Skip to content

Commit

Permalink
feat: allow disabling value and provider storage/messages
Browse files Browse the repository at this point in the history
fixes #274
  • Loading branch information
Stebalien committed Oct 31, 2019
1 parent 9c02087 commit d43ce3c
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 14 deletions.
4 changes: 4 additions & 0 deletions dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ type IpfsDHT struct {
bootstrapCfg opts.BootstrapConfig

triggerBootstrap chan struct{}

enableProviders, enableValues bool
}

// Assert that IPFS assumptions about interfaces aren't broken. These aren't a
Expand All @@ -90,6 +92,8 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er
return nil, err
}
dht := makeDHT(ctx, h, cfg.Datastore, cfg.Protocols, cfg.BucketSize)
dht.enableProviders = cfg.EnableProviders
dht.enableValues = cfg.EnableValues
dht.bootstrapCfg = cfg.BootstrapConfig

// register for network notifs.
Expand Down
64 changes: 61 additions & 3 deletions dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,14 @@ func (testAtomicPutValidator) Select(_ string, bs [][]byte) (int, error) {
return index, nil
}

func setupDHT(ctx context.Context, t *testing.T, client bool) *IpfsDHT {
func setupDHT(ctx context.Context, t *testing.T, client bool, options ...opts.Option) *IpfsDHT {
d, err := New(
ctx,
bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)),
opts.Client(client),
opts.NamespacedValidator("v", blankValidator{}),
append([]opts.Option{
opts.Client(client),
opts.NamespacedValidator("v", blankValidator{}),
}, options...)...,
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1407,6 +1409,62 @@ func TestFindClosestPeers(t *testing.T) {
}
}

// TestProvideDisabled verifies that the EnableProviders option gates both
// publishing (Provide) and lookup (FindProviders): a disabled node must
// refuse with routing.ErrNotSupported and must not store provider records
// locally.
//
// The loop covers i = 0..2, i.e. (a=off,b=off), (a=on,b=off), (a=off,b=on);
// the both-enabled case is exercised by the regular provider tests.
func TestProvideDisabled(t *testing.T) {
	k := testCaseCids[0]
	for i := 0; i < 3; i++ {
		enabledA := (i & 0x1) > 0
		enabledB := (i & 0x2) > 0
		t.Run(fmt.Sprintf("a=%v/b=%v", enabledA, enabledB), func(t *testing.T) {
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			dhtA := setupDHT(ctx, t, false, opts.EnableProviders(enabledA))
			dhtB := setupDHT(ctx, t, false, opts.EnableProviders(enabledB))

			defer dhtA.Close()
			defer dhtB.Close()
			defer dhtA.host.Close()
			defer dhtB.host.Close()

			connect(t, ctx, dhtA, dhtB)

			// B publishes. With providers disabled this must fail with
			// ErrNotSupported and leave no trace, locally or remotely.
			err := dhtB.Provide(ctx, k, true)
			if enabledB {
				if err != nil {
					t.Fatal("put should have succeeded on node B", err)
				}
			} else {
				if err != routing.ErrNotSupported {
					t.Fatal("should not have put the value to node B", err)
				}
				_, err = dhtB.FindProviders(ctx, k)
				if err != routing.ErrNotSupported {
					t.Fatal("get should have failed on node B")
				}
				provs := dhtB.providers.GetProviders(ctx, k)
				if len(provs) != 0 {
					t.Fatal("node B should not have found local providers")
				}
			}

			provs, err := dhtA.FindProviders(ctx, k)
			if enabledA {
				// In this loop enabledA implies !enabledB, so B never
				// published and the lookup must come up empty. Fix: the
				// lookup error was previously ignored in this branch.
				if err != nil && err != routing.ErrNotFound {
					t.Fatal("node A lookup failed", err)
				}
				if len(provs) != 0 {
					t.Fatal("node A should not have found providers")
				}
			} else {
				if err != routing.ErrNotSupported {
					t.Fatal("node A should not have found providers")
				}
			}
			provAddrs := dhtA.providers.GetProviders(ctx, k)
			if len(provAddrs) != 0 {
				t.Fatal("node A should not have found local providers")
			}
		})
	}
}

func TestGetSetPluggedProtocol(t *testing.T) {
t.Run("PutValue/GetValue - same protocol", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
Expand Down
30 changes: 20 additions & 10 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,31 @@ type dhtHandler func(context.Context, peer.ID, *pb.Message) (*pb.Message, error)

// handlerForMsgType returns the handler for the given message type, or nil
// if the type is unknown or disabled by configuration.
//
// NOTE(review): the pasted diff merged the deleted pre-commit switch (whose
// `default: return nil` returned on every path) with the newly added gating,
// leaving the enableValues/enableProviders checks unreachable. This
// reconstructs the intended post-commit dispatch.
func (dht *IpfsDHT) handlerForMsgType(t pb.Message_MessageType) dhtHandler {
	// FIND_NODE and PING are always served, regardless of configuration.
	switch t {
	case pb.Message_FIND_NODE:
		return dht.handleFindPeer
	case pb.Message_PING:
		return dht.handlePing
	}

	// Value record messages are only handled when value storage is enabled.
	if dht.enableValues {
		switch t {
		case pb.Message_GET_VALUE:
			return dht.handleGetValue
		case pb.Message_PUT_VALUE:
			return dht.handlePutValue
		}
	}

	// Provider record messages are only handled when providing is enabled.
	if dht.enableProviders {
		switch t {
		case pb.Message_ADD_PROVIDER:
			return dht.handleAddProvider
		case pb.Message_GET_PROVIDERS:
			return dht.handleGetProviders
		}
	}

	return nil
}

func (dht *IpfsDHT) handleGetValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
Expand Down
24 changes: 24 additions & 0 deletions opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type Options struct {
Protocols []protocol.ID
BucketSize int
BootstrapConfig BootstrapConfig
EnableProviders bool
EnableValues bool
}

// Apply applies the given options to this Option
Expand All @@ -58,6 +60,8 @@ var Defaults = func(o *Options) error {
}
o.Datastore = dssync.MutexWrap(ds.NewMapDatastore())
o.Protocols = DefaultProtocols
o.EnableProviders = true
o.EnableValues = true

o.BootstrapConfig = BootstrapConfig{
// same as that mentioned in the kad dht paper
Expand Down Expand Up @@ -149,3 +153,23 @@ func BucketSize(bucketSize int) Option {
return nil
}
}

// EnableProviders enables storing and retrieving provider records.
// When disabled, Provide/FindProviders return routing.ErrNotSupported
// and the corresponding DHT messages are not handled.
//
// Defaults to true.
func EnableProviders(enable bool) Option {
	return func(cfg *Options) error {
		cfg.EnableProviders = enable
		return nil
	}
}

// EnableValues enables storing and retrieving value records.
// When disabled, PutValue/GetValue return routing.ErrNotSupported
// and the corresponding DHT messages are not handled.
//
// Defaults to true.
func EnableValues(enable bool) Option {
	return func(cfg *Options) error {
		cfg.EnableValues = enable
		return nil
	}
}
65 changes: 65 additions & 0 deletions records_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dht
import (
"context"
"crypto/rand"
"fmt"
"github.com/libp2p/go-libp2p-core/test"
"testing"
"time"
Expand All @@ -13,6 +14,8 @@ import (
"github.com/libp2p/go-libp2p-core/routing"
record "github.com/libp2p/go-libp2p-record"
tnet "github.com/libp2p/go-libp2p-testing/net"

dhtopt "github.com/libp2p/go-libp2p-kad-dht/opts"
)

// Check that GetPublicKey() correctly extracts a public key
Expand Down Expand Up @@ -305,3 +308,65 @@ func TestPubkeyGoodKeyFromDHTGoodKeyDirect(t *testing.T) {
t.Fatal("got incorrect public key")
}
}

// TestValuesDisabled checks that value storage honours the EnableValues
// option: a disabled node must refuse PutValue/GetValue with
// routing.ErrNotSupported and must not keep a local record.
//
// Iterations cover (a,b) = (off,off), (on,off), (off,on); the both-on case
// is covered by the regular value tests.
func TestValuesDisabled(t *testing.T) {
	for i := 0; i < 3; i++ {
		enabledA := (i & 0x1) > 0
		enabledB := (i & 0x2) > 0
		t.Run(fmt.Sprintf("a=%v/b=%v", enabledA, enabledB), func(t *testing.T) {
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			dhtA := setupDHT(ctx, t, false, dhtopt.EnableValues(enabledA))
			dhtB := setupDHT(ctx, t, false, dhtopt.EnableValues(enabledB))
			defer dhtA.Close()
			defer dhtB.Close()
			defer dhtA.host.Close()
			defer dhtB.host.Close()

			connect(t, ctx, dhtA, dhtB)

			// Use node B's own public key as the record under test.
			pkbytes, err := dhtB.peerstore.PubKey(dhtB.self).Bytes()
			if err != nil {
				t.Fatal(err)
			}
			pkkey := routing.KeyForPublicKey(dhtB.self)

			err = dhtB.PutValue(ctx, pkkey, pkbytes)
			if !enabledB {
				if err != routing.ErrNotSupported {
					t.Fatal("should not have put the value to node B", err)
				}
				if _, err = dhtB.GetValue(ctx, pkkey); err != routing.ErrNotSupported {
					t.Fatal("get should have failed on node B")
				}
				if rec, _ := dhtB.getLocal(pkkey); rec != nil {
					t.Fatal("node B should not have found the value locally")
				}
			} else if err != nil {
				t.Fatal("put should have succeeded on node B", err)
			}

			// A either fails to find the record (values enabled but nothing
			// was stored) or refuses the operation outright (disabled).
			_, err = dhtA.GetValue(ctx, pkkey)
			if enabledA {
				if err != routing.ErrNotFound {
					t.Fatal("node A should not have found the value")
				}
			} else if err != routing.ErrNotSupported {
				t.Fatal("node A should not have found the value")
			}
			if rec, _ := dhtA.getLocal(pkkey); rec != nil {
				t.Fatal("node A should not have found the value locally")
			}
		})
	}
}
27 changes: 26 additions & 1 deletion routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ var asyncQueryBuffer = 10
// PutValue adds value corresponding to given Key.
// This is the top level "Store" operation of the DHT
func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts ...routing.Option) (err error) {
if !dht.enableValues {
return routing.ErrNotSupported
}

eip := logger.EventBegin(ctx, "PutValue")
defer func() {
eip.Append(loggableKey(key))
Expand Down Expand Up @@ -110,6 +114,10 @@ type RecvdVal struct {

// GetValue searches for the value corresponding to given Key.
func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Option) (_ []byte, err error) {
if !dht.enableValues {
return nil, routing.ErrNotSupported
}

eip := logger.EventBegin(ctx, "GetValue")
defer func() {
eip.Append(loggableKey(key))
Expand Down Expand Up @@ -148,6 +156,10 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Op
}

func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing.Option) (<-chan []byte, error) {
if !dht.enableValues {
return nil, routing.ErrNotSupported
}

var cfg routing.Options
if err := cfg.Apply(opts...); err != nil {
return nil, err
Expand Down Expand Up @@ -250,8 +262,11 @@ func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing

// GetValues gets nvals values corresponding to the given key.
func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []RecvdVal, err error) {
eip := logger.EventBegin(ctx, "GetValues")
if !dht.enableValues {
return nil, routing.ErrNotSupported
}

eip := logger.EventBegin(ctx, "GetValues")
eip.Append(loggableKey(key))
defer eip.Done()

Expand Down Expand Up @@ -398,6 +413,9 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, nvals int) (<-cha

// Provide makes this node announce that it can provide a value for the given key
func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err error) {
if !dht.enableProviders {
return routing.ErrNotSupported
}
eip := logger.EventBegin(ctx, "Provide", key, logging.LoggableMap{"broadcast": brdcst})
defer func() {
if err != nil {
Expand Down Expand Up @@ -477,6 +495,9 @@ func (dht *IpfsDHT) makeProvRecord(skey cid.Cid) (*pb.Message, error) {

// FindProviders searches until the context expires.
func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) {
if !dht.enableProviders {
return nil, routing.ErrNotSupported
}
var providers []peer.AddrInfo
for p := range dht.FindProvidersAsync(ctx, c, dht.bucketSize) {
providers = append(providers, p)
Expand All @@ -488,6 +509,10 @@ func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrIn
// Peers will be returned on the channel as soon as they are found, even before
// the search query completes.
func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo {
if !dht.enableProviders {
return nil
}

logger.Event(ctx, "findProviders", key)
peerOut := make(chan peer.AddrInfo, count)

Expand Down

0 comments on commit d43ce3c

Please sign in to comment.