Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose the peer that propagates a message to the recipient #218

Merged
merged 3 commits into from
Oct 18, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
56 changes: 56 additions & 0 deletions floodsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1398,3 +1398,59 @@ func readAllQueuedEvents(ctx context.Context, t *testing.T, sub *Subscription) m
}
return peerState
}

func TestMessageSender(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

const topic = "foobar"

hosts := getNetHosts(t, ctx, 3)
psubs := getPubsubs(ctx, hosts)

var msgs []*Subscription
for _, ps := range psubs {
subch, err := ps.Subscribe(topic)
if err != nil {
t.Fatal(err)
}

msgs = append(msgs, subch)
}

connect(t, hosts[0], hosts[1])
connect(t, hosts[1], hosts[2])

time.Sleep(time.Millisecond * 100)

for i:=0; i < 3; i++ {
for j := 0; j < 100; j++ {
msg := []byte(fmt.Sprintf("%d sent %d", i, j))

psubs[i].Publish(topic, msg)

for k, sub := range msgs {
got, err := sub.Next(ctx)
if err != nil {
t.Fatal(sub.err)
}
if !bytes.Equal(msg, got.Data) {
t.Fatal("got wrong message!")
}

var expectedHost int
if i == k {
expectedHost = i
} else if k != 1 {
expectedHost = 1
} else {
expectedHost = i
}

if got.ReceivedFrom != hosts[expectedHost].ID() {
t.Fatal("got wrong message sender")
}
}
}
}
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ require (
github.com/multiformats/go-multistream v0.1.0
github.com/whyrusleeping/timecache v0.0.0-20160911033111-cfcb2f1abfee
)

go 1.13
aschmahmann marked this conversation as resolved.
Show resolved Hide resolved
42 changes: 19 additions & 23 deletions pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type PubSub struct {
topics map[string]map[peer.ID]struct{}

// sendMsg handles messages that have been validated
sendMsg chan *sendReq
sendMsg chan *Message

// addVal handles validator registration requests
addVal chan *addValReq
Expand Down Expand Up @@ -135,6 +135,7 @@ type PubSubRouter interface {

type Message struct {
*pb.Message
ReceivedFrom peer.ID
}

func (m *Message) GetFrom() peer.ID {
Expand Down Expand Up @@ -170,7 +171,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option
getPeers: make(chan *listPeerReq),
addSub: make(chan *addSubReq),
getTopics: make(chan *topicReq),
sendMsg: make(chan *sendReq, 32),
sendMsg: make(chan *Message, 32),
addVal: make(chan *addValReq),
rmVal: make(chan *rmValReq),
eval: make(chan func()),
Expand Down Expand Up @@ -373,10 +374,10 @@ func (p *PubSub) processLoop(ctx context.Context) {
p.handleIncomingRPC(rpc)

case msg := <-p.publish:
p.pushMsg(p.host.ID(), msg)
p.pushMsg(msg)

case req := <-p.sendMsg:
p.publishMessage(req.from, req.msg.Message)
case msg := <-p.sendMsg:
p.publishMessage(msg)

case req := <-p.addVal:
p.val.AddValidator(req)
Expand Down Expand Up @@ -522,12 +523,12 @@ func (p *PubSub) doAnnounceRetry(pid peer.ID, topic string, sub bool) {

// notifySubs sends a given message to all corresponding subscribers.
// Only called from processLoop.
func (p *PubSub) notifySubs(msg *pb.Message) {
func (p *PubSub) notifySubs(msg *Message) {
for _, topic := range msg.GetTopicIDs() {
subs := p.myTopics[topic]
for f := range subs {
select {
case f.ch <- &Message{msg}:
case f.ch <- msg:
default:
log.Infof("Can't deliver message to subscription for topic %s; subscriber too slow", topic)
}
Expand Down Expand Up @@ -616,8 +617,8 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) {
continue
}

msg := &Message{pmsg}
p.pushMsg(rpc.from, msg)
msg := &Message{pmsg, rpc.from}
p.pushMsg(msg)
}

p.rt.HandleRPC(rpc)
Expand All @@ -629,7 +630,8 @@ func msgID(pmsg *pb.Message) string {
}

// pushMsg pushes a message performing validation as necessary
func (p *PubSub) pushMsg(src peer.ID, msg *Message) {
func (p *PubSub) pushMsg(msg *Message) {
src := msg.ReceivedFrom
// reject messages from blacklisted peers
if p.blacklist.Contains(src) {
log.Warningf("dropping message from blacklisted peer %s", src)
Expand Down Expand Up @@ -659,13 +661,13 @@ func (p *PubSub) pushMsg(src peer.ID, msg *Message) {
}

if p.markSeen(id) {
p.publishMessage(src, msg.Message)
p.publishMessage(msg)
}
}

func (p *PubSub) publishMessage(from peer.ID, pmsg *pb.Message) {
p.notifySubs(pmsg)
p.rt.Publish(from, pmsg)
func (p *PubSub) publishMessage(msg *Message) {
p.notifySubs(msg)
p.rt.Publish(msg.ReceivedFrom, msg.Message)
}

type addSubReq struct {
Expand Down Expand Up @@ -734,10 +736,11 @@ func (p *PubSub) GetTopics() []string {
// Publish publishes data to the given topic.
func (p *PubSub) Publish(topic string, data []byte) error {
seqno := p.nextSeqno()
id := p.host.ID()
m := &pb.Message{
Data: data,
TopicIDs: []string{topic},
From: []byte(p.host.ID()),
From: []byte(id),
Seqno: seqno,
}
if p.signKey != nil {
Expand All @@ -747,7 +750,7 @@ func (p *PubSub) Publish(topic string, data []byte) error {
return err
}
}
p.publish <- &Message{m}
p.publish <- &Message{m, id}
return nil
}

Expand All @@ -763,13 +766,6 @@ type listPeerReq struct {
topic string
}

// sendReq is a request to call publishMessage.
// It is issued after message validation is done.
type sendReq struct {
from peer.ID
msg *Message
}

// ListPeers returns a list of peers we are connected to in the given topic.
func (p *PubSub) ListPeers(topic string) []peer.ID {
out := make(chan []peer.ID)
Expand Down
10 changes: 2 additions & 8 deletions validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,7 @@ func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message) {
}

// no async validators, send the message
v.p.sendMsg <- &sendReq{
from: src,
msg: msg,
}
v.p.sendMsg <- msg
}

func (v *validation) validateSignature(msg *Message) bool {
Expand All @@ -255,10 +252,7 @@ func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message
return
}

v.p.sendMsg <- &sendReq{
from: src,
msg: msg,
}
v.p.sendMsg <- msg
}

func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message) bool {
Expand Down