Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix remote PeerId check on dial #215

Merged
merged 11 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 42 additions & 25 deletions src/main/kotlin/io/libp2p/core/multiformats/Multiaddr.kt
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@ class Multiaddr(val components: List<Pair<Protocol, ByteArray>>) {
*/
constructor(bytes: ByteArray) : this(parseBytes(bytes.toByteBuf()))

constructor(parentAddr: Multiaddr, childAddr: Multiaddr) :
this(concatProtocols(parentAddr, childAddr))

constructor(parentAddr: Multiaddr, peerId: PeerId) :
this(concatPeerId(parentAddr, peerId))

/**
* Returns only components matching any of supplied protocols
*/
Expand Down Expand Up @@ -90,16 +84,51 @@ class Multiaddr(val components: List<Pair<Protocol, ByteArray>>) {
*/
fun getBytes(): ByteArray = writeBytes(Unpooled.buffer()).toByteArray()

fun toPeerIdAndAddr(): Pair<PeerId, Multiaddr> {
if (!has(Protocol.IPFS))
throw IllegalArgumentException("Multiaddr has no peer id")
/**
* Returns [PeerId] from either `/ipfs/` or `/p2p/` component value. `null` if none of those components exists
*/
fun getPeerId(): PeerId? =
components.filter { it.first in Protocol.PEER_ID_PROTOCOLS }.map { PeerId(it.second) }.firstOrNull()

return Pair(
PeerId.fromBase58(getStringComponent(Protocol.IPFS)!!),
Multiaddr(components.subList(0, components.lastIndex))
)
/**
* Appends `/p2p/` component if absent or checks that existing and supplied ids are equal
* @throws IllegalArgumentException if existing `/p2p/` identity doesn't match [peerId]
*/
fun withP2P(peerId: PeerId) = withComponent(Protocol.P2P, peerId.bytes)

/**
* Appends new component if absent or checks that existing and supplied component values are equal
* @throws IllegalArgumentException if existing component value doesn't match [value]
*/
fun withComponent(protocol: Protocol, value: ByteArray): Multiaddr {
val curVal = getComponent(protocol)
return if (curVal != null) {
if (!curVal.contentEquals(value)) {
throw IllegalArgumentException("Value (${protocol.bytesToAddress(value)}) for $protocol doesn't match existing value in $this")
} else {
this
}
} else {
Multiaddr(this.components + (protocol to value))
}
}

/**
* Returns [Multiaddr] with concatenated components of `this` and [other] `Multiaddr`
* No cross component checks or merge is performed
*/
fun concatenated(other: Multiaddr) = Multiaddr(this.components + other.components)

/**
* Merges components of this [Multiaddr] with [other]
* Has the same effect as appending [other] components subsequently by [withComponent]
* @throws IllegalArgumentException if any of `this` component value doesn't match the value for the same protocol in [other]
*/
fun merged(other: Multiaddr) = other.components
.fold(this) { accumulator, component ->
accumulator.withComponent(component.first, component.second)
}

internal fun split(pred: (Protocol) -> Boolean): List<Multiaddr> {
val addresses = mutableListOf<Multiaddr>()
split(
Expand Down Expand Up @@ -199,17 +228,5 @@ class Multiaddr(val components: List<Pair<Protocol, ByteArray>>) {
}
return ret
}

private fun concatProtocols(parentAddr: Multiaddr, childAddr: Multiaddr): List<Pair<Protocol, ByteArray>> {
return parentAddr.components + childAddr.components
}

private fun concatPeerId(addr: Multiaddr, peerId: PeerId): List<Pair<Protocol, ByteArray>> {
if (addr.has(Protocol.IPFS))
throw IllegalArgumentException("Multiaddr already has peer id")
val protocols = addr.components.toMutableList()
protocols.add(Pair(Protocol.IPFS, peerId.bytes))
return protocols
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class MultiaddrDns {
else
addressMatrix[0].flatMap { parent ->
crossProduct(addressMatrix.subList(1, addressMatrix.size))
.map { child -> Multiaddr(parent, child) }
.map { child -> parent.concatenated(child) }
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/main/kotlin/io/libp2p/core/multiformats/Protocol.kt
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ enum class Protocol(val code: Int, val size: Int, val typeName: String) {
}

companion object {
@JvmStatic
val PEER_ID_PROTOCOLS = listOf(P2P, IPFS)

private val byCode = values().associate { p -> p.code to p }
private val byName = values().associate { p -> p.typeName to p }

Expand Down
5 changes: 3 additions & 2 deletions src/main/kotlin/io/libp2p/core/multistream/ProtocolBinding.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ interface ProtocolBinding<out TController> {
*/
@JvmDefault
fun dial(host: Host, addrWithPeer: Multiaddr): StreamPromise<out TController> {
val (peerId, addr) = addrWithPeer.toPeerIdAndAddr()
return dial(host, peerId, addr)
val peerId = addrWithPeer.getPeerId()
?: throw IllegalArgumentException("Expected remote peer ID in the dial Multiaddr: $addrWithPeer")
return dial(host, peerId, addrWithPeer)
}

@JvmDefault
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/io/libp2p/host/HostImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class HostImpl(

network.transports.forEach {
listening.addAll(
it.listenAddresses().map { Multiaddr(it, peerId) }
it.listenAddresses().map { it.withP2P(peerId) }
)
}

Expand Down
8 changes: 5 additions & 3 deletions src/main/kotlin/io/libp2p/network/NetworkImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ class NetworkImpl(
connections.find { it.secureSession().remoteId == id }
?.apply { return CompletableFuture.completedFuture(this) }

val addrsWithP2P = addrs.map { it.withP2P(id) }

// 1. check that some transport can dial at least one addr.
// 2. trigger dials in parallel via all transports.
// 3. when the first dial succeeds, cancel all pending dials and return the connection. // TODO cancel
// 4. if no emitted dial succeeds, or if we time out, fail the future. make sure to cancel
// pending dials to avoid leaking.
val connectionFuts = addrs.mapNotNull { addr ->
val connectionFuts = addrsWithP2P.mapNotNull { addr ->
transports.firstOrNull { tpt -> tpt.handles(addr) }?.let { addr to it }
}.map {
it.second.dial(it.first, createHookedConnHandler(connectionHandler))
}.map { (addr, transport) ->
transport.dial(addr, createHookedConnHandler(connectionHandler))
}
return anyComplete(connectionFuts)
}
Expand Down
5 changes: 4 additions & 1 deletion src/main/kotlin/io/libp2p/security/SecureChannelError.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ open class SecureChannelError : Exception {
constructor(message: String) : super(message)
}

open class SecureHandshakeError : SecureChannelError()
open class SecureHandshakeError : SecureChannelError {
constructor() : super()
constructor(message: String) : super(message)
}

class InvalidRemotePubKey : SecureHandshakeError()
class InvalidInitialPacket : SecureHandshakeError()
Expand Down
14 changes: 11 additions & 3 deletions src/main/kotlin/io/libp2p/security/noise/NoiseXXSecureChannel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import io.libp2p.core.crypto.marshalPublicKey
import io.libp2p.core.crypto.unmarshalPublicKey
import io.libp2p.core.multistream.ProtocolDescriptor
import io.libp2p.core.security.SecureChannel
import io.libp2p.etc.REMOTE_PEER_ID
import io.libp2p.etc.types.toByteArray
import io.libp2p.etc.types.toByteBuf
import io.libp2p.etc.types.toUShortBigEndian
Expand Down Expand Up @@ -97,6 +98,7 @@ private class NoiseIoHandshake(
private var instancePayload: ByteArray? = null
private var activated = false
private var remotePeerId: PeerId? = null
private var expectedRemotePeerId: PeerId? = null

init {
log.debug("Starting handshake")
Expand All @@ -117,6 +119,10 @@ private class NoiseIoHandshake(
// the Noise protocol only permits alice to send a packet first
if (role == Role.INIT) {
sendNoiseMessage(ctx)
if (!ctx.channel().hasAttr(REMOTE_PEER_ID)) {
throw SecureHandshakeError("Remote Peer ID missing for initiating party")
}
expectedRemotePeerId = ctx.channel().attr(REMOTE_PEER_ID).get()
}
} // channelActive

Expand All @@ -131,9 +137,11 @@ private class NoiseIoHandshake(

// verify the signature of the remote's noise static public key once
// the remote public key has been provided by the XX protocol
with(handshakeState.remotePublicKey) {
if (hasPublicKey()) {
remotePeerId = verifyPayload(ctx, instancePayload!!, this)
val derivedRemotePublicKey = handshakeState.remotePublicKey
if (derivedRemotePublicKey.hasPublicKey()) {
remotePeerId = verifyPayload(ctx, instancePayload!!, derivedRemotePublicKey)
if (role == Role.INIT && expectedRemotePeerId != remotePeerId) {
throw InvalidRemotePubKey()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/test/java/io/libp2p/core/HostTestJava.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void ping() throws Exception {
Assertions.assertEquals(0, clientHost.listenAddresses().size());
Assertions.assertEquals(1, serverHost.listenAddresses().size());
Assertions.assertEquals(
localListenAddress + "/ipfs/" + serverHost.getPeerId(),
localListenAddress + "/p2p/" + serverHost.getPeerId(),
serverHost.listenAddresses().get(0).toString()
);

Expand Down
56 changes: 43 additions & 13 deletions src/test/kotlin/io/libp2p/core/multiformats/MultiaddrTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -196,39 +196,69 @@ class MultiaddrTest {
val parentAddr = Multiaddr("/ip4/127.0.0.1/tcp/20000")
val peerId = testPeerId()

val addr = Multiaddr(parentAddr, peerId)
assertEquals("/ip4/127.0.0.1/tcp/20000/ipfs/QmULzn6KtFUCKpkFymEUgUvkLtv9j2Eo4utZPELmQEebR6", addr.toString())
val addr = parentAddr.withP2P(peerId)
assertEquals("/ip4/127.0.0.1/tcp/20000/p2p/QmULzn6KtFUCKpkFymEUgUvkLtv9j2Eo4utZPELmQEebR6", addr.toString())
assertEquals(addr.withP2P(peerId), addr)

assertThrows(java.lang.IllegalArgumentException::class.java) {
Multiaddr(addr, peerId) // parent already has peer id
assertThrows(IllegalArgumentException::class.java) {
addr.withP2P(PeerId.random()) // parent has another peer id
}
}

@Test
fun concatTwoMultiaddrs() {
fun `concatenated() should just concat components`() {
val parentAddr = Multiaddr("/ip4/127.0.0.1/tcp/20000")
val childAddr = Multiaddr("/p2p-circuit/ip4/127.0.0.2")

val addr = parentAddr.concatenated(childAddr)
assertEquals(
"/ip4/127.0.0.1/tcp/20000/p2p-circuit/ip4/127.0.0.2",
addr.toString()
)
}

@Test
fun `merged() should succeed with distinct components`() {
val parentAddr = Multiaddr("/ip4/127.0.0.1/tcp/20000")
val childAddr = Multiaddr("/p2p-circuit/dns4/trousers.org")

val addr = Multiaddr(parentAddr, childAddr)
val addr = parentAddr.merged(childAddr)
assertEquals(
"/ip4/127.0.0.1/tcp/20000/p2p-circuit/dns4/trousers.org",
addr.toString()
)
}

@Test
fun testSplitIntoPeerAndMultiaddr() {
val addr = Multiaddr("/ip4/127.0.0.1/tcp/20000/ipfs/QmULzn6KtFUCKpkFymEUgUvkLtv9j2Eo4utZPELmQEebR6")
fun `merged() should succeed with matching component values`() {
val parentAddr = Multiaddr("/ip4/127.0.0.1/tcp/20000")
val childAddr = Multiaddr("/ip4/127.0.0.1/p2p-circuit/dns4/trousers.org")

val (splitPeerId, addrWithoutPeer) = addr.toPeerIdAndAddr()
assertEquals(testPeerId(), splitPeerId)
assertEquals("/ip4/127.0.0.1/tcp/20000", addrWithoutPeer.toString())
val addr = parentAddr.merged(childAddr)
assertEquals(
"/ip4/127.0.0.1/tcp/20000/p2p-circuit/dns4/trousers.org",
addr.toString()
)
}

@Test
fun `merged() should throw with non-matching component values`() {
val parentAddr = Multiaddr("/ip4/127.0.0.1/tcp/20000")
val childAddr = Multiaddr("/ip4/127.0.0.1/tcp/30000/p2p-circuit/dns4/trousers.org")

assertThrows(java.lang.IllegalArgumentException::class.java) {
addrWithoutPeer.toPeerIdAndAddr()
assertThrows(IllegalArgumentException::class.java) {
parentAddr.merged(childAddr)
}
}

@Test
fun testGetPeerId() {
val addr = Multiaddr("/ip4/127.0.0.1/tcp/20000/p2p/QmULzn6KtFUCKpkFymEUgUvkLtv9j2Eo4utZPELmQEebR6")

assertEquals(testPeerId(), addr.getPeerId())
assertEquals(Multiaddr("/ip4/127.0.0.1/tcp/20000").getPeerId(), null)
}

@ParameterizedTest
@MethodSource("splitParams")
fun splitMultiAddr(addr: Multiaddr, expected: List<String>) {
Expand Down
23 changes: 17 additions & 6 deletions src/test/kotlin/io/libp2p/pubsub/GoInteropTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import io.libp2p.core.P2PChannelHandler
import io.libp2p.core.PeerId
import io.libp2p.core.Stream
import io.libp2p.core.crypto.KEY_TYPE
import io.libp2p.core.crypto.PrivKey
import io.libp2p.core.crypto.generateKeyPair
import io.libp2p.core.crypto.unmarshalPublicKey
import io.libp2p.core.dsl.host
Expand Down Expand Up @@ -51,7 +52,9 @@ import org.junit.jupiter.api.extension.ExecutionCondition
import org.junit.jupiter.api.extension.ExtendWith
import org.junit.jupiter.api.extension.ExtensionContext
import pubsub.pb.Rpc
import java.io.File
import java.nio.charset.StandardCharsets
import java.security.SecureRandom
import java.util.concurrent.CompletableFuture
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -83,6 +86,16 @@ annotation class AssumeP2PAvailable()
@AssumeP2PAvailable
@Tag("interop")
class GoInteropTest {

val idPrivateKey: PrivKey = generateKeyPair(KEY_TYPE.SECP256K1, random = SecureRandom(byteArrayOf(0))).first
val daemonPeerId = PeerId.fromPubKey(idPrivateKey.publicKey())
val skFile = File.createTempFile("p2pd_pkey_", ".bin").also { file ->
file.outputStream().use { os ->
os.write(idPrivateKey.bytes())
}
file.deleteOnExit()
}

init {
ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID)
}
Expand All @@ -91,10 +104,7 @@ class GoInteropTest {
fun connect1() {
val logger = LogManager.getLogger("test")
val daemonLauncher = P2pdRunner().launcher()!!
val identityFileArgs = arrayOf<String>()

// uncomment the following line and set the generated (with p2p-keygen tool) key file path
// val identityFileArgs = arrayOf<String>("-id", "E:\\ws\\jvm-libp2p-minimal\\p2pd.key")
val identityFileArgs = arrayOf<String>("-id", skFile.absoluteFile.canonicalPath)

val pdHost = daemonLauncher
.launch(45555, "-pubsub", *identityFileArgs)
Expand Down Expand Up @@ -211,8 +221,9 @@ class GoInteropTest {
fun hostTest() = defer { d ->
val logger = LogManager.getLogger("test")
val daemonLauncher = P2pdRunner().launcher()!!
val identityFileArgs = arrayOf<String>("-id", skFile.absoluteFile.canonicalPath)
val pdHost = daemonLauncher
.launch(45555, "-pubsub")
.launch(45555, "-pubsub", *identityFileArgs)

d.defer {
println("Killing p2pd process")
Expand Down Expand Up @@ -258,7 +269,7 @@ class GoInteropTest {
host.start().get(5, TimeUnit.SECONDS)
println("Host started")

val connFuture = host.network.connect(PeerId.random(), Multiaddr("/ip4/127.0.0.1/tcp/45555"))
val connFuture = host.network.connect(daemonPeerId, Multiaddr("/ip4/127.0.0.1/tcp/45555"))

connFuture.thenAccept {
logger.info("Connection made")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ abstract class TwoGossipHostTestBase {

protected fun connect() {
val connect = host1.network
.connect(host2.peerId, Multiaddr.fromString("/ip4/127.0.0.1/tcp/40001"))
.connect(host2.peerId, Multiaddr.fromString("/ip4/127.0.0.1/tcp/40001/p2p/" + host2.peerId))
connect.get(10, TimeUnit.SECONDS)

waitFor { gossipConnected(router1) }
Expand Down
Loading