Merge pull request #400 from libp2p/feat/disable-providers

feat: allow disabling value and provider storage/messages

Stebalien committed Dec 12, 2019
2 parents c08274f + 5a048ea commit 2e6adb8
Showing 6 changed files with 239 additions and 22 deletions.
7 changes: 7 additions & 0 deletions dht.go
@@ -75,6 +75,11 @@ type IpfsDHT struct {
triggerRtRefresh chan chan<- error

maxRecordAge time.Duration

// Allows disabling dht subsystems. These should _only_ be set on
// "forked" DHTs (e.g., DHTs with custom protocols and/or private
// networks).
enableProviders, enableValues bool
}

// Assert that IPFS assumptions about interfaces aren't broken. These aren't a
@@ -100,6 +105,8 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er
dht.rtRefreshQueryTimeout = cfg.RoutingTable.RefreshQueryTimeout

dht.maxRecordAge = cfg.MaxRecordAge
dht.enableProviders = cfg.EnableProviders
dht.enableValues = cfg.EnableValues

// register for network notifs.
dht.host.Network().Notify((*netNotifiee)(dht))
76 changes: 72 additions & 4 deletions dht_test.go
@@ -107,13 +107,15 @@ func (testAtomicPutValidator) Select(_ string, bs [][]byte) (int, error) {
return index, nil
}

-func setupDHT(ctx context.Context, t *testing.T, client bool) *IpfsDHT {
+func setupDHT(ctx context.Context, t *testing.T, client bool, options ...opts.Option) *IpfsDHT {
	d, err := New(
		ctx,
		bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)),
-		opts.Client(client),
-		opts.NamespacedValidator("v", blankValidator{}),
-		opts.DisableAutoRefresh(),
+		append([]opts.Option{
+			opts.Client(client),
+			opts.NamespacedValidator("v", blankValidator{}),
+			opts.DisableAutoRefresh(),
+		}, options...)...,
)
if err != nil {
t.Fatal(err)
@@ -1500,6 +1502,72 @@ func TestFindClosestPeers(t *testing.T) {
}
}

func TestProvideDisabled(t *testing.T) {
k := testCaseCids[0]
for i := 0; i < 3; i++ {
enabledA := (i & 0x1) > 0
enabledB := (i & 0x2) > 0
t.Run(fmt.Sprintf("a=%v/b=%v", enabledA, enabledB), func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var (
optsA, optsB []opts.Option
)
if !enabledA {
optsA = append(optsA, opts.DisableProviders())
}
if !enabledB {
optsB = append(optsB, opts.DisableProviders())
}

dhtA := setupDHT(ctx, t, false, optsA...)
dhtB := setupDHT(ctx, t, false, optsB...)

defer dhtA.Close()
defer dhtB.Close()
defer dhtA.host.Close()
defer dhtB.host.Close()

connect(t, ctx, dhtA, dhtB)

err := dhtB.Provide(ctx, k, true)
if enabledB {
if err != nil {
t.Fatal("put should have succeeded on node B", err)
}
} else {
if err != routing.ErrNotSupported {
t.Fatal("should not have put the value to node B", err)
}
_, err = dhtB.FindProviders(ctx, k)
if err != routing.ErrNotSupported {
t.Fatal("get should have failed on node B")
}
provs := dhtB.providers.GetProviders(ctx, k)
if len(provs) != 0 {
t.Fatal("node B should not have found local providers")
}
}

provs, err := dhtA.FindProviders(ctx, k)
if enabledA {
if len(provs) != 0 {
t.Fatal("node A should not have found providers")
}
} else {
if err != routing.ErrNotSupported {
t.Fatal("node A should not have found providers")
}
}
provAddrs := dhtA.providers.GetProviders(ctx, k)
if len(provAddrs) != 0 {
t.Fatal("node A should not have found local providers")
}
})
}
}

func TestGetSetPluggedProtocol(t *testing.T) {
t.Run("PutValue/GetValue - same protocol", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
30 changes: 20 additions & 10 deletions handlers.go
@@ -26,21 +26,31 @@ type dhtHandler func(context.Context, peer.ID, *pb.Message) (*pb.Message, error)

func (dht *IpfsDHT) handlerForMsgType(t pb.Message_MessageType) dhtHandler {
	switch t {
-	case pb.Message_GET_VALUE:
-		return dht.handleGetValue
-	case pb.Message_PUT_VALUE:
-		return dht.handlePutValue
	case pb.Message_FIND_NODE:
		return dht.handleFindPeer
-	case pb.Message_ADD_PROVIDER:
-		return dht.handleAddProvider
-	case pb.Message_GET_PROVIDERS:
-		return dht.handleGetProviders
	case pb.Message_PING:
		return dht.handlePing
-	default:
-		return nil
	}
+
+	if dht.enableValues {
+		switch t {
+		case pb.Message_GET_VALUE:
+			return dht.handleGetValue
+		case pb.Message_PUT_VALUE:
+			return dht.handlePutValue
+		}
+	}
+
+	if dht.enableProviders {
+		switch t {
+		case pb.Message_ADD_PROVIDER:
+			return dht.handleAddProvider
+		case pb.Message_GET_PROVIDERS:
+			return dht.handleGetProviders
+		}
+	}
+
+	return nil
}

func (dht *IpfsDHT) handleGetValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
43 changes: 37 additions & 6 deletions opts/options.go
@@ -21,12 +21,14 @@ var (

// Options is a structure containing all the options that can be used when constructing a DHT.
type Options struct {
-	Datastore    ds.Batching
-	Validator    record.Validator
-	Client       bool
-	Protocols    []protocol.ID
-	BucketSize   int
-	MaxRecordAge time.Duration
+	Datastore       ds.Batching
+	Validator       record.Validator
+	Client          bool
+	Protocols       []protocol.ID
+	BucketSize      int
+	MaxRecordAge    time.Duration
+	EnableProviders bool
+	EnableValues    bool

RoutingTable struct {
RefreshQueryTimeout time.Duration
@@ -56,6 +58,8 @@ var Defaults = func(o *Options) error {
}
o.Datastore = dssync.MutexWrap(ds.NewMapDatastore())
o.Protocols = DefaultProtocols
o.EnableProviders = true
o.EnableValues = true

o.RoutingTable.RefreshQueryTimeout = 10 * time.Second
o.RoutingTable.RefreshPeriod = 1 * time.Hour
@@ -177,3 +181,30 @@ func DisableAutoRefresh() Option {
return nil
}
}

// DisableProviders disables storing and retrieving provider records.
//
// Defaults to enabled.
//
// WARNING: do not change this unless you're using a forked DHT (i.e., a private
// network and/or distinct DHT protocols with the `Protocols` option).
func DisableProviders() Option {
return func(o *Options) error {
o.EnableProviders = false
return nil
}
}

// DisableValues disables storing and retrieving value records (including
// public keys).
//
// Defaults to enabled.
//
// WARNING: do not change this unless you're using a forked DHT (i.e., a private
// network and/or distinct DHT protocols with the `Protocols` option).
func DisableValues() Option {
return func(o *Options) error {
o.EnableValues = false
return nil
}
}
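
For illustration only (not part of this diff): a minimal sketch of how a forked DHT could consume the two new options. The helper name and the protocol ID are made up, and the opts.Protocols option is assumed to already exist in this package; only DisableProviders and DisableValues are introduced by this change.

package example

import (
	"context"

	dht "github.com/libp2p/go-libp2p-kad-dht"
	dhtopts "github.com/libp2p/go-libp2p-kad-dht/opts"

	"github.com/libp2p/go-libp2p-core/host"
	"github.com/libp2p/go-libp2p-core/protocol"
)

// newForkedDHT builds a DHT for a private/forked network that neither stores
// nor serves provider and value records. Hypothetical helper for illustration.
func newForkedDHT(ctx context.Context, h host.Host) (*dht.IpfsDHT, error) {
	return dht.New(
		ctx,
		h,
		// Assumed pre-existing option: the warnings above call for a distinct
		// DHT protocol (or a private network) before disabling subsystems.
		dhtopts.Protocols(protocol.ID("/myapp/kad/1.0.0")),
		dhtopts.DisableProviders(), // added in this change
		dhtopts.DisableValues(),    // added in this change
	)
}

With both options set, the node still answers FIND_NODE and PING (see handlers.go above), so it keeps routing for its peers without storing records on their behalf.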
75 changes: 75 additions & 0 deletions records_test.go
@@ -3,6 +3,7 @@ package dht
import (
"context"
"crypto/rand"
"fmt"
"github.com/libp2p/go-libp2p-core/test"
"testing"
"time"
@@ -13,6 +14,8 @@ import (
"github.com/libp2p/go-libp2p-core/routing"
record "github.com/libp2p/go-libp2p-record"
tnet "github.com/libp2p/go-libp2p-testing/net"

dhtopt "github.com/libp2p/go-libp2p-kad-dht/opts"
)

// Check that GetPublicKey() correctly extracts a public key
@@ -305,3 +308,75 @@ func TestPubkeyGoodKeyFromDHTGoodKeyDirect(t *testing.T) {
t.Fatal("got incorrect public key")
}
}

func TestValuesDisabled(t *testing.T) {
for i := 0; i < 3; i++ {
enabledA := (i & 0x1) > 0
enabledB := (i & 0x2) > 0
t.Run(fmt.Sprintf("a=%v/b=%v", enabledA, enabledB), func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var (
optsA, optsB []dhtopt.Option
)
if !enabledA {
optsA = append(optsA, dhtopt.DisableValues())
}
if !enabledB {
optsB = append(optsB, dhtopt.DisableValues())
}

dhtA := setupDHT(ctx, t, false, optsA...)
dhtB := setupDHT(ctx, t, false, optsB...)

defer dhtA.Close()
defer dhtB.Close()
defer dhtA.host.Close()
defer dhtB.host.Close()

connect(t, ctx, dhtA, dhtB)

pubk := dhtB.peerstore.PubKey(dhtB.self)
pkbytes, err := pubk.Bytes()
if err != nil {
t.Fatal(err)
}

pkkey := routing.KeyForPublicKey(dhtB.self)
err = dhtB.PutValue(ctx, pkkey, pkbytes)
if enabledB {
if err != nil {
t.Fatal("put should have succeeded on node B", err)
}
} else {
if err != routing.ErrNotSupported {
t.Fatal("should not have put the value to node B", err)
}
_, err = dhtB.GetValue(ctx, pkkey)
if err != routing.ErrNotSupported {
t.Fatal("get should have failed on node B")
}
rec, _ := dhtB.getLocal(pkkey)
if rec != nil {
t.Fatal("node B should not have found the value locally")
}
}

_, err = dhtA.GetValue(ctx, pkkey)
if enabledA {
if err != routing.ErrNotFound {
t.Fatal("node A should not have found the value")
}
} else {
if err != routing.ErrNotSupported {
t.Fatal("node A should not have found the value")
}
}
rec, _ := dhtA.getLocal(pkkey)
if rec != nil {
t.Fatal("node A should not have found the value locally")
}
})
}
}
30 changes: 28 additions & 2 deletions routing.go
@@ -34,6 +34,10 @@ var asyncQueryBuffer = 10
// PutValue adds value corresponding to given Key.
// This is the top level "Store" operation of the DHT
func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts ...routing.Option) (err error) {
if !dht.enableValues {
return routing.ErrNotSupported
}

eip := logger.EventBegin(ctx, "PutValue")
defer func() {
eip.Append(loggableKey(key))
@@ -110,6 +114,10 @@ type RecvdVal struct {

// GetValue searches for the value corresponding to given Key.
func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Option) (_ []byte, err error) {
if !dht.enableValues {
return nil, routing.ErrNotSupported
}

eip := logger.EventBegin(ctx, "GetValue")
defer func() {
eip.Append(loggableKey(key))
@@ -148,6 +156,10 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Op
}

func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing.Option) (<-chan []byte, error) {
if !dht.enableValues {
return nil, routing.ErrNotSupported
}

var cfg routing.Options
if err := cfg.Apply(opts...); err != nil {
return nil, err
@@ -250,8 +262,11 @@ func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing

// GetValues gets nvals values corresponding to the given key.
func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []RecvdVal, err error) {
-	eip := logger.EventBegin(ctx, "GetValues")
+	if !dht.enableValues {
+		return nil, routing.ErrNotSupported
+	}
+
+	eip := logger.EventBegin(ctx, "GetValues")
eip.Append(loggableKey(key))
defer eip.Done()

@@ -398,6 +413,9 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, nvals int) (<-cha

// Provide makes this node announce that it can provide a value for the given key
func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err error) {
if !dht.enableProviders {
return routing.ErrNotSupported
}
eip := logger.EventBegin(ctx, "Provide", key, logging.LoggableMap{"broadcast": brdcst})
defer func() {
if err != nil {
@@ -477,6 +495,9 @@ func (dht *IpfsDHT) makeProvRecord(skey cid.Cid) (*pb.Message, error) {

// FindProviders searches until the context expires.
func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) {
if !dht.enableProviders {
return nil, routing.ErrNotSupported
}
var providers []peer.AddrInfo
for p := range dht.FindProvidersAsync(ctx, c, dht.bucketSize) {
providers = append(providers, p)
@@ -488,8 +509,13 @@ func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrIn
// Peers will be returned on the channel as soon as they are found, even before
// the search query completes.
func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo {
-	logger.Event(ctx, "findProviders", key)
	peerOut := make(chan peer.AddrInfo, count)
+	if !dht.enableProviders {
+		close(peerOut)
+		return peerOut
+	}
+
+	logger.Event(ctx, "findProviders", key)

go dht.findProvidersAsyncRoutine(ctx, key, count, peerOut)
return peerOut
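
A caller-side sketch of the resulting behaviour, again illustrative only: disabled value methods return routing.ErrNotSupported, while FindProvidersAsync hands back an already-closed channel instead of an error. The helper names here are hypothetical.

package example

import (
	"context"

	"github.com/ipfs/go-cid"
	dht "github.com/libp2p/go-libp2p-kad-dht"
	"github.com/libp2p/go-libp2p-core/routing"
)

// putIfSupported treats a DHT built with DisableValues() as a no-op store
// rather than a hard failure.
func putIfSupported(ctx context.Context, d *dht.IpfsDHT, key string, val []byte) error {
	err := d.PutValue(ctx, key, val)
	if err == routing.ErrNotSupported {
		// Values are disabled on this node; skip or use another store.
		return nil
	}
	return err
}

// countProviders shows the provider side: with DisableProviders() the channel
// from FindProvidersAsync is closed immediately, so the loop yields nothing.
func countProviders(ctx context.Context, d *dht.IpfsDHT, c cid.Cid) int {
	n := 0
	for range d.FindProvidersAsync(ctx, c, 10) {
		n++
	}
	return n
}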
