Skip to content
Permalink
Browse files

Update to use xor metric satisfying the unidirectional property.

  • Loading branch information...
jtownson committed Aug 13, 2019
1 parent 4010c9f commit 07268c1e98cae9117ebd97ab47558ddaa6ca7570
@@ -1,16 +1,19 @@
package io.iohk.scalanet.peergroup.kademlia

import java.util.concurrent.{ConcurrentHashMap, CopyOnWriteArraySet}
import java.util.concurrent.ConcurrentSkipListSet

import io.iohk.scalanet.peergroup.kademlia.KBuckets._
import scodec.bits.BitVector

import scala.collection.JavaConverters._
import scala.collection.mutable

/**
* Skeletal kbucket implementation.
* @param baseId the nodes own id.
*/
class KBuckets(val baseId: BitVector) {

private val buckets = new ConcurrentHashMap[Int, KBucket]().asScala
private val nodeIds =
new ConcurrentSkipListSet[BitVector](new XorOrdering(baseId)).asScala

add(baseId)

@@ -20,19 +23,8 @@ class KBuckets(val baseId: BitVector) {
* The indices into the kBuckets are defined by their distance from referenceNodeId.
*/
def closestNodes(nodeId: BitVector, n: Int): List[BitVector] = {
// sketch
// the id range of a bucket is 2^i bits,
// where i is the common prefix length between the bucket prefix and the base nodeId.
// each kbucket covers a distance range [2^i, 2^i+1 - 1]
// this means that buckets with larger i cover larger ranges.
// so, for eg, bucket 1 covers 4-2-1=1 node whereas bucket 7 covers 256-128-1=127 nodes
// in general each bucket covers 2^i+1 - 2^i - 1 = 2^i - 1

// therefore when finding an arbitrary nodeId
// find the prefix

// replace with routing tree described in kademlia paper...
iterator.toList.sorted(new XorOrdering(nodeId)).take(n)
nodeIds.toList.sorted(new XorOrdering(nodeId)).take(n)
}

/**
@@ -46,11 +38,7 @@ class KBuckets(val baseId: BitVector) {
s"Illegal attempt to add node id with a length different than the this node id."
)

val d = Xor.d(nodeId, this.baseId)

val bucket: KBucket = buckets.getOrElseUpdate(d, newKBucket)

bucket.add(nodeId)
nodeIds.add(nodeId)

this
}
@@ -61,20 +49,16 @@ class KBuckets(val baseId: BitVector) {
* @return true if present
*/
def contains(nodeId: BitVector): Boolean = {
buckets.getOrElse(Xor.d(nodeId, this.baseId), Set.empty[BitVector]).contains(nodeId)
nodeIds.contains(nodeId)
}

/**
* Iterate over all elements in the KBuckets.
*
* @return an iterator
*/
def iterator: Iterator[BitVector] = new Iterator[BitVector] {
val flatIterator: Iterator[BitVector] = buckets.flatMap(_._2).iterator

override def hasNext: Boolean = flatIterator.hasNext

override def next(): BitVector = flatIterator.next()
def iterator: Iterator[BitVector] = {
nodeIds.iterator
}

/**
@@ -83,27 +67,11 @@ class KBuckets(val baseId: BitVector) {
* @param nodeId the nodeId to remove
* @return
*/
def remove(nodeId: BitVector): KBuckets = {
buckets
.get(Xor.d(nodeId, this.baseId))
.foreach(
bucket =>
bucket
.find(_ == nodeId)
.foreach(record => bucket.remove(record))
)
this
def remove(nodeId: BitVector): Boolean = {
nodeIds.remove(nodeId)
}

override def toString: String = {
s"KBuckets(baseId = ${baseId.toHex}): ${iterator.toList.sorted(new XorOrdering(baseId)).map(id => id.toHex).mkString(", ")}"
s"KBuckets(baseId = ${baseId.toHex}): ${nodeIds.toList.map(id => s"(id=${id.toHex}, d=${Xor.d(id, baseId)})").mkString(", ")}"
}

private def newKBucket: KBucket =
new CopyOnWriteArraySet[BitVector]().asScala
}

object KBuckets {

type KBucket = mutable.Set[BitVector]
}
@@ -4,20 +4,8 @@ import scodec.bits.BitVector

object Xor {

def d(a: BitVector, b: BitVector): Int = {
def d(a: BitVector, b: BitVector): BigInt = {
assert(a.length == b.length)
a.length.toInt - leadingZeros(a.xor(b))
}

private def leadingZeros(b: BitVector): Int = {
@annotation.tailrec
def loop(count: Int): Int = {
if (count != b.length && !b.get(count)) {
loop(count + 1)
} else {
count
}
}
loop(0)
BigInt((a xor b).toBin, 2)
}
}
@@ -51,10 +51,22 @@ object Generators {
l.toList
}

def genBitVectorTripsExhaustive(
bitLength: Int
): List[(BitVector, BitVector, BitVector)] = {
for {
x <- genBitVectorExhaustive(bitLength)
y <- genBitVectorExhaustive(bitLength)
z <- genBitVectorExhaustive(bitLength)
} yield (x, y, z)
}

def aRandomBitVector(bitLength: Int = defaultBitLength): BitVector =
BitVector.bits(Range(0, bitLength).map(_ => Random.nextBoolean()))

def aRandomNodeRecord(bitLength: Int = defaultBitLength): NodeRecord[String] = {
def aRandomNodeRecord(
bitLength: Int = defaultBitLength
): NodeRecord[String] = {
NodeRecord(
id = aRandomBitVector(bitLength),
routingAddress = Random.alphanumeric.take(4).mkString,
@@ -38,10 +38,16 @@ class XorSpec extends FlatSpec {
}
}

it should "provide the correct maximal distance" in forAll(posNum[Byte]) { bitCount =>
it should "provide the correct maximal distance" in forAll(posNum[Int]) { bitCount =>
val zero = BitVector.low(bitCount)
val max = BitVector.high(bitCount)

d(zero, max) shouldBe bitCount
d(zero, max) shouldBe BigInt(2).pow(bitCount) - 1
}

it should "satisfy the unidirectional property (from the last para of section 2.1)" in
genBitVectorTripsExhaustive(4).foreach {
case (x, y, z) =>
if (y != z)
d(x, y) should not be d(x, z)
}
}

0 comments on commit 07268c1

Please sign in to comment.
You can’t perform that action at this time.