From a541748c32c5644aa8098c81e96e51940cb25eaa Mon Sep 17 00:00:00 2001 From: achingbrain Date: Mon, 3 Jul 2023 21:51:20 +0200 Subject: [PATCH] fix!: close streams gracefully - Refactors `.close`, `closeRead` and `.closeWrite` methods on the `Stream` interface to be async - The `Connection` interface now has `.close` and `.abort` methods - `.close` on `Stream`s and `Connection`s wait for the internal message queues to empty before closing - `.abort` on `Stream`s and `Connection`s close the underlying stream immediately and discards any unsent data - `@chainsafe/libp2p-yamux` now uses the `AbstractStream` class from `@libp2p/interface` the same as `@libp2p/mplex` and `@libp2p/webrtc` Follow-up PRs will be necessary to `@chainsafe/libp2p-yamux`, `@chainsafe/libp2p-gossipsub` and `@chainsafe/libp2p-noise` though they will not block the release as their code is temporarily added to this repo to let CI run. Fixes #1793 Fixes #656 BREAKING CHANGE: the `.close`, `closeRead` and `closeWrite` methods on the `Stream` interface are now asynchronous --- .github/workflows/main.yml | 3 +- doc/METRICS.md | 2 +- interop/BrowserDockerfile | 2 + interop/Dockerfile | 3 + interop/README.md | 79 ++- interop/package.json | 1 + interop/src/index.ts | 4 +- interop/test/ping.spec.ts | 20 +- .../connection-encryption-noise/package.json | 99 +++ .../src/@types/basic.ts | 5 + .../src/@types/handshake-interface.ts | 12 + .../src/@types/handshake.ts | 48 ++ .../src/@types/libp2p.ts | 10 + .../src/constants.ts | 4 + .../connection-encryption-noise/src/crypto.ts | 16 + .../src/crypto/js.ts | 58 ++ .../src/crypto/streaming.ts | 58 ++ .../src/encoder.ts | 81 +++ .../src/handshake-xx.ts | 177 ++++++ .../src/handshakes/abstract-handshake.ts | 180 ++++++ .../src/handshakes/xx.ts | 184 ++++++ .../connection-encryption-noise/src/index.ts | 10 + .../connection-encryption-noise/src/logger.ts | 51 ++ .../src/metrics.ts | 32 + .../connection-encryption-noise/src/noise.ts | 192 ++++++ .../connection-encryption-noise/src/nonce.ts | 49 ++ .../src/proto/payload.proto | 11 + .../src/proto/payload.ts | 148 +++++ .../connection-encryption-noise/src/utils.ts | 111 ++++ .../test/compliance.spec.ts | 11 + .../test/fixtures/peer.ts | 37 ++ .../test/handshakes/xx.spec.ts | 160 +++++ .../test/index.spec.ts | 54 ++ .../test/noise.spec.ts | 210 ++++++ .../connection-encryption-noise/test/utils.ts | 19 + .../test/xx-handshake.spec.ts | 126 ++++ .../connection-encryption-noise/tsconfig.json | 42 ++ .../interface-compliance-tests/package.json | 10 +- .../src/connection-encryption/utils/index.ts | 1 + .../src/connection/index.ts | 10 +- .../src/mocks/connection-manager.ts | 4 +- .../src/mocks/connection.ts | 57 +- .../src/mocks/multiaddr-connection.ts | 9 + .../src/mocks/muxer.ts | 352 +++-------- .../src/stream-muxer/close-test.ts | 124 +++- .../stream-muxer/fixtures/pb/message.proto | 7 + .../src/stream-muxer/fixtures/pb/message.ts | 87 +++ .../src/stream-muxer/spawner.ts | 5 +- .../src/stream-muxer/stress-test.ts | 2 +- .../src/transport/listen-test.ts | 2 +- .../src/connection-manager/index.ts | 13 +- packages/interface-internal/src/index.ts | 598 +----------------- packages/interface/package.json | 8 +- packages/interface/src/connection/index.ts | 134 +++- packages/interface/src/connection/status.ts | 4 - packages/interface/src/index.ts | 8 +- packages/interface/src/stream-muxer/index.ts | 9 +- packages/interface/src/stream-muxer/stream.ts | 461 ++++++++------ packages/kad-dht/src/network.ts | 4 +- packages/kad-dht/src/routing-table/index.ts | 2 +- packages/libp2p-daemon-server/package.json | 2 +- packages/libp2p/.aegir.js | 3 +- packages/libp2p/package.json | 4 +- .../libp2p/src/circuit-relay/server/index.ts | 34 +- .../src/circuit-relay/transport/index.ts | 66 +- .../src/circuit-relay/transport/listener.ts | 2 + .../transport/reservation-store.ts | 43 +- .../src/connection-manager/auto-dial.ts | 4 +- .../src/connection-manager/dial-queue.ts | 22 +- .../libp2p/src/connection-manager/index.ts | 10 +- packages/libp2p/src/connection/index.ts | 79 ++- packages/libp2p/src/fetch/index.ts | 8 +- packages/libp2p/src/identify/identify.ts | 120 ++-- packages/libp2p/src/libp2p.ts | 4 +- packages/libp2p/src/ping/index.ts | 51 +- packages/libp2p/src/transport-manager.ts | 4 + packages/libp2p/src/upgrader.ts | 27 +- .../libp2p/test/circuit-relay/hop.spec.ts | 10 +- .../libp2p/test/circuit-relay/relay.node.ts | 57 +- .../libp2p/test/circuit-relay/stop.spec.ts | 53 +- packages/libp2p/test/circuit-relay/utils.ts | 19 + .../test/connection-manager/direct.node.ts | 22 +- .../test/connection-manager/index.node.ts | 5 +- .../test/connection-manager/index.spec.ts | 7 +- .../libp2p/test/connection/compliance.spec.ts | 15 +- packages/libp2p/test/connection/index.spec.ts | 19 +- packages/libp2p/test/identify/index.spec.ts | 31 +- packages/libp2p/test/identify/service.node.ts | 2 +- .../test/transports/transport-manager.node.ts | 11 +- .../test/transports/transport-manager.spec.ts | 14 +- .../libp2p/test/upgrading/upgrader.spec.ts | 13 +- .../libp2p/test/upnp-nat/upnp-nat.node.ts | 8 +- packages/multistream-select/package.json | 2 +- packages/pubsub-gossipsub/package.json | 2 +- packages/pubsub-gossipsub/src/index.ts | 2 +- packages/pubsub/package.json | 2 +- packages/pubsub/src/peer-streams.ts | 7 +- packages/pubsub/test/utils/index.ts | 6 +- .../stream-multiplexer-mplex/package.json | 3 +- .../stream-multiplexer-mplex/src/mplex.ts | 191 +++--- .../stream-multiplexer-mplex/src/stream.ts | 49 +- .../test/mplex.spec.ts | 13 +- .../test/stream.spec.ts | 223 +++---- .../stream-multiplexer-yamux/package.json | 4 +- .../stream-multiplexer-yamux/src/muxer.ts | 94 +-- .../stream-multiplexer-yamux/src/stream.ts | 351 ++++------ .../test/bench/comparison.bench.ts | 2 +- .../test/muxer.spec.ts | 57 +- .../test/stream.spec.ts | 71 ++- .../stream-multiplexer-yamux/test/util.ts | 21 +- packages/transport-tcp/src/constants.ts | 2 +- packages/transport-tcp/src/socket-to-conn.ts | 98 ++- .../transport-tcp/test/socket-to-conn.spec.ts | 12 +- .../examples/browser-to-browser/package.json | 2 +- .../examples/browser-to-server/package.json | 2 +- packages/transport-webrtc/package.json | 4 +- packages/transport-webrtc/src/maconn.ts | 17 +- packages/transport-webrtc/src/muxer.ts | 10 +- .../src/private-to-private/handler.ts | 15 +- .../src/private-to-private/transport.ts | 10 +- packages/transport-webrtc/src/stream.ts | 35 +- .../test/peer.browser.spec.ts | 8 +- .../test/stream.browser.spec.ts | 22 +- packages/transport-webrtc/test/stream.spec.ts | 8 +- packages/transport-websockets/package.json | 1 - .../transport-websockets/src/constants.ts | 2 +- .../src/socket-to-conn.ts | 31 +- packages/transport-webtransport/src/index.ts | 103 ++- .../transport-webtransport/test/browser.ts | 6 +- packages/utils/package.json | 1 - packages/utils/src/stream-to-ma-conn.ts | 55 +- packages/utils/test/stream-to-ma-conn.spec.ts | 12 +- 132 files changed, 4337 insertions(+), 2208 deletions(-) create mode 100644 packages/connection-encryption-noise/package.json create mode 100644 packages/connection-encryption-noise/src/@types/basic.ts create mode 100644 packages/connection-encryption-noise/src/@types/handshake-interface.ts create mode 100644 packages/connection-encryption-noise/src/@types/handshake.ts create mode 100644 packages/connection-encryption-noise/src/@types/libp2p.ts create mode 100644 packages/connection-encryption-noise/src/constants.ts create mode 100644 packages/connection-encryption-noise/src/crypto.ts create mode 100644 packages/connection-encryption-noise/src/crypto/js.ts create mode 100644 packages/connection-encryption-noise/src/crypto/streaming.ts create mode 100644 packages/connection-encryption-noise/src/encoder.ts create mode 100644 packages/connection-encryption-noise/src/handshake-xx.ts create mode 100644 packages/connection-encryption-noise/src/handshakes/abstract-handshake.ts create mode 100644 packages/connection-encryption-noise/src/handshakes/xx.ts create mode 100644 packages/connection-encryption-noise/src/index.ts create mode 100644 packages/connection-encryption-noise/src/logger.ts create mode 100644 packages/connection-encryption-noise/src/metrics.ts create mode 100644 packages/connection-encryption-noise/src/noise.ts create mode 100644 packages/connection-encryption-noise/src/nonce.ts create mode 100644 packages/connection-encryption-noise/src/proto/payload.proto create mode 100644 packages/connection-encryption-noise/src/proto/payload.ts create mode 100644 packages/connection-encryption-noise/src/utils.ts create mode 100644 packages/connection-encryption-noise/test/compliance.spec.ts create mode 100644 packages/connection-encryption-noise/test/fixtures/peer.ts create mode 100644 packages/connection-encryption-noise/test/handshakes/xx.spec.ts create mode 100644 packages/connection-encryption-noise/test/index.spec.ts create mode 100644 packages/connection-encryption-noise/test/noise.spec.ts create mode 100644 packages/connection-encryption-noise/test/utils.ts create mode 100644 packages/connection-encryption-noise/test/xx-handshake.spec.ts create mode 100644 packages/connection-encryption-noise/tsconfig.json create mode 100644 packages/interface-compliance-tests/src/stream-muxer/fixtures/pb/message.proto create mode 100644 packages/interface-compliance-tests/src/stream-muxer/fixtures/pb/message.ts delete mode 100644 packages/interface/src/connection/status.ts diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b4fff9e2ca..1ad3954303 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -186,7 +186,7 @@ jobs: - uses: actions/checkout@v3 - uses: ipfs/aegir/actions/cache-node-modules@master - name: Build images - run: (cd interop && make) + run: (cd interop && make -j 4) - name: Save package-lock.json as artifact uses: actions/upload-artifact@v2 with: @@ -197,6 +197,7 @@ jobs: - uses: libp2p/test-plans/.github/actions/run-interop-ping-test@master with: test-filter: js-libp2p-head + test-ignore: nim extra-versions: ${{ github.workspace }}/interop/node-version.json ${{ github.workspace }}/interop/chromium-version.json ${{ github.workspace }}/interop/firefox-version.json s3-cache-bucket: ${{ vars.S3_LIBP2P_BUILD_CACHE_BUCKET_NAME }} s3-access-key-id: ${{ vars.S3_LIBP2P_BUILD_CACHE_AWS_ACCESS_KEY_ID }} diff --git a/doc/METRICS.md b/doc/METRICS.md index 93764e08ed..0fe2c8e01b 100644 --- a/doc/METRICS.md +++ b/doc/METRICS.md @@ -64,7 +64,7 @@ const node = await createLibp2p({ To define component metrics first get a reference to the metrics object: ```ts -import type { Metrics } from '@libp2p/interface-metrics' +import type { Metrics } from '@libp2p/interface/metrics' interface MyClassComponents { metrics: Metrics diff --git a/interop/BrowserDockerfile b/interop/BrowserDockerfile index b97304cb9d..b235693b02 100644 --- a/interop/BrowserDockerfile +++ b/interop/BrowserDockerfile @@ -7,5 +7,7 @@ WORKDIR /app/interop RUN npx playwright install ARG BROWSER=chromium # Options: chromium, firefox, webkit ENV BROWSER=$BROWSER +# disable colored output and CLI animation from test runners +ENV CI true ENTRYPOINT npm run test:interop:multidim -- --build false --types false -t browser -- --browser $BROWSER diff --git a/interop/Dockerfile b/interop/Dockerfile index e95bfc4aec..1f77807abc 100644 --- a/interop/Dockerfile +++ b/interop/Dockerfile @@ -8,4 +8,7 @@ COPY ./interop ./interop WORKDIR /app/interop +# disable colored output and CLI animation from test runners +ENV CI true + ENTRYPOINT [ "npm", "run", "test:interop:multidim", "--", "--build", "false", "--types", "false", "-t", "node" ] diff --git a/interop/README.md b/interop/README.md index 114e8ce543..bf2bf9a1a7 100644 --- a/interop/README.md +++ b/interop/README.md @@ -9,14 +9,87 @@ ## Table of contents -- [Install](#install) +- [Usage](#usage) + - [Build js-libp2p](#build-js-libp2p) + - [node.js](#nodejs) + - [Browsers](#browsers) + - [Build another libp2p implementation](#build-another-libp2p-implementation) + - [Running Redis](#running-redis) + - [Start libp2p](#start-libp2p) + - [Start another libp2p implementation](#start-another-libp2p-implementation) - [License](#license) - [Contribution](#contribution) -## Install +## Usage + +The multidim interop tests use random high ports for listeners. Since you need to know which port will be listened on ahead of time to `EXPOSE` a port in a Docker image to the host machine, this means everything has to be run in Docker. + +### Build js-libp2p + +This must be repeated every time you make a change to the js-libp2p source code. + +#### node.js + +```console +$ npm run build +$ docker build . -f ./interop/Dockerfile -t js-libp2p-node +``` + +#### Browsers + +```console +$ npm run build +$ docker build . -f ./interop/BrowserDockerfile -t js-libp2p-browsers +``` + +### Build another libp2p implementation + +1. Clone the test-plans repo somewhere + ```console + $ git clone https://github.com/libp2p/test-plans.git + ``` +2. (Optional) If you are running an M1 Mac you may need to override the build platform. + - Edit `/multidim-interop/dockerBuildWrapper.sh` + - Add `--platform linux/arm64/v8` to the `docker buildx build` command + ``` + docker buildx build \ + --platform linux/arm64/v8 \ <-- add this line + --load \ + -t $IMAGE_NAME $CACHING_OPTIONS "$@" + ``` +3. (Optional) Enable some sort of debug output + - nim-libp2p + - edit `/multidim-interop/impl/nim/$VERSION/Dockerfile` + - Change `-d:chronicles_log_level=WARN` to `-d:chronicles_log_level=DEBUG` + - rust-libp2p + - When starting the docker container add `-e RUST_LOG=debug` + - go-libp2p + - When starting the docker container add `-e GOLOG_LOG_LEVEL=debug` +4. Build the version you want to test against + ```console + $ cd impl/$IMPL/$VERSION + $ make + ... + ``` + +### Running Redis + +Redis is used to allow inter-container communication, exchanging listen addresses etc. It must be started as a Docker container: + +```console +$ docker run --name redis --rm -p 6379:6379 redis:7-alpine +``` + +### Start libp2p + +```console +$ docker run -e transport=tcp -e muxer=yamux -e security=noise -e is_dialer=true -e redis_addr=redis:6379 --link redis:redis js-libp2p-node +``` + +### Start another libp2p implementation ```console -$ npm i multidim-interop +$ docker run -e transport=tcp -e muxer=yamux -e security=noise -e is_dialer=false -e redis_addr=redis:6379 --link redis:redis nim-v1.0 ``` ## License diff --git a/interop/package.json b/interop/package.json index 9a54387b42..3942d0e123 100644 --- a/interop/package.json +++ b/interop/package.json @@ -35,6 +35,7 @@ "scripts": { "start": "node index.js", "build": "aegir build", + "lint": "aegir lint", "test:interop:multidim": "aegir test" }, "dependencies": { diff --git a/interop/src/index.ts b/interop/src/index.ts index 17adc4ed14..c7522ec635 100644 --- a/interop/src/index.ts +++ b/interop/src/index.ts @@ -1,3 +1,3 @@ -console.log("Everything is defined in the test folder") +// Everything is defined in the test folder -export { } \ No newline at end of file +export { } diff --git a/interop/test/ping.spec.ts b/interop/test/ping.spec.ts index f182f1a66f..75cf38ca0d 100644 --- a/interop/test/ping.spec.ts +++ b/interop/test/ping.spec.ts @@ -13,7 +13,7 @@ import { webTransport } from '@libp2p/webtransport' import { type Multiaddr, multiaddr } from '@multiformats/multiaddr' import { createLibp2p, type Libp2p, type Libp2pOptions } from 'libp2p' import { circuitRelayTransport } from 'libp2p/circuit-relay' -import { IdentifyService, identifyService } from 'libp2p/identify' +import { type IdentifyService, identifyService } from 'libp2p/identify' import { pingService, type PingService } from 'libp2p/ping' async function redisProxy (commands: any[]): Promise { @@ -28,7 +28,10 @@ let node: Libp2p<{ ping: PingService, identify: IdentifyService }> const isDialer: boolean = process.env.is_dialer === 'true' const timeoutSecs: string = process.env.test_timeout_secs ?? '180' -describe('ping test', () => { +describe('ping test', function () { + // make the default timeout longer than the listener timeout + this.timeout((parseInt(timeoutSecs) * 1000) + 30000) + // eslint-disable-next-line complexity beforeEach(async () => { // Setup libp2p node @@ -39,6 +42,9 @@ describe('ping test', () => { const options: Libp2pOptions<{ ping: PingService, identify: IdentifyService }> = { start: true, + connectionManager: { + minConnections: 0 + }, connectionGater: { denyDialMultiaddr: async () => false }, @@ -208,14 +214,18 @@ describe('ping test', () => { // eslint-disable-next-line complexity (isDialer ? it : it.skip)('should dial and ping', async () => { try { - let otherMa: string = (await redisProxy(['BLPOP', 'listenerAddr', timeoutSecs]).catch(err => { throw new Error(`Failed to wait for listener: ${err}`) }))[1] + let otherMaStr: string = (await redisProxy(['BLPOP', 'listenerAddr', timeoutSecs]).catch(err => { throw new Error(`Failed to wait for listener: ${err}`) }))[1] // Hack until these are merged: // - https://github.com/multiformats/js-multiaddr-to-uri/pull/120 - otherMa = otherMa.replace('/tls/ws', '/wss') + otherMaStr = otherMaStr.replace('/tls/ws', '/wss') + + const otherMa = multiaddr(otherMaStr) console.error(`node ${node.peerId.toString()} pings: ${otherMa}`) const handshakeStartInstant = Date.now() - await node.dial(multiaddr(otherMa)) + + await node.dial(otherMa) + const pingRTT = await node.services.ping.ping(multiaddr(otherMa)) const handshakePlusOneRTT = Date.now() - handshakeStartInstant console.log(JSON.stringify({ diff --git a/packages/connection-encryption-noise/package.json b/packages/connection-encryption-noise/package.json new file mode 100644 index 0000000000..49a10fe7ce --- /dev/null +++ b/packages/connection-encryption-noise/package.json @@ -0,0 +1,99 @@ +{ + "name": "@chainsafe/libp2p-noise", + "version": "12.0.1", + "author": "ChainSafe ", + "license": "Apache-2.0 OR MIT", + "homepage": "https://github.com/ChainSafe/js-libp2p-noise#readme", + "repository": { + "type": "git", + "url": "git+https://github.com/ChainSafe/js-libp2p-noise.git" + }, + "bugs": { + "url": "https://github.com/ChainSafe/js-libp2p-noise/issues" + }, + "keywords": [ + "crypto", + "libp2p", + "noise" + ], + "engines": { + "node": ">=16.0.0", + "npm": ">=7.0.0" + }, + "type": "module", + "types": "./dist/src/index.d.ts", + "files": [ + "src", + "dist", + "!dist/test", + "!**/*.tsbuildinfo" + ], + "exports": { + ".": { + "types": "./dist/src/index.d.ts", + "import": "./dist/src/index.js" + } + }, + "eslintConfig": { + "extends": "ipfs", + "parserOptions": { + "sourceType": "module" + }, + "rules": { + "@typescript-eslint/no-unused-vars": "error", + "@typescript-eslint/explicit-function-return-type": "warn", + "@typescript-eslint/strict-boolean-expressions": "off" + }, + "ignorePatterns": [ + "src/proto/payload.js", + "src/proto/payload.d.ts", + "test/fixtures/node-globals.js" + ] + }, + "scripts": { + "bench": "node benchmarks/benchmark.js", + "clean": "aegir clean", + "dep-check": "aegir dep-check", + "build": "aegir build", + "lint": "aegir lint", + "lint:fix": "aegir lint --fix", + "test": "aegir test", + "test:node": "aegir test -t node", + "test:browser": "aegir test -t browser -t webworker", + "test:electron-main": "aegir test -t electron-main", + "docs": "aegir docs", + "proto:gen": "protons ./src/proto/payload.proto", + "prepublish": "npm run build" + }, + "dependencies": { + "@libp2p/crypto": "^1.0.11", + "@libp2p/interface": "~0.0.1", + "@libp2p/logger": "^2.1.1", + "@libp2p/peer-id": "^2.0.0", + "@stablelib/chacha20poly1305": "^1.0.1", + "@noble/hashes": "^1.3.0", + "@stablelib/x25519": "^1.0.3", + "it-length-prefixed": "^9.0.1", + "it-length-prefixed-stream": "^1.0.0", + "it-byte-stream": "^1.0.0", + "it-pair": "^2.0.2", + "it-pipe": "^3.0.1", + "it-stream-types": "^2.0.1", + "protons-runtime": "^5.0.0", + "uint8arraylist": "^2.3.2", + "uint8arrays": "^4.0.2" + }, + "devDependencies": { + "@libp2p/interface-compliance-tests": "^3.0.0", + "@libp2p/peer-id-factory": "^2.0.0", + "@types/sinon": "^10.0.14", + "aegir": "^39.0.5", + "iso-random-stream": "^2.0.2", + "protons": "^7.0.0", + "sinon": "^15.0.0" + }, + "browser": { + "./dist/src/alloc-unsafe.js": "./dist/src/alloc-unsafe-browser.js", + "util": false + } +} diff --git a/packages/connection-encryption-noise/src/@types/basic.ts b/packages/connection-encryption-noise/src/@types/basic.ts new file mode 100644 index 0000000000..364d1f89ee --- /dev/null +++ b/packages/connection-encryption-noise/src/@types/basic.ts @@ -0,0 +1,5 @@ +export type bytes = Uint8Array +export type bytes32 = Uint8Array +export type bytes16 = Uint8Array + +export type uint64 = number diff --git a/packages/connection-encryption-noise/src/@types/handshake-interface.ts b/packages/connection-encryption-noise/src/@types/handshake-interface.ts new file mode 100644 index 0000000000..9b402b1fd3 --- /dev/null +++ b/packages/connection-encryption-noise/src/@types/handshake-interface.ts @@ -0,0 +1,12 @@ +import type { bytes } from './basic.js' +import type { NoiseSession } from './handshake.js' +import type { NoiseExtensions } from '../proto/payload.js' +import type { PeerId } from '@libp2p/interface/peer-id' + +export interface IHandshake { + session: NoiseSession + remotePeer: PeerId + remoteExtensions: NoiseExtensions + encrypt: (plaintext: bytes, session: NoiseSession) => bytes + decrypt: (ciphertext: bytes, session: NoiseSession, dst?: Uint8Array) => { plaintext: bytes, valid: boolean } +} diff --git a/packages/connection-encryption-noise/src/@types/handshake.ts b/packages/connection-encryption-noise/src/@types/handshake.ts new file mode 100644 index 0000000000..ec333b703b --- /dev/null +++ b/packages/connection-encryption-noise/src/@types/handshake.ts @@ -0,0 +1,48 @@ +import type { bytes, bytes32, uint64 } from './basic.js' +import type { KeyPair } from './libp2p.js' +import type { Nonce } from '../nonce.js' + +export type Hkdf = [bytes, bytes, bytes] + +export interface MessageBuffer { + ne: bytes32 + ns: bytes + ciphertext: bytes +} + +export interface CipherState { + k: bytes32 + // For performance reasons, the nonce is represented as a Nonce object + // The nonce is treated as a uint64, even though the underlying `number` only has 52 safely-available bits. + n: Nonce +} + +export interface SymmetricState { + cs: CipherState + ck: bytes32 // chaining key + h: bytes32 // handshake hash +} + +export interface HandshakeState { + ss: SymmetricState + s: KeyPair + e?: KeyPair + rs: bytes32 + re: bytes32 + psk: bytes32 +} + +export interface NoiseSession { + hs: HandshakeState + h?: bytes32 + cs1?: CipherState + cs2?: CipherState + mc: uint64 + i: boolean +} + +export interface INoisePayload { + identityKey: bytes + identitySig: bytes + data: bytes +} diff --git a/packages/connection-encryption-noise/src/@types/libp2p.ts b/packages/connection-encryption-noise/src/@types/libp2p.ts new file mode 100644 index 0000000000..c20fe93952 --- /dev/null +++ b/packages/connection-encryption-noise/src/@types/libp2p.ts @@ -0,0 +1,10 @@ +import type { bytes32 } from './basic.js' +import type { NoiseExtensions } from '../proto/payload.js' +import type { ConnectionEncrypter } from '@libp2p/interface/connection-encrypter' + +export interface KeyPair { + publicKey: bytes32 + privateKey: bytes32 +} + +export interface INoiseConnection extends ConnectionEncrypter {} diff --git a/packages/connection-encryption-noise/src/constants.ts b/packages/connection-encryption-noise/src/constants.ts new file mode 100644 index 0000000000..7e8105c47b --- /dev/null +++ b/packages/connection-encryption-noise/src/constants.ts @@ -0,0 +1,4 @@ +export const NOISE_MSG_MAX_LENGTH_BYTES = 65535 +export const NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG = NOISE_MSG_MAX_LENGTH_BYTES - 16 + +export const DUMP_SESSION_KEYS = Boolean(globalThis.process?.env?.DUMP_SESSION_KEYS) diff --git a/packages/connection-encryption-noise/src/crypto.ts b/packages/connection-encryption-noise/src/crypto.ts new file mode 100644 index 0000000000..108dfee1c6 --- /dev/null +++ b/packages/connection-encryption-noise/src/crypto.ts @@ -0,0 +1,16 @@ +import type { bytes32, bytes } from './@types/basic.js' +import type { Hkdf } from './@types/handshake.js' +import type { KeyPair } from './@types/libp2p.js' + +export interface ICryptoInterface { + hashSHA256: (data: Uint8Array) => Uint8Array + + getHKDF: (ck: bytes32, ikm: Uint8Array) => Hkdf + + generateX25519KeyPair: () => KeyPair + generateX25519KeyPairFromSeed: (seed: Uint8Array) => KeyPair + generateX25519SharedKey: (privateKey: Uint8Array, publicKey: Uint8Array) => Uint8Array + + chaCha20Poly1305Encrypt: (plaintext: Uint8Array, nonce: Uint8Array, ad: Uint8Array, k: bytes32) => bytes + chaCha20Poly1305Decrypt: (ciphertext: Uint8Array, nonce: Uint8Array, ad: Uint8Array, k: bytes32, dst?: Uint8Array) => bytes | null +} diff --git a/packages/connection-encryption-noise/src/crypto/js.ts b/packages/connection-encryption-noise/src/crypto/js.ts new file mode 100644 index 0000000000..2a50102128 --- /dev/null +++ b/packages/connection-encryption-noise/src/crypto/js.ts @@ -0,0 +1,58 @@ +import { hkdf } from '@noble/hashes/hkdf' +import { sha256 } from '@noble/hashes/sha256' +import { ChaCha20Poly1305 } from '@stablelib/chacha20poly1305' +import * as x25519 from '@stablelib/x25519' +import type { bytes, bytes32 } from '../@types/basic.js' +import type { Hkdf } from '../@types/handshake.js' +import type { KeyPair } from '../@types/libp2p.js' +import type { ICryptoInterface } from '../crypto.js' + +export const pureJsCrypto: ICryptoInterface = { + hashSHA256 (data: Uint8Array): Uint8Array { + return sha256(data) + }, + + getHKDF (ck: bytes32, ikm: Uint8Array): Hkdf { + const okm = hkdf(sha256, ikm, ck, undefined, 96) + + const k1 = okm.subarray(0, 32) + const k2 = okm.subarray(32, 64) + const k3 = okm.subarray(64, 96) + + return [k1, k2, k3] + }, + + generateX25519KeyPair (): KeyPair { + const keypair = x25519.generateKeyPair() + + return { + publicKey: keypair.publicKey, + privateKey: keypair.secretKey + } + }, + + generateX25519KeyPairFromSeed (seed: Uint8Array): KeyPair { + const keypair = x25519.generateKeyPairFromSeed(seed) + + return { + publicKey: keypair.publicKey, + privateKey: keypair.secretKey + } + }, + + generateX25519SharedKey (privateKey: Uint8Array, publicKey: Uint8Array): Uint8Array { + return x25519.sharedKey(privateKey, publicKey) + }, + + chaCha20Poly1305Encrypt (plaintext: Uint8Array, nonce: Uint8Array, ad: Uint8Array, k: bytes32): bytes { + const ctx = new ChaCha20Poly1305(k) + + return ctx.seal(nonce, plaintext, ad) + }, + + chaCha20Poly1305Decrypt (ciphertext: Uint8Array, nonce: Uint8Array, ad: Uint8Array, k: bytes32, dst?: Uint8Array): bytes | null { + const ctx = new ChaCha20Poly1305(k) + + return ctx.open(nonce, ciphertext, ad, dst) + } +} diff --git a/packages/connection-encryption-noise/src/crypto/streaming.ts b/packages/connection-encryption-noise/src/crypto/streaming.ts new file mode 100644 index 0000000000..e785a649fd --- /dev/null +++ b/packages/connection-encryption-noise/src/crypto/streaming.ts @@ -0,0 +1,58 @@ +import { TAG_LENGTH } from '@stablelib/chacha20poly1305' +import { NOISE_MSG_MAX_LENGTH_BYTES, NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG } from '../constants.js' +import { uint16BEEncode } from '../encoder.js' +import type { IHandshake } from '../@types/handshake-interface.js' +import type { MetricsRegistry } from '../metrics.js' +import type { Transform } from 'it-stream-types' +import type { Uint8ArrayList } from 'uint8arraylist' + +// Returns generator that encrypts payload from the user +export function encryptStream (handshake: IHandshake, metrics?: MetricsRegistry): Transform> { + return async function * (source) { + for await (const chunk of source) { + for (let i = 0; i < chunk.length; i += NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG) { + let end = i + NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG + if (end > chunk.length) { + end = chunk.length + } + + const data = handshake.encrypt(chunk.subarray(i, end), handshake.session) + metrics?.encryptedPackets.increment() + + yield uint16BEEncode(data.byteLength) + yield data + } + } + } +} + +// Decrypt received payload to the user +export function decryptStream (handshake: IHandshake, metrics?: MetricsRegistry): Transform, AsyncIterable> { + return async function * (source) { + for await (const chunk of source) { + for (let i = 0; i < chunk.length; i += NOISE_MSG_MAX_LENGTH_BYTES) { + let end = i + NOISE_MSG_MAX_LENGTH_BYTES + if (end > chunk.length) { + end = chunk.length + } + + if (end - TAG_LENGTH < i) { + throw new Error('Invalid chunk') + } + const encrypted = chunk.subarray(i, end) + // memory allocation is not cheap so reuse the encrypted Uint8Array + // see https://github.com/ChainSafe/js-libp2p-noise/pull/242#issue-1422126164 + // this is ok because chacha20 reads bytes one by one and don't reread after that + // it's also tested in https://github.com/ChainSafe/as-chacha20poly1305/pull/1/files#diff-25252846b58979dcaf4e41d47b3eadd7e4f335e7fb98da6c049b1f9cd011f381R48 + const dst = chunk.subarray(i, end - TAG_LENGTH) + const { plaintext: decrypted, valid } = handshake.decrypt(encrypted, handshake.session, dst) + if (!valid) { + metrics?.decryptErrors.increment() + throw new Error('Failed to validate decrypted chunk') + } + metrics?.decryptedPackets.increment() + yield decrypted + } + } + } +} diff --git a/packages/connection-encryption-noise/src/encoder.ts b/packages/connection-encryption-noise/src/encoder.ts new file mode 100644 index 0000000000..f1f963f3a7 --- /dev/null +++ b/packages/connection-encryption-noise/src/encoder.ts @@ -0,0 +1,81 @@ +import { concat as uint8ArrayConcat } from 'uint8arrays/concat' +import type { bytes } from './@types/basic.js' +import type { MessageBuffer } from './@types/handshake.js' +import type { LengthDecoderFunction } from 'it-length-prefixed' +import type { Uint8ArrayList } from 'uint8arraylist' + +const allocUnsafe = (len: number): Uint8Array => { + if (globalThis.Buffer) { + return globalThis.Buffer.allocUnsafe(len) + } + + return new Uint8Array(len) +} + +export const uint16BEEncode = (value: number): Uint8Array => { + const target = allocUnsafe(2) + new DataView(target.buffer, target.byteOffset, target.byteLength).setUint16(0, value, false) + return target +} +uint16BEEncode.bytes = 2 + +export const uint16BEDecode: LengthDecoderFunction = (data: Uint8Array | Uint8ArrayList): number => { + if (data.length < 2) throw RangeError('Could not decode int16BE') + + if (data instanceof Uint8Array) { + return new DataView(data.buffer, data.byteOffset, data.byteLength).getUint16(0, false) + } + + return data.getUint16(0) +} +uint16BEDecode.bytes = 2 + +// Note: IK and XX encoder usage is opposite (XX uses in stages encode0 where IK uses encode1) + +export function encode0 (message: MessageBuffer): bytes { + return uint8ArrayConcat([message.ne, message.ciphertext], message.ne.length + message.ciphertext.length) +} + +export function encode1 (message: MessageBuffer): bytes { + return uint8ArrayConcat([message.ne, message.ns, message.ciphertext], message.ne.length + message.ns.length + message.ciphertext.length) +} + +export function encode2 (message: MessageBuffer): bytes { + return uint8ArrayConcat([message.ns, message.ciphertext], message.ns.length + message.ciphertext.length) +} + +export function decode0 (input: bytes): MessageBuffer { + if (input.length < 32) { + throw new Error('Cannot decode stage 0 MessageBuffer: length less than 32 bytes.') + } + + return { + ne: input.subarray(0, 32), + ciphertext: input.subarray(32, input.length), + ns: new Uint8Array(0) + } +} + +export function decode1 (input: bytes): MessageBuffer { + if (input.length < 80) { + throw new Error('Cannot decode stage 1 MessageBuffer: length less than 80 bytes.') + } + + return { + ne: input.subarray(0, 32), + ns: input.subarray(32, 80), + ciphertext: input.subarray(80, input.length) + } +} + +export function decode2 (input: bytes): MessageBuffer { + if (input.length < 48) { + throw new Error('Cannot decode stage 2 MessageBuffer: length less than 48 bytes.') + } + + return { + ne: new Uint8Array(0), + ns: input.subarray(0, 48), + ciphertext: input.subarray(48, input.length) + } +} diff --git a/packages/connection-encryption-noise/src/handshake-xx.ts b/packages/connection-encryption-noise/src/handshake-xx.ts new file mode 100644 index 0000000000..92a6cf6e01 --- /dev/null +++ b/packages/connection-encryption-noise/src/handshake-xx.ts @@ -0,0 +1,177 @@ +import { InvalidCryptoExchangeError, UnexpectedPeerError } from '@libp2p/interface/errors' +import { decode0, decode1, decode2, encode0, encode1, encode2 } from './encoder.js' +import { XX } from './handshakes/xx.js' +import { + logger, + logLocalStaticKeys, + logLocalEphemeralKeys, + logRemoteEphemeralKey, + logRemoteStaticKey, + logCipherState +} from './logger.js' +import { + decodePayload, + getPeerIdFromPayload, + verifySignedPayload +} from './utils.js' +import type { bytes, bytes32 } from './@types/basic.js' +import type { IHandshake } from './@types/handshake-interface.js' +import type { CipherState, NoiseSession } from './@types/handshake.js' +import type { KeyPair } from './@types/libp2p.js' +import type { ICryptoInterface } from './crypto.js' +import type { NoiseExtensions } from './proto/payload.js' +import type { PeerId } from '@libp2p/interface/peer-id' +import type { LengthPrefixedStream } from 'it-length-prefixed-stream' + +export class XXHandshake implements IHandshake { + public isInitiator: boolean + public session: NoiseSession + public remotePeer!: PeerId + public remoteExtensions: NoiseExtensions = { webtransportCerthashes: [] } + + protected payload: bytes + protected connection: LengthPrefixedStream + protected xx: XX + protected staticKeypair: KeyPair + + private readonly prologue: bytes32 + + constructor ( + isInitiator: boolean, + payload: bytes, + prologue: bytes32, + crypto: ICryptoInterface, + staticKeypair: KeyPair, + connection: LengthPrefixedStream, + remotePeer?: PeerId, + handshake?: XX + ) { + this.isInitiator = isInitiator + this.payload = payload + this.prologue = prologue + this.staticKeypair = staticKeypair + this.connection = connection + if (remotePeer) { + this.remotePeer = remotePeer + } + this.xx = handshake ?? new XX(crypto) + this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair) + } + + // stage 0 + public async propose (): Promise { + logLocalStaticKeys(this.session.hs.s) + if (this.isInitiator) { + logger.trace('Stage 0 - Initiator starting to send first message.') + const messageBuffer = this.xx.sendMessage(this.session, new Uint8Array(0)) + await this.connection.write(encode0(messageBuffer)) + logger.trace('Stage 0 - Initiator finished sending first message.') + logLocalEphemeralKeys(this.session.hs.e) + } else { + logger.trace('Stage 0 - Responder waiting to receive first message...') + const receivedMessageBuffer = decode0((await this.connection.read()).subarray()) + const { valid } = this.xx.recvMessage(this.session, receivedMessageBuffer) + if (!valid) { + throw new InvalidCryptoExchangeError('xx handshake stage 0 validation fail') + } + logger.trace('Stage 0 - Responder received first message.') + logRemoteEphemeralKey(this.session.hs.re) + } + } + + // stage 1 + public async exchange (): Promise { + if (this.isInitiator) { + logger.trace('Stage 1 - Initiator waiting to receive first message from responder...') + const receivedMessageBuffer = decode1((await this.connection.read()).subarray()) + const { plaintext, valid } = this.xx.recvMessage(this.session, receivedMessageBuffer) + if (!valid) { + throw new InvalidCryptoExchangeError('xx handshake stage 1 validation fail') + } + logger.trace('Stage 1 - Initiator received the message.') + logRemoteEphemeralKey(this.session.hs.re) + logRemoteStaticKey(this.session.hs.rs) + + logger.trace("Initiator going to check remote's signature...") + try { + const decodedPayload = decodePayload(plaintext) + this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) + await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) + this.setRemoteNoiseExtension(decodedPayload.extensions) + } catch (e) { + const err = e as Error + throw new UnexpectedPeerError(`Error occurred while verifying signed payload: ${err.message}`) + } + logger.trace('All good with the signature!') + } else { + logger.trace('Stage 1 - Responder sending out first message with signed payload and static key.') + const messageBuffer = this.xx.sendMessage(this.session, this.payload) + await this.connection.write(encode1(messageBuffer)) + logger.trace('Stage 1 - Responder sent the second handshake message with signed payload.') + logLocalEphemeralKeys(this.session.hs.e) + } + } + + // stage 2 + public async finish (): Promise { + if (this.isInitiator) { + logger.trace('Stage 2 - Initiator sending third handshake message.') + const messageBuffer = this.xx.sendMessage(this.session, this.payload) + await this.connection.write(encode2(messageBuffer)) + logger.trace('Stage 2 - Initiator sent message with signed payload.') + } else { + logger.trace('Stage 2 - Responder waiting for third handshake message...') + const receivedMessageBuffer = decode2((await this.connection.read()).subarray()) + const { plaintext, valid } = this.xx.recvMessage(this.session, receivedMessageBuffer) + if (!valid) { + throw new InvalidCryptoExchangeError('xx handshake stage 2 validation fail') + } + logger.trace('Stage 2 - Responder received the message, finished handshake.') + + try { + const decodedPayload = decodePayload(plaintext) + this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) + await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) + this.setRemoteNoiseExtension(decodedPayload.extensions) + } catch (e) { + const err = e as Error + throw new UnexpectedPeerError(`Error occurred while verifying signed payload: ${err.message}`) + } + } + logCipherState(this.session) + } + + public encrypt (plaintext: Uint8Array, session: NoiseSession): bytes { + const cs = this.getCS(session) + + return this.xx.encryptWithAd(cs, new Uint8Array(0), plaintext) + } + + public decrypt (ciphertext: Uint8Array, session: NoiseSession, dst?: Uint8Array): { plaintext: bytes, valid: boolean } { + const cs = this.getCS(session, false) + + return this.xx.decryptWithAd(cs, new Uint8Array(0), ciphertext, dst) + } + + public getRemoteStaticKey (): bytes { + return this.session.hs.rs + } + + private getCS (session: NoiseSession, encryption = true): CipherState { + if (!session.cs1 || !session.cs2) { + throw new InvalidCryptoExchangeError('Handshake not completed properly, cipher state does not exist.') + } + + if (this.isInitiator) { + return encryption ? session.cs1 : session.cs2 + } else { + return encryption ? session.cs2 : session.cs1 + } + } + + protected setRemoteNoiseExtension (e: NoiseExtensions | null | undefined): void { + if (e) { + this.remoteExtensions = e + } + } +} diff --git a/packages/connection-encryption-noise/src/handshakes/abstract-handshake.ts b/packages/connection-encryption-noise/src/handshakes/abstract-handshake.ts new file mode 100644 index 0000000000..714f8141fc --- /dev/null +++ b/packages/connection-encryption-noise/src/handshakes/abstract-handshake.ts @@ -0,0 +1,180 @@ +import { fromString as uint8ArrayFromString } from 'uint8arrays' +import { concat as uint8ArrayConcat } from 'uint8arrays/concat' +import { equals as uint8ArrayEquals } from 'uint8arrays/equals' +import { logger } from '../logger.js' +import { Nonce } from '../nonce.js' +import type { bytes, bytes32 } from '../@types/basic.js' +import type { CipherState, MessageBuffer, SymmetricState } from '../@types/handshake.js' +import type { ICryptoInterface } from '../crypto.js' + +export interface DecryptedResult { + plaintext: bytes + valid: boolean +} + +export interface SplitState { + cs1: CipherState + cs2: CipherState +} + +export abstract class AbstractHandshake { + public crypto: ICryptoInterface + + constructor (crypto: ICryptoInterface) { + this.crypto = crypto + } + + public encryptWithAd (cs: CipherState, ad: Uint8Array, plaintext: Uint8Array): bytes { + const e = this.encrypt(cs.k, cs.n, ad, plaintext) + cs.n.increment() + + return e + } + + public decryptWithAd (cs: CipherState, ad: Uint8Array, ciphertext: Uint8Array, dst?: Uint8Array): DecryptedResult { + const { plaintext, valid } = this.decrypt(cs.k, cs.n, ad, ciphertext, dst) + if (valid) cs.n.increment() + + return { plaintext, valid } + } + + // Cipher state related + protected hasKey (cs: CipherState): boolean { + return !this.isEmptyKey(cs.k) + } + + protected createEmptyKey (): bytes32 { + return new Uint8Array(32) + } + + protected isEmptyKey (k: bytes32): boolean { + const emptyKey = this.createEmptyKey() + return uint8ArrayEquals(emptyKey, k) + } + + protected encrypt (k: bytes32, n: Nonce, ad: Uint8Array, plaintext: Uint8Array): bytes { + n.assertValue() + + return this.crypto.chaCha20Poly1305Encrypt(plaintext, n.getBytes(), ad, k) + } + + protected encryptAndHash (ss: SymmetricState, plaintext: bytes): bytes { + let ciphertext + if (this.hasKey(ss.cs)) { + ciphertext = this.encryptWithAd(ss.cs, ss.h, plaintext) + } else { + ciphertext = plaintext + } + + this.mixHash(ss, ciphertext) + return ciphertext + } + + protected decrypt (k: bytes32, n: Nonce, ad: bytes, ciphertext: bytes, dst?: Uint8Array): DecryptedResult { + n.assertValue() + + const encryptedMessage = this.crypto.chaCha20Poly1305Decrypt(ciphertext, n.getBytes(), ad, k, dst) + + if (encryptedMessage) { + return { + plaintext: encryptedMessage, + valid: true + } + } else { + return { + plaintext: new Uint8Array(0), + valid: false + } + } + } + + protected decryptAndHash (ss: SymmetricState, ciphertext: bytes): DecryptedResult { + let plaintext: bytes; let valid = true + if (this.hasKey(ss.cs)) { + ({ plaintext, valid } = this.decryptWithAd(ss.cs, ss.h, ciphertext)) + } else { + plaintext = ciphertext + } + + this.mixHash(ss, ciphertext) + return { plaintext, valid } + } + + protected dh (privateKey: bytes32, publicKey: bytes32): bytes32 { + try { + const derivedU8 = this.crypto.generateX25519SharedKey(privateKey, publicKey) + + if (derivedU8.length === 32) { + return derivedU8 + } + + return derivedU8.subarray(0, 32) + } catch (e) { + const err = e as Error + logger.error(err) + return new Uint8Array(32) + } + } + + protected mixHash (ss: SymmetricState, data: bytes): void { + ss.h = this.getHash(ss.h, data) + } + + protected getHash (a: Uint8Array, b: Uint8Array): bytes32 { + const u = this.crypto.hashSHA256(uint8ArrayConcat([a, b], a.length + b.length)) + return u + } + + protected mixKey (ss: SymmetricState, ikm: bytes32): void { + const [ck, tempK] = this.crypto.getHKDF(ss.ck, ikm) + ss.cs = this.initializeKey(tempK) + ss.ck = ck + } + + protected initializeKey (k: bytes32): CipherState { + return { k, n: new Nonce() } + } + + // Symmetric state related + + protected initializeSymmetric (protocolName: string): SymmetricState { + const protocolNameBytes = uint8ArrayFromString(protocolName, 'utf-8') + const h = this.hashProtocolName(protocolNameBytes) + + const ck = h + const key = this.createEmptyKey() + const cs: CipherState = this.initializeKey(key) + + return { cs, ck, h } + } + + protected hashProtocolName (protocolName: Uint8Array): bytes32 { + if (protocolName.length <= 32) { + const h = new Uint8Array(32) + h.set(protocolName) + return h + } else { + return this.getHash(protocolName, new Uint8Array(0)) + } + } + + protected split (ss: SymmetricState): SplitState { + const [tempk1, tempk2] = this.crypto.getHKDF(ss.ck, new Uint8Array(0)) + const cs1 = this.initializeKey(tempk1) + const cs2 = this.initializeKey(tempk2) + + return { cs1, cs2 } + } + + protected writeMessageRegular (cs: CipherState, payload: bytes): MessageBuffer { + const ciphertext = this.encryptWithAd(cs, new Uint8Array(0), payload) + const ne = this.createEmptyKey() + const ns = new Uint8Array(0) + + return { ne, ns, ciphertext } + } + + protected readMessageRegular (cs: CipherState, message: MessageBuffer): DecryptedResult { + return this.decryptWithAd(cs, new Uint8Array(0), message.ciphertext) + } +} diff --git a/packages/connection-encryption-noise/src/handshakes/xx.ts b/packages/connection-encryption-noise/src/handshakes/xx.ts new file mode 100644 index 0000000000..44d26fa4c0 --- /dev/null +++ b/packages/connection-encryption-noise/src/handshakes/xx.ts @@ -0,0 +1,184 @@ +import { isValidPublicKey } from '../utils.js' +import { AbstractHandshake, type DecryptedResult } from './abstract-handshake.js' +import type { bytes32, bytes } from '../@types/basic.js' +import type { CipherState, HandshakeState, MessageBuffer, NoiseSession } from '../@types/handshake.js' +import type { KeyPair } from '../@types/libp2p.js' + +export class XX extends AbstractHandshake { + private initializeInitiator (prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): HandshakeState { + const name = 'Noise_XX_25519_ChaChaPoly_SHA256' + const ss = this.initializeSymmetric(name) + this.mixHash(ss, prologue) + const re = new Uint8Array(32) + + return { ss, s, rs, psk, re } + } + + private initializeResponder (prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): HandshakeState { + const name = 'Noise_XX_25519_ChaChaPoly_SHA256' + const ss = this.initializeSymmetric(name) + this.mixHash(ss, prologue) + const re = new Uint8Array(32) + + return { ss, s, rs, psk, re } + } + + private writeMessageA (hs: HandshakeState, payload: bytes, e?: KeyPair): MessageBuffer { + const ns = new Uint8Array(0) + + if (e !== undefined) { + hs.e = e + } else { + hs.e = this.crypto.generateX25519KeyPair() + } + + const ne = hs.e.publicKey + + this.mixHash(hs.ss, ne) + const ciphertext = this.encryptAndHash(hs.ss, payload) + + return { ne, ns, ciphertext } + } + + private writeMessageB (hs: HandshakeState, payload: bytes): MessageBuffer { + hs.e = this.crypto.generateX25519KeyPair() + const ne = hs.e.publicKey + this.mixHash(hs.ss, ne) + + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)) + const spk = hs.s.publicKey + const ns = this.encryptAndHash(hs.ss, spk) + + this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)) + const ciphertext = this.encryptAndHash(hs.ss, payload) + + return { ne, ns, ciphertext } + } + + private writeMessageC (hs: HandshakeState, payload: bytes): { messageBuffer: MessageBuffer, cs1: CipherState, cs2: CipherState, h: bytes } { + const spk = hs.s.publicKey + const ns = this.encryptAndHash(hs.ss, spk) + this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)) + const ciphertext = this.encryptAndHash(hs.ss, payload) + const ne = this.createEmptyKey() + const messageBuffer: MessageBuffer = { ne, ns, ciphertext } + const { cs1, cs2 } = this.split(hs.ss) + + return { h: hs.ss.h, messageBuffer, cs1, cs2 } + } + + private readMessageA (hs: HandshakeState, message: MessageBuffer): DecryptedResult { + if (isValidPublicKey(message.ne)) { + hs.re = message.ne + } + + this.mixHash(hs.ss, hs.re) + return this.decryptAndHash(hs.ss, message.ciphertext) + } + + private readMessageB (hs: HandshakeState, message: MessageBuffer): DecryptedResult { + if (isValidPublicKey(message.ne)) { + hs.re = message.ne + } + + this.mixHash(hs.ss, hs.re) + if (!hs.e) { + throw new Error('Handshake state `e` param is missing.') + } + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)) + const { plaintext: ns, valid: valid1 } = this.decryptAndHash(hs.ss, message.ns) + if (valid1 && isValidPublicKey(ns)) { + hs.rs = ns + } + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)) + const { plaintext, valid: valid2 } = this.decryptAndHash(hs.ss, message.ciphertext) + return { plaintext, valid: (valid1 && valid2) } + } + + private readMessageC (hs: HandshakeState, message: MessageBuffer): { h: bytes, plaintext: bytes, valid: boolean, cs1: CipherState, cs2: CipherState } { + const { plaintext: ns, valid: valid1 } = this.decryptAndHash(hs.ss, message.ns) + if (valid1 && isValidPublicKey(ns)) { + hs.rs = ns + } + if (!hs.e) { + throw new Error('Handshake state `e` param is missing.') + } + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)) + + const { plaintext, valid: valid2 } = this.decryptAndHash(hs.ss, message.ciphertext) + const { cs1, cs2 } = this.split(hs.ss) + + return { h: hs.ss.h, plaintext, valid: (valid1 && valid2), cs1, cs2 } + } + + public initSession (initiator: boolean, prologue: bytes32, s: KeyPair): NoiseSession { + const psk = this.createEmptyKey() + const rs = new Uint8Array(32) // no static key yet + let hs + + if (initiator) { + hs = this.initializeInitiator(prologue, s, rs, psk) + } else { + hs = this.initializeResponder(prologue, s, rs, psk) + } + + return { + hs, + i: initiator, + mc: 0 + } + } + + public sendMessage (session: NoiseSession, message: bytes, ephemeral?: KeyPair): MessageBuffer { + let messageBuffer: MessageBuffer + if (session.mc === 0) { + messageBuffer = this.writeMessageA(session.hs, message, ephemeral) + } else if (session.mc === 1) { + messageBuffer = this.writeMessageB(session.hs, message) + } else if (session.mc === 2) { + const { h, messageBuffer: resultingBuffer, cs1, cs2 } = this.writeMessageC(session.hs, message) + messageBuffer = resultingBuffer + session.h = h + session.cs1 = cs1 + session.cs2 = cs2 + } else if (session.mc > 2) { + if (session.i) { + if (!session.cs1) { + throw new Error('CS1 (cipher state) is not defined') + } + + messageBuffer = this.writeMessageRegular(session.cs1, message) + } else { + if (!session.cs2) { + throw new Error('CS2 (cipher state) is not defined') + } + + messageBuffer = this.writeMessageRegular(session.cs2, message) + } + } else { + throw new Error('Session invalid.') + } + + session.mc++ + return messageBuffer + } + + public recvMessage (session: NoiseSession, message: MessageBuffer): DecryptedResult { + let plaintext: bytes = new Uint8Array(0) + let valid = false + if (session.mc === 0) { + ({ plaintext, valid } = this.readMessageA(session.hs, message)) + } else if (session.mc === 1) { + ({ plaintext, valid } = this.readMessageB(session.hs, message)) + } else if (session.mc === 2) { + const { h, plaintext: resultingPlaintext, valid: resultingValid, cs1, cs2 } = this.readMessageC(session.hs, message) + plaintext = resultingPlaintext + valid = resultingValid + session.h = h + session.cs1 = cs1 + session.cs2 = cs2 + } + session.mc++ + return { plaintext, valid } + } +} diff --git a/packages/connection-encryption-noise/src/index.ts b/packages/connection-encryption-noise/src/index.ts new file mode 100644 index 0000000000..3a42c89722 --- /dev/null +++ b/packages/connection-encryption-noise/src/index.ts @@ -0,0 +1,10 @@ +import { Noise } from './noise.js' +import type { NoiseInit } from './noise.js' +import type { NoiseExtensions } from './proto/payload.js' +import type { ConnectionEncrypter } from '@libp2p/interface/connection-encrypter' +export type { ICryptoInterface } from './crypto.js' +export { pureJsCrypto } from './crypto/js.js' + +export function noise (init: NoiseInit = {}): () => ConnectionEncrypter { + return () => new Noise(init) +} diff --git a/packages/connection-encryption-noise/src/logger.ts b/packages/connection-encryption-noise/src/logger.ts new file mode 100644 index 0000000000..b44ca7b420 --- /dev/null +++ b/packages/connection-encryption-noise/src/logger.ts @@ -0,0 +1,51 @@ +import { type Logger, logger } from '@libp2p/logger' +import { toString as uint8ArrayToString } from 'uint8arrays/to-string' +import { DUMP_SESSION_KEYS } from './constants.js' +import type { NoiseSession } from './@types/handshake.js' +import type { KeyPair } from './@types/libp2p.js' + +const log = logger('libp2p:noise') + +export { log as logger } + +let keyLogger: Logger +if (DUMP_SESSION_KEYS) { + keyLogger = log +} else { + keyLogger = Object.assign(() => { /* do nothing */ }, { + enabled: false, + trace: () => {}, + error: () => {} + }) +} + +export function logLocalStaticKeys (s: KeyPair): void { + keyLogger(`LOCAL_STATIC_PUBLIC_KEY ${uint8ArrayToString(s.publicKey, 'hex')}`) + keyLogger(`LOCAL_STATIC_PRIVATE_KEY ${uint8ArrayToString(s.privateKey, 'hex')}`) +} + +export function logLocalEphemeralKeys (e: KeyPair | undefined): void { + if (e) { + keyLogger(`LOCAL_PUBLIC_EPHEMERAL_KEY ${uint8ArrayToString(e.publicKey, 'hex')}`) + keyLogger(`LOCAL_PRIVATE_EPHEMERAL_KEY ${uint8ArrayToString(e.privateKey, 'hex')}`) + } else { + keyLogger('Missing local ephemeral keys.') + } +} + +export function logRemoteStaticKey (rs: Uint8Array): void { + keyLogger(`REMOTE_STATIC_PUBLIC_KEY ${uint8ArrayToString(rs, 'hex')}`) +} + +export function logRemoteEphemeralKey (re: Uint8Array): void { + keyLogger(`REMOTE_EPHEMERAL_PUBLIC_KEY ${uint8ArrayToString(re, 'hex')}`) +} + +export function logCipherState (session: NoiseSession): void { + if (session.cs1 && session.cs2) { + keyLogger(`CIPHER_STATE_1 ${session.cs1.n.getUint64()} ${uint8ArrayToString(session.cs1.k, 'hex')}`) + keyLogger(`CIPHER_STATE_2 ${session.cs2.n.getUint64()} ${uint8ArrayToString(session.cs2.k, 'hex')}`) + } else { + keyLogger('Missing cipher state.') + } +} diff --git a/packages/connection-encryption-noise/src/metrics.ts b/packages/connection-encryption-noise/src/metrics.ts new file mode 100644 index 0000000000..8d0b3a4e70 --- /dev/null +++ b/packages/connection-encryption-noise/src/metrics.ts @@ -0,0 +1,32 @@ +import type { Counter, Metrics } from '@libp2p/interface/metrics' + +export type MetricsRegistry = Record + +export function registerMetrics (metrics: Metrics): MetricsRegistry { + return { + xxHandshakeSuccesses: metrics.registerCounter( + 'libp2p_noise_xxhandshake_successes_total', { + help: 'Total count of noise xxHandshakes successes_' + }), + + xxHandshakeErrors: metrics.registerCounter( + 'libp2p_noise_xxhandshake_error_total', { + help: 'Total count of noise xxHandshakes errors' + }), + + encryptedPackets: metrics.registerCounter( + 'libp2p_noise_encrypted_packets_total', { + help: 'Total count of noise encrypted packets successfully' + }), + + decryptedPackets: metrics.registerCounter( + 'libp2p_noise_decrypted_packets_total', { + help: 'Total count of noise decrypted packets' + }), + + decryptErrors: metrics.registerCounter( + 'libp2p_noise_decrypt_errors_total', { + help: 'Total count of noise decrypt errors' + }) + } +} diff --git a/packages/connection-encryption-noise/src/noise.ts b/packages/connection-encryption-noise/src/noise.ts new file mode 100644 index 0000000000..2277a751ec --- /dev/null +++ b/packages/connection-encryption-noise/src/noise.ts @@ -0,0 +1,192 @@ +import { decode } from 'it-length-prefixed' +import { lpStream, type LengthPrefixedStream } from 'it-length-prefixed-stream' +import { duplexPair } from 'it-pair/duplex' +import { pipe } from 'it-pipe' +import { NOISE_MSG_MAX_LENGTH_BYTES } from './constants.js' +import { pureJsCrypto } from './crypto/js.js' +import { decryptStream, encryptStream } from './crypto/streaming.js' +import { uint16BEDecode, uint16BEEncode } from './encoder.js' +import { XXHandshake } from './handshake-xx.js' +import { type MetricsRegistry, registerMetrics } from './metrics.js' +import { getPayload } from './utils.js' +import type { bytes } from './@types/basic.js' +import type { IHandshake } from './@types/handshake-interface.js' +import type { INoiseConnection, KeyPair } from './@types/libp2p.js' +import type { ICryptoInterface } from './crypto.js' +import type { NoiseExtensions } from './proto/payload.js' +import type { SecuredConnection } from '@libp2p/interface/connection-encrypter' +import type { Metrics } from '@libp2p/interface/metrics' +import type { PeerId } from '@libp2p/interface/peer-id' +import type { Duplex, Source } from 'it-stream-types' + +interface HandshakeParams { + connection: LengthPrefixedStream + isInitiator: boolean + localPeer: PeerId + remotePeer?: PeerId +} + +export interface NoiseInit { + /** + * x25519 private key, reuse for faster handshakes + */ + staticNoiseKey?: bytes + extensions?: NoiseExtensions + crypto?: ICryptoInterface + prologueBytes?: Uint8Array + metrics?: Metrics +} + +export class Noise implements INoiseConnection { + public protocol = '/noise' + public crypto: ICryptoInterface + + private readonly prologue: Uint8Array + private readonly staticKeys: KeyPair + private readonly extensions?: NoiseExtensions + private readonly metrics?: MetricsRegistry + + constructor (init: NoiseInit = {}) { + const { staticNoiseKey, extensions, crypto, prologueBytes, metrics } = init + + this.crypto = crypto ?? pureJsCrypto + this.extensions = extensions + this.metrics = metrics ? registerMetrics(metrics) : undefined + + if (staticNoiseKey) { + // accepts x25519 private key of length 32 + this.staticKeys = this.crypto.generateX25519KeyPairFromSeed(staticNoiseKey) + } else { + this.staticKeys = this.crypto.generateX25519KeyPair() + } + this.prologue = prologueBytes ?? new Uint8Array(0) + } + + /** + * Encrypt outgoing data to the remote party (handshake as initiator) + * + * @param {PeerId} localPeer - PeerId of the receiving peer + * @param {Duplex, AsyncIterable, Promise>} connection - streaming iterable duplex that will be encrypted + * @param {PeerId} remotePeer - PeerId of the remote peer. Used to validate the integrity of the remote peer. + * @returns {Promise} + */ + public async secureOutbound (localPeer: PeerId, connection: Duplex, AsyncIterable, Promise>, remotePeer?: PeerId): Promise> { + const wrappedConnection = lpStream( + connection, + { + lengthEncoder: uint16BEEncode, + lengthDecoder: uint16BEDecode, + maxDataLength: NOISE_MSG_MAX_LENGTH_BYTES + } + ) + const handshake = await this.performHandshake({ + connection: wrappedConnection, + isInitiator: true, + localPeer, + remotePeer + }) + const conn = await this.createSecureConnection(wrappedConnection, handshake) + + return { + conn, + remoteExtensions: handshake.remoteExtensions, + remotePeer: handshake.remotePeer + } + } + + /** + * Decrypt incoming data (handshake as responder). + * + * @param {PeerId} localPeer - PeerId of the receiving peer. + * @param {Duplex, AsyncIterable, Promise>} connection - streaming iterable duplex that will be encryption. + * @param {PeerId} remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades. + * @returns {Promise} + */ + public async secureInbound (localPeer: PeerId, connection: Duplex, AsyncIterable, Promise>, remotePeer?: PeerId): Promise> { + const wrappedConnection = lpStream( + connection, + { + lengthEncoder: uint16BEEncode, + lengthDecoder: uint16BEDecode, + maxDataLength: NOISE_MSG_MAX_LENGTH_BYTES + } + ) + const handshake = await this.performHandshake({ + connection: wrappedConnection, + isInitiator: false, + localPeer, + remotePeer + }) + const conn = await this.createSecureConnection(wrappedConnection, handshake) + + return { + conn, + remotePeer: handshake.remotePeer, + remoteExtensions: handshake.remoteExtensions + } + } + + /** + * If Noise pipes supported, tries IK handshake first with XX as fallback if it fails. + * If noise pipes disabled or remote peer static key is unknown, use XX. + * + * @param {HandshakeParams} params + */ + private async performHandshake (params: HandshakeParams): Promise { + const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.extensions) + + // run XX handshake + return this.performXXHandshake(params, payload) + } + + private async performXXHandshake ( + params: HandshakeParams, + payload: bytes + ): Promise { + const { isInitiator, remotePeer, connection } = params + const handshake = new XXHandshake( + isInitiator, + payload, + this.prologue, + this.crypto, + this.staticKeys, + connection, + remotePeer + ) + + try { + await handshake.propose() + await handshake.exchange() + await handshake.finish() + this.metrics?.xxHandshakeSuccesses.increment() + } catch (e: unknown) { + this.metrics?.xxHandshakeErrors.increment() + if (e instanceof Error) { + e.message = `Error occurred during XX handshake: ${e.message}` + throw e + } + } + + return handshake + } + + private async createSecureConnection ( + connection: LengthPrefixedStream, AsyncIterable, Promise>>, + handshake: IHandshake + ): Promise, Source, Promise>> { + // Create encryption box/unbox wrapper + const [secure, user] = duplexPair() + const network = connection.unwrap() + + await pipe( + secure, // write to wrapper + encryptStream(handshake, this.metrics), // encrypt data + prefix with message length + network, // send to the remote peer + (source) => decode(source, { lengthDecoder: uint16BEDecode }), // read message length prefix + decryptStream(handshake, this.metrics), // decrypt the incoming data + secure // pipe to the wrapper + ) + + return user + } +} diff --git a/packages/connection-encryption-noise/src/nonce.ts b/packages/connection-encryption-noise/src/nonce.ts new file mode 100644 index 0000000000..fab31ace7a --- /dev/null +++ b/packages/connection-encryption-noise/src/nonce.ts @@ -0,0 +1,49 @@ +import type { bytes, uint64 } from './@types/basic.js' + +export const MIN_NONCE = 0 +// For performance reasons, the nonce is represented as a JS `number` +// Although JS `number` can safely represent integers up to 2 ** 53 - 1, we choose to only use +// 4 bytes to store the data for performance reason. +// This is a slight deviation from the noise spec, which describes the max nonce as 2 ** 64 - 2 +// The effect is that this implementation will need a new handshake to be performed after fewer messages are exchanged than other implementations with full uint64 nonces. +// this MAX_NONCE is still a large number of messages, so the practical effect of this is negligible. +export const MAX_NONCE = 0xffffffff + +const ERR_MAX_NONCE = 'Cipherstate has reached maximum n, a new handshake must be performed' + +/** + * The nonce is an uint that's increased over time. + * Maintaining different representations help improve performance. + */ +export class Nonce { + private n: uint64 + private readonly bytes: bytes + private readonly view: DataView + + constructor (n = MIN_NONCE) { + this.n = n + this.bytes = new Uint8Array(12) + this.view = new DataView(this.bytes.buffer, this.bytes.byteOffset, this.bytes.byteLength) + this.view.setUint32(4, n, true) + } + + increment (): void { + this.n++ + // Even though we're treating the nonce as 8 bytes, RFC7539 specifies 12 bytes for a nonce. + this.view.setUint32(4, this.n, true) + } + + getBytes (): bytes { + return this.bytes + } + + getUint64 (): uint64 { + return this.n + } + + assertValue (): void { + if (this.n > MAX_NONCE) { + throw new Error(ERR_MAX_NONCE) + } + } +} diff --git a/packages/connection-encryption-noise/src/proto/payload.proto b/packages/connection-encryption-noise/src/proto/payload.proto new file mode 100644 index 0000000000..cdb2383cb0 --- /dev/null +++ b/packages/connection-encryption-noise/src/proto/payload.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +message NoiseExtensions { + repeated bytes webtransport_certhashes = 1; +} + +message NoiseHandshakePayload { + bytes identity_key = 1; + bytes identity_sig = 2; + optional NoiseExtensions extensions = 4; +} \ No newline at end of file diff --git a/packages/connection-encryption-noise/src/proto/payload.ts b/packages/connection-encryption-noise/src/proto/payload.ts new file mode 100644 index 0000000000..50acd62e39 --- /dev/null +++ b/packages/connection-encryption-noise/src/proto/payload.ts @@ -0,0 +1,148 @@ +/* eslint-disable import/export */ +/* eslint-disable complexity */ +/* eslint-disable @typescript-eslint/no-namespace */ +/* eslint-disable @typescript-eslint/no-unnecessary-boolean-literal-compare */ +/* eslint-disable @typescript-eslint/no-empty-interface */ + +import { encodeMessage, decodeMessage, message } from 'protons-runtime' +import type { Codec } from 'protons-runtime' +import type { Uint8ArrayList } from 'uint8arraylist' + +export interface NoiseExtensions { + webtransportCerthashes: Uint8Array[] +} + +export namespace NoiseExtensions { + let _codec: Codec + + export const codec = (): Codec => { + if (_codec == null) { + _codec = message((obj, w, opts = {}) => { + if (opts.lengthDelimited !== false) { + w.fork() + } + + if (obj.webtransportCerthashes != null) { + for (const value of obj.webtransportCerthashes) { + w.uint32(10) + w.bytes(value) + } + } + + if (opts.lengthDelimited !== false) { + w.ldelim() + } + }, (reader, length) => { + const obj: any = { + webtransportCerthashes: [] + } + + const end = length == null ? reader.len : reader.pos + length + + while (reader.pos < end) { + const tag = reader.uint32() + + switch (tag >>> 3) { + case 1: + obj.webtransportCerthashes.push(reader.bytes()) + break + default: + reader.skipType(tag & 7) + break + } + } + + return obj + }) + } + + return _codec + } + + export const encode = (obj: Partial): Uint8Array => { + return encodeMessage(obj, NoiseExtensions.codec()) + } + + export const decode = (buf: Uint8Array | Uint8ArrayList): NoiseExtensions => { + return decodeMessage(buf, NoiseExtensions.codec()) + } +} + +export interface NoiseHandshakePayload { + identityKey: Uint8Array + identitySig: Uint8Array + extensions?: NoiseExtensions +} + +export namespace NoiseHandshakePayload { + let _codec: Codec + + export const codec = (): Codec => { + if (_codec == null) { + _codec = message((obj, w, opts = {}) => { + if (opts.lengthDelimited !== false) { + w.fork() + } + + if (opts.writeDefaults === true || (obj.identityKey != null && obj.identityKey.byteLength > 0)) { + w.uint32(10) + w.bytes(obj.identityKey ?? new Uint8Array(0)) + } + + if (opts.writeDefaults === true || (obj.identitySig != null && obj.identitySig.byteLength > 0)) { + w.uint32(18) + w.bytes(obj.identitySig ?? new Uint8Array(0)) + } + + if (obj.extensions != null) { + w.uint32(34) + NoiseExtensions.codec().encode(obj.extensions, w, { + writeDefaults: false + }) + } + + if (opts.lengthDelimited !== false) { + w.ldelim() + } + }, (reader, length) => { + const obj: any = { + identityKey: new Uint8Array(0), + identitySig: new Uint8Array(0) + } + + const end = length == null ? reader.len : reader.pos + length + + while (reader.pos < end) { + const tag = reader.uint32() + + switch (tag >>> 3) { + case 1: + obj.identityKey = reader.bytes() + break + case 2: + obj.identitySig = reader.bytes() + break + case 4: + obj.extensions = NoiseExtensions.codec().decode(reader, reader.uint32()) + break + default: + reader.skipType(tag & 7) + break + } + } + + return obj + }) + } + + return _codec + } + + export const encode = (obj: Partial): Uint8Array => { + return encodeMessage(obj, NoiseHandshakePayload.codec()) + } + + export const decode = (buf: Uint8Array | Uint8ArrayList): NoiseHandshakePayload => { + return decodeMessage(buf, NoiseHandshakePayload.codec()) + } +} diff --git a/packages/connection-encryption-noise/src/utils.ts b/packages/connection-encryption-noise/src/utils.ts new file mode 100644 index 0000000000..993c9628c3 --- /dev/null +++ b/packages/connection-encryption-noise/src/utils.ts @@ -0,0 +1,111 @@ +import { unmarshalPublicKey, unmarshalPrivateKey } from '@libp2p/crypto/keys' +import { peerIdFromKeys } from '@libp2p/peer-id' +import { concat as uint8ArrayConcat } from 'uint8arrays/concat' +import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' +import { type NoiseExtensions, NoiseHandshakePayload } from './proto/payload.js' +import type { bytes } from './@types/basic.js' +import type { PeerId } from '@libp2p/interface/peer-id' + +export async function getPayload ( + localPeer: PeerId, + staticPublicKey: bytes, + extensions?: NoiseExtensions +): Promise { + const signedPayload = await signPayload(localPeer, getHandshakePayload(staticPublicKey)) + + if (localPeer.publicKey == null) { + throw new Error('PublicKey was missing from local PeerId') + } + + return createHandshakePayload( + localPeer.publicKey, + signedPayload, + extensions + ) +} + +export function createHandshakePayload ( + libp2pPublicKey: Uint8Array, + signedPayload: Uint8Array, + extensions?: NoiseExtensions +): bytes { + return NoiseHandshakePayload.encode({ + identityKey: libp2pPublicKey, + identitySig: signedPayload, + extensions: extensions ?? { webtransportCerthashes: [] } + }).subarray() +} + +export async function signPayload (peerId: PeerId, payload: bytes): Promise { + if (peerId.privateKey == null) { + throw new Error('PrivateKey was missing from PeerId') + } + + const privateKey = await unmarshalPrivateKey(peerId.privateKey) + + return privateKey.sign(payload) +} + +export async function getPeerIdFromPayload (payload: NoiseHandshakePayload): Promise { + return peerIdFromKeys(payload.identityKey) +} + +export function decodePayload (payload: bytes | Uint8Array): NoiseHandshakePayload { + return NoiseHandshakePayload.decode(payload) +} + +export function getHandshakePayload (publicKey: bytes): bytes { + const prefix = uint8ArrayFromString('noise-libp2p-static-key:') + return uint8ArrayConcat([prefix, publicKey], prefix.length + publicKey.length) +} + +/** + * Verifies signed payload, throws on any irregularities. + * + * @param {bytes} noiseStaticKey - owner's noise static key + * @param {bytes} payload - decoded payload + * @param {PeerId} remotePeer - owner's libp2p peer ID + * @returns {Promise} - peer ID of payload owner + */ +export async function verifySignedPayload ( + noiseStaticKey: bytes, + payload: NoiseHandshakePayload, + remotePeer: PeerId +): Promise { + // Unmarshaling from PublicKey protobuf + const payloadPeerId = await peerIdFromKeys(payload.identityKey) + if (!payloadPeerId.equals(remotePeer)) { + throw new Error(`Payload identity key ${payloadPeerId.toString()} does not match expected remote peer ${remotePeer.toString()}`) + } + const generatedPayload = getHandshakePayload(noiseStaticKey) + + if (payloadPeerId.publicKey == null) { + throw new Error('PublicKey was missing from PeerId') + } + + if (payload.identitySig == null) { + throw new Error('Signature was missing from message') + } + + const publicKey = unmarshalPublicKey(payloadPeerId.publicKey) + + const valid = await publicKey.verify(generatedPayload, payload.identitySig) + + if (!valid) { + throw new Error("Static key doesn't match to peer that signed payload!") + } + + return payloadPeerId +} + +export function isValidPublicKey (pk: bytes): boolean { + if (!(pk instanceof Uint8Array)) { + return false + } + + if (pk.length !== 32) { + return false + } + + return true +} diff --git a/packages/connection-encryption-noise/test/compliance.spec.ts b/packages/connection-encryption-noise/test/compliance.spec.ts new file mode 100644 index 0000000000..e97dc0ad7f --- /dev/null +++ b/packages/connection-encryption-noise/test/compliance.spec.ts @@ -0,0 +1,11 @@ +import tests from '@libp2p/interface-compliance-tests/connection-encryption' +import { Noise } from '../src/noise.js' + +describe('spec compliance tests', function () { + tests({ + async setup () { + return new Noise() + }, + async teardown () {} + }) +}) diff --git a/packages/connection-encryption-noise/test/fixtures/peer.ts b/packages/connection-encryption-noise/test/fixtures/peer.ts new file mode 100644 index 0000000000..753d9d6927 --- /dev/null +++ b/packages/connection-encryption-noise/test/fixtures/peer.ts @@ -0,0 +1,37 @@ +import { createEd25519PeerId, createFromJSON } from '@libp2p/peer-id-factory' +import type { PeerId } from '@libp2p/interface/peer-id' + +// ed25519 keys +const peers = [{ + id: '12D3KooWH45PiqBjfnEfDfCD6TqJrpqTBJvQDwGHvjGpaWwms46D', + privKey: 'CAESYBtKXrMwawAARmLScynQUuSwi/gGSkwqDPxi15N3dqDHa4T4iWupkMe5oYGwGH3Hyfvd/QcgSTqg71oYZJadJ6prhPiJa6mQx7mhgbAYfcfJ+939ByBJOqDvWhhklp0nqg==', + pubKey: 'CAESIGuE+IlrqZDHuaGBsBh9x8n73f0HIEk6oO9aGGSWnSeq' +}, { + id: '12D3KooWP63uzL78BRMpkQ7augMdNi1h3VBrVWZucKjyhzGVaSi1', + privKey: 'CAESYPxO3SHyfc2578hDmfkGGBY255JjiLuVavJWy+9ivlpsxSyVKf36ipyRGL6szGzHuFs5ceEuuGVrPMg/rW2Ch1bFLJUp/fqKnJEYvqzMbMe4Wzlx4S64ZWs8yD+tbYKHVg==', + pubKey: 'CAESIMUslSn9+oqckRi+rMxsx7hbOXHhLrhlazzIP61tgodW' +}, { + id: '12D3KooWF85R7CM2Wikdtb2sjwnd24e1tgojf3MEWwizmVB8PA6U', + privKey: 'CAESYNXoQ5CnooE939AEqE2JJGPqvhoFJn0xP+j9KwjfOfDkTtPyfn2kJ1gn3uOYTcmoHFU1bbETNtRVuPMi1fmDmqFO0/J+faQnWCfe45hNyagcVTVtsRM21FW48yLV+YOaoQ==', + pubKey: 'CAESIE7T8n59pCdYJ97jmE3JqBxVNW2xEzbUVbjzItX5g5qh' +}, { + id: '12D3KooWPCofiCjhdtezP4eMnqBjjutFZNHjV39F5LWNrCvaLnzT', + privKey: 'CAESYLhUut01XPu+yIPbtZ3WnxOd26FYuTMRn/BbdFYsZE2KxueKRlo9yIAxmFReoNFUKztUU4G2aUiTbqDQaA6i0MDG54pGWj3IgDGYVF6g0VQrO1RTgbZpSJNuoNBoDqLQwA==', + pubKey: 'CAESIMbnikZaPciAMZhUXqDRVCs7VFOBtmlIk26g0GgOotDA' +}] + +export async function createPeerIdsFromFixtures (length: number): Promise { + return Promise.all( + Array.from({ length }).map(async (_, i) => createFromJSON(peers[i])) + ) +} + +export async function createPeerIds (length: number): Promise { + const peerIds: PeerId[] = [] + for (let i = 0; i < length; i++) { + const id = await createEd25519PeerId() + peerIds.push(id) + } + + return peerIds +} diff --git a/packages/connection-encryption-noise/test/handshakes/xx.spec.ts b/packages/connection-encryption-noise/test/handshakes/xx.spec.ts new file mode 100644 index 0000000000..69e2733a73 --- /dev/null +++ b/packages/connection-encryption-noise/test/handshakes/xx.spec.ts @@ -0,0 +1,160 @@ +import { Buffer } from 'buffer' +import { expect, assert } from 'aegir/chai' +import { equals as uint8ArrayEquals } from 'uint8arrays/equals' +import { toString as uint8ArrayToString } from 'uint8arrays/to-string' +import { pureJsCrypto } from '../../src/crypto/js.js' +import { XX } from '../../src/handshakes/xx.js' +import { createHandshakePayload, getHandshakePayload } from '../../src/utils.js' +import { generateEd25519Keys } from '../utils.js' +import type { NoiseSession } from '../../src/@types/handshake.js' +import type { KeyPair } from '../../src/@types/libp2p.js' + +describe('XX Handshake', () => { + const prologue = Buffer.alloc(0) + + it('Test creating new XX session', async () => { + try { + const xx = new XX(pureJsCrypto) + + const kpInitiator: KeyPair = pureJsCrypto.generateX25519KeyPair() + + xx.initSession(true, prologue, kpInitiator) + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) + + it('Test get HKDF', () => { + const ckBytes = Buffer.from('4e6f6973655f58585f32353531395f58436861436861506f6c795f53484132353600000000000000000000000000000000000000000000000000000000000000', 'hex') + const ikm = Buffer.from('a3eae50ea37a47e8a7aa0c7cd8e16528670536dcd538cebfd724fb68ce44f1910ad898860666227d4e8dd50d22a9a64d1c0a6f47ace092510161e9e442953da3', 'hex') + const ck = Buffer.alloc(32) + ckBytes.copy(ck) + + const [k1, k2, k3] = pureJsCrypto.getHKDF(ck, ikm) + expect(uint8ArrayToString(k1, 'hex')).to.equal('cc5659adff12714982f806e2477a8d5ddd071def4c29bb38777b7e37046f6914') + expect(uint8ArrayToString(k2, 'hex')).to.equal('a16ada915e551ab623f38be674bb4ef15d428ae9d80688899c9ef9b62ef208fa') + expect(uint8ArrayToString(k3, 'hex')).to.equal('ff67bf9727e31b06efc203907e6786667d2c7a74ac412b4d31a80ba3fd766f68') + }) + + async function doHandshake (xx: XX): Promise<{ nsInit: NoiseSession, nsResp: NoiseSession }> { + const kpInit = pureJsCrypto.generateX25519KeyPair() + const kpResp = pureJsCrypto.generateX25519KeyPair() + + // initiator setup + const libp2pInitKeys = await generateEd25519Keys() + const initSignedPayload = await libp2pInitKeys.sign(getHandshakePayload(kpInit.publicKey)) + + // responder setup + const libp2pRespKeys = await generateEd25519Keys() + const respSignedPayload = await libp2pRespKeys.sign(getHandshakePayload(kpResp.publicKey)) + + // initiator: new XX noise session + const nsInit = xx.initSession(true, prologue, kpInit) + // responder: new XX noise session + const nsResp = xx.initSession(false, prologue, kpResp) + + /* STAGE 0 */ + + // initiator creates payload + libp2pInitKeys.marshal().slice(0, 32) + const libp2pInitPubKey = libp2pInitKeys.marshal().slice(32, 64) + + const payloadInitEnc = createHandshakePayload(libp2pInitPubKey, initSignedPayload) + + // initiator sends message + const message = Buffer.concat([Buffer.alloc(0), payloadInitEnc]) + const messageBuffer = xx.sendMessage(nsInit, message) + + expect(messageBuffer.ne.length).not.equal(0) + + // responder receives message + xx.recvMessage(nsResp, messageBuffer) + + /* STAGE 1 */ + + // responder creates payload + libp2pRespKeys.marshal().slice(0, 32) + const libp2pRespPubKey = libp2pRespKeys.marshal().slice(32, 64) + const payloadRespEnc = createHandshakePayload(libp2pRespPubKey, respSignedPayload) + + const message1 = Buffer.concat([message, payloadRespEnc]) + const messageBuffer2 = xx.sendMessage(nsResp, message1) + + expect(messageBuffer2.ne.length).not.equal(0) + expect(messageBuffer2.ns.length).not.equal(0) + + // initiator receive payload + xx.recvMessage(nsInit, messageBuffer2) + + /* STAGE 2 */ + + // initiator send message + const messageBuffer3 = xx.sendMessage(nsInit, Buffer.alloc(0)) + + // responder receive message + xx.recvMessage(nsResp, messageBuffer3) + + if (nsInit.cs1 == null || nsResp.cs1 == null || nsInit.cs2 == null || nsResp.cs2 == null) { + throw new Error('CipherState missing') + } + + assert(uint8ArrayEquals(nsInit.cs1.k, nsResp.cs1.k)) + assert(uint8ArrayEquals(nsInit.cs2.k, nsResp.cs2.k)) + + return { nsInit, nsResp } + } + + it('Test handshake', async () => { + try { + const xx = new XX(pureJsCrypto) + await doHandshake(xx) + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) + + it('Test symmetric encrypt and decrypt', async () => { + try { + const xx = new XX(pureJsCrypto) + const { nsInit, nsResp } = await doHandshake(xx) + const ad = Buffer.from('authenticated') + const message = Buffer.from('HelloCrypto') + + if (nsInit.cs1 == null || nsResp.cs1 == null || nsInit.cs2 == null || nsResp.cs2 == null) { + throw new Error('CipherState missing') + } + + const ciphertext = xx.encryptWithAd(nsInit.cs1, ad, message) + assert(!uint8ArrayEquals(Buffer.from('HelloCrypto'), ciphertext), 'Encrypted message should not be same as plaintext.') + const { plaintext: decrypted, valid } = xx.decryptWithAd(nsResp.cs1, ad, ciphertext) + + assert(uint8ArrayEquals(Buffer.from('HelloCrypto'), decrypted), 'Decrypted text not equal to original message.') + assert(valid) + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) + + it('Test multiple messages encryption and decryption', async () => { + const xx = new XX(pureJsCrypto) + const { nsInit, nsResp } = await doHandshake(xx) + const ad = Buffer.from('authenticated') + const message = Buffer.from('ethereum1') + + if (nsInit.cs1 == null || nsResp.cs1 == null || nsInit.cs2 == null || nsResp.cs2 == null) { + throw new Error('CipherState missing') + } + + const encrypted = xx.encryptWithAd(nsInit.cs1, ad, message) + const { plaintext: decrypted } = xx.decryptWithAd(nsResp.cs1, ad, encrypted) + assert.equal('ethereum1', uint8ArrayToString(decrypted, 'utf8'), 'Decrypted text not equal to original message.') + + const message2 = Buffer.from('ethereum2') + const encrypted2 = xx.encryptWithAd(nsInit.cs1, ad, message2) + const { plaintext: decrypted2 } = xx.decryptWithAd(nsResp.cs1, ad, encrypted2) + assert.equal('ethereum2', uint8ArrayToString(decrypted2, 'utf-8'), 'Decrypted text not equal to original message.') + }) +}) diff --git a/packages/connection-encryption-noise/test/index.spec.ts b/packages/connection-encryption-noise/test/index.spec.ts new file mode 100644 index 0000000000..5ce901d262 --- /dev/null +++ b/packages/connection-encryption-noise/test/index.spec.ts @@ -0,0 +1,54 @@ +import { expect } from 'aegir/chai' +import { lpStream } from 'it-length-prefixed-stream' +import { duplexPair } from 'it-pair/duplex' +import sinon from 'sinon' +import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' +import { noise } from '../src/index.js' +import { Noise } from '../src/noise.js' +import { createPeerIdsFromFixtures } from './fixtures/peer.js' +import type { Metrics } from '@libp2p/interface/metrics' + +function createCounterSpy (): ReturnType { + return sinon.spy({ + increment: () => {}, + reset: () => {} + }) +} + +describe('Index', () => { + it('should expose class with tag and required functions', () => { + const noiseInstance = noise()() + expect(noiseInstance.protocol).to.equal('/noise') + expect(typeof (noiseInstance.secureInbound)).to.equal('function') + expect(typeof (noiseInstance.secureOutbound)).to.equal('function') + }) + + it('should collect metrics', async () => { + const [localPeer, remotePeer] = await createPeerIdsFromFixtures(2) + const metricsRegistry = new Map>() + const metrics = { + registerCounter: (name: string) => { + const counter = createCounterSpy() + metricsRegistry.set(name, counter) + return counter + } + } + const noiseInit = new Noise({ metrics: metrics as any as Metrics }) + const noiseResp = new Noise({}) + + const [inboundConnection, outboundConnection] = duplexPair() + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer) + ]) + const wrappedInbound = lpStream(inbound.conn) + const wrappedOutbound = lpStream(outbound.conn) + + await wrappedOutbound.write(uint8ArrayFromString('test')) + await wrappedInbound.read() + expect(metricsRegistry.get('libp2p_noise_xxhandshake_successes_total')?.increment.callCount).to.equal(1) + expect(metricsRegistry.get('libp2p_noise_xxhandshake_error_total')?.increment.callCount).to.equal(0) + expect(metricsRegistry.get('libp2p_noise_encrypted_packets_total')?.increment.callCount).to.equal(1) + expect(metricsRegistry.get('libp2p_noise_decrypt_errors_total')?.increment.callCount).to.equal(0) + }) +}) diff --git a/packages/connection-encryption-noise/test/noise.spec.ts b/packages/connection-encryption-noise/test/noise.spec.ts new file mode 100644 index 0000000000..8b2c6937f4 --- /dev/null +++ b/packages/connection-encryption-noise/test/noise.spec.ts @@ -0,0 +1,210 @@ +import { Buffer } from 'buffer' +import { assert, expect } from 'aegir/chai' +import { randomBytes } from 'iso-random-stream' +import { byteStream } from 'it-byte-stream' +import { lpStream } from 'it-length-prefixed-stream' +import { duplexPair } from 'it-pair/duplex' +import sinon from 'sinon' +import { equals as uint8ArrayEquals } from 'uint8arrays/equals' +import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' +import { toString as uint8ArrayToString } from 'uint8arrays/to-string' +import { NOISE_MSG_MAX_LENGTH_BYTES } from '../src/constants.js' +import { pureJsCrypto } from '../src/crypto/js.js' +import { decode0, decode2, encode1, uint16BEDecode, uint16BEEncode } from '../src/encoder.js' +import { XXHandshake } from '../src/handshake-xx.js' +import { XX } from '../src/handshakes/xx.js' +import { Noise } from '../src/noise.js' +import { createHandshakePayload, getHandshakePayload, getPayload, signPayload } from '../src/utils.js' +import { createPeerIdsFromFixtures } from './fixtures/peer.js' +import { getKeyPairFromPeerId } from './utils.js' +import type { PeerId } from '@libp2p/interface/peer-id' + +describe('Noise', () => { + let remotePeer: PeerId, localPeer: PeerId + const sandbox = sinon.createSandbox() + + before(async () => { + [localPeer, remotePeer] = await createPeerIdsFromFixtures(2) + }) + + afterEach(function () { + sandbox.restore() + }) + + it('should communicate through encrypted streams without noise pipes', async () => { + try { + const noiseInit = new Noise({ staticNoiseKey: undefined, extensions: undefined }) + const noiseResp = new Noise({ staticNoiseKey: undefined, extensions: undefined }) + + const [inboundConnection, outboundConnection] = duplexPair() + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer) + ]) + const wrappedInbound = lpStream(inbound.conn) + const wrappedOutbound = lpStream(outbound.conn) + + await wrappedOutbound.write(Buffer.from('test')) + const response = await wrappedInbound.read() + expect(uint8ArrayToString(response.slice())).equal('test') + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) + + it('should test that secureOutbound is spec compliant', async () => { + const noiseInit = new Noise({ staticNoiseKey: undefined }) + const [inboundConnection, outboundConnection] = duplexPair() + + const [outbound, { wrapped, handshake }] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + (async () => { + const wrapped = lpStream( + inboundConnection, + { + lengthEncoder: uint16BEEncode, + lengthDecoder: uint16BEDecode, + maxDataLength: NOISE_MSG_MAX_LENGTH_BYTES + } + ) + const prologue = Buffer.alloc(0) + const staticKeys = pureJsCrypto.generateX25519KeyPair() + const xx = new XX(pureJsCrypto) + + const payload = await getPayload(remotePeer, staticKeys.publicKey) + const handshake = new XXHandshake(false, payload, prologue, pureJsCrypto, staticKeys, wrapped, localPeer, xx) + + let receivedMessageBuffer = decode0((await wrapped.read()).slice()) + // The first handshake message contains the initiator's ephemeral public key + expect(receivedMessageBuffer.ne.length).equal(32) + xx.recvMessage(handshake.session, receivedMessageBuffer) + + // Stage 1 + const { publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer) + const signedPayload = await signPayload(remotePeer, getHandshakePayload(staticKeys.publicKey)) + const handshakePayload = createHandshakePayload(libp2pPubKey, signedPayload) + + const messageBuffer = xx.sendMessage(handshake.session, handshakePayload) + await wrapped.write(encode1(messageBuffer)) + + // Stage 2 - finish handshake + receivedMessageBuffer = decode2((await wrapped.read()).slice()) + xx.recvMessage(handshake.session, receivedMessageBuffer) + return { wrapped, handshake } + })() + ]) + + const wrappedOutbound = byteStream(outbound.conn) + await wrappedOutbound.write(uint8ArrayFromString('test')) + + // Check that noise message is prefixed with 16-bit big-endian unsigned integer + const data = (await wrapped.read()).slice() + const { plaintext: decrypted, valid } = handshake.decrypt(data, handshake.session) + // Decrypted data should match + expect(uint8ArrayEquals(decrypted, uint8ArrayFromString('test'))).to.be.true() + expect(valid).to.be.true() + }) + + it('should test large payloads', async function () { + this.timeout(10000) + try { + const noiseInit = new Noise({ staticNoiseKey: undefined }) + const noiseResp = new Noise({ staticNoiseKey: undefined }) + + const [inboundConnection, outboundConnection] = duplexPair() + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer) + ]) + const wrappedInbound = byteStream(inbound.conn) + const wrappedOutbound = lpStream(outbound.conn) + + const largePlaintext = randomBytes(60000) + await wrappedOutbound.write(Buffer.from(largePlaintext)) + const response = await wrappedInbound.read(60000) + + expect(response.length).equals(largePlaintext.length) + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) + + it('should working without remote peer provided in incoming connection', async () => { + try { + const staticKeysInitiator = pureJsCrypto.generateX25519KeyPair() + const noiseInit = new Noise({ staticNoiseKey: staticKeysInitiator.privateKey }) + const staticKeysResponder = pureJsCrypto.generateX25519KeyPair() + const noiseResp = new Noise({ staticNoiseKey: staticKeysResponder.privateKey }) + + const [inboundConnection, outboundConnection] = duplexPair() + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection) + ]) + const wrappedInbound = lpStream(inbound.conn) + const wrappedOutbound = lpStream(outbound.conn) + + await wrappedOutbound.write(Buffer.from('test v2')) + const response = await wrappedInbound.read() + expect(uint8ArrayToString(response.slice())).equal('test v2') + + if (inbound.remotePeer.publicKey == null || localPeer.publicKey == null || + outbound.remotePeer.publicKey == null || remotePeer.publicKey == null) { + throw new Error('Public key missing from PeerId') + } + + assert(uint8ArrayEquals(inbound.remotePeer.publicKey, localPeer.publicKey)) + assert(uint8ArrayEquals(outbound.remotePeer.publicKey, remotePeer.publicKey)) + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) + + it('should accept and return Noise extension from remote peer', async () => { + try { + const certhashInit = Buffer.from('certhash data from init') + const staticKeysInitiator = pureJsCrypto.generateX25519KeyPair() + const noiseInit = new Noise({ staticNoiseKey: staticKeysInitiator.privateKey, extensions: { webtransportCerthashes: [certhashInit] } }) + const staticKeysResponder = pureJsCrypto.generateX25519KeyPair() + const certhashResp = Buffer.from('certhash data from respon') + const noiseResp = new Noise({ staticNoiseKey: staticKeysResponder.privateKey, extensions: { webtransportCerthashes: [certhashResp] } }) + + const [inboundConnection, outboundConnection] = duplexPair() + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection) + ]) + + assert(uint8ArrayEquals(inbound.remoteExtensions?.webtransportCerthashes[0] ?? new Uint8Array(), certhashInit)) + assert(uint8ArrayEquals(outbound.remoteExtensions?.webtransportCerthashes[0] ?? new Uint8Array(), certhashResp)) + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) + + it('should accept a prologue', async () => { + try { + const noiseInit = new Noise({ staticNoiseKey: undefined, crypto: pureJsCrypto, prologueBytes: Buffer.from('Some prologue') }) + const noiseResp = new Noise({ staticNoiseKey: undefined, crypto: pureJsCrypto, prologueBytes: Buffer.from('Some prologue') }) + + const [inboundConnection, outboundConnection] = duplexPair() + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer) + ]) + const wrappedInbound = lpStream(inbound.conn) + const wrappedOutbound = lpStream(outbound.conn) + + await wrappedOutbound.write(Buffer.from('test')) + const response = await wrappedInbound.read() + expect(uint8ArrayToString(response.slice())).equal('test') + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) +}) diff --git a/packages/connection-encryption-noise/test/utils.ts b/packages/connection-encryption-noise/test/utils.ts new file mode 100644 index 0000000000..4e5cd5dbef --- /dev/null +++ b/packages/connection-encryption-noise/test/utils.ts @@ -0,0 +1,19 @@ +import { keys } from '@libp2p/crypto' +import type { KeyPair } from '../src/@types/libp2p.js' +import type { PrivateKey } from '@libp2p/interface/keys' +import type { PeerId } from '@libp2p/interface/peer-id' + +export async function generateEd25519Keys (): Promise { + return keys.generateKeyPair('Ed25519', 32) +} + +export function getKeyPairFromPeerId (peerId: PeerId): KeyPair { + if (peerId.privateKey == null || peerId.publicKey == null) { + throw new Error('PrivateKey or PublicKey missing from PeerId') + } + + return { + privateKey: peerId.privateKey.subarray(0, 32), + publicKey: peerId.publicKey + } +} diff --git a/packages/connection-encryption-noise/test/xx-handshake.spec.ts b/packages/connection-encryption-noise/test/xx-handshake.spec.ts new file mode 100644 index 0000000000..7223e76536 --- /dev/null +++ b/packages/connection-encryption-noise/test/xx-handshake.spec.ts @@ -0,0 +1,126 @@ +import { Buffer } from 'buffer' +import { assert, expect } from 'aegir/chai' +import { lpStream } from 'it-length-prefixed-stream' +import { duplexPair } from 'it-pair/duplex' +import { equals as uint8ArrayEquals } from 'uint8arrays/equals' +import { pureJsCrypto } from '../src/crypto/js.js' +import { XXHandshake } from '../src/handshake-xx.js' +import { getPayload } from '../src/utils.js' +import { createPeerIdsFromFixtures } from './fixtures/peer.js' +import type { PeerId } from '@libp2p/interface/peer-id' + +describe('XX Handshake', () => { + let peerA: PeerId, peerB: PeerId, fakePeer: PeerId + + before(async () => { + [peerA, peerB, fakePeer] = await createPeerIdsFromFixtures(3) + }) + + it('should propose, exchange and finish handshake', async () => { + try { + const duplex = duplexPair() + const connectionFrom = lpStream(duplex[0]) + const connectionTo = lpStream(duplex[1]) + + const prologue = Buffer.alloc(0) + const staticKeysInitiator = pureJsCrypto.generateX25519KeyPair() + const staticKeysResponder = pureJsCrypto.generateX25519KeyPair() + + const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey) + const handshakeInitator = new XXHandshake(true, initPayload, prologue, pureJsCrypto, staticKeysInitiator, connectionFrom, peerB) + + const respPayload = await getPayload(peerB, staticKeysResponder.publicKey) + const handshakeResponder = new XXHandshake(false, respPayload, prologue, pureJsCrypto, staticKeysResponder, connectionTo, peerA) + + await handshakeInitator.propose() + await handshakeResponder.propose() + + await handshakeResponder.exchange() + await handshakeInitator.exchange() + + await handshakeInitator.finish() + await handshakeResponder.finish() + + const sessionInitator = handshakeInitator.session + const sessionResponder = handshakeResponder.session + + // Test shared key + if (sessionInitator.cs1 && sessionResponder.cs1 && sessionInitator.cs2 && sessionResponder.cs2) { + assert(uint8ArrayEquals(sessionInitator.cs1.k, sessionResponder.cs1.k)) + assert(uint8ArrayEquals(sessionInitator.cs2.k, sessionResponder.cs2.k)) + } else { + assert(false) + } + + // Test encryption and decryption + const encrypted = handshakeInitator.encrypt(Buffer.from('encryptthis'), handshakeInitator.session) + const { plaintext: decrypted, valid } = handshakeResponder.decrypt(encrypted, handshakeResponder.session) + assert(uint8ArrayEquals(decrypted, Buffer.from('encryptthis'))) + assert(valid) + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) + + it('Initiator should fail to exchange handshake if given wrong public key in payload', async () => { + try { + const duplex = duplexPair() + const connectionFrom = lpStream(duplex[0]) + const connectionTo = lpStream(duplex[1]) + + const prologue = Buffer.alloc(0) + const staticKeysInitiator = pureJsCrypto.generateX25519KeyPair() + const staticKeysResponder = pureJsCrypto.generateX25519KeyPair() + + const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey) + const handshakeInitator = new XXHandshake(true, initPayload, prologue, pureJsCrypto, staticKeysInitiator, connectionFrom, fakePeer) + + const respPayload = await getPayload(peerB, staticKeysResponder.publicKey) + const handshakeResponder = new XXHandshake(false, respPayload, prologue, pureJsCrypto, staticKeysResponder, connectionTo, peerA) + + await handshakeInitator.propose() + await handshakeResponder.propose() + + await handshakeResponder.exchange() + await handshakeInitator.exchange() + + assert(false, 'Should throw exception') + } catch (e) { + const err = e as Error + expect(err.message).equals(`Error occurred while verifying signed payload: Payload identity key ${peerB.toString()} does not match expected remote peer ${fakePeer.toString()}`) + } + }) + + it('Responder should fail to exchange handshake if given wrong public key in payload', async () => { + try { + const duplex = duplexPair() + const connectionFrom = lpStream(duplex[0]) + const connectionTo = lpStream(duplex[1]) + + const prologue = Buffer.alloc(0) + const staticKeysInitiator = pureJsCrypto.generateX25519KeyPair() + const staticKeysResponder = pureJsCrypto.generateX25519KeyPair() + + const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey) + const handshakeInitator = new XXHandshake(true, initPayload, prologue, pureJsCrypto, staticKeysInitiator, connectionFrom, peerB) + + const respPayload = await getPayload(peerB, staticKeysResponder.publicKey) + const handshakeResponder = new XXHandshake(false, respPayload, prologue, pureJsCrypto, staticKeysResponder, connectionTo, fakePeer) + + await handshakeInitator.propose() + await handshakeResponder.propose() + + await handshakeResponder.exchange() + await handshakeInitator.exchange() + + await handshakeInitator.finish() + await handshakeResponder.finish() + + assert(false, 'Should throw exception') + } catch (e) { + const err = e as Error + expect(err.message).equals(`Error occurred while verifying signed payload: Payload identity key ${peerA.toString()} does not match expected remote peer ${fakePeer.toString()}`) + } + }) +}) diff --git a/packages/connection-encryption-noise/tsconfig.json b/packages/connection-encryption-noise/tsconfig.json new file mode 100644 index 0000000000..15b6697fa5 --- /dev/null +++ b/packages/connection-encryption-noise/tsconfig.json @@ -0,0 +1,42 @@ +{ + "extends": "aegir/src/config/tsconfig.aegir.json", + "compilerOptions": { + "outDir": "dist" + }, + "include": [ + "src", + "test" + ], + "references": [ + { + "path": "../crypto" + }, + { + "path": "../interface" + }, + { + "path": "../logger" + }, + { + "path": "../peer-id" + }, + { + "path": "../stream-multiplexer-yamux" + }, + { + "path": "../libp2p-daemon-client" + }, + { + "path": "../libp2p-daemon-server" + }, + { + "path": "../interface-compliance-tests" + }, + { + "path": "../peer-id-factory" + }, + { + "path": "../transport-tcp" + } + ] +} diff --git a/packages/interface-compliance-tests/package.json b/packages/interface-compliance-tests/package.json index 1231f7d485..2a8e388300 100644 --- a/packages/interface-compliance-tests/package.json +++ b/packages/interface-compliance-tests/package.json @@ -92,6 +92,7 @@ "lint": "aegir lint", "dep-check": "aegir dep-check", "build": "aegir build", + "generate": "protons src/stream-muxer/fixtures/pb/message.proto", "test": "aegir test", "test:chrome": "aegir test -t browser --cov", "test:chrome-webworker": "aegir test -t webworker", @@ -110,7 +111,6 @@ "@libp2p/peer-id-factory": "^2.0.0", "@multiformats/multiaddr": "^12.1.3", "abortable-iterator": "^5.0.1", - "any-signal": "^4.1.1", "delay": "^6.0.0", "it-all": "^3.0.2", "it-drain": "^3.0.2", @@ -118,21 +118,25 @@ "it-map": "^3.0.3", "it-ndjson": "^1.0.3", "it-pair": "^2.0.6", + "it-protobuf-stream": "^1.0.0", "it-pipe": "^3.0.1", - "it-pushable": "^3.1.3", + "it-pushable": "^3.2.0", "it-stream-types": "^2.0.1", + "it-to-buffer": "^4.0.2", "merge-options": "^3.0.4", "p-defer": "^4.0.0", "p-event": "^6.0.0", "p-limit": "^4.0.0", "p-wait-for": "^5.0.2", + "protons-runtime": "^5.0.0", "sinon": "^15.1.2", "ts-sinon": "^2.0.2", "uint8arraylist": "^2.4.3", "uint8arrays": "^4.0.4" }, "devDependencies": { - "aegir": "^39.0.13" + "aegir": "^39.0.13", + "protons": "^7.0.2" }, "typedoc": { "entryPoint": "./src/index.ts" diff --git a/packages/interface-compliance-tests/src/connection-encryption/utils/index.ts b/packages/interface-compliance-tests/src/connection-encryption/utils/index.ts index 6543eedc32..173c963173 100644 --- a/packages/interface-compliance-tests/src/connection-encryption/utils/index.ts +++ b/packages/interface-compliance-tests/src/connection-encryption/utils/index.ts @@ -10,6 +10,7 @@ export function createMaConnPair (): [MultiaddrConnection, MultiaddrConnection] const output: MultiaddrConnection = { ...duplex, close: async () => {}, + abort: () => {}, remoteAddr: multiaddr('/ip4/127.0.0.1/tcp/4001'), timeline: { open: Date.now() diff --git a/packages/interface-compliance-tests/src/connection/index.ts b/packages/interface-compliance-tests/src/connection/index.ts index afd0213347..833ca30ba1 100644 --- a/packages/interface-compliance-tests/src/connection/index.ts +++ b/packages/interface-compliance-tests/src/connection/index.ts @@ -22,7 +22,7 @@ export default (test: TestSetup): void => { expect(connection.id).to.exist() expect(connection.remotePeer).to.exist() expect(connection.remoteAddr).to.exist() - expect(connection.status).to.equal('OPEN') + expect(connection.status).to.equal('open') expect(connection.timeline.open).to.exist() expect(connection.timeline.close).to.not.exist() expect(connection.direction).to.exist() @@ -31,7 +31,7 @@ export default (test: TestSetup): void => { }) it('should get the metadata of an open connection', () => { - expect(connection.status).to.equal('OPEN') + expect(connection.status).to.equal('open') expect(connection.direction).to.exist() expect(connection.timeline.open).to.exist() expect(connection.timeline.close).to.not.exist() @@ -89,7 +89,7 @@ export default (test: TestSetup): void => { await connection.close() expect(connection.timeline.close).to.exist() - expect(connection.status).to.equal('CLOSED') + expect(connection.status).to.equal('closed') }) it('should be able to close the connection after opening a stream', async () => { @@ -102,7 +102,7 @@ export default (test: TestSetup): void => { await connection.close() expect(connection.timeline.close).to.exist() - expect(connection.status).to.equal('CLOSED') + expect(connection.status).to.equal('closed') }) it('should properly track streams', async () => { @@ -112,7 +112,7 @@ export default (test: TestSetup): void => { expect(stream).to.have.property('protocol', protocol) // Close stream - stream.close() + await stream.close() expect(connection.streams.filter(s => s.id === stream.id)).to.be.empty() }) diff --git a/packages/interface-compliance-tests/src/mocks/connection-manager.ts b/packages/interface-compliance-tests/src/mocks/connection-manager.ts index 26a1ae1600..4f5cecbac0 100644 --- a/packages/interface-compliance-tests/src/mocks/connection-manager.ts +++ b/packages/interface-compliance-tests/src/mocks/connection-manager.ts @@ -4,12 +4,12 @@ import { PeerMap } from '@libp2p/peer-collections' import { peerIdFromString } from '@libp2p/peer-id' import { isMultiaddr, type Multiaddr } from '@multiformats/multiaddr' import { connectionPair } from './connection.js' -import type { Libp2pEvents } from '@libp2p/interface' +import type { Libp2pEvents, PendingDial } from '@libp2p/interface' import type { Connection } from '@libp2p/interface/connection' import type { EventEmitter } from '@libp2p/interface/events' import type { PubSub } from '@libp2p/interface/pubsub' import type { Startable } from '@libp2p/interface/startable' -import type { ConnectionManager, PendingDial } from '@libp2p/interface-internal/connection-manager' +import type { ConnectionManager } from '@libp2p/interface-internal/connection-manager' import type { Registrar } from '@libp2p/interface-internal/registrar' export interface MockNetworkComponents { diff --git a/packages/interface-compliance-tests/src/mocks/connection.ts b/packages/interface-compliance-tests/src/mocks/connection.ts index 05ab576899..cbd99530da 100644 --- a/packages/interface-compliance-tests/src/mocks/connection.ts +++ b/packages/interface-compliance-tests/src/mocks/connection.ts @@ -1,4 +1,3 @@ -import * as STATUS from '@libp2p/interface/connection/status' import { CodeError } from '@libp2p/interface/errors' import { logger } from '@libp2p/logger' import * as mss from '@libp2p/multistream-select' @@ -9,8 +8,7 @@ import { mockMultiaddrConnection } from './multiaddr-connection.js' import { mockMuxer } from './muxer.js' import { mockRegistrar } from './registrar.js' import type { AbortOptions } from '@libp2p/interface' -import type { MultiaddrConnection, Connection, Stream, Direction, ConnectionTimeline } from '@libp2p/interface/connection' -import type * as Status from '@libp2p/interface/connection/status' +import type { MultiaddrConnection, Connection, Stream, Direction, ConnectionTimeline, ConnectionStatus } from '@libp2p/interface/connection' import type { PeerId } from '@libp2p/interface/peer-id' import type { StreamMuxer, StreamMuxerFactory } from '@libp2p/interface/stream-muxer' import type { Registrar } from '@libp2p/interface-internal/registrar' @@ -42,7 +40,7 @@ class MockConnection implements Connection { public timeline: ConnectionTimeline public multiplexer?: string public encryption?: string - public status: keyof typeof Status + public status: ConnectionStatus public streams: Stream[] public tags: string[] @@ -56,7 +54,7 @@ class MockConnection implements Connection { this.remoteAddr = remoteAddr this.remotePeer = remotePeer this.direction = direction - this.status = STATUS.OPEN + this.status = 'open' this.direction = direction this.timeline = maConn.timeline this.multiplexer = 'test-multiplexer' @@ -76,7 +74,7 @@ class MockConnection implements Connection { throw new Error('protocols must have a length') } - if (this.status !== STATUS.OPEN) { + if (this.status !== 'open') { throw new CodeError('connection must be open to create streams', 'ERR_CONNECTION_CLOSED') } @@ -84,16 +82,14 @@ class MockConnection implements Connection { const stream = await this.muxer.newStream(id) const result = await mss.select(stream, protocols, options) - const streamWithProtocol: Stream = { - ...stream, - ...result.stream, - direction: 'outbound', - protocol: result.protocol - } + stream.protocol = result.protocol + stream.direction = 'outbound' + stream.sink = result.stream.sink + stream.source = result.stream.source - this.streams.push(streamWithProtocol) + this.streams.push(stream) - return streamWithProtocol + return stream } addStream (stream: Stream): void { @@ -104,13 +100,23 @@ class MockConnection implements Connection { this.streams = this.streams.filter(stream => stream.id !== id) } - async close (): Promise { - this.status = STATUS.CLOSING + async close (options?: AbortOptions): Promise { + this.status = 'closing' + await Promise.all( + this.streams.map(async s => s.close(options)) + ) await this.maConn.close() + this.status = 'closed' + this.timeline.close = Date.now() + } + + abort (err: Error): void { + this.status = 'closing' this.streams.forEach(s => { - s.close() + s.abort(err) }) - this.status = STATUS.CLOSED + this.maConn.abort(err) + this.status = 'closed' this.timeline.close = Date.now() } } @@ -135,8 +141,9 @@ export function mockConnection (maConn: MultiaddrConnection, opts: MockConnectio mss.handle(muxedStream, registrar.getProtocols()) .then(({ stream, protocol }) => { log('%s: incoming stream opened on %s', direction, protocol) - muxedStream = { ...muxedStream, ...stream } muxedStream.protocol = protocol + muxedStream.sink = stream.sink + muxedStream.source = stream.source connection.addStream(muxedStream) const { handler } = registrar.getHandler(protocol) @@ -172,18 +179,20 @@ export function mockConnection (maConn: MultiaddrConnection, opts: MockConnectio export function mockStream (stream: Duplex, Source, Promise>): Stream { return { ...stream, - close: () => {}, - closeRead: () => {}, - closeWrite: () => {}, + close: async () => {}, + closeRead: async () => {}, + closeWrite: async () => {}, abort: () => {}, - reset: () => {}, direction: 'outbound', protocol: '/foo/1.0.0', timeline: { open: Date.now() }, metadata: {}, - id: `stream-${Date.now()}` + id: `stream-${Date.now()}`, + status: 'open', + readStatus: 'ready', + writeStatus: 'ready' } } diff --git a/packages/interface-compliance-tests/src/mocks/multiaddr-connection.ts b/packages/interface-compliance-tests/src/mocks/multiaddr-connection.ts index 5bfc94f6ae..8944a4384b 100644 --- a/packages/interface-compliance-tests/src/mocks/multiaddr-connection.ts +++ b/packages/interface-compliance-tests/src/mocks/multiaddr-connection.ts @@ -11,6 +11,7 @@ export function mockMultiaddrConnection (source: Duplex {}, timeline: { open: Date.now() }, @@ -44,6 +45,10 @@ export function mockMultiaddrConnPair (opts: MockMultiaddrConnPairOptions): { in close: async () => { outbound.timeline.close = Date.now() controller.abort() + }, + abort: (err: Error) => { + outbound.timeline.close = Date.now() + controller.abort(err) } } @@ -56,6 +61,10 @@ export function mockMultiaddrConnPair (opts: MockMultiaddrConnPairOptions): { in close: async () => { inbound.timeline.close = Date.now() controller.abort() + }, + abort: (err: Error) => { + outbound.timeline.close = Date.now() + controller.abort(err) } } diff --git a/packages/interface-compliance-tests/src/mocks/muxer.ts b/packages/interface-compliance-tests/src/mocks/muxer.ts index 592cf95252..e41aa24a4c 100644 --- a/packages/interface-compliance-tests/src/mocks/muxer.ts +++ b/packages/interface-compliance-tests/src/mocks/muxer.ts @@ -1,7 +1,6 @@ -import { CodeError } from '@libp2p/interface/errors' +import { AbstractStream, type AbstractStreamInit } from '@libp2p/interface/stream-muxer/stream' import { type Logger, logger } from '@libp2p/logger' import { abortableSource } from 'abortable-iterator' -import { anySignal } from 'any-signal' import map from 'it-map' import * as ndjson from 'it-ndjson' import { pipe } from 'it-pipe' @@ -9,254 +8,94 @@ import { type Pushable, pushable } from 'it-pushable' import { Uint8ArrayList } from 'uint8arraylist' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' import { toString as uint8ArrayToString } from 'uint8arrays/to-string' -import type { Stream } from '@libp2p/interface/connection' +import type { AbortOptions } from '@libp2p/interface' +import type { Direction, Stream } from '@libp2p/interface/connection' import type { StreamMuxer, StreamMuxerFactory, StreamMuxerInit } from '@libp2p/interface/stream-muxer' import type { Source } from 'it-stream-types' let muxers = 0 let streams = 0 -const MAX_MESSAGE_SIZE = 1024 * 1024 interface DataMessage { id: string type: 'data' - direction: 'initiator' | 'recipient' + direction: Direction chunk: string } interface ResetMessage { id: string type: 'reset' - direction: 'initiator' | 'recipient' + direction: Direction } interface CloseMessage { id: string type: 'close' - direction: 'initiator' | 'recipient' + direction: Direction } interface CreateMessage { id: string type: 'create' - direction: 'initiator' + direction: 'outbound' } type StreamMessage = DataMessage | ResetMessage | CloseMessage | CreateMessage -class MuxedStream { - public id: string - public input: Pushable - public stream: Stream - public type: 'initiator' | 'recipient' - - private sinkEnded: boolean - private sourceEnded: boolean - private readonly abortController: AbortController - private readonly resetController: AbortController - private readonly closeController: AbortController - private readonly log: Logger - - constructor (init: { id: string, type: 'initiator' | 'recipient', push: Pushable, onEnd: (err?: Error) => void }) { - const { id, type, push, onEnd } = init - - this.log = logger(`libp2p:mock-muxer:stream:${id}:${type}`) - - this.id = id - this.type = type - this.abortController = new AbortController() - this.resetController = new AbortController() - this.closeController = new AbortController() - - this.sourceEnded = false - this.sinkEnded = false - - let endErr: Error | undefined - - const onSourceEnd = (err?: Error): void => { - if (this.sourceEnded) { - return - } - - this.log('onSourceEnd sink ended? %s', this.sinkEnded) +export interface MockMuxedStreamInit extends AbstractStreamInit { + push: Pushable +} - this.sourceEnded = true +class MuxedStream extends AbstractStream { + private readonly push: Pushable - if (err != null && endErr == null) { - endErr = err - } + constructor (init: MockMuxedStreamInit) { + super(init) - if (this.sinkEnded) { - this.stream.timeline.close = Date.now() + this.push = init.push + } - if (onEnd != null) { - onEnd(endErr) - } - } + sendNewStream (): void { + // If initiator, open a new stream + const createMsg: CreateMessage = { + id: this.id, + type: 'create', + direction: 'outbound' } + this.push.push(createMsg) + } - const onSinkEnd = (err?: Error): void => { - if (this.sinkEnded) { - return - } - - this.log('onSinkEnd source ended? %s', this.sourceEnded) - - this.sinkEnded = true - - if (err != null && endErr == null) { - endErr = err - } - - if (this.sourceEnded) { - this.stream.timeline.close = Date.now() - - if (onEnd != null) { - onEnd(endErr) - } - } + sendData (data: Uint8ArrayList): void { + const dataMsg: DataMessage = { + id: this.id, + type: 'data', + chunk: uint8ArrayToString(data.subarray(), 'base64pad'), + direction: this.direction } + this.push.push(dataMsg) + } - this.input = pushable({ - onEnd: onSourceEnd - }) - - this.stream = { - id, - sink: async (source) => { - if (this.sinkEnded) { - throw new CodeError('stream closed for writing', 'ERR_SINK_ENDED') - } - - const signal = anySignal([ - this.abortController.signal, - this.resetController.signal, - this.closeController.signal - ]) - - source = abortableSource(source, signal) - - try { - if (this.type === 'initiator') { - // If initiator, open a new stream - const createMsg: CreateMessage = { - id: this.id, - type: 'create', - direction: this.type - } - push.push(createMsg) - } - - const list = new Uint8ArrayList() - - for await (const chunk of source) { - list.append(chunk) - - while (list.length > 0) { - const available = Math.min(list.length, MAX_MESSAGE_SIZE) - const dataMsg: DataMessage = { - id, - type: 'data', - chunk: uint8ArrayToString(list.subarray(0, available), 'base64pad'), - direction: this.type - } - - push.push(dataMsg) - list.consume(available) - } - } - } catch (err: any) { - if (err.type === 'aborted' && err.message === 'The operation was aborted') { - if (this.closeController.signal.aborted) { - return - } - - if (this.resetController.signal.aborted) { - err.message = 'stream reset' - err.code = 'ERR_STREAM_RESET' - } - - if (this.abortController.signal.aborted) { - err.message = 'stream aborted' - err.code = 'ERR_STREAM_ABORT' - } - } - - // Send no more data if this stream was remotely reset - if (err.code !== 'ERR_STREAM_RESET') { - const resetMsg: ResetMessage = { - id, - type: 'reset', - direction: this.type - } - push.push(resetMsg) - } - - this.log('sink erred', err) - - this.input.end(err) - onSinkEnd(err) - return - } finally { - signal.clear() - } - - this.log('sink ended') - - onSinkEnd() - - const closeMsg: CloseMessage = { - id, - type: 'close', - direction: this.type - } - push.push(closeMsg) - }, - source: this.input, - - // Close for reading - close: () => { - this.stream.closeRead() - this.stream.closeWrite() - }, - - closeRead: () => { - this.input.end() - }, - - closeWrite: () => { - this.closeController.abort() - - const closeMsg: CloseMessage = { - id, - type: 'close', - direction: this.type - } - push.push(closeMsg) - onSinkEnd() - }, - - // Close for reading and writing (local error) - abort: (err: Error) => { - // End the source with the passed error - this.input.end(err) - this.abortController.abort() - onSinkEnd(err) - }, + sendReset (): void { + const resetMsg: ResetMessage = { + id: this.id, + type: 'reset', + direction: this.direction + } + this.push.push(resetMsg) + } - // Close immediately for reading and writing (remote error) - reset: () => { - const err = new CodeError('stream reset', 'ERR_STREAM_RESET') - this.resetController.abort() - this.input.end(err) - onSinkEnd(err) - }, - direction: type === 'initiator' ? 'outbound' : 'inbound', - timeline: { - open: Date.now() - }, - metadata: {} + sendCloseWrite (): void { + const closeMsg: CloseMessage = { + id: this.id, + type: 'close', + direction: this.direction } + this.push.push(closeMsg) + } + + sendCloseRead (): void { + // does not support close read, only close write } } @@ -284,8 +123,14 @@ class MockMuxer implements StreamMuxer { this.closeController = new AbortController() // receives data from the muxer at the other end of the stream this.source = this.input = pushable({ - onEnd: (err) => { - this.close(err) + onEnd: () => { + for (const stream of this.registryInitiatorStreams.values()) { + stream.destroy() + } + + for (const stream of this.registryRecipientStreams.values()) { + stream.destroy() + } } }) @@ -321,18 +166,18 @@ class MockMuxer implements StreamMuxer { handleMessage (message: StreamMessage): void { let muxedStream: MuxedStream | undefined - const registry = message.direction === 'initiator' ? this.registryRecipientStreams : this.registryInitiatorStreams + const registry = message.direction === 'outbound' ? this.registryRecipientStreams : this.registryInitiatorStreams if (message.type === 'create') { if (registry.has(message.id)) { throw new Error(`Already had stream for ${message.id}`) } - muxedStream = this.createStream(message.id, 'recipient') - registry.set(muxedStream.stream.id, muxedStream) + muxedStream = this.createStream(message.id, 'inbound') + registry.set(muxedStream.id, muxedStream) if (this.options.onIncomingStream != null) { - this.options.onIncomingStream(muxedStream.stream) + this.options.onIncomingStream(muxedStream) } } @@ -345,20 +190,19 @@ class MockMuxer implements StreamMuxer { } if (message.type === 'data') { - muxedStream.input.push(new Uint8ArrayList(uint8ArrayFromString(message.chunk, 'base64pad'))) + muxedStream.sourcePush(new Uint8ArrayList(uint8ArrayFromString(message.chunk, 'base64pad'))) } else if (message.type === 'reset') { - this.log('-> reset stream %s %s', muxedStream.type, muxedStream.stream.id) - muxedStream.stream.reset() + this.log('-> reset stream %s %s', muxedStream.direction, muxedStream.id) + muxedStream.reset() } else if (message.type === 'close') { - this.log('-> closing stream %s %s', muxedStream.type, muxedStream.stream.id) - muxedStream.stream.closeRead() + this.log('-> closing stream %s %s', muxedStream.direction, muxedStream.id) + muxedStream.remoteCloseWrite() } } get streams (): Stream[] { return Array.from(this.registryRecipientStreams.values()) .concat(Array.from(this.registryInitiatorStreams.values())) - .map(({ stream }) => stream) } newStream (name?: string): Stream { @@ -366,53 +210,67 @@ class MockMuxer implements StreamMuxer { throw new Error('Muxer already closed') } this.log('newStream %s', name) - const storedStream = this.createStream(name, 'initiator') - this.registryInitiatorStreams.set(storedStream.stream.id, storedStream) + const storedStream = this.createStream(name, 'outbound') + this.registryInitiatorStreams.set(storedStream.id, storedStream) - return storedStream.stream + return storedStream } - createStream (name?: string, type: 'initiator' | 'recipient' = 'initiator'): MuxedStream { - const id = name ?? `${this.name}:stream:${streams++}` + createStream (name?: string, direction: Direction = 'outbound'): MuxedStream { + const id = name ?? `${streams++}` - this.log('createStream %s %s', type, id) + this.log('createStream %s %s', direction, id) const muxedStream: MuxedStream = new MuxedStream({ id, - type, + direction, push: this.streamInput, onEnd: () => { - this.log('stream ended %s %s', type, id) + this.log('stream ended') - if (type === 'initiator') { - this.registryInitiatorStreams.delete(id) + if (direction === 'outbound') { + this.registryInitiatorStreams.delete(muxedStream.id) } else { - this.registryRecipientStreams.delete(id) + this.registryRecipientStreams.delete(muxedStream.id) } if (this.options.onStreamEnd != null) { - this.options.onStreamEnd(muxedStream.stream) + this.options.onStreamEnd(muxedStream) } - } + }, + log: logger(`libp2p:mock-muxer:stream:${direction}:${id}`) }) return muxedStream } - close (err?: Error): void { - if (this.closeController.signal.aborted) return + async close (options?: AbortOptions): Promise { + if (this.closeController.signal.aborted) { + return + } + this.log('closing muxed streams') - if (err == null) { - this.streams.forEach(s => { - s.close() - }) - } else { - this.streams.forEach(s => { - s.abort(err) - }) - } + await Promise.all( + this.streams.map(async s => s.close()) + ) + this.closeController.abort() + this.input.end() + } + + abort (err: Error): void { + if (this.closeController.signal.aborted) { + return + } + + this.log('aborting muxed streams') + + this.streams.forEach(s => { + s.abort(err) + }) + + this.closeController.abort(err) this.input.end(err) } } diff --git a/packages/interface-compliance-tests/src/stream-muxer/close-test.ts b/packages/interface-compliance-tests/src/stream-muxer/close-test.ts index 8122ba712c..a7ab9b6c54 100644 --- a/packages/interface-compliance-tests/src/stream-muxer/close-test.ts +++ b/packages/interface-compliance-tests/src/stream-muxer/close-test.ts @@ -6,9 +6,12 @@ import all from 'it-all' import drain from 'it-drain' import { duplexPair } from 'it-pair/duplex' import { pipe } from 'it-pipe' +import { pbStream } from 'it-protobuf-stream' +import toBuffer from 'it-to-buffer' import pDefer from 'p-defer' import { Uint8ArrayList } from 'uint8arraylist' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' +import { Message } from './fixtures/pb/message.js' import type { TestSetup } from '../index.js' import type { StreamMuxerFactory } from '@libp2p/interface/stream-muxer' @@ -61,9 +64,9 @@ export default (common: TestSetup): void => { expect(dialer.streams).to.have.lengthOf(expectedStreams) - // Pause, and then send some data and close the dialer + // Pause, and then close the dialer await delay(50) - await pipe([randomBuffer()], dialer, drain) + await pipe([], dialer, drain) expect(openedStreams).to.have.equal(expectedStreams) expect(dialer.streams).to.have.lengthOf(0) @@ -106,7 +109,7 @@ export default (common: TestSetup): void => { // Pause, and then close the dialer await delay(50) - dialer.close() + await dialer.close() expect(openedStreams, 'listener - number of opened streams should match number of calls to newStream').to.have.equal(expectedStreams) expect(dialer.streams, 'all tracked streams should be deleted after the muxer has called close').to.have.lengthOf(0) @@ -148,7 +151,7 @@ export default (common: TestSetup): void => { await delay(50) // close _with an error_ - dialer.close(new Error()) + dialer.abort(new Error('Oh no!')) const timeoutError = new Error('timeout') for (const pipe of streamPipes) { @@ -173,7 +176,7 @@ export default (common: TestSetup): void => { const dialerFactory = await common.setup() const dialer = dialerFactory.createStreamMuxer({ direction: 'outbound' }) - dialer.close() + await dialer.close() try { await dialer.newStream() @@ -246,7 +249,7 @@ export default (common: TestSetup): void => { onIncomingStream: (stream) => { void Promise.resolve().then(async () => { // Immediate close for write - stream.closeWrite() + await stream.closeWrite() const results = await pipe(stream, async (source) => { const data = [] @@ -275,16 +278,16 @@ export default (common: TestSetup): void => { await stream.sink(data) const err = await deferred.promise - expect(err).to.have.property('message').that.matches(/stream closed for writing/) + expect(err).to.have.property('code', 'ERR_SINK_INVALID_STATE') }) it('can close a stream for reading', async () => { - const deferred = pDefer() - + const deferred = pDefer() const p = duplexPair() const dialerFactory = await common.setup() const dialer = dialerFactory.createStreamMuxer({ direction: 'outbound' }) const data = [randomBuffer(), randomBuffer()].map(d => new Uint8ArrayList(d)) + const expected = toBuffer(data.map(d => d.subarray())) const listenerFactory = await common.setup() const listener = listenerFactory.createStreamMuxer({ @@ -298,7 +301,7 @@ export default (common: TestSetup): void => { void pipe(p[1], listener, p[1]) const stream = await dialer.newStream() - stream.closeRead() + await stream.closeRead() // Source should be done void Promise.resolve().then(async () => { @@ -307,7 +310,7 @@ export default (common: TestSetup): void => { }) const results = await deferred.promise - expect(results).to.eql(data) + expect(toBuffer(results.map(b => b.subarray()))).to.equalBytes(expected) }) it('calls onStreamEnd for closed streams not previously written', async () => { @@ -322,7 +325,7 @@ export default (common: TestSetup): void => { const stream = await dialer.newStream() - stream.close() + await stream.close() await deferred.promise }) @@ -338,9 +341,102 @@ export default (common: TestSetup): void => { const stream = await dialer.newStream() - stream.closeWrite() - stream.closeRead() + await stream.closeWrite() + await stream.closeRead() + await deferred.promise + }) + + it('should wait for all data to be sent when closing streams', async () => { + const deferred = pDefer() + + const p = duplexPair() + const dialerFactory = await common.setup() + const dialer = dialerFactory.createStreamMuxer({ direction: 'outbound' }) + + const listenerFactory = await common.setup() + const listener = listenerFactory.createStreamMuxer({ + direction: 'inbound', + onIncomingStream: (stream) => { + const pb = pbStream(stream) + + void pb.read(Message) + .then(async message => { + deferred.resolve(message) + await pb.unwrap().close() + }) + .catch(err => { + deferred.reject(err) + }) + } + }) + + void pipe(p[0], dialer, p[0]) + void pipe(p[1], listener, p[1]) + + const message = { + message: 'hello world', + value: 5, + flag: true + } + + const stream = await dialer.newStream() + + const pb = pbStream(stream) + await pb.write(message, Message) + await pb.unwrap().close() + + await expect(deferred.promise).to.eventually.deep.equal(message) + }) + /* + it('should abort closing a stream with outstanding data to read', async () => { + const deferred = pDefer() + + const p = duplexPair() + const dialerFactory = await common.setup() + const dialer = dialerFactory.createStreamMuxer({ direction: 'outbound' }) + + const listenerFactory = await common.setup() + const listener = listenerFactory.createStreamMuxer({ + direction: 'inbound', + onIncomingStream: (stream) => { + const pb = pbStream(stream) + + void pb.read(Message) + .then(async message => { + await pb.write(message, Message) + await pb.unwrap().close() + deferred.resolve(message) + }) + .catch(err => { + deferred.reject(err) + }) + } + }) + + void pipe(p[0], dialer, p[0]) + void pipe(p[1], listener, p[1]) + + const message = { + message: 'hello world', + value: 5, + flag: true + } + + const stream = await dialer.newStream() + + const pb = pbStream(stream) + await pb.write(message, Message) + + console.info('await write back') await deferred.promise + + // let message arrive + await delay(100) + + // close should time out as message is never read + await expect(pb.unwrap().close()).to.eventually.be.rejected + .with.property('code', 'ERR_CLOSE_READ_ABORTED') }) + */ }) } diff --git a/packages/interface-compliance-tests/src/stream-muxer/fixtures/pb/message.proto b/packages/interface-compliance-tests/src/stream-muxer/fixtures/pb/message.proto new file mode 100644 index 0000000000..f734b891e3 --- /dev/null +++ b/packages/interface-compliance-tests/src/stream-muxer/fixtures/pb/message.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +message Message { + string message = 1; + uint32 value = 2; + bool flag = 3; +} diff --git a/packages/interface-compliance-tests/src/stream-muxer/fixtures/pb/message.ts b/packages/interface-compliance-tests/src/stream-muxer/fixtures/pb/message.ts new file mode 100644 index 0000000000..74bdd8bb68 --- /dev/null +++ b/packages/interface-compliance-tests/src/stream-muxer/fixtures/pb/message.ts @@ -0,0 +1,87 @@ +/* eslint-disable import/export */ +/* eslint-disable complexity */ +/* eslint-disable @typescript-eslint/no-namespace */ +/* eslint-disable @typescript-eslint/no-unnecessary-boolean-literal-compare */ +/* eslint-disable @typescript-eslint/no-empty-interface */ + +import { encodeMessage, decodeMessage, message } from 'protons-runtime' +import type { Codec } from 'protons-runtime' +import type { Uint8ArrayList } from 'uint8arraylist' + +export interface Message { + message: string + value: number + flag: boolean +} + +export namespace Message { + let _codec: Codec + + export const codec = (): Codec => { + if (_codec == null) { + _codec = message((obj, w, opts = {}) => { + if (opts.lengthDelimited !== false) { + w.fork() + } + + if ((obj.message != null && obj.message !== '')) { + w.uint32(10) + w.string(obj.message) + } + + if ((obj.value != null && obj.value !== 0)) { + w.uint32(16) + w.uint32(obj.value) + } + + if ((obj.flag != null && obj.flag !== false)) { + w.uint32(24) + w.bool(obj.flag) + } + + if (opts.lengthDelimited !== false) { + w.ldelim() + } + }, (reader, length) => { + const obj: any = { + message: '', + value: 0, + flag: false + } + + const end = length == null ? reader.len : reader.pos + length + + while (reader.pos < end) { + const tag = reader.uint32() + + switch (tag >>> 3) { + case 1: + obj.message = reader.string() + break + case 2: + obj.value = reader.uint32() + break + case 3: + obj.flag = reader.bool() + break + default: + reader.skipType(tag & 7) + break + } + } + + return obj + }) + } + + return _codec + } + + export const encode = (obj: Partial): Uint8Array => { + return encodeMessage(obj, Message.codec()) + } + + export const decode = (buf: Uint8Array | Uint8ArrayList): Message => { + return decodeMessage(buf, Message.codec()) + } +} diff --git a/packages/interface-compliance-tests/src/stream-muxer/spawner.ts b/packages/interface-compliance-tests/src/stream-muxer/spawner.ts index 140ce234c0..f6df6b5c02 100644 --- a/packages/interface-compliance-tests/src/stream-muxer/spawner.ts +++ b/packages/interface-compliance-tests/src/stream-muxer/spawner.ts @@ -19,9 +19,10 @@ export default async (createMuxer: (init?: StreamMuxerInit) => Promise { - stream.close() + ).then(async () => { + await stream.close() }) + .catch(err => { stream.abort(err) }) } }) const dialer = await createMuxer({ direction: 'outbound' }) diff --git a/packages/interface-compliance-tests/src/stream-muxer/stress-test.ts b/packages/interface-compliance-tests/src/stream-muxer/stress-test.ts index cd2f4d4d1c..3da0ed22f0 100644 --- a/packages/interface-compliance-tests/src/stream-muxer/stress-test.ts +++ b/packages/interface-compliance-tests/src/stream-muxer/stress-test.ts @@ -9,7 +9,7 @@ export default (common: TestSetup): void => { } describe('stress test', function () { - this.timeout(800000) + this.timeout(1600000) it('1 stream with 1 msg', async () => { await spawn(createMuxer, 1, 1) }) it('1 stream with 10 msg', async () => { await spawn(createMuxer, 1, 10) }) diff --git a/packages/interface-compliance-tests/src/transport/listen-test.ts b/packages/interface-compliance-tests/src/transport/listen-test.ts index 8f9bc3b3ce..d518511dc7 100644 --- a/packages/interface-compliance-tests/src/transport/listen-test.ts +++ b/packages/interface-compliance-tests/src/transport/listen-test.ts @@ -93,7 +93,7 @@ export default (common: TestSetup): void => { listener.close() ]) - stream1.close() + await stream1.close() await conn1.close() expect(isValidTick(conn1.timeline.close)).to.equal(true) diff --git a/packages/interface-internal/src/connection-manager/index.ts b/packages/interface-internal/src/connection-manager/index.ts index a03569c237..faa6fc94a5 100644 --- a/packages/interface-internal/src/connection-manager/index.ts +++ b/packages/interface-internal/src/connection-manager/index.ts @@ -1,18 +1,9 @@ -import type { AbortOptions } from '@libp2p/interface' +import type { AbortOptions, PendingDial } from '@libp2p/interface' import type { Connection, MultiaddrConnection } from '@libp2p/interface/connection' import type { PeerId } from '@libp2p/interface/peer-id' import type { PeerMap } from '@libp2p/peer-collections' import type { Multiaddr } from '@multiformats/multiaddr' -export type PendingDialStatus = 'queued' | 'active' | 'error' | 'success' - -export interface PendingDial { - id: string - status: PendingDialStatus - peerId?: PeerId - multiaddrs: Multiaddr[] -} - export interface ConnectionManager { /** * Return connections, optionally filtering by a PeerId @@ -51,7 +42,7 @@ export interface ConnectionManager { /** * Close our connections to a peer */ - closeConnections: (peer: PeerId) => Promise + closeConnections: (peer: PeerId, options?: AbortOptions) => Promise /** * Invoked after an incoming connection is opened but before PeerIds are diff --git a/packages/interface-internal/src/index.ts b/packages/interface-internal/src/index.ts index 68753a6328..59ad568a72 100644 --- a/packages/interface-internal/src/index.ts +++ b/packages/interface-internal/src/index.ts @@ -1,598 +1,2 @@ -/** - * @packageDocumentation - * - * Exports a `Libp2p` type for modules to use as a type argument. - * - * @example - * - * ```typescript - * import type { Libp2p } from '@libp2p/interface' - * - * function doSomethingWithLibp2p (node: Libp2p) { - * // ... - * } - * ``` - */ -import type { StreamHandler, StreamHandlerOptions } from './registrar/index.js' -import type { AbortOptions } from '@libp2p/interface' -import type { Connection, Stream } from '@libp2p/interface/connection' -import type { ContentRouting } from '@libp2p/interface/content-routing' -import type { EventEmitter } from '@libp2p/interface/events' -import type { KeyChain } from '@libp2p/interface/keychain' -import type { Metrics } from '@libp2p/interface/metrics' -import type { PeerId } from '@libp2p/interface/peer-id' -import type { PeerInfo } from '@libp2p/interface/peer-info' -import type { PeerRouting } from '@libp2p/interface/peer-routing' -import type { Address, Peer, PeerStore } from '@libp2p/interface/peer-store' -import type { Startable } from '@libp2p/interface/startable' -import type { Topology } from '@libp2p/interface/topology' -import type { Listener } from '@libp2p/interface/transport' -import type { Multiaddr } from '@multiformats/multiaddr' - -/** - * Used by the connection manager to sort addresses into order before dialling - */ -export interface AddressSorter { - (a: Address, b: Address): -1 | 0 | 1 -} - -/** - * Event detail emitted when peer data changes - */ -export interface PeerUpdate { - peer: Peer - previous?: Peer -} - -/** - * Peer data signed by the remote Peer's public key - */ -export interface SignedPeerRecord { - addresses: Multiaddr[] - seq: bigint -} - -/** - * Data returned from a successful identify response - */ -export interface IdentifyResult { - /** - * The remote Peer's PeerId - */ - peerId: PeerId - - /** - * The unsigned addresses they are listening on. Note - any multiaddrs present - * in the signed peer record should be preferred to the value here. - */ - listenAddrs: Multiaddr[] - - /** - * The protocols the remote peer supports - */ - protocols: string[] - - /** - * The remote protocol version - */ - protocolVersion?: string - - /** - * The remote agent version - */ - agentVersion?: string - - /** - * The public key part of the remote PeerId - this is only useful for older - * RSA-based PeerIds, the more modern Ed25519 and secp256k1 types have the - * public key embedded in them - */ - publicKey?: Uint8Array - - /** - * If set this is the address that the remote peer saw the identify request - * originate from - */ - observedAddr?: Multiaddr - - /** - * If sent by the remote peer this is the deserialized signed peer record - */ - signedPeerRecord?: SignedPeerRecord -} - -/** - * Once you have a libp2p instance, you can listen to several events it emits, - * so that you can be notified of relevant network events. - * - * Event names are `noun:verb` so the first part is the name of the object - * being acted on and the second is the action. - */ -export interface Libp2pEvents { - /** - * This event is dispatched when a new network peer is discovered. - * - * @example - * - * ```js - * libp2p.addEventListener('peer:discovery', (event) => { - * const peerInfo = event.detail - * // ... - * }) - * ``` - */ - 'peer:discovery': CustomEvent - - /** - * This event will be triggered any time a new peer connects. - * - * @example - * - * ```js - * libp2p.addEventListener('peer:connect', (event) => { - * const peerId = event.detail - * // ... - * }) - * ``` - */ - 'peer:connect': CustomEvent - - /** - * This event will be triggered any time we are disconnected from another peer, regardless of - * the circumstances of that disconnection. If we happen to have multiple connections to a - * peer, this event will **only** be triggered when the last connection is closed. - * - * @example - * - * ```js - * libp2p.addEventListener('peer:disconnect', (event) => { - * const peerId = event.detail - * // ... - * }) - * ``` - */ - 'peer:disconnect': CustomEvent - - /** - * This event is dispatched after a remote peer has successfully responded to the identify - * protocol. Note that for this to be emitted, both peers must have an identify service - * configured. - * - * @example - * - * ```js - * libp2p.addEventListener('peer:identify', (event) => { - * const identifyResult = event.detail - * // ... - * }) - * ``` - */ - 'peer:identify': CustomEvent - - /** - * This event is dispatched when the peer store data for a peer has been - * updated - e.g. their multiaddrs, protocols etc have changed. - * - * If they were previously known to this node, the old peer data will be - * set in the `previous` field. - * - * This may be in response to the identify protocol running, a manual - * update or some other event. - */ - 'peer:update': CustomEvent - - /** - * This event is dispatched when the current node's peer record changes - - * for example a transport started listening on a new address or a new - * protocol handler was registered. - * - * @example - * - * ```js - * libp2p.addEventListener('self:peer:update', (event) => { - * const { peer } = event.detail - * // ... - * }) - * ``` - */ - 'self:peer:update': CustomEvent - - /** - * This event is dispatched when a transport begins listening on a new address - */ - 'transport:listening': CustomEvent - - /** - * This event is dispatched when a transport stops listening on an address - */ - 'transport:close': CustomEvent - - /** - * This event is dispatched when the connection manager has more than the - * configured allowable max connections and has closed some connections to - * bring the node back under the limit. - */ - 'connection:prune': CustomEvent - - /** - * This event notifies listeners when new incoming or outgoing connections - * are opened. - */ - 'connection:open': CustomEvent - - /** - * This event notifies listeners when incoming or outgoing connections are - * closed. - */ - 'connection:close': CustomEvent - - /** - * This event notifies listeners that the node has started - * - * ```js - * libp2p.addEventListener('start', (event) => { - * console.info(libp2p.isStarted()) // true - * }) - * ``` - */ - 'start': CustomEvent> - - /** - * This event notifies listeners that the node has stopped - * - * ```js - * libp2p.addEventListener('stop', (event) => { - * console.info(libp2p.isStarted()) // false - * }) - * ``` - */ - 'stop': CustomEvent> -} - -/** - * A map of user defined services available on the libp2p node via the - * `services` key - * - * @example - * - * ```js - * const node = await createLibp2p({ - * // ...other options - * services: { - * myService: myService({ - * // ...service options - * }) - * } - * }) - * - * // invoke methods on the service - * node.services.myService.anOperation() - * ``` - */ -export type ServiceMap = Record - -export type PendingDialStatus = 'queued' | 'active' | 'error' | 'success' - -/** - * An item in the dial queue - */ -export interface PendingDial { - /** - * A unique identifier for this dial - */ - id: string - - /** - * The current status of the dial - */ - status: PendingDialStatus - - /** - * If known, this is the peer id that libp2p expects to be dialling - */ - peerId?: PeerId - - /** - * The list of multiaddrs that will be dialled. The returned connection will - * use the first address that succeeds, all other dials part of this pending - * dial will be cancelled. - */ - multiaddrs: Multiaddr[] -} - -/** - * Libp2p nodes implement this interface. - */ -export interface Libp2p extends Startable, EventEmitter> { - /** - * The PeerId is a unique identifier for a node on the network. - * - * It is the hash of an RSA public key or, for Ed25519 or secp256k1 keys, - * the key itself. - * - * @example - * - * ```js - * console.info(libp2p.peerId) - * // PeerId(12D3Foo...) - * ```` - */ - peerId: PeerId - - /** - * The peer store holds information we know about other peers on the network. - * - multiaddrs, supported protocols, etc. - * - * @example - * - * ```js - * const peer = await libp2p.peerStore.get(peerId) - * console.info(peer) - * // { id: PeerId(12D3Foo...), addresses: [] ... } - * ``` - */ - peerStore: PeerStore - - /** - * The peer routing subsystem allows the user to find peers on the network - * or to find peers close to binary keys. - * - * @example - * - * ```js - * const peerInfo = await libp2p.peerRouting.findPeer(peerId) - * console.info(peerInfo) - * // { id: PeerId(12D3Foo...), multiaddrs: [] ... } - * ``` - * - * @example - * - * ```js - * for await (const peerInfo of libp2p.peerRouting.getClosestPeers(key)) { - * console.info(peerInfo) - * // { id: PeerId(12D3Foo...), multiaddrs: [] ... } - * } - * ``` - */ - peerRouting: PeerRouting - - /** - * The content routing subsystem allows the user to find providers for content, - * let the network know they are providers for content, and get/put values to - * the DHT. - * - * @example - * - * ```js - * for await (const peerInfo of libp2p.contentRouting.findProviders(cid)) { - * console.info(peerInfo) - * // { id: PeerId(12D3Foo...), multiaddrs: [] ... } - * } - * ``` - */ - contentRouting: ContentRouting - - /** - * The keychain contains the keys used by the current node, and can create new - * keys, export them, import them, etc. - * - * @example - * - * ```js - * const keyInfo = await libp2p.keychain.createKey('new key') - * console.info(keyInfo) - * // { id: '...', name: 'new key' } - * ``` - */ - keychain: KeyChain - - /** - * The metrics subsystem allows recording values to assess the health/performance - * of the running node. - * - * @example - * - * ```js - * const metric = libp2p.metrics.registerMetric({ - * 'my-metric' - * }) - * - * // later - * metric.update(5) - * ``` - */ - metrics?: Metrics - - /** - * Get a deduplicated list of peer advertising multiaddrs by concatenating - * the listen addresses used by transports with any configured - * announce addresses as well as observed addresses reported by peers. - * - * If Announce addrs are specified, configured listen addresses will be - * ignored though observed addresses will still be included. - * - * @example - * - * ```js - * const listenMa = libp2p.getMultiaddrs() - * // [ ] - * ``` - */ - getMultiaddrs: () => Multiaddr[] - - /** - * Returns a list of supported protocols - * - * @example - * - * ```js - * const protocols = libp2p.getProtocols() - * // [ '/ipfs/ping/1.0.0', '/ipfs/id/1.0.0' ] - * ``` - */ - getProtocols: () => string[] - - /** - * Return a list of all connections this node has open, optionally filtering - * by a PeerId - * - * @example - * - * ```js - * for (const connection of libp2p.getConnections()) { - * console.log(peerId, connection.remoteAddr.toString()) - * // Logs the PeerId string and the observed remote multiaddr of each Connection - * } - * ``` - */ - getConnections: (peerId?: PeerId) => Connection[] - - /** - * Return the list of dials currently in progress or queued to start - * - * @example - * - * ```js - * for (const pendingDial of libp2p.getDialQueue()) { - * console.log(pendingDial) - * } - * ``` - */ - getDialQueue: () => PendingDial[] - - /** - * Return a list of all peers we currently have a connection open to - */ - getPeers: () => PeerId[] - - /** - * Dials to the provided peer. If successful, the known metadata of the - * peer will be added to the nodes `peerStore`. - * - * If a PeerId is passed as the first argument, the peer will need to have known multiaddrs for it in the PeerStore. - * - * @example - * - * ```js - * const conn = await libp2p.dial(remotePeerId) - * - * // create a new stream within the connection - * const { stream, protocol } = await conn.newStream(['/echo/1.1.0', '/echo/1.0.0']) - * - * // protocol negotiated: 'echo/1.0.0' means that the other party only supports the older version - * - * // ... - * await conn.close() - * ``` - */ - dial: (peer: PeerId | Multiaddr | Multiaddr[], options?: AbortOptions) => Promise - - /** - * Dials to the provided peer and tries to handshake with the given protocols in order. - * If successful, the known metadata of the peer will be added to the nodes `peerStore`, - * and the `MuxedStream` will be returned together with the successful negotiated protocol. - * - * @example - * - * ```js - * import { pipe } from 'it-pipe' - * - * const { stream, protocol } = await libp2p.dialProtocol(remotePeerId, protocols) - * - * // Use this new stream like any other duplex stream - * pipe([1, 2, 3], stream, consume) - * ``` - */ - dialProtocol: (peer: PeerId | Multiaddr | Multiaddr[], protocols: string | string[], options?: AbortOptions) => Promise - - /** - * Attempts to gracefully close an open connection to the given peer. If the connection is not closed in the grace period, it will be forcefully closed. - * - * @example - * - * ```js - * await libp2p.hangUp(remotePeerId) - * ``` - */ - hangUp: (peer: PeerId | Multiaddr) => Promise - - /** - * Sets up [multistream-select routing](https://github.com/multiformats/multistream-select) of protocols to their application handlers. Whenever a stream is opened on one of the provided protocols, the handler will be called. `handle` must be called in order to register a handler and support for a given protocol. This also informs other peers of the protocols you support. - * - * `libp2p.handle(protocols, handler, options)` - * - * In the event of a new handler for the same protocol being added, the first one is discarded. - * - * @example - * - * ```js - * const handler = ({ connection, stream, protocol }) => { - * // use stream or connection according to the needs - * } - * - * libp2p.handle('/echo/1.0.0', handler, { - * maxInboundStreams: 5, - * maxOutboundStreams: 5 - * }) - * ``` - */ - handle: (protocol: string | string[], handler: StreamHandler, options?: StreamHandlerOptions) => Promise - - /** - * Removes the handler for each protocol. The protocol - * will no longer be supported on streams. - * - * @example - * - * ```js - * libp2p.unhandle(['/echo/1.0.0']) - * ``` - */ - unhandle: (protocols: string[] | string) => Promise - - /** - * Register a topology to be informed when peers are encountered that - * support the specified protocol - * - * @example - * - * ```js - * const id = await libp2p.register('/echo/1.0.0', { - * onConnect: (peer, connection) => { - * // handle connect - * }, - * onDisconnect: (peer, connection) => { - * // handle disconnect - * } - * }) - * ``` - */ - register: (protocol: string, topology: Topology) => Promise - - /** - * Unregister topology to no longer be informed when peers connect or - * disconnect. - * - * @example - * - * ```js - * const id = await libp2p.register(...) - * - * libp2p.unregister(id) - * ``` - */ - unregister: (id: string) => void - - /** - * Returns the public key for the passed PeerId. If the PeerId is of the 'RSA' type - * this may mean searching the DHT if the key is not present in the KeyStore. - * A set of user defined services - */ - getPublicKey: (peer: PeerId, options?: AbortOptions) => Promise - - /** - * A set of user defined services - */ - services: T -} +export {} diff --git a/packages/interface/package.json b/packages/interface/package.json index 112f9c5ee0..48976e7f8d 100644 --- a/packages/interface/package.json +++ b/packages/interface/package.json @@ -56,10 +56,6 @@ "types": "./dist/src/connection-gater/index.d.ts", "import": "./dist/src/connection-gater/index.js" }, - "./connection/status": { - "types": "./dist/src/connection/status.d.ts", - "import": "./dist/src/connection/status.js" - }, "./content-routing": { "types": "./dist/src/content-routing/index.d.ts", "import": "./dist/src/content-routing/index.js" @@ -163,10 +159,10 @@ "dependencies": { "@multiformats/multiaddr": "^12.1.3", "abortable-iterator": "^5.0.1", - "any-signal": "^4.1.1", - "it-pushable": "^3.1.3", + "it-pushable": "^3.2.0", "it-stream-types": "^2.0.1", "multiformats": "^12.0.1", + "p-defer": "^4.0.0", "uint8arraylist": "^2.4.3" }, "devDependencies": { diff --git a/packages/interface/src/connection/index.ts b/packages/interface/src/connection/index.ts index fe4dfe0831..07ffa2b7ac 100644 --- a/packages/interface/src/connection/index.ts +++ b/packages/interface/src/connection/index.ts @@ -1,4 +1,3 @@ -import type * as Status from './status.js' import type { AbortOptions } from '../index.js' import type { PeerId } from '../peer-id/index.js' import type { Multiaddr } from '@multiformats/multiaddr' @@ -6,8 +5,20 @@ import type { Duplex, Source } from 'it-stream-types' import type { Uint8ArrayList } from 'uint8arraylist' export interface ConnectionTimeline { + /** + * When the connection was opened + */ open: number + + /** + * When the MultiaddrConnection was upgraded to a Connection - e.g. the type + * of connection encryption and multiplexing was negotiated. + */ upgraded?: number + + /** + * When the connection was closed. + */ close?: number } @@ -41,8 +52,38 @@ export interface StreamTimeline { * A timestamp of when the stream was reset */ reset?: number + + /** + * A timestamp of when the stream was aborted + */ + abort?: number } +/** + * The states a stream can be in + */ +export type StreamStatus = 'open' | 'closing' | 'closed' | 'aborted' | 'reset' + +/** + * The states the readable end of a stream can be in + * + * ready - the readable end is ready for reading + * closing - the readable end is closing + * closed - the readable end has closed + */ +export type ReadStatus = 'ready' | 'closing' | 'closed' + +/** + * The states the writable end of a stream can be in + * + * ready - the writable end is ready for writing + * writing - the writable end is in the process of being written to + * done - the source passed to the `.sink` function yielded all values without error + * closing - the writable end is closing + * closed - the writable end has closed + */ +export type WriteStatus = 'ready' | 'writing' | 'done' | 'closing' | 'closed' + /** * A Stream is a data channel between two peers that * can be written to and read from at both ends. @@ -60,7 +101,7 @@ export interface Stream extends Duplex, Source void + close: (options?: AbortOptions) => Promise /** * Closes the stream for **reading**. If iterating over the source of this stream in a `for await of` loop, it will return (exit the loop) after any buffered data has been consumed. @@ -69,14 +110,14 @@ export interface Stream extends Duplex, Source void + closeRead: (options?: AbortOptions) => Promise /** * Closes the stream for **writing**. If iterating over the source of this stream in a `for await of` loop, it will return (exit the loop) after any buffered data has been consumed. * * The source will return normally, the sink will continue to consume. */ - closeWrite: () => void + closeWrite: (options?: AbortOptions) => Promise /** * Closes the stream for **reading** *and* **writing**. This should be called when a *local error* has occurred. @@ -89,15 +130,6 @@ export interface Stream extends Duplex, Source void - /** - * Closes the stream *immediately* for **reading** *and* **writing**. This should be called when a *remote error* has occurred. - * - * This function is called automatically by the muxer when it receives a `RESET` message from the remote. - * - * The sink will return and the source will throw. - */ - reset: () => void - /** * Unique identifier for a stream. Identifiers are not unique across muxers. */ @@ -122,6 +154,21 @@ export interface Stream extends Duplex, Source + + /** + * The current status of the stream + */ + status: StreamStatus + + /** + * The current status of the readable end of the stream + */ + readStatus: ReadStatus + + /** + * The current status of the writable end of the stream + */ + writeStatus: WriteStatus } export interface NewStreamOptions extends AbortOptions { @@ -133,6 +180,8 @@ export interface NewStreamOptions extends AbortOptions { maxOutboundStreams?: number } +export type ConnectionStatus = 'open' | 'closing' | 'closed' + /** * A Connection is a high-level representation of a connection * to a remote peer that may have been secured by encryption and @@ -188,12 +237,33 @@ export interface Connection { /** * The current status of the connection */ - status: keyof typeof Status + status: ConnectionStatus - newStream: (multicodecs: string | string[], options?: NewStreamOptions) => Promise + /** + * Create a new stream on this connection and negotiate one of the passed protocols + */ + newStream: (protocols: string | string[], options?: NewStreamOptions) => Promise + + /** + * Add a stream to this connection + */ addStream: (stream: Stream) => void + + /** + * Remove a stream from this connection + */ removeStream: (id: string) => void - close: () => Promise + + /** + * Gracefully close the connection. All queued data will be written to the + * underlying transport. + */ + close: (options?: AbortOptions) => Promise + + /** + * Immediately close the connection, any queued data will be discarded + */ + abort: (err: Error) => void } export const symbol = Symbol.for('@libp2p/connection') @@ -203,7 +273,6 @@ export function isConnection (other: any): other is Connection { } export interface ConnectionProtector { - /** * Takes a given Connection and creates a private encryption stream * between its two peers from the PSK the Protector instance was @@ -213,8 +282,20 @@ export interface ConnectionProtector { } export interface MultiaddrConnectionTimeline { + /** + * When the connection was opened + */ open: number + + /** + * When the MultiaddrConnection was upgraded to a Connection - the type of + * connection encryption and multiplexing was negotiated. + */ upgraded?: number + + /** + * When the connection was closed. + */ close?: number } @@ -224,7 +305,24 @@ export interface MultiaddrConnectionTimeline { * without encryption or stream multiplexing. */ export interface MultiaddrConnection extends Duplex, Source, Promise> { - close: (err?: Error) => Promise + /** + * Gracefully close the connection. All queued data will be written to the + * underlying transport. + */ + close: (options?: AbortOptions) => Promise + + /** + * Immediately close the connection, any queued data will be discarded + */ + abort: (err: Error) => void + + /** + * The address of the remote end of the connection + */ remoteAddr: Multiaddr + + /** + * When connection lifecycle events occurred + */ timeline: MultiaddrConnectionTimeline } diff --git a/packages/interface/src/connection/status.ts b/packages/interface/src/connection/status.ts deleted file mode 100644 index a640d97e03..0000000000 --- a/packages/interface/src/connection/status.ts +++ /dev/null @@ -1,4 +0,0 @@ - -export const OPEN = 'OPEN' -export const CLOSING = 'CLOSING' -export const CLOSED = 'CLOSED' diff --git a/packages/interface/src/index.ts b/packages/interface/src/index.ts index d530e5e764..91860040b0 100644 --- a/packages/interface/src/index.ts +++ b/packages/interface/src/index.ts @@ -506,7 +506,11 @@ export interface Libp2p extends Startable, Ev dialProtocol: (peer: PeerId | Multiaddr | Multiaddr[], protocols: string | string[], options?: AbortOptions) => Promise /** - * Attempts to gracefully close an open connection to the given peer. If the connection is not closed in the grace period, it will be forcefully closed. + * Attempts to gracefully close an open connection to the given peer. If the + * connection is not closed in the grace period, it will be forcefully closed. + * + * An AbortSignal can optionally be passed to control when the connection is + * forcefully closed. * * @example * @@ -514,7 +518,7 @@ export interface Libp2p extends Startable, Ev * await libp2p.hangUp(remotePeerId) * ``` */ - hangUp: (peer: PeerId | Multiaddr) => Promise + hangUp: (peer: PeerId | Multiaddr, options?: AbortOptions) => Promise /** * Sets up [multistream-select routing](https://github.com/multiformats/multistream-select) of protocols to their application handlers. Whenever a stream is opened on one of the provided protocols, the handler will be called. `handle` must be called in order to register a handler and support for a given protocol. This also informs other peers of the protocols you support. diff --git a/packages/interface/src/stream-muxer/index.ts b/packages/interface/src/stream-muxer/index.ts index c4861fceb5..fba7b9da34 100644 --- a/packages/interface/src/stream-muxer/index.ts +++ b/packages/interface/src/stream-muxer/index.ts @@ -37,10 +37,15 @@ export interface StreamMuxer extends Duplex, Source void + close: (options?: AbortOptions) => Promise + + /** + * Close or abort all tracked streams and stop the muxer + */ + abort: (err: Error) => void } -export interface StreamMuxerInit extends AbortOptions { +export interface StreamMuxerInit { /** * A callback function invoked every time an incoming stream is opened */ diff --git a/packages/interface/src/stream-muxer/stream.ts b/packages/interface/src/stream-muxer/stream.ts index 8767b3b10f..5027fa0ebd 100644 --- a/packages/interface/src/stream-muxer/stream.ts +++ b/packages/interface/src/stream-muxer/stream.ts @@ -1,22 +1,21 @@ -// import { logger } from '@libp2p/logger' import { abortableSource } from 'abortable-iterator' -import { anySignal } from 'any-signal' import { type Pushable, pushable } from 'it-pushable' +import defer, { type DeferredPromise } from 'p-defer' import { Uint8ArrayList } from 'uint8arraylist' import { CodeError } from '../errors.js' -import type { Direction, Stream, StreamTimeline } from '../connection/index.js' +import type { Direction, ReadStatus, Stream, StreamStatus, StreamTimeline, WriteStatus } from '../connection/index.js' +import type { AbortOptions } from '../index.js' import type { Source } from 'it-stream-types' -// const log = logger('libp2p:stream') - -const log: any = () => {} -log.trace = () => {} -log.error = () => {} +interface Logger { + (formatter: any, ...args: any[]): void + error: (formatter: any, ...args: any[]) => void + trace: (formatter: any, ...args: any[]) => void + enabled: boolean +} const ERR_STREAM_RESET = 'ERR_STREAM_RESET' -const ERR_STREAM_ABORT = 'ERR_STREAM_ABORT' -const ERR_SINK_ENDED = 'ERR_SINK_ENDED' -const ERR_DOUBLE_SINK = 'ERR_DOUBLE_SINK' +const ERR_SINK_INVALID_STATE = 'ERR_SINK_INVALID_STATE' export interface AbstractStreamInit { /** @@ -30,10 +29,9 @@ export interface AbstractStreamInit { direction: Direction /** - * The maximum allowable data size, any data larger than this will be - * chunked and sent in multiple data messages + * A Logger implementation used to log stream-specific information */ - maxDataSize: number + log: Logger /** * User specific stream metadata @@ -44,6 +42,32 @@ export interface AbstractStreamInit { * Invoked when the stream ends */ onEnd?: (err?: Error | undefined) => void + + /** + * Invoked when the readable end of the stream is closed + */ + onCloseRead?: () => void + + /** + * Invoked when the writable end of the stream is closed + */ + onCloseWrite?: () => void + + /** + * Invoked when the the stream has been reset by the remote + */ + onReset?: () => void + + /** + * Invoked when the the stream has errored + */ + onAbort?: (err: Error) => void + + /** + * How long to wait in ms for stream data to be written to the underlying + * connection when closing the writable end of the stream. (default: 500) + */ + closeTimeout?: number } function isPromise (res?: any): res is Promise { @@ -57,25 +81,31 @@ export abstract class AbstractStream implements Stream { public protocol?: string public metadata: Record public source: AsyncGenerator + public status: StreamStatus + public readStatus: ReadStatus + public writeStatus: WriteStatus - private readonly abortController: AbortController - private readonly resetController: AbortController - private readonly closeController: AbortController - private sourceEnded: boolean - private sinkEnded: boolean - private sinkSunk: boolean + private readonly sinkController: AbortController + private readonly sinkEnd: DeferredPromise private endErr: Error | undefined private readonly streamSource: Pushable private readonly onEnd?: (err?: Error | undefined) => void - private readonly maxDataSize: number + private readonly onCloseRead?: () => void + private readonly onCloseWrite?: () => void + private readonly onReset?: () => void + private readonly onAbort?: (err: Error) => void + + protected readonly log: Logger constructor (init: AbstractStreamInit) { - this.abortController = new AbortController() - this.resetController = new AbortController() - this.closeController = new AbortController() - this.sourceEnded = false - this.sinkEnded = false - this.sinkSunk = false + this.sinkController = new AbortController() + this.sinkEnd = defer() + this.log = init.log + + // stream status + this.status = 'open' + this.readStatus = 'ready' + this.writeStatus = 'ready' this.id = init.id this.metadata = init.metadata ?? {} @@ -83,23 +113,23 @@ export abstract class AbstractStream implements Stream { this.timeline = { open: Date.now() } - this.maxDataSize = init.maxDataSize + this.onEnd = init.onEnd + this.onCloseRead = init?.onCloseRead + this.onCloseWrite = init?.onCloseWrite + this.onReset = init?.onReset + this.onAbort = init?.onAbort this.source = this.streamSource = pushable({ - onEnd: () => { - // already sent a reset message - if (this.timeline.reset !== null) { - const res = this.sendCloseRead() - - if (isPromise(res)) { - res.catch(err => { - log.error('error while sending close read', err) - }) - } + onEnd: (err) => { + if (err != null) { + this.log.trace('source ended with error', err) + } else { + this.log.trace('source ended') } - this.onSourceEnd() + this.readStatus = 'closed' + this.onSourceEnd(err) } }) @@ -107,216 +137,295 @@ export abstract class AbstractStream implements Stream { this.sink = this.sink.bind(this) } + async sink (source: Source): Promise { + if (this.writeStatus !== 'ready') { + throw new CodeError(`writable end state is "${this.writeStatus}" not "ready"`, ERR_SINK_INVALID_STATE) + } + + try { + this.writeStatus = 'writing' + + const options: AbortOptions = { + signal: this.sinkController.signal + } + + if (this.direction === 'outbound') { // If initiator, open a new stream + const res = this.sendNewStream(options) + + if (isPromise(res)) { + await res + } + } + + source = abortableSource(source, this.sinkController.signal, { + returnOnAbort: true + }) + + this.log.trace('sink reading from source') + + for await (let data of source) { + data = data instanceof Uint8Array ? new Uint8ArrayList(data) : data + + const res = this.sendData(data, options) + + if (isPromise(res)) { // eslint-disable-line max-depth + await res + } + } + + this.log.trace('sink finished reading from source') + this.writeStatus = 'done' + + this.log.trace('sink calling closeWrite') + await this.closeWrite(options) + this.onSinkEnd() + } catch (err: any) { + this.log.trace('sink ended with error, calling abort with error', err) + this.abort(err) + + throw err + } finally { + this.log.trace('resolve sink end') + this.sinkEnd.resolve() + } + } + protected onSourceEnd (err?: Error): void { - if (this.sourceEnded) { + if (this.timeline.closeRead != null) { return } this.timeline.closeRead = Date.now() - this.sourceEnded = true - log.trace('%s stream %s source end - err: %o', this.direction, this.id, err) if (err != null && this.endErr == null) { this.endErr = err } - if (this.sinkEnded) { + this.onCloseRead?.() + + if (this.timeline.closeWrite != null) { + this.log.trace('source and sink ended') this.timeline.close = Date.now() if (this.onEnd != null) { this.onEnd(this.endErr) } + } else { + this.log.trace('source ended, waiting for sink to end') } } protected onSinkEnd (err?: Error): void { - if (this.sinkEnded) { + if (this.timeline.closeWrite != null) { return } this.timeline.closeWrite = Date.now() - this.sinkEnded = true - log.trace('%s stream %s sink end - err: %o', this.direction, this.id, err) if (err != null && this.endErr == null) { this.endErr = err } - if (this.sourceEnded) { + this.onCloseWrite?.() + + if (this.timeline.closeRead != null) { + this.log.trace('sink and source ended') this.timeline.close = Date.now() if (this.onEnd != null) { this.onEnd(this.endErr) } + } else { + this.log.trace('sink ended, waiting for source to end') } } // Close for both Reading and Writing - close (): void { - log.trace('%s stream %s close', this.direction, this.id) + async close (options?: AbortOptions): Promise { + this.log.trace('closing gracefully') - this.closeRead() - this.closeWrite() - } + this.status = 'closing' - // Close for reading - closeRead (): void { - log.trace('%s stream %s closeRead', this.direction, this.id) + await Promise.all([ + this.closeRead(options), + this.closeWrite(options) + ]) - if (this.sourceEnded) { - return - } + this.status = 'closed' - this.streamSource.end() + this.log.trace('closed gracefully') } - // Close for writing - closeWrite (): void { - log.trace('%s stream %s closeWrite', this.direction, this.id) - - if (this.sinkEnded) { + async closeRead (options: AbortOptions = {}): Promise { + if (this.readStatus === 'closing' || this.readStatus === 'closed') { return } - this.closeController.abort() + this.log.trace('closing readable end of stream with starting read status "%s"', this.readStatus) - try { - // need to call this here as the sink method returns in the catch block - // when the close controller is aborted - const res = this.sendCloseWrite() + const readStatus = this.readStatus + this.readStatus = 'closing' - if (isPromise(res)) { - res.catch(err => { - log.error('error while sending close write', err) - }) - } - } catch (err) { - log.trace('%s stream %s error sending close', this.direction, this.id, err) + if (readStatus === 'ready') { + this.log.trace('ending internal source queue') + this.streamSource.end() } - this.onSinkEnd() - } + if (this.status !== 'reset' && this.status !== 'aborted') { + this.log.trace('send close read to remote') + await this.sendCloseRead(options) + } - // Close for reading and writing (local error) - abort (err: Error): void { - log.trace('%s stream %s abort', this.direction, this.id, err) - // End the source with the passed error - this.streamSource.end(err) - this.abortController.abort() - this.onSinkEnd(err) + this.log.trace('closed readable end of stream') } - // Close immediately for reading and writing (remote error) - reset (): void { - const err = new CodeError('stream reset', ERR_STREAM_RESET) - this.resetController.abort() - this.streamSource.end(err) - this.onSinkEnd(err) - } + async closeWrite (options: AbortOptions = {}): Promise { + if (this.writeStatus === 'closing' || this.writeStatus === 'closed') { + return + } - async sink (source: Source): Promise { - if (this.sinkSunk) { - throw new CodeError('sink already called on stream', ERR_DOUBLE_SINK) + this.log.trace('closing writable end of stream with starting write status "%s"', this.writeStatus) + + const writeStatus = this.writeStatus + + if (this.writeStatus === 'ready') { + this.log.trace('sink was never sunk, sink an empty array') + await this.sink([]) } - this.sinkSunk = true + this.writeStatus = 'closing' + + if (writeStatus === 'writing') { + // stop reading from the source passed to `.sink` in the microtask queue + // - this lets any data queued by the user in the current tick get read + // before we exit + await new Promise((resolve, reject) => { + queueMicrotask(() => { + this.log.trace('aborting source passed to .sink') + this.sinkController.abort() + this.sinkEnd.promise.then(resolve, reject) + }) + }) + } - if (this.sinkEnded) { - throw new CodeError('stream closed for writing', ERR_SINK_ENDED) + if (this.status !== 'reset' && this.status !== 'aborted') { + this.log.trace('send close write to remote') + await this.sendCloseWrite(options) } - const signal = anySignal([ - this.abortController.signal, - this.resetController.signal, - this.closeController.signal - ]) + this.writeStatus = 'closed' - try { - source = abortableSource(source, signal) + this.log.trace('closed writable end of stream') + } - if (this.direction === 'outbound') { // If initiator, open a new stream - const res = this.sendNewStream() + /** + * Close immediately for reading and writing and send a reset message (local + * error) + */ + abort (err: Error): void { + if (this.status === 'closed' || this.status === 'aborted' || this.status === 'reset') { + return + } - if (isPromise(res)) { - await res - } - } + this.log('abort with error', err) - for await (let data of source) { - while (data.length > 0) { - if (data.length <= this.maxDataSize) { - const res = this.sendData(data instanceof Uint8Array ? new Uint8ArrayList(data) : data) + // try to send a reset message + this.log('try to send reset to remote') + const res = this.sendReset() - if (isPromise(res)) { // eslint-disable-line max-depth - await res - } + if (isPromise(res)) { + res.catch((err) => { + this.log.error('error sending reset message', err) + }) + } - break - } - data = data instanceof Uint8Array ? new Uint8ArrayList(data) : data - const res = this.sendData(data.sublist(0, this.maxDataSize)) + this.status = 'aborted' + this.timeline.abort = Date.now() + this._closeSinkAndSource(err) + this.onAbort?.(err) + } - if (isPromise(res)) { - await res - } + /** + * Receive a reset message - close immediately for reading and writing (remote + * error) + */ + reset (): void { + if (this.status === 'closed' || this.status === 'aborted' || this.status === 'reset') { + return + } - data.consume(this.maxDataSize) - } - } - } catch (err: any) { - if (err.type === 'aborted' && err.message === 'The operation was aborted') { - if (this.closeController.signal.aborted) { - return - } + const err = new CodeError('stream reset', ERR_STREAM_RESET) - if (this.resetController.signal.aborted) { - err.message = 'stream reset' - err.code = ERR_STREAM_RESET - } + this.status = 'reset' + this._closeSinkAndSource(err) + this.onReset?.() + } - if (this.abortController.signal.aborted) { - err.message = 'stream aborted' - err.code = ERR_STREAM_ABORT - } - } + _closeSinkAndSource (err?: Error): void { + this._closeSink(err) + this._closeSource(err) + } - // Send no more data if this stream was remotely reset - if (err.code === ERR_STREAM_RESET) { - log.trace('%s stream %s reset', this.direction, this.id) - } else { - log.trace('%s stream %s error', this.direction, this.id, err) - try { - const res = this.sendReset() - - if (isPromise(res)) { - await res - } - - this.timeline.reset = Date.now() - } catch (err) { - log.trace('%s stream %s error sending reset', this.direction, this.id, err) - } - } + _closeSink (err?: Error): void { + // if the sink function is running, cause it to end + if (this.writeStatus === 'writing') { + this.log.trace('end sink source') + this.sinkController.abort() + } + this.onSinkEnd(err) + } + + _closeSource (err?: Error): void { + // if the source is not ending, end it + if (this.readStatus !== 'closing' && this.readStatus !== 'closed') { + this.log.trace('ending source with %d bytes to be read by consumer', this.streamSource.readableLength) + this.readStatus = 'closing' this.streamSource.end(err) - this.onSinkEnd(err) + } + } - throw err - } finally { - signal.clear() + /** + * The remote closed for writing so we should expect to receive no more + * messages + */ + remoteCloseWrite (): void { + if (this.readStatus === 'closing' || this.readStatus === 'closed') { + this.log('received remote close write but local source is already closed') + return } - try { - const res = this.sendCloseWrite() + this.log.trace('remote close write') + this._closeSource() + } - if (isPromise(res)) { - await res - } - } catch (err) { - log.trace('%s stream %s error sending close', this.direction, this.id, err) + /** + * The remote closed for reading so we should not send any more + * messages + */ + remoteCloseRead (): void { + if (this.writeStatus === 'closing' || this.writeStatus === 'closed') { + this.log('received remote close read but local sink is already closed') + return } - this.onSinkEnd() + this.log.trace('remote close read') + this._closeSink() + } + + /** + * The underlying muxer has closed, no more messages can be sent or will + * be received, close immediately to free up resources + */ + destroy (): void { + if (this.status === 'closed' || this.status === 'aborted' || this.status === 'reset') { + this.log('received destroy but we are already closed') + return + } + + this.log.trace('muxer destroyed') + + this._closeSinkAndSource() } /** @@ -339,27 +448,27 @@ export abstract class AbstractStream implements Stream { * Send a message to the remote muxer informing them a new stream is being * opened */ - abstract sendNewStream (): void | Promise + abstract sendNewStream (options?: AbortOptions): void | Promise /** * Send a data message to the remote muxer */ - abstract sendData (buf: Uint8ArrayList): void | Promise + abstract sendData (buf: Uint8ArrayList, options?: AbortOptions): void | Promise /** * Send a reset message to the remote muxer */ - abstract sendReset (): void | Promise + abstract sendReset (options?: AbortOptions): void | Promise /** * Send a message to the remote muxer, informing them no more data messages * will be sent by this end of the stream */ - abstract sendCloseWrite (): void | Promise + abstract sendCloseWrite (options?: AbortOptions): void | Promise /** * Send a message to the remote muxer, informing them no more data messages * will be read by this end of the stream */ - abstract sendCloseRead (): void | Promise + abstract sendCloseRead (options?: AbortOptions): void | Promise } diff --git a/packages/kad-dht/src/network.ts b/packages/kad-dht/src/network.ts index ad20c47be5..b10ad6e592 100644 --- a/packages/kad-dht/src/network.ts +++ b/packages/kad-dht/src/network.ts @@ -110,7 +110,7 @@ export class Network extends EventEmitter implements Startable { yield queryErrorEvent({ from: to, error: err }, options) } finally { if (stream != null) { - stream.close() + await stream.close() } } } @@ -140,7 +140,7 @@ export class Network extends EventEmitter implements Startable { yield queryErrorEvent({ from: to, error: err }, options) } finally { if (stream != null) { - stream.close() + await stream.close() } } } diff --git a/packages/kad-dht/src/routing-table/index.ts b/packages/kad-dht/src/routing-table/index.ts index 148d126b7b..b0bfbb90ff 100644 --- a/packages/kad-dht/src/routing-table/index.ts +++ b/packages/kad-dht/src/routing-table/index.ts @@ -217,7 +217,7 @@ export class RoutingTable extends EventEmitter implements St this.log('pinging old contact %p', oldContact.peer) const connection = await this.components.connectionManager.openConnection(oldContact.peer, options) const stream = await connection.newStream(this.protocol, options) - stream.close() + await stream.close() responded++ } catch (err: any) { if (this.running && this.kb != null) { diff --git a/packages/libp2p-daemon-server/package.json b/packages/libp2p-daemon-server/package.json index 9b7366400a..059eae1305 100644 --- a/packages/libp2p-daemon-server/package.json +++ b/packages/libp2p-daemon-server/package.json @@ -59,7 +59,7 @@ "it-drain": "^3.0.2", "it-length-prefixed": "^9.0.1", "it-pipe": "^3.0.1", - "it-pushable": "^3.1.3", + "it-pushable": "^3.2.0", "multiformats": "^12.0.1", "uint8arrays": "^4.0.4" }, diff --git a/packages/libp2p/.aegir.js b/packages/libp2p/.aegir.js index 21e54fde91..a0d8cbc706 100644 --- a/packages/libp2p/.aegir.js +++ b/packages/libp2p/.aegir.js @@ -24,7 +24,8 @@ export default { const peerId = await createEd25519PeerId() const libp2p = await createLibp2p({ connectionManager: { - inboundConnectionThreshold: Infinity + inboundConnectionThreshold: Infinity, + minConnections: 0 }, addresses: { listen: [ diff --git a/packages/libp2p/package.json b/packages/libp2p/package.json index ac58f9137b..587f5a5421 100644 --- a/packages/libp2p/package.json +++ b/packages/libp2p/package.json @@ -143,8 +143,8 @@ "it-merge": "^3.0.0", "it-pair": "^2.0.6", "it-parallel": "^3.0.0", - "it-pb-stream": "^4.0.1", "it-pipe": "^3.0.1", + "it-protobuf-stream": "^1.0.0", "it-stream-types": "^2.0.1", "merge-options": "^3.0.4", "multiformats": "^12.0.1", @@ -180,7 +180,7 @@ "delay": "^6.0.0", "execa": "^7.1.1", "go-libp2p": "^1.1.1", - "it-pushable": "^3.0.0", + "it-pushable": "^3.2.0", "it-to-buffer": "^4.0.1", "npm-run-all": "^4.1.5", "p-event": "^6.0.0", diff --git a/packages/libp2p/src/circuit-relay/server/index.ts b/packages/libp2p/src/circuit-relay/server/index.ts index 8c383b074a..67f74b74f8 100644 --- a/packages/libp2p/src/circuit-relay/server/index.ts +++ b/packages/libp2p/src/circuit-relay/server/index.ts @@ -4,7 +4,7 @@ import { logger } from '@libp2p/logger' import { peerIdFromBytes } from '@libp2p/peer-id' import { RecordEnvelope } from '@libp2p/peer-record' import { type Multiaddr, multiaddr } from '@multiformats/multiaddr' -import { pbStream, type ProtobufStream } from 'it-pb-stream' +import { pbStream, type ProtobufStream } from 'it-protobuf-stream' import pDefer from 'p-defer' import { MAX_CONNECTIONS } from '../../connection-manager/constants.js' import { @@ -223,7 +223,7 @@ class CircuitRelayServer extends EventEmitter implements Star ]) } catch (err: any) { log.error('error while handling hop', err) - pbstr.pb(HopMessage).write({ + await pbstr.pb(HopMessage).write({ type: HopMessage.Type.STATUS, status: Status.MALFORMED_MESSAGE }) @@ -240,7 +240,7 @@ class CircuitRelayServer extends EventEmitter implements Star case HopMessage.Type.CONNECT: await this.handleConnect({ stream, request, connection }); break default: { log.error('invalid hop request type %s via peer %p', request.type, connection.remotePeer) - stream.pb(HopMessage).write({ type: HopMessage.Type.STATUS, status: Status.UNEXPECTED_MESSAGE }) + await stream.pb(HopMessage).write({ type: HopMessage.Type.STATUS, status: Status.UNEXPECTED_MESSAGE }) } } } @@ -251,20 +251,20 @@ class CircuitRelayServer extends EventEmitter implements Star if (isRelayAddr(connection.remoteAddr)) { log.error('relay reservation over circuit connection denied for peer: %p', connection.remotePeer) - hopstr.write({ type: HopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }) return } if ((await this.connectionGater.denyInboundRelayReservation?.(connection.remotePeer)) === true) { log.error('reservation for %p denied by connection gater', connection.remotePeer) - hopstr.write({ type: HopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }) return } const result = this.reservationStore.reserve(connection.remotePeer, connection.remoteAddr) if (result.status !== Status.OK) { - hopstr.write({ type: HopMessage.Type.STATUS, status: result.status }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: result.status }) return } @@ -280,7 +280,7 @@ class CircuitRelayServer extends EventEmitter implements Star }) } - hopstr.write({ + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.OK, reservation: await this.makeReservation(connection.remotePeer, BigInt(result.expire ?? 0)), @@ -325,7 +325,7 @@ class CircuitRelayServer extends EventEmitter implements Star if (isRelayAddr(connection.remoteAddr)) { log.error('relay reservation over circuit connection denied for peer: %p', connection.remotePeer) - hopstr.write({ type: HopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }) return } @@ -343,19 +343,19 @@ class CircuitRelayServer extends EventEmitter implements Star dstPeer = peerIdFromBytes(request.peer.id) } catch (err) { log.error('invalid hop connect request via peer %p %s', connection.remotePeer, err) - hopstr.write({ type: HopMessage.Type.STATUS, status: Status.MALFORMED_MESSAGE }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.MALFORMED_MESSAGE }) return } if (!this.reservationStore.hasReservation(dstPeer)) { log.error('hop connect denied for destination peer %p not having a reservation for %p with status %s', dstPeer, connection.remotePeer, Status.NO_RESERVATION) - hopstr.write({ type: HopMessage.Type.STATUS, status: Status.NO_RESERVATION }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.NO_RESERVATION }) return } if ((await this.connectionGater.denyOutboundRelayedConnection?.(connection.remotePeer, dstPeer)) === true) { log.error('hop connect for %p to %p denied by connection gater', connection.remotePeer, dstPeer) - hopstr.write({ type: HopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }) return } @@ -363,7 +363,7 @@ class CircuitRelayServer extends EventEmitter implements Star if (connections.length === 0) { log('hop connect denied for destination peer %p not having a connection for %p as there is no destination connection', dstPeer, connection.remotePeer) - hopstr.write({ type: HopMessage.Type.STATUS, status: Status.NO_RESERVATION }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.NO_RESERVATION }) return } @@ -382,11 +382,11 @@ class CircuitRelayServer extends EventEmitter implements Star if (destinationStream == null) { log.error('failed to open stream to destination peer %p', destinationConnection?.remotePeer) - hopstr.write({ type: HopMessage.Type.STATUS, status: Status.CONNECTION_FAILED }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.CONNECTION_FAILED }) return } - hopstr.write({ type: HopMessage.Type.STATUS, status: Status.OK }) + await hopstr.write({ type: HopMessage.Type.STATUS, status: Status.OK }) const sourceStream = stream.unwrap() log('connection from %p to %p established - merging streans', connection.remotePeer, dstPeer) @@ -408,7 +408,7 @@ class CircuitRelayServer extends EventEmitter implements Star }) const pbstr = pbStream(stream) const stopstr = pbstr.pb(StopMessage) - stopstr.write(request) + await stopstr.write(request) let response try { @@ -419,7 +419,7 @@ class CircuitRelayServer extends EventEmitter implements Star if (response == null) { log.error('could not read response from %p', connection.remotePeer) - stream.close() + await stream.close() return } @@ -429,7 +429,7 @@ class CircuitRelayServer extends EventEmitter implements Star } log('stop request failed with code %d', response.status) - stream.close() + await stream.close() } get reservations (): PeerMap { diff --git a/packages/libp2p/src/circuit-relay/transport/index.ts b/packages/libp2p/src/circuit-relay/transport/index.ts index 1df1b53cc4..42f504cfc1 100644 --- a/packages/libp2p/src/circuit-relay/transport/index.ts +++ b/packages/libp2p/src/circuit-relay/transport/index.ts @@ -5,7 +5,7 @@ import { peerIdFromBytes, peerIdFromString } from '@libp2p/peer-id' import { streamToMaConnection } from '@libp2p/utils/stream-to-ma-conn' import * as mafmt from '@multiformats/mafmt' import { multiaddr } from '@multiformats/multiaddr' -import { pbStream } from 'it-pb-stream' +import { pbStream } from 'it-protobuf-stream' import { MAX_CONNECTIONS } from '../../connection-manager/constants.js' import { codes } from '../../errors.js' import { CIRCUIT_PROTO_CODE, RELAY_V2_HOP_CODEC, RELAY_V2_STOP_CODEC } from '../constants.js' @@ -85,11 +85,25 @@ export interface CircuitRelayTransportInit extends RelayStoreInit { * should be set to the same value (default: 300) */ maxOutboundStopStreams?: number + + /** + * Incoming STOP requests (e.g. when a remote peer wants to dial us via a relay) + * must finish the initial protocol negotiation within this timeout in ms + * (default: 30000) + */ + stopTimeout?: number + + /** + * When creating a reservation it must complete within this number of ms + * (default: 10000) + */ + reservationCompletionTimeout?: number } const defaults = { maxInboundStopStreams: MAX_CONNECTIONS, - maxOutboundStopStreams: MAX_CONNECTIONS + maxOutboundStopStreams: MAX_CONNECTIONS, + stopTimeout: 30000 } class CircuitRelayTransport implements Transport { @@ -104,6 +118,7 @@ class CircuitRelayTransport implements Transport { private readonly reservationStore: ReservationStore private readonly maxInboundStopStreams: number private readonly maxOutboundStopStreams?: number + private readonly stopTimeout: number private started: boolean constructor (components: CircuitRelayTransportComponents, init: CircuitRelayTransportInit) { @@ -116,6 +131,7 @@ class CircuitRelayTransport implements Transport { this.connectionGater = components.connectionGater this.maxInboundStopStreams = init.maxInboundStopStreams ?? defaults.maxInboundStopStreams this.maxOutboundStopStreams = init.maxOutboundStopStreams ?? defaults.maxOutboundStopStreams + this.stopTimeout = init.stopTimeout ?? defaults.stopTimeout if (init.discoverRelays != null && init.discoverRelays > 0) { this.discovery = new RelayDiscovery(components) @@ -148,7 +164,8 @@ class CircuitRelayTransport implements Transport { await this.registrar.handle(RELAY_V2_STOP_CODEC, (data) => { void this.onStop(data).catch(err => { - log.error(err) + log.error('error while handling STOP protocol', err) + data.stream.abort(err) }) }, { maxInboundStreams: this.maxInboundStopStreams, @@ -244,7 +261,7 @@ class CircuitRelayTransport implements Transport { try { const pbstr = pbStream(stream) const hopstr = pbstr.pb(HopMessage) - hopstr.write({ + await hopstr.write({ type: HopMessage.Type.CONNECT, peer: { id: destinationPeer.toBytes(), @@ -303,43 +320,62 @@ class CircuitRelayTransport implements Transport { * An incoming STOP request means a remote peer wants to dial us via a relay */ async onStop ({ connection, stream }: IncomingStreamData): Promise { - const pbstr = pbStream(stream) - const request = await pbstr.readPB(StopMessage) - log('received circuit v2 stop protocol request from %s', connection.remotePeer) + const signal = AbortSignal.timeout(this.stopTimeout) + const pbstr = pbStream(stream).pb(StopMessage) + const request = await pbstr.read({ + signal + }) + + log('new circuit relay v2 stop stream from %p with type %s', connection.remotePeer, request.type) if (request?.type === undefined) { + log.error('type was missing from circuit v2 stop protocol request from %s', connection.remotePeer) + await pbstr.write({ type: StopMessage.Type.STATUS, status: Status.MALFORMED_MESSAGE }, { + signal + }) + await stream.close() return } - const stopstr = pbstr.pb(StopMessage) - log('new circuit relay v2 stop stream from %p', connection.remotePeer) - // Validate the STOP request has the required input if (request.type !== StopMessage.Type.CONNECT) { log.error('invalid stop connect request via peer %p', connection.remotePeer) - stopstr.write({ type: StopMessage.Type.STATUS, status: Status.UNEXPECTED_MESSAGE }) + await pbstr.write({ type: StopMessage.Type.STATUS, status: Status.UNEXPECTED_MESSAGE }, { + signal + }) + await stream.close() return } if (!isValidStop(request)) { log.error('invalid stop connect request via peer %p', connection.remotePeer) - stopstr.write({ type: StopMessage.Type.STATUS, status: Status.MALFORMED_MESSAGE }) + await pbstr.write({ type: StopMessage.Type.STATUS, status: Status.MALFORMED_MESSAGE }, { + signal + }) + await stream.close() return } const remotePeerId = peerIdFromBytes(request.peer.id) if ((await this.connectionGater.denyInboundRelayedConnection?.(connection.remotePeer, remotePeerId)) === true) { - stopstr.write({ type: StopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }) + log.error('connection gater denied inbound relayed connection from %p', connection.remotePeer) + await pbstr.write({ type: StopMessage.Type.STATUS, status: Status.PERMISSION_DENIED }, { + signal + }) + await stream.close() return } - stopstr.write({ type: StopMessage.Type.STATUS, status: Status.OK }) + log.trace('sending success response to %p', connection.remotePeer) + await pbstr.write({ type: StopMessage.Type.STATUS, status: Status.OK }, { + signal + }) const remoteAddr = connection.remoteAddr.encapsulate(`/p2p-circuit/p2p/${remotePeerId.toString()}`) const localAddr = this.addressManager.getAddresses()[0] const maConn = streamToMaConnection({ - stream: pbstr.unwrap(), + stream: pbstr.unwrap().unwrap(), remoteAddr, localAddr }) diff --git a/packages/libp2p/src/circuit-relay/transport/listener.ts b/packages/libp2p/src/circuit-relay/transport/listener.ts index aa10997f6d..dea961b5fd 100644 --- a/packages/libp2p/src/circuit-relay/transport/listener.ts +++ b/packages/libp2p/src/circuit-relay/transport/listener.ts @@ -95,6 +95,8 @@ class CircuitRelayTransportListener extends EventEmitter impleme #removeRelayPeer (peerId: PeerId): void { const had = this.listeningAddrs.has(peerId) + log('relay peer removed %p - had reservation', peerId, had) + this.listeningAddrs.delete(peerId) if (had) { diff --git a/packages/libp2p/src/circuit-relay/transport/reservation-store.ts b/packages/libp2p/src/circuit-relay/transport/reservation-store.ts index c01e616da7..027fc1ea05 100644 --- a/packages/libp2p/src/circuit-relay/transport/reservation-store.ts +++ b/packages/libp2p/src/circuit-relay/transport/reservation-store.ts @@ -2,13 +2,13 @@ import { EventEmitter } from '@libp2p/interface/events' import { logger } from '@libp2p/logger' import { PeerMap } from '@libp2p/peer-collections' import { multiaddr } from '@multiformats/multiaddr' -import { pbStream } from 'it-pb-stream' +import { pbStream } from 'it-protobuf-stream' import { PeerJobQueue } from '../../utils/peer-job-queue.js' import { DEFAULT_RESERVATION_CONCURRENCY, RELAY_TAG, RELAY_V2_HOP_CODEC } from '../constants.js' import { HopMessage, Status } from '../pb/index.js' import { getExpirationMilliseconds } from '../utils.js' import type { Reservation } from '../pb/index.js' -import type { Libp2pEvents } from '@libp2p/interface' +import type { Libp2pEvents, AbortOptions } from '@libp2p/interface' import type { Connection } from '@libp2p/interface/connection' import type { PeerId } from '@libp2p/interface/peer-id' import type { PeerStore } from '@libp2p/interface/peer-store' @@ -55,6 +55,12 @@ export interface RelayStoreInit { * Limit the number of potential relays we will dial (default: 100) */ maxReservationQueueLength?: number + + /** + * When creating a reservation it must complete within this number of ms + * (default: 5000) + */ + reservationCompletionTimeout?: number } export type RelayType = 'discovered' | 'configured' @@ -80,6 +86,7 @@ export class ReservationStore extends EventEmitter imple private readonly reservations: PeerMap private readonly maxDiscoveredRelays: number private readonly maxReservationQueueLength: number + private readonly reservationCompletionTimeout: number private started: boolean constructor (components: RelayStoreComponents, init?: RelayStoreInit) { @@ -93,6 +100,7 @@ export class ReservationStore extends EventEmitter imple this.reservations = new PeerMap() this.maxDiscoveredRelays = init?.discoverRelays ?? 0 this.maxReservationQueueLength = init?.maxReservationQueueLength ?? 100 + this.reservationCompletionTimeout = init?.reservationCompletionTimeout ?? 10000 this.started = false // ensure we don't listen on multiple relays simultaneously @@ -117,12 +125,12 @@ export class ReservationStore extends EventEmitter imple } async stop (): Promise { + this.reserveQueue.clear() this.reservations.forEach(({ timeout }) => { clearTimeout(timeout) }) this.reservations.clear() - - this.started = true + this.started = false } /** @@ -175,14 +183,20 @@ export class ReservationStore extends EventEmitter imple return } - const connection = await this.connectionManager.openConnection(peerId) + const signal = AbortSignal.timeout(this.reservationCompletionTimeout) + + const connection = await this.connectionManager.openConnection(peerId, { + signal + }) if (connection.remoteAddr.protoNames().includes('p2p-circuit')) { log('not creating reservation over relayed connection') return } - const reservation = await this.#createReservation(connection) + const reservation = await this.#createReservation(connection, { + signal + }) log('created reservation on relay peer %p', peerId) @@ -220,6 +234,13 @@ export class ReservationStore extends EventEmitter imple } catch (err) { log.error('could not reserve slot on %p', peerId, err) + // cancel the renewal timeout if it's been set + const reservation = this.reservations.get(peerId) + + if (reservation != null) { + clearTimeout(reservation.timeout) + } + // if listening failed, remove the reservation this.reservations.delete(peerId) } @@ -236,22 +257,24 @@ export class ReservationStore extends EventEmitter imple return this.reservations.get(peerId)?.reservation } - async #createReservation (connection: Connection): Promise { + async #createReservation (connection: Connection, options: AbortOptions): Promise { + options.signal?.throwIfAborted() + log('requesting reservation from %p', connection.remotePeer) const stream = await connection.newStream(RELAY_V2_HOP_CODEC) const pbstr = pbStream(stream) const hopstr = pbstr.pb(HopMessage) - hopstr.write({ type: HopMessage.Type.RESERVE }) + await hopstr.write({ type: HopMessage.Type.RESERVE }, options) let response: HopMessage try { - response = await hopstr.read() + response = await hopstr.read(options) } catch (err: any) { log.error('error parsing reserve message response from %p because', connection.remotePeer, err) throw err } finally { - stream.close() + await stream.close() } if (response.status === Status.OK && (response.reservation != null)) { diff --git a/packages/libp2p/src/connection-manager/auto-dial.ts b/packages/libp2p/src/connection-manager/auto-dial.ts index a9a87ed408..c7357652dd 100644 --- a/packages/libp2p/src/connection-manager/auto-dial.ts +++ b/packages/libp2p/src/connection-manager/auto-dial.ts @@ -113,7 +113,9 @@ export class AutoDial implements Startable { // Already has enough connections if (numConnections >= this.minConnections) { - log.trace('have enough connections %d/%d', numConnections, this.minConnections) + if (this.minConnections > 0) { + log.trace('have enough connections %d/%d', numConnections, this.minConnections) + } return } diff --git a/packages/libp2p/src/connection-manager/dial-queue.ts b/packages/libp2p/src/connection-manager/dial-queue.ts index 53658b4d6e..fb540bb33e 100644 --- a/packages/libp2p/src/connection-manager/dial-queue.ts +++ b/packages/libp2p/src/connection-manager/dial-queue.ts @@ -16,7 +16,7 @@ import { MAX_PEER_ADDRS_TO_DIAL } from './constants.js' import { combineSignals, resolveMultiaddrs } from './utils.js' -import type { AddressSorter, AbortOptions } from '@libp2p/interface' +import type { AddressSorter, AbortOptions, PendingDial } from '@libp2p/interface' import type { Connection } from '@libp2p/interface/connection' import type { ConnectionGater } from '@libp2p/interface/connection-gater' import type { Metric, Metrics } from '@libp2p/interface/metrics' @@ -26,16 +26,6 @@ import type { TransportManager } from '@libp2p/interface-internal/transport-mana const log = logger('libp2p:connection-manager:dial-queue') -export type PendingDialStatus = 'queued' | 'active' | 'error' | 'success' - -export interface PendingDial { - id: string - status: PendingDialStatus - peerId?: PeerId - multiaddrs: Multiaddr[] - promise: Promise -} - export interface PendingDialTarget { resolve: (value: any) => void reject: (err: Error) => void @@ -45,6 +35,10 @@ export interface DialOptions extends AbortOptions { priority?: number } +interface PendingDialInternal extends PendingDial { + promise: Promise +} + interface DialerInit { addressSorter?: AddressSorter maxParallelDials?: number @@ -74,7 +68,7 @@ interface DialQueueComponents { } export class DialQueue { - public pendingDials: PendingDial[] + public pendingDials: PendingDialInternal[] public queue: PQueue private readonly peerId: PeerId private readonly peerStore: PeerStore @@ -218,7 +212,7 @@ export class DialQueue { log('creating dial target for', addrsToDial.map(({ multiaddr }) => multiaddr.toString())) // @ts-expect-error .promise property is set below - const pendingDial: PendingDial = { + const pendingDial: PendingDialInternal = { id: randomId(), status: 'queued', peerId, @@ -394,7 +388,7 @@ export class DialQueue { return sortedGatedAddrs } - private async performDial (pendingDial: PendingDial, options: DialOptions = {}): Promise { + private async performDial (pendingDial: PendingDialInternal, options: DialOptions = {}): Promise { const dialAbortControllers: Array<(AbortController | undefined)> = pendingDial.multiaddrs.map(() => new AbortController()) try { diff --git a/packages/libp2p/src/connection-manager/index.ts b/packages/libp2p/src/connection-manager/index.ts index 96cafb3eac..a73ff85a75 100644 --- a/packages/libp2p/src/connection-manager/index.ts +++ b/packages/libp2p/src/connection-manager/index.ts @@ -488,6 +488,8 @@ export class DefaultConnectionManager implements ConnectionManager, Startable { throw new CodeError('Not started', codes.ERR_NODE_NOT_STARTED) } + options.signal?.throwIfAborted() + const { peerId } = getPeerAddress(peerIdOrMultiaddr) if (peerId != null) { @@ -530,12 +532,16 @@ export class DefaultConnectionManager implements ConnectionManager, Startable { return connection } - async closeConnections (peerId: PeerId): Promise { + async closeConnections (peerId: PeerId, options: AbortOptions = {}): Promise { const connections = this.connections.get(peerId) ?? [] await Promise.all( connections.map(async connection => { - await connection.close() + try { + await connection.close(options) + } catch (err: any) { + connection.abort(err) + } }) ) } diff --git a/packages/libp2p/src/connection/index.ts b/packages/libp2p/src/connection/index.ts index 30bd308eec..94241bd49c 100644 --- a/packages/libp2p/src/connection/index.ts +++ b/packages/libp2p/src/connection/index.ts @@ -1,21 +1,23 @@ -import { type Direction, symbol, type Connection, type Stream, type ConnectionTimeline } from '@libp2p/interface/connection' -import { OPEN, CLOSING, CLOSED } from '@libp2p/interface/connection/status' +import { setMaxListeners } from 'events' +import { type Direction, symbol, type Connection, type Stream, type ConnectionTimeline, type ConnectionStatus } from '@libp2p/interface/connection' import { CodeError } from '@libp2p/interface/errors' import { logger } from '@libp2p/logger' import type { AbortOptions } from '@libp2p/interface' -import type * as Status from '@libp2p/interface/connection/status' import type { PeerId } from '@libp2p/interface/peer-id' import type { Multiaddr } from '@multiformats/multiaddr' const log = logger('libp2p:connection') +const CLOSE_TIMEOUT = 500 + interface ConnectionInit { remoteAddr: Multiaddr remotePeer: PeerId newStream: (protocols: string[], options?: AbortOptions) => Promise - close: () => Promise + close: (options?: AbortOptions) => Promise + abort: (err: Error) => void getStreams: () => Stream[] - status: keyof typeof Status + status: ConnectionStatus direction: Direction timeline: ConnectionTimeline multiplexer?: string @@ -46,7 +48,7 @@ export class ConnectionImpl implements Connection { public timeline: ConnectionTimeline public multiplexer?: string public encryption?: string - public status: keyof typeof Status + public status: ConnectionStatus /** * User provided tags @@ -62,36 +64,36 @@ export class ConnectionImpl implements Connection { /** * Reference to the close function of the raw connection */ - private readonly _close: () => Promise + private readonly _close: (options?: AbortOptions) => Promise + + private readonly _abort: (err: Error) => void /** * Reference to the getStreams function of the muxer */ private readonly _getStreams: () => Stream[] - private _closing: boolean - /** * An implementation of the js-libp2p connection. * Any libp2p transport should use an upgrader to return this connection. */ constructor (init: ConnectionInit) { - const { remoteAddr, remotePeer, newStream, close, getStreams } = init + const { remoteAddr, remotePeer, newStream, close, abort, getStreams } = init this.id = `${(parseInt(String(Math.random() * 1e9))).toString(36)}${Date.now()}` this.remoteAddr = remoteAddr this.remotePeer = remotePeer this.direction = init.direction - this.status = OPEN + this.status = 'open' this.timeline = init.timeline this.multiplexer = init.multiplexer this.encryption = init.encryption this._newStream = newStream this._close = close + this._abort = abort this._getStreams = getStreams this.tags = [] - this._closing = false } readonly [Symbol.toStringTag] = 'Connection' @@ -109,11 +111,11 @@ export class ConnectionImpl implements Connection { * Create a new stream from this connection */ async newStream (protocols: string | string[], options?: AbortOptions): Promise { - if (this.status === CLOSING) { + if (this.status === 'closing') { throw new CodeError('the connection is being closed', 'ERR_CONNECTION_BEING_CLOSED') } - if (this.status === CLOSED) { + if (this.status === 'closed') { throw new CodeError('the connection is closed', 'ERR_CONNECTION_CLOSED') } @@ -145,27 +147,52 @@ export class ConnectionImpl implements Connection { /** * Close the connection */ - async close (): Promise { - if (this.status === CLOSED || this._closing) { + async close (options: AbortOptions = {}): Promise { + if (this.status === 'closed' || this.status === 'closing') { return } - this.status = CLOSING + log('closing connection to %a', this.remoteAddr) + + this.status = 'closing' + + options.signal = options?.signal ?? AbortSignal.timeout(CLOSE_TIMEOUT) - // close all streams - this can throw if we're not multiplexed try { - this.streams.forEach(s => { s.close() }) - } catch (err) { - log.error(err) + // fails on node < 15.4 + setMaxListeners?.(Infinity, options.signal) + } catch { } + + try { + // close all streams gracefully - this can throw if we're not multiplexed + await Promise.all( + this.streams.map(async s => s.close(options)) + ) + + // Close raw connection + await this._close(options) + + this.timeline.close = Date.now() + this.status = 'closed' + } catch (err: any) { + log.error('error encountered during graceful close of connection to %a', this.remoteAddr, err) + this.abort(err) } + } + + abort (err: Error): void { + log.error('aborting connection to %a due to error', this.remoteAddr, err) + + this.status = 'closing' + this.streams.forEach(s => { s.abort(err) }) + + log.error('all streams aborted', this.streams.length) - // Close raw connection - this._closing = true - await this._close() - this._closing = false + // Abort raw connection + this._abort(err) this.timeline.close = Date.now() - this.status = CLOSED + this.status = 'closed' } } diff --git a/packages/libp2p/src/fetch/index.ts b/packages/libp2p/src/fetch/index.ts index 8084fadc82..f72e1e7b4a 100644 --- a/packages/libp2p/src/fetch/index.ts +++ b/packages/libp2p/src/fetch/index.ts @@ -109,12 +109,12 @@ class DefaultFetchService implements Startable, FetchService { async start (): Promise { await this.components.registrar.handle(this.protocol, (data) => { void this.handleMessage(data) + .then(async () => { + await data.stream.close() + }) .catch(err => { log.error(err) }) - .finally(() => { - data.stream.close() - }) }, { maxInboundStreams: this.init.maxInboundStreams, maxOutboundStreams: this.init.maxOutboundStreams @@ -201,7 +201,7 @@ class DefaultFetchService implements Startable, FetchService { return result ?? null } finally { if (stream != null) { - stream.close() + await stream.close() } } } diff --git a/packages/libp2p/src/identify/identify.ts b/packages/libp2p/src/identify/identify.ts index a3bdf54936..351074786b 100644 --- a/packages/libp2p/src/identify/identify.ts +++ b/packages/libp2p/src/identify/identify.ts @@ -4,12 +4,7 @@ import { logger } from '@libp2p/logger' import { peerIdFromKeys } from '@libp2p/peer-id' import { RecordEnvelope, PeerRecord } from '@libp2p/peer-record' import { type Multiaddr, multiaddr, protocols } from '@multiformats/multiaddr' -import { abortableDuplex } from 'abortable-iterator' -import { anySignal } from 'any-signal' -import first from 'it-first' -import * as lp from 'it-length-prefixed' -import { pbStream } from 'it-pb-stream' -import { pipe } from 'it-pipe' +import { pbStream } from 'it-protobuf-stream' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' import { toString as uint8ArrayToString } from 'uint8arrays/to-string' import { isNode, isBrowser, isWebWorker, isElectronMain, isElectronRenderer, isReactNative } from 'wherearewe' @@ -197,26 +192,27 @@ export class DefaultIdentifyService implements Startable, IdentifyService { signal }) - // make stream abortable - const source = abortableDuplex(stream, signal) - - await source.sink(pipe( - [Identify.encode({ - listenAddrs: listenAddresses.map(ma => ma.bytes), - signedPeerRecord: signedPeerRecord.marshal(), - protocols: supportedProtocols, - agentVersion, - protocolVersion - })], - (source) => lp.encode(source) - )) + const pb = pbStream(stream, { + maxDataLength: this.maxIdentifyMessageSize ?? MAX_IDENTIFY_MESSAGE_SIZE + }).pb(Identify) + + await pb.write({ + listenAddrs: listenAddresses.map(ma => ma.bytes), + signedPeerRecord: signedPeerRecord.marshal(), + protocols: supportedProtocols, + agentVersion, + protocolVersion + }, { + signal + }) + + await stream.close({ + signal + }) } catch (err: any) { // Just log errors log.error('could not push identify update to peer', err) - } finally { - if (stream != null) { - stream.close() - } + stream?.abort(err) } }) @@ -258,44 +254,24 @@ export class DefaultIdentifyService implements Startable, IdentifyService { async _identify (connection: Connection, options: AbortOptions = {}): Promise { let stream: Stream | undefined - const signal = anySignal([AbortSignal.timeout(this.timeout), options?.signal]) + options.signal = options.signal ?? AbortSignal.timeout(this.timeout) try { - // fails on node < 15.4 - setMaxListeners?.(Infinity, signal) - } catch {} + stream = await connection.newStream([this.identifyProtocolStr], options) - try { - stream = await connection.newStream([this.identifyProtocolStr], { - signal - }) + const pb = pbStream(stream, { + maxDataLength: this.maxIdentifyMessageSize ?? MAX_IDENTIFY_MESSAGE_SIZE + }).pb(Identify) - // make stream abortable - const source = abortableDuplex(stream, signal) + const message = await pb.read(options) - const data = await pipe( - [], - source, - (source) => lp.decode(source, { - maxDataLength: this.maxIdentifyMessageSize ?? MAX_IDENTIFY_MESSAGE_SIZE - }), - async (source) => first(source) - ) + await stream.close(options) - if (data == null) { - throw new CodeError('No data could be retrieved', codes.ERR_CONNECTION_ENDED) - } - - try { - return Identify.decode(data) - } catch (err: any) { - throw new CodeError(String(err), codes.ERR_INVALID_MESSAGE) - } - } finally { - if (stream != null) { - stream.close() - } - signal.clear() + return message + } catch (err: any) { + log.error('error while reading identify message', err) + stream?.abort(err) + throw err } } @@ -381,7 +357,9 @@ export class DefaultIdentifyService implements Startable, IdentifyService { signedPeerRecord = envelope.marshal().subarray() } - const message = Identify.encode({ + const pb = pbStream(stream).pb(Identify) + + await pb.write({ protocolVersion: this.host.protocolVersion, agentVersion: this.host.agentVersion, publicKey, @@ -389,17 +367,16 @@ export class DefaultIdentifyService implements Startable, IdentifyService { signedPeerRecord, observedAddr: connection.remoteAddr.bytes, protocols: peerData.protocols + }, { + signal }) - // make stream abortable - const source = abortableDuplex(stream, signal) - - const msgWithLenPrefix = pipe([message], (source) => lp.encode(source)) - await source.sink(msgWithLenPrefix) + await stream.close({ + signal + }) } catch (err: any) { log.error('could not respond to identify request', err) - } finally { - stream.close() + stream.abort(err) } } @@ -414,19 +391,22 @@ export class DefaultIdentifyService implements Startable, IdentifyService { throw new Error('received push from ourselves?') } - // make stream abortable - const source = abortableDuplex(stream, AbortSignal.timeout(this.timeout)) - const pb = pbStream(source, { + const options = { + signal: AbortSignal.timeout(this.timeout) + } + + const pb = pbStream(stream, { maxDataLength: this.maxIdentifyMessageSize ?? MAX_IDENTIFY_MESSAGE_SIZE - }) - const message = await pb.readPB(Identify) + }).pb(Identify) + + const message = await pb.read(options) + await stream.close(options) await this.#consumeIdentifyMessage(connection.remotePeer, message) } catch (err: any) { log.error('received invalid message', err) + stream.abort(err) return - } finally { - stream.close() } log('handled push from %p', connection.remotePeer) diff --git a/packages/libp2p/src/libp2p.ts b/packages/libp2p/src/libp2p.ts index 79fa645b79..206c0c45d7 100644 --- a/packages/libp2p/src/libp2p.ts +++ b/packages/libp2p/src/libp2p.ts @@ -307,12 +307,12 @@ export class Libp2pNode> extends return this.components.registrar.getProtocols() } - async hangUp (peer: PeerId | Multiaddr): Promise { + async hangUp (peer: PeerId | Multiaddr, options: AbortOptions = {}): Promise { if (isMultiaddr(peer)) { peer = peerIdFromString(peer.getPeerId() ?? '') } - await this.components.connectionManager.closeConnections(peer) + await this.components.connectionManager.closeConnections(peer, options) } /** diff --git a/packages/libp2p/src/ping/index.ts b/packages/libp2p/src/ping/index.ts index 3903f9f0ae..d189d41202 100644 --- a/packages/libp2p/src/ping/index.ts +++ b/packages/libp2p/src/ping/index.ts @@ -1,9 +1,7 @@ -import { setMaxListeners } from 'events' import { randomBytes } from '@libp2p/crypto' import { CodeError } from '@libp2p/interface/errors' import { logger } from '@libp2p/logger' import { abortableDuplex } from 'abortable-iterator' -import { anySignal } from 'any-signal' import first from 'it-first' import { pipe } from 'it-pipe' import { equals as uint8ArrayEquals } from 'uint8arrays/equals' @@ -77,11 +75,19 @@ class DefaultPingService implements Startable, PingService { * A handler to register with Libp2p to process ping messages */ handleMessage (data: IncomingStreamData): void { + log('incoming ping from %p', data.connection.remotePeer) + const { stream } = data + const start = Date.now() void pipe(stream, stream) .catch(err => { - log.error(err) + log.error('incoming ping from %p failed with error', data.connection.remotePeer, err) + }) + .finally(() => { + const ms = Date.now() - start + + log('incoming ping from %p complete in %dms', data.connection.remotePeer, ms) }) } @@ -92,45 +98,50 @@ class DefaultPingService implements Startable, PingService { * @returns {Promise} */ async ping (peer: PeerId | Multiaddr | Multiaddr[], options: AbortOptions = {}): Promise { - log('dialing %s to %p', this.protocol, peer) + log('pinging %p', peer) const start = Date.now() const data = randomBytes(PING_LENGTH) const connection = await this.components.connectionManager.openConnection(peer, options) let stream: Stream | undefined - const signal = anySignal([AbortSignal.timeout(this.timeout), options?.signal]) + options.signal = options.signal ?? AbortSignal.timeout(this.timeout) try { - // fails on node < 15.4 - setMaxListeners?.(Infinity, signal) - } catch {} - - try { - stream = await connection.newStream([this.protocol], { - signal - }) + stream = await connection.newStream([this.protocol], options) // make stream abortable - const source = abortableDuplex(stream, signal) + const source = abortableDuplex(stream, options.signal) const result = await pipe( [data], source, async (source) => first(source) ) - const end = Date.now() - if (result == null || !uint8ArrayEquals(data, result.subarray())) { - throw new CodeError('Received wrong ping ack', codes.ERR_WRONG_PING_ACK) + const ms = Date.now() - start + + if (result == null) { + throw new CodeError(`Did not receive a ping ack after ${ms}ms`, codes.ERR_WRONG_PING_ACK) + } + + if (!uint8ArrayEquals(data, result.subarray())) { + throw new CodeError(`Received wrong ping ack after ${ms}ms`, codes.ERR_WRONG_PING_ACK) } - return end - start + log('ping %p complete in %dms', connection.remotePeer, ms) + + return ms + } catch (err: any) { + log.error('error while pinging %p', connection.remotePeer, err) + + stream?.abort(err) + + throw err } finally { if (stream != null) { - stream.close() + await stream.close() } - signal.clear() } } } diff --git a/packages/libp2p/src/transport-manager.ts b/packages/libp2p/src/transport-manager.ts index a5b1742ca6..070503764a 100644 --- a/packages/libp2p/src/transport-manager.ts +++ b/packages/libp2p/src/transport-manager.ts @@ -177,6 +177,10 @@ export class DefaultTransportManager implements TransportManager, Startable { * Starts listeners for each listen Multiaddr */ async listen (addrs: Multiaddr[]): Promise { + if (!this.isStarted()) { + throw new CodeError('Not started', codes.ERR_NODE_NOT_STARTED) + } + if (addrs == null || addrs.length === 0) { log('no addresses were provided for listening, this node is dial only') return diff --git a/packages/libp2p/src/upgrader.ts b/packages/libp2p/src/upgrader.ts index 20b0402e95..f40e1cdd52 100644 --- a/packages/libp2p/src/upgrader.ts +++ b/packages/libp2p/src/upgrader.ts @@ -416,11 +416,11 @@ export class DefaultUpgrader implements Upgrader { this._onStream({ connection, stream: muxedStream, protocol }) }) - .catch(err => { + .catch(async err => { log.error(err) if (muxedStream.timeline.close == null) { - muxedStream.close() + await muxedStream.close() } }) }, @@ -481,7 +481,7 @@ export class DefaultUpgrader implements Upgrader { log.error('could not create new stream', err) if (muxedStream.timeline.close == null) { - muxedStream.close() + muxedStream.abort(err) } if (err.code != null) { @@ -508,7 +508,7 @@ export class DefaultUpgrader implements Upgrader { // Wait for close to finish before notifying of the closure (async () => { try { - if (connection.status === 'OPEN') { + if (connection.status === 'open') { await connection.close() } } catch (err: any) { @@ -536,18 +536,25 @@ export class DefaultUpgrader implements Upgrader { connection = createConnection({ remoteAddr: maConn.remoteAddr, remotePeer, - status: 'OPEN', + status: 'open', direction, timeline: maConn.timeline, multiplexer: muxer?.protocol, encryption: cryptoProtocol, newStream: newStream ?? errConnectionNotMultiplexed, - getStreams: () => { if (muxer != null) { return muxer.streams } else { return errConnectionNotMultiplexed() } }, - close: async () => { - await maConn.close() - // Ensure remaining streams are closed + getStreams: () => { if (muxer != null) { return muxer.streams } else { return [] } }, + close: async (options?: AbortOptions) => { + await maConn.close(options) + // Ensure remaining streams are closed gracefully + if (muxer != null) { + await muxer.close(options) + } + }, + abort: (err) => { + maConn.abort(err) + // Ensure remaining streams are aborted if (muxer != null) { - muxer.close() + muxer.abort(err) } } }) diff --git a/packages/libp2p/test/circuit-relay/hop.spec.ts b/packages/libp2p/test/circuit-relay/hop.spec.ts index 5c7cd47788..7c3405197c 100644 --- a/packages/libp2p/test/circuit-relay/hop.spec.ts +++ b/packages/libp2p/test/circuit-relay/hop.spec.ts @@ -8,7 +8,7 @@ import { PeerMap } from '@libp2p/peer-collections' import { createEd25519PeerId } from '@libp2p/peer-id-factory' import { type Multiaddr, multiaddr } from '@multiformats/multiaddr' import { expect } from 'aegir/chai' -import { type MessageStream, pbStream } from 'it-pb-stream' +import { type MessageStream, pbStream } from 'it-protobuf-stream' import Sinon from 'sinon' import { type StubbedInstance, stubInterface } from 'sinon-ts' import { DEFAULT_MAX_RESERVATION_STORE_SIZE, RELAY_SOURCE_TAG, RELAY_V2_HOP_CODEC } from '../../src/circuit-relay/constants.js' @@ -149,7 +149,7 @@ describe('circuit-relay hop protocol', function () { const clientPbStream = await openStream(client, relay, RELAY_V2_HOP_CODEC) // send reserve message - clientPbStream.write({ + await clientPbStream.write({ type: HopMessage.Type.RESERVE }) @@ -163,7 +163,7 @@ describe('circuit-relay hop protocol', function () { const clientPbStream = await openStream(client, relay, RELAY_V2_HOP_CODEC) // send reserve message - clientPbStream.write({ + await clientPbStream.write({ type: HopMessage.Type.CONNECT, peer: { id: target.peerId.toBytes(), @@ -206,7 +206,7 @@ describe('circuit-relay hop protocol', function () { const clientPbStream = await openStream(clientNode, relayNode, RELAY_V2_HOP_CODEC) // wrong initial message - clientPbStream.write({ + await clientPbStream.write({ type: HopMessage.Type.STATUS, status: Status.MALFORMED_MESSAGE }) @@ -322,7 +322,7 @@ describe('circuit-relay hop protocol', function () { await expect(makeReservation(targetNode, relayNode)).to.eventually.have.nested.property('response.status', Status.OK) const clientPbStream = await openStream(clientNode, relayNode, RELAY_V2_HOP_CODEC) - clientPbStream.write({ + await clientPbStream.write({ type: HopMessage.Type.CONNECT, // @ts-expect-error {} is missing the following properties from peer: id, addrs peer: {} diff --git a/packages/libp2p/test/circuit-relay/relay.node.ts b/packages/libp2p/test/circuit-relay/relay.node.ts index e91dd40657..ae2786fea9 100644 --- a/packages/libp2p/test/circuit-relay/relay.node.ts +++ b/packages/libp2p/test/circuit-relay/relay.node.ts @@ -8,7 +8,7 @@ import { Circuit } from '@multiformats/mafmt' import { multiaddr } from '@multiformats/multiaddr' import { expect } from 'aegir/chai' import delay from 'delay' -import { pbStream } from 'it-pb-stream' +import { pbStream } from 'it-protobuf-stream' import defer from 'p-defer' import pWaitFor from 'p-wait-for' import sinon from 'sinon' @@ -19,7 +19,7 @@ import { HopMessage, Status } from '../../src/circuit-relay/pb/index.js' import { identifyService } from '../../src/identify/index.js' import { createLibp2p } from '../../src/index.js' import { plaintext } from '../../src/insecure/index.js' -import { discoveredRelayConfig, getRelayAddress, hasRelay, usingAsRelay } from './utils.js' +import { discoveredRelayConfig, doesNotHaveRelay, getRelayAddress, hasRelay, usingAsRelay } from './utils.js' import type { Libp2p } from '@libp2p/interface' import type { Connection } from '@libp2p/interface/connection' @@ -34,6 +34,9 @@ describe('circuit-relay', () => { // create 1 node and 3 relays [local, relay1, relay2, relay3] = await Promise.all([ createLibp2p({ + connectionManager: { + minConnections: 0 + }, addresses: { listen: ['/ip4/127.0.0.1/tcp/0'] }, @@ -55,6 +58,9 @@ describe('circuit-relay', () => { } }), createLibp2p({ + connectionManager: { + minConnections: 0 + }, addresses: { listen: ['/ip4/127.0.0.1/tcp/0'] }, @@ -77,6 +83,9 @@ describe('circuit-relay', () => { } }), createLibp2p({ + connectionManager: { + minConnections: 0 + }, addresses: { listen: ['/ip4/127.0.0.1/tcp/0'] }, @@ -99,6 +108,9 @@ describe('circuit-relay', () => { } }), createLibp2p({ + connectionManager: { + minConnections: 0 + }, addresses: { listen: ['/ip4/127.0.0.1/tcp/0'] }, @@ -258,22 +270,24 @@ describe('circuit-relay', () => { // wait for identify for newly dialled peer await discoveredRelayConfig(relay1, relay3) - // disconnect not used listen relay - await relay1.hangUp(relay3.peerId) - - // Stub dial + // stub dial, make sure we can't reconnect // @ts-expect-error private field sinon.stub(relay1.components.connectionManager, 'openConnection').callsFake(async () => { deferred.resolve() return Promise.reject(new Error('failed to dial')) }) - // Remove peer used as relay from peerStore and disconnect it - await relay1.hangUp(relay2.peerId) - await relay1.peerStore.delete(relay2.peerId) + await Promise.all([ + // disconnect not used listen relay + relay1.hangUp(relay3.peerId), + + // disconnect from relay + relay1.hangUp(relay2.peerId) + ]) + expect(relay1.getConnections()).to.be.empty() - // Wait for failed dial + // wait for failed dial await deferred.promise }) @@ -316,24 +330,11 @@ describe('circuit-relay', () => { // wait for peer added as listen relay await usingAsRelay(local, relay1) - // set up listener for address change - const deferred = defer() - - local.addEventListener('self:peer:update', ({ detail }) => { - const hasNoCircuitRelayAddress = detail.peer.addresses - .map(({ multiaddr }) => multiaddr) - .find(ma => Circuit.matches(ma)) == null - - if (hasNoCircuitRelayAddress) { - deferred.resolve() - } - }) - // shut down the relay await relay1.stop() // should no longer have a circuit address - await deferred.promise + await doesNotHaveRelay(local) }) }) @@ -550,7 +551,7 @@ describe('circuit-relay', () => { // we should still be connected to the relay const conns = local.getConnections(relay1.peerId) expect(conns).to.have.lengthOf(1) - expect(conns).to.have.nested.property('[0].status', 'OPEN') + expect(conns).to.have.nested.property('[0].status', 'open') }) it('dialer should close hop stream on hop failure', async () => { @@ -566,7 +567,7 @@ describe('circuit-relay', () => { // we should still be connected to the relay const conns = local.getConnections(relay1.peerId) expect(conns).to.have.lengthOf(1) - expect(conns).to.have.nested.property('[0].status', 'OPEN') + expect(conns).to.have.nested.property('[0].status', 'open') // we should not have any streams with the hop codec const streams = local.getConnections(relay1.peerId) @@ -591,7 +592,7 @@ describe('circuit-relay', () => { // we should still be connected to the relay const remoteConns = local.getConnections(relay1.peerId) expect(remoteConns).to.have.lengthOf(1) - expect(remoteConns).to.have.nested.property('[0].status', 'OPEN') + expect(remoteConns).to.have.nested.property('[0].status', 'open') }) it('should fail to dial remote over relay over relay', async () => { @@ -627,7 +628,7 @@ describe('circuit-relay', () => { const hopStream = pbStream(stream).pb(HopMessage) - hopStream.write({ + await hopStream.write({ type: HopMessage.Type.CONNECT, peer: { id: remote.peerId.toBytes(), diff --git a/packages/libp2p/test/circuit-relay/stop.spec.ts b/packages/libp2p/test/circuit-relay/stop.spec.ts index fc115ced80..872f01d104 100644 --- a/packages/libp2p/test/circuit-relay/stop.spec.ts +++ b/packages/libp2p/test/circuit-relay/stop.spec.ts @@ -5,12 +5,14 @@ import { isStartable } from '@libp2p/interface/startable' import { mockStream } from '@libp2p/interface-compliance-tests/mocks' import { createEd25519PeerId } from '@libp2p/peer-id-factory' import { expect } from 'aegir/chai' +import delay from 'delay' import { duplexPair } from 'it-pair/duplex' -import { pbStream, type MessageStream } from 'it-pb-stream' +import { pbStream, type MessageStream } from 'it-protobuf-stream' +import Sinon from 'sinon' import { stubInterface } from 'sinon-ts' import { circuitRelayTransport } from '../../src/circuit-relay/index.js' import { Status, StopMessage } from '../../src/circuit-relay/pb/index.js' -import type { Connection } from '@libp2p/interface/connection' +import type { Connection, Stream } from '@libp2p/interface/connection' import type { ConnectionGater } from '@libp2p/interface/connection-gater' import type { ContentRouting } from '@libp2p/interface/content-routing' import type { PeerId } from '@libp2p/interface/peer-id' @@ -26,6 +28,9 @@ describe('circuit-relay stop protocol', function () { let handler: StreamHandler let pbstr: MessageStream let sourcePeer: PeerId + const stopTimeout = 100 + let localStream: Stream + let remoteStream: Stream beforeEach(async () => { const components = { @@ -41,7 +46,9 @@ describe('circuit-relay stop protocol', function () { events: new EventEmitter() } - transport = circuitRelayTransport({})(components) + transport = circuitRelayTransport({ + stopTimeout + })(components) if (isStartable(transport)) { await transport.start() @@ -51,10 +58,13 @@ describe('circuit-relay stop protocol', function () { handler = components.registrar.handle.getCall(0).args[1] - const [localStream, remoteStream] = duplexPair() + const [localDuplex, remoteDuplex] = duplexPair() + + localStream = mockStream(localDuplex) + remoteStream = mockStream(remoteDuplex) handler({ - stream: mockStream(remoteStream), + stream: remoteStream, connection: stubInterface() }) @@ -68,7 +78,7 @@ describe('circuit-relay stop protocol', function () { }) it('handle stop - success', async function () { - pbstr.write({ + await pbstr.write({ type: StopMessage.Type.CONNECT, peer: { id: sourcePeer.toBytes(), @@ -80,8 +90,15 @@ describe('circuit-relay stop protocol', function () { expect(response.status).to.be.equal(Status.OK) }) + it('handle stop error - invalid request - missing type', async function () { + await pbstr.write({}) + + const response = await pbstr.read() + expect(response.status).to.be.equal(Status.MALFORMED_MESSAGE) + }) + it('handle stop error - invalid request - wrong type', async function () { - pbstr.write({ + await pbstr.write({ type: StopMessage.Type.STATUS, peer: { id: sourcePeer.toBytes(), @@ -94,7 +111,7 @@ describe('circuit-relay stop protocol', function () { }) it('handle stop error - invalid request - missing peer', async function () { - pbstr.write({ + await pbstr.write({ type: StopMessage.Type.CONNECT }) @@ -103,7 +120,7 @@ describe('circuit-relay stop protocol', function () { }) it('handle stop error - invalid request - invalid peer addr', async function () { - pbstr.write({ + await pbstr.write({ type: StopMessage.Type.CONNECT, peer: { id: sourcePeer.toBytes(), @@ -116,4 +133,22 @@ describe('circuit-relay stop protocol', function () { const response = await pbstr.read() expect(response.status).to.be.equal(Status.MALFORMED_MESSAGE) }) + + it('handle stop error - timeout', async function () { + const abortSpy = Sinon.spy(remoteStream, 'abort') + + await pbstr.write({ + type: StopMessage.Type.CONNECT, + peer: { + id: sourcePeer.toBytes(), + addrs: [] + } + }) + + // take longer than `stopTimeout` to read the response + await delay(stopTimeout * 2) + + // should have aborted remote stream + expect(abortSpy).to.have.property('called', true) + }) }) diff --git a/packages/libp2p/test/circuit-relay/utils.ts b/packages/libp2p/test/circuit-relay/utils.ts index 818f3b036c..e11bd708f6 100644 --- a/packages/libp2p/test/circuit-relay/utils.ts +++ b/packages/libp2p/test/circuit-relay/utils.ts @@ -30,6 +30,16 @@ export async function usingAsRelay (node: Libp2p, relay: Libp2p, opts?: PWaitFor }, opts) } +export async function notUsingAsRelay (node: Libp2p, relay: Libp2p, opts?: PWaitForOptions): Promise { + // Wait for peer to be used as a relay + await pWaitFor(() => { + const search = `${relay.peerId.toString()}/p2p-circuit` + const relayAddrs = node.getMultiaddrs().filter(addr => addr.toString().includes(search)) + + return relayAddrs.length === 0 + }, opts) +} + export async function hasRelay (node: Libp2p, opts?: PWaitForOptions): Promise { let relayPeerId: PeerId | undefined @@ -70,6 +80,15 @@ export async function hasRelay (node: Libp2p, opts?: PWaitForOptions): P return relayPeerId } +export async function doesNotHaveRelay (node: Libp2p, opts?: PWaitForOptions): Promise { + // Wait for peer to be used as a relay + await pWaitFor(() => { + const relayAddrs = node.getMultiaddrs().filter(addr => addr.protoNames().includes('p2p-circuit')) + + return relayAddrs.length === 0 + }, opts) +} + export async function discoveredRelayConfig (node: Libp2p, relay: Libp2p, opts?: PWaitForOptions): Promise { await pWaitFor(async () => { try { diff --git a/packages/libp2p/test/connection-manager/direct.node.ts b/packages/libp2p/test/connection-manager/direct.node.ts index dfc290947c..c18a5147f6 100644 --- a/packages/libp2p/test/connection-manager/direct.node.ts +++ b/packages/libp2p/test/connection-manager/direct.node.ts @@ -7,6 +7,7 @@ import { yamux } from '@chainsafe/libp2p-yamux' import { type Connection, type ConnectionProtector, isConnection } from '@libp2p/interface/connection' import { AbortError } from '@libp2p/interface/errors' import { EventEmitter } from '@libp2p/interface/events' +import { start, stop } from '@libp2p/interface/startable' import { mockConnection, mockConnectionGater, mockDuplex, mockMultiaddrConnection, mockUpgrader } from '@libp2p/interface-compliance-tests/mocks' import { mplex } from '@libp2p/mplex' import { peerIdFromString } from '@libp2p/peer-id' @@ -74,7 +75,7 @@ describe('dialing (direct, TCP)', () => { listenAddr.toString() ] }) - remoteTM = new DefaultTransportManager(remoteComponents) + remoteTM = remoteComponents.transportManager = new DefaultTransportManager(remoteComponents) remoteTM.add(tcp()()) const localEvents = new EventEmitter() @@ -92,18 +93,20 @@ describe('dialing (direct, TCP)', () => { minConnections: 50, inboundUpgradeTimeout: 1000 }) - - localTM = new DefaultTransportManager(localComponents) + localComponents.addressManager = new DefaultAddressManager(localComponents) + localTM = localComponents.transportManager = new DefaultTransportManager(localComponents) localTM.add(tcp()()) - localComponents.transportManager = localTM - - await remoteTM.listen([listenAddr]) + await start(localComponents) + await start(remoteComponents) remoteAddr = remoteTM.getAddrs()[0].encapsulate(`/p2p/${remotePeerId.toString()}`) }) - afterEach(async () => { await remoteTM.stop() }) + afterEach(async () => { + await stop(localComponents) + await stop(remoteComponents) + }) afterEach(() => { sinon.restore() @@ -405,10 +408,7 @@ describe('libp2p.dialer (direct, TCP)', () => { void pipe(stream, stream) }) - await libp2p.peerStore.patch(remotePeerId, { - multiaddrs: remoteLibp2p.getMultiaddrs() - }) - const connection = await libp2p.dial(remotePeerId) + const connection = await libp2p.dial(remoteLibp2p.getMultiaddrs()) // Create local to remote streams const stream = await connection.newStream('/echo/1.0.0') diff --git a/packages/libp2p/test/connection-manager/index.node.ts b/packages/libp2p/test/connection-manager/index.node.ts index 2612b6a17c..91118332cb 100644 --- a/packages/libp2p/test/connection-manager/index.node.ts +++ b/packages/libp2p/test/connection-manager/index.node.ts @@ -1,6 +1,5 @@ /* eslint-env mocha */ -import * as STATUS from '@libp2p/interface/connection/status' import { EventEmitter } from '@libp2p/interface/events' import { start } from '@libp2p/interface/startable' import { mockConnection, mockDuplex, mockMultiaddrConnection } from '@libp2p/interface-compliance-tests/mocks' @@ -80,7 +79,7 @@ describe('Connection Manager', () => { expect(connectionManager.getConnections(peerIds[1])).to.have.lengthOf(1) - expect(conn1).to.have.nested.property('status', STATUS.OPEN) + expect(conn1).to.have.nested.property('status', 'open') await connectionManager.stop() }) @@ -360,7 +359,7 @@ describe('libp2p.connections', () => { const conn = conns[0] await libp2p.stop() - expect(conn.status).to.eql(STATUS.CLOSED) + expect(conn.status).to.eql('closed') await remoteLibp2p.stop() }) diff --git a/packages/libp2p/test/connection-manager/index.spec.ts b/packages/libp2p/test/connection-manager/index.spec.ts index 080557beb5..699339da03 100644 --- a/packages/libp2p/test/connection-manager/index.spec.ts +++ b/packages/libp2p/test/connection-manager/index.spec.ts @@ -14,6 +14,7 @@ import { DefaultConnectionManager } from '../../src/connection-manager/index.js' import { createBaseOptions } from '../fixtures/base-options.browser.js' import { createNode } from '../fixtures/creators/peer.js' import type { Libp2pNode } from '../../src/libp2p.js' +import type { AbortOptions } from '@libp2p/interface' import type { Connection } from '@libp2p/interface/connection' import type { ConnectionGater } from '@libp2p/interface/connection-gater' import type { PeerStore } from '@libp2p/interface/peer-store' @@ -86,7 +87,7 @@ describe('Connection Manager', () => { const connectionManager = libp2p.components.connectionManager as DefaultConnectionManager const connectionManagerMaybePruneConnectionsSpy = sinon.spy(connectionManager.connectionPruner, 'maybePruneConnections') - const spies = new Map>>() + const spies = new Map>>() // wait for prune event const eventPromise = pEvent(libp2p, 'connection:prune') @@ -145,7 +146,7 @@ describe('Connection Manager', () => { const connectionManager = libp2p.components.connectionManager as DefaultConnectionManager const connectionManagerMaybePruneConnectionsSpy = sinon.spy(connectionManager.connectionPruner, 'maybePruneConnections') - const spies = new Map>>() + const spies = new Map>>() const eventPromise = pEvent(libp2p, 'connection:prune') const createConnection = async (value: number, open: number = Date.now(), peerTag: string = 'test-tag'): Promise => { @@ -213,7 +214,7 @@ describe('Connection Manager', () => { const connectionManager = libp2p.components.connectionManager as DefaultConnectionManager const connectionManagerMaybePruneConnectionsSpy = sinon.spy(connectionManager.connectionPruner, 'maybePruneConnections') - const spies = new Map>>() + const spies = new Map>>() const eventPromise = pEvent(libp2p, 'connection:prune') // Max out connections diff --git a/packages/libp2p/test/connection/compliance.spec.ts b/packages/libp2p/test/connection/compliance.spec.ts index 43cfe4e2e7..ca216acd5d 100644 --- a/packages/libp2p/test/connection/compliance.spec.ts +++ b/packages/libp2p/test/connection/compliance.spec.ts @@ -28,29 +28,31 @@ describe('connection compliance', () => { direction: 'outbound', encryption: '/secio/1.0.0', multiplexer: '/mplex/6.7.0', - status: 'OPEN', + status: 'open', newStream: async (protocols) => { const id = `${streamId++}` const stream: Stream = { ...pair(), - close: () => { + close: async () => { void stream.sink(async function * () {}()) connection.removeStream(stream.id) openStreams = openStreams.filter(s => s.id !== id) }, - closeRead: () => {}, - closeWrite: () => { + closeRead: async () => {}, + closeWrite: async () => { void stream.sink(async function * () {}()) }, id, abort: () => {}, - reset: () => {}, direction: 'outbound', protocol: protocols[0], timeline: { open: 0 }, - metadata: {} + metadata: {}, + status: 'open', + writeStatus: 'ready', + readStatus: 'ready' } openStreams.push(stream) @@ -58,6 +60,7 @@ describe('connection compliance', () => { return stream }, close: async () => {}, + abort: () => {}, getStreams: () => openStreams, ...properties }) diff --git a/packages/libp2p/test/connection/index.spec.ts b/packages/libp2p/test/connection/index.spec.ts index c6b11c2709..03c21202be 100644 --- a/packages/libp2p/test/connection/index.spec.ts +++ b/packages/libp2p/test/connection/index.spec.ts @@ -47,29 +47,31 @@ describe('connection', () => { direction: 'outbound', encryption: '/secio/1.0.0', multiplexer: '/mplex/6.7.0', - status: 'OPEN', + status: 'open', newStream: async (protocols) => { const id = `${streamId++}` const stream: Stream = { ...pair(), - close: () => { - void stream.sink(async function * () {}()) + close: async () => { + await stream.sink(async function * () {}()) openStreams = openStreams.filter(s => s.id !== id) }, - closeRead: () => {}, - closeWrite: () => { - void stream.sink(async function * () {}()) + closeRead: async () => {}, + closeWrite: async () => { + await stream.sink(async function * () {}()) }, id, abort: () => {}, - reset: () => {}, direction: 'outbound', protocol: protocols[0], timeline: { open: 0 }, - metadata: {} + metadata: {}, + status: 'open', + writeStatus: 'ready', + readStatus: 'ready' } openStreams.push(stream) @@ -77,6 +79,7 @@ describe('connection', () => { return stream }, close: async () => {}, + abort: () => {}, getStreams: () => openStreams }) }) diff --git a/packages/libp2p/test/identify/index.spec.ts b/packages/libp2p/test/identify/index.spec.ts index 386d9ce926..ed7d54ff09 100644 --- a/packages/libp2p/test/identify/index.spec.ts +++ b/packages/libp2p/test/identify/index.spec.ts @@ -13,8 +13,8 @@ import { MemoryDatastore } from 'datastore-core/memory' import delay from 'delay' import drain from 'it-drain' import * as lp from 'it-length-prefixed' -import { pbStream } from 'it-pb-stream' import { pipe } from 'it-pipe' +import { pbStream } from 'it-protobuf-stream' import pDefer from 'p-defer' import sinon from 'sinon' import { stubInterface } from 'sinon-ts' @@ -235,7 +235,7 @@ describe('identify', () => { }) it('should limit incoming identify message sizes', async () => { - const deferred = pDefer() + const deferred = pDefer() const remoteIdentify = identifyService({ ...defaultInit, @@ -252,14 +252,16 @@ describe('identify', () => { const data = new Uint8Array(1024) void Promise.resolve().then(async () => { - await pipe( - [data], - (source) => lp.encode(source), - stream, - async (source) => { await drain(source) } - ) - - deferred.resolve() + try { + await pipe( + [data], + (source) => lp.encode(source), + stream, + async (source) => { await drain(source) } + ) + } catch (err: any) { + deferred.resolve(err) + } }) }) @@ -271,7 +273,10 @@ describe('identify', () => { detail: remoteToLocal }) - await deferred.promise + // stream sending too much data should have received a reset message + const err = await deferred.promise + expect(err).to.have.property('code', 'ERR_STREAM_RESET') + await stop(remoteIdentify) expect(identifySpy.called).to.be.true() @@ -425,7 +430,7 @@ describe('identify', () => { remoteIdentify._handleIdentify = async (data: IncomingStreamData): Promise => { const { stream } = data const pb = pbStream(stream) - pb.writePB(message, Identify) + await pb.write(message, Identify) } // Run identify @@ -480,7 +485,7 @@ describe('identify', () => { remoteIdentify._handleIdentify = async (data: IncomingStreamData): Promise => { const { stream } = data const pb = pbStream(stream) - pb.writePB(message, Identify) + await pb.write(message, Identify) } // Run identify diff --git a/packages/libp2p/test/identify/service.node.ts b/packages/libp2p/test/identify/service.node.ts index c99c981b80..e7ce86b685 100644 --- a/packages/libp2p/test/identify/service.node.ts +++ b/packages/libp2p/test/identify/service.node.ts @@ -143,7 +143,7 @@ describe('identify', () => { expect(clientPeer.addresses[0].multiaddr.toString()).to.equal(announceAddrs[0].toString()) expect(clientPeer.addresses[1].multiaddr.toString()).to.equal(announceAddrs[1].toString()) - stream.close() + await stream.close() await connection.close() await receiver.stop() await sender.stop() diff --git a/packages/libp2p/test/transports/transport-manager.node.ts b/packages/libp2p/test/transports/transport-manager.node.ts index bff2e28fb7..7404e88bbe 100644 --- a/packages/libp2p/test/transports/transport-manager.node.ts +++ b/packages/libp2p/test/transports/transport-manager.node.ts @@ -1,6 +1,8 @@ /* eslint-env mocha */ import { EventEmitter } from '@libp2p/interface/events' +import { start, stop } from '@libp2p/interface/startable' +import { FaultTolerance } from '@libp2p/interface/transport' import { mockUpgrader } from '@libp2p/interface-compliance-tests/mocks' import { createEd25519PeerId } from '@libp2p/peer-id-factory' import { PersistentPeerStore } from '@libp2p/peer-store' @@ -30,7 +32,7 @@ describe('Transport Manager (TCP)', () => { localPeer = await createEd25519PeerId() }) - beforeEach(() => { + beforeEach(async () => { const events = new EventEmitter() components = defaultComponents({ peerId: localPeer, @@ -41,14 +43,19 @@ describe('Transport Manager (TCP)', () => { components.addressManager = new DefaultAddressManager(components, { listen: addrs.map(addr => addr.toString()) }) components.peerStore = new PersistentPeerStore(components) - tm = new DefaultTransportManager(components) + tm = new DefaultTransportManager(components, { + faultTolerance: FaultTolerance.NO_FATAL + }) components.transportManager = tm + + await start(tm) }) afterEach(async () => { await tm.removeAll() expect(tm.getTransports()).to.be.empty() + await stop(tm) }) it('should be able to add and remove a transport', async () => { diff --git a/packages/libp2p/test/transports/transport-manager.spec.ts b/packages/libp2p/test/transports/transport-manager.spec.ts index 1a996f76f0..89024dd8bc 100644 --- a/packages/libp2p/test/transports/transport-manager.spec.ts +++ b/packages/libp2p/test/transports/transport-manager.spec.ts @@ -1,6 +1,7 @@ /* eslint-env mocha */ import { EventEmitter } from '@libp2p/interface/events' +import { start, stop } from '@libp2p/interface/startable' import { FaultTolerance } from '@libp2p/interface/transport' import { mockUpgrader } from '@libp2p/interface-compliance-tests/mocks' import { createEd25519PeerId } from '@libp2p/peer-id-factory' @@ -24,7 +25,7 @@ describe('Transport Manager (WebSockets)', () => { let tm: DefaultTransportManager let components: Components - before(async () => { + beforeEach(async () => { const events = new EventEmitter() components = { peerId: await createEd25519PeerId(), @@ -33,11 +34,15 @@ describe('Transport Manager (WebSockets)', () => { } as any components.addressManager = new DefaultAddressManager(components, { listen: [listenAddr.toString()] }) - tm = new DefaultTransportManager(components) + tm = new DefaultTransportManager(components, { + faultTolerance: FaultTolerance.NO_FATAL + }) + await start(tm) }) afterEach(async () => { await tm.removeAll() + await stop(tm) expect(tm.getTransports()).to.be.empty() }) @@ -81,11 +86,14 @@ describe('Transport Manager (WebSockets)', () => { }) it('should fail to listen with no valid address', async () => { + tm = new DefaultTransportManager(components) tm.add(webSockets({ filter: filters.all })()) - await expect(tm.listen([listenAddr])) + await expect(start(tm)) .to.eventually.be.rejected() .and.to.have.property('code', ErrorCodes.ERR_NO_VALID_ADDRESSES) + + await stop(tm) }) }) diff --git a/packages/libp2p/test/upgrading/upgrader.spec.ts b/packages/libp2p/test/upgrading/upgrader.spec.ts index 45a5ce3f16..1ec2290f46 100644 --- a/packages/libp2p/test/upgrading/upgrader.spec.ts +++ b/packages/libp2p/test/upgrading/upgrader.spec.ts @@ -50,7 +50,7 @@ describe('Upgrader', () => { let localConnectionProtector: StubbedInstance let remoteUpgrader: Upgrader let remoteMuxerFactory: StreamMuxerFactory - let remotreYamuxerFactory: StreamMuxerFactory + let remoteYamuxerFactory: StreamMuxerFactory let remoteConnectionEncrypter: ConnectionEncrypter let remoteConnectionProtector: StubbedInstance let localPeer: PeerId @@ -108,7 +108,7 @@ describe('Upgrader', () => { remoteComponents.peerStore = new PersistentPeerStore(remoteComponents) remoteComponents.connectionManager = mockConnectionManager(remoteComponents) remoteMuxerFactory = mplex()() - remotreYamuxerFactory = yamux()() + remoteYamuxerFactory = yamux()() remoteConnectionEncrypter = plaintext()() remoteUpgrader = new DefaultUpgrader(remoteComponents, { connectionEncryption: [ @@ -116,7 +116,7 @@ describe('Upgrader', () => { ], muxers: [ remoteMuxerFactory, - remotreYamuxerFactory + remoteYamuxerFactory ], inboundUpgradeTimeout: 1000 }) @@ -294,7 +294,8 @@ describe('Upgrader', () => { })() async sink (): Promise {} - close (): void {} + async close (): Promise {} + abort (): void {} } class OtherMuxerFactory implements StreamMuxerFactory { @@ -661,10 +662,10 @@ describe('libp2p.upgrader', () => { const remoteLibp2pUpgraderOnStreamSpy = sinon.spy(remoteComponents.upgrader as DefaultUpgrader, '_onStream') const stream = await localConnection.newStream(['/echo/1.0.0']) - expect(stream).to.include.keys(['id', 'recvWindowCapacity', 'sendWindowCapacity', 'sourceInput']) + expect(stream).to.include.keys(['id', 'sink', 'source']) const [arg0] = remoteLibp2pUpgraderOnStreamSpy.getCall(0).args - expect(arg0.stream).to.include.keys(['id', 'recvWindowCapacity', 'sendWindowCapacity', 'sourceInput']) + expect(arg0.stream).to.include.keys(['id', 'sink', 'source']) }) it('should emit connect and disconnect events', async () => { diff --git a/packages/libp2p/test/upnp-nat/upnp-nat.node.ts b/packages/libp2p/test/upnp-nat/upnp-nat.node.ts index dee88506e2..42d6cc6407 100644 --- a/packages/libp2p/test/upnp-nat/upnp-nat.node.ts +++ b/packages/libp2p/test/upnp-nat/upnp-nat.node.ts @@ -32,12 +32,12 @@ describe('UPnP NAT (TCP)', () => { async function createNatManager (addrs = DEFAULT_ADDRESSES, natManagerOptions = {}): Promise<{ natManager: any, components: Components }> { const events = new EventEmitter() - const components: any = { + const components: any = defaultComponents({ peerId: await createEd25519PeerId(), upgrader: mockUpgrader({ events }), events, peerStore: stubInterface() - } + }) components.peerStore.patch.callsFake(async (peerId: PeerId, details: PeerData) => { components.events.safeDispatchEvent('self:peer:update', { @@ -65,11 +65,13 @@ describe('UPnP NAT (TCP)', () => { } components.transportManager.add(tcp()()) - await components.transportManager.listen(components.addressManager.getListenAddrs()) + + await start(components) teardown.push(async () => { await stop(natManager) await components.transportManager.removeAll() + await stop(components) }) return { diff --git a/packages/multistream-select/package.json b/packages/multistream-select/package.json index b0effbc562..07acef9b10 100644 --- a/packages/multistream-select/package.json +++ b/packages/multistream-select/package.json @@ -60,7 +60,7 @@ "it-length-prefixed": "^9.0.1", "it-merge": "^3.0.0", "it-pipe": "^3.0.1", - "it-pushable": "^3.1.3", + "it-pushable": "^3.2.0", "it-reader": "^6.0.1", "it-stream-types": "^2.0.1", "uint8arraylist": "^2.4.3", diff --git a/packages/pubsub-gossipsub/package.json b/packages/pubsub-gossipsub/package.json index 26bc559748..17c4f91138 100644 --- a/packages/pubsub-gossipsub/package.json +++ b/packages/pubsub-gossipsub/package.json @@ -93,7 +93,7 @@ "denque": "^1.5.0", "it-length-prefixed": "^9.0.1", "it-pipe": "^3.0.1", - "it-pushable": "^3.1.3", + "it-pushable": "^3.2.0", "multiformats": "^12.0.1", "protobufjs": "^6.11.2", "uint8arraylist": "^2.4.3", diff --git a/packages/pubsub-gossipsub/src/index.ts b/packages/pubsub-gossipsub/src/index.ts index 7d7c1db25d..7261abeb06 100644 --- a/packages/pubsub-gossipsub/src/index.ts +++ b/packages/pubsub-gossipsub/src/index.ts @@ -710,7 +710,7 @@ export class GossipSub extends EventEmitter implements PubSub { objectMode: true, onEnd: (shouldEmit) => { // close writable side of the stream - if (this._rawOutboundStream != null && this._rawOutboundStream.reset != null) { // eslint-disable-line @typescript-eslint/prefer-optional-chain - this._rawOutboundStream.reset() + if (this._rawOutboundStream != null) { // eslint-disable-line @typescript-eslint/prefer-optional-chain + this._rawOutboundStream.closeWrite() + .catch(err => { + log('error closing outbound stream', err) + }) } this._rawOutboundStream = undefined diff --git a/packages/pubsub/test/utils/index.ts b/packages/pubsub/test/utils/index.ts index 4b112f9337..b98c78b76f 100644 --- a/packages/pubsub/test/utils/index.ts +++ b/packages/pubsub/test/utils/index.ts @@ -134,14 +134,16 @@ export const ConnectionPair = (): [Connection, Connection] => { // @ts-expect-error incomplete implementation newStream: async (protocol: string[]) => Promise.resolve({ ...d0, - protocol: protocol[0] + protocol: protocol[0], + closeWrite: async () => {} }) }, { // @ts-expect-error incomplete implementation newStream: async (protocol: string[]) => Promise.resolve({ ...d1, - protocol: protocol[0] + protocol: protocol[0], + closeWrite: async () => {} }) } ] diff --git a/packages/stream-multiplexer-mplex/package.json b/packages/stream-multiplexer-mplex/package.json index 89ac185871..985efaf975 100644 --- a/packages/stream-multiplexer-mplex/package.json +++ b/packages/stream-multiplexer-mplex/package.json @@ -59,10 +59,9 @@ "@libp2p/interface": "~0.0.1", "@libp2p/logger": "^2.0.0", "abortable-iterator": "^5.0.1", - "any-signal": "^4.1.1", "benchmark": "^2.1.4", "it-batched-bytes": "^2.0.2", - "it-pushable": "^3.1.3", + "it-pushable": "^3.2.0", "it-stream-types": "^2.0.1", "rate-limiter-flexible": "^2.3.11", "uint8arraylist": "^2.4.3", diff --git a/packages/stream-multiplexer-mplex/src/mplex.ts b/packages/stream-multiplexer-mplex/src/mplex.ts index ec45a0400a..41ad0123be 100644 --- a/packages/stream-multiplexer-mplex/src/mplex.ts +++ b/packages/stream-multiplexer-mplex/src/mplex.ts @@ -1,15 +1,16 @@ import { CodeError } from '@libp2p/interface/errors' import { logger } from '@libp2p/logger' import { abortableSource } from 'abortable-iterator' -import { anySignal } from 'any-signal' -import { pushableV } from 'it-pushable' +import { pipe } from 'it-pipe' +import { type PushableV, pushableV } from 'it-pushable' import { RateLimiterMemory } from 'rate-limiter-flexible' import { toString as uint8ArrayToString } from 'uint8arrays' import { Decoder } from './decode.js' import { encode } from './encode.js' import { MessageTypes, MessageTypeNames, type Message } from './message-types.js' -import { createStream } from './stream.js' +import { createStream, type MplexStream } from './stream.js' import type { MplexInit } from './index.js' +import type { AbortOptions } from '@libp2p/interface' import type { Stream } from '@libp2p/interface/connection' import type { StreamMuxer, StreamMuxerInit } from '@libp2p/interface/stream-muxer' import type { Sink, Source } from 'it-stream-types' @@ -21,6 +22,7 @@ const MAX_STREAMS_INBOUND_STREAMS_PER_CONNECTION = 1024 const MAX_STREAMS_OUTBOUND_STREAMS_PER_CONNECTION = 1024 const MAX_STREAM_BUFFER_SIZE = 1024 * 1024 * 4 // 4MB const DISCONNECT_THRESHOLD = 5 +const CLOSE_TIMEOUT = 500 function printMessage (msg: Message): any { const output: any = { @@ -39,13 +41,13 @@ function printMessage (msg: Message): any { return output } -export interface MplexStream extends Stream { - sourceReadableLength: () => number - sourcePush: (data: Uint8ArrayList) => void +interface MplexStreamMuxerInit extends MplexInit, StreamMuxerInit { + /** + * The default timeout to use in ms when shutting down the muxer. + */ + closeTimeout?: number } -interface MplexStreamMuxerInit extends MplexInit, StreamMuxerInit {} - export class MplexStreamMuxer implements StreamMuxer { public protocol = '/mplex/6.7.0' @@ -55,9 +57,10 @@ export class MplexStreamMuxer implements StreamMuxer { private _streamId: number private readonly _streams: { initiators: Map, receivers: Map } private readonly _init: MplexStreamMuxerInit - private readonly _source: { push: (val: Message) => void, end: (err?: Error) => void } + private readonly _source: PushableV private readonly closeController: AbortController private readonly rateLimiter: RateLimiterMemory + private readonly closeTimeout: number constructor (init?: MplexStreamMuxerInit) { init = init ?? {} @@ -74,6 +77,7 @@ export class MplexStreamMuxer implements StreamMuxer { receivers: new Map() } this._init = init + this.closeTimeout = init.closeTimeout ?? CLOSE_TIMEOUT /** * An iterable sink @@ -83,9 +87,24 @@ export class MplexStreamMuxer implements StreamMuxer { /** * An iterable source */ - const source = this._createSource() - this._source = source - this.source = source + this._source = pushableV({ + objectMode: true, + onEnd: (): void => { + // the source has ended, we can't write any more messages to gracefully + // close streams so all we can do is destroy them + for (const stream of this._streams.initiators.values()) { + stream.destroy() + } + + for (const stream of this._streams.receivers.values()) { + stream.destroy() + } + } + }) + this.source = pipe( + this._source, + source => encode(source, this._init.minSendBytes) + ) /** * Close controller @@ -131,15 +150,41 @@ export class MplexStreamMuxer implements StreamMuxer { /** * Close or abort all tracked streams and stop the muxer */ - close (err?: Error | undefined): void { - if (this.closeController.signal.aborted) return + async close (options?: AbortOptions): Promise { + if (this.closeController.signal.aborted) { + return + } + + const signal = options?.signal ?? AbortSignal.timeout(this.closeTimeout) + + try { + // try to gracefully close all streams + await Promise.all( + this.streams.map(async s => s.close({ + signal + })) + ) - if (err != null) { - this.streams.forEach(s => { s.abort(err) }) - } else { - this.streams.forEach(s => { s.close() }) + this._source.end() + + // try to gracefully close the muxer + await this._source.onEmpty({ + signal + }) + + this.closeController.abort() + } catch (err: any) { + this.abort(err) } - this.closeController.abort() + } + + abort (err: Error): void { + if (this.closeController.signal.aborted) { + return + } + + this.streams.forEach(s => { s.abort(err) }) + this.closeController.abort(err) } /** @@ -164,7 +209,7 @@ export class MplexStreamMuxer implements StreamMuxer { throw new Error(`${type} stream ${id} already exists!`) } - const send = (msg: Message): void => { + const send = async (msg: Message): Promise => { if (log.enabled) { log.trace('%s stream %s send', type, id, printMessage(msg)) } @@ -192,10 +237,10 @@ export class MplexStreamMuxer implements StreamMuxer { */ _createSink (): Sink, Promise> { const sink: Sink, Promise> = async source => { - const signal = anySignal([this.closeController.signal, this._init.signal]) - try { - source = abortableSource(source, signal) + source = abortableSource(source, this.closeController.signal, { + returnOnAbort: true + }) const decoder = new Decoder(this._init.maxMsgSize, this._init.maxUnprocessedMessageQueueSize) @@ -209,34 +254,12 @@ export class MplexStreamMuxer implements StreamMuxer { } catch (err: any) { log('error in sink', err) this._source.end(err) // End the source with an error - } finally { - signal.clear() } } return sink } - /** - * Creates a source that restricts outgoing message sizes - * and varint encodes them - */ - _createSource (): any { - const onEnd = (err?: Error): void => { - this.close(err) - } - const source = pushableV({ - objectMode: true, - onEnd - }) - - return Object.assign(encode(source, this._init.minSendBytes), { - push: source.push, - end: source.end, - return: source.return - }) - } - async _handleIncoming (message: Message): Promise { const { id, type } = message @@ -264,7 +287,7 @@ export class MplexStreamMuxer implements StreamMuxer { } catch { log('rate limit hit when opening too many new streams over the inbound stream limit - closing remote connection') // since there's no backpressure in mplex, the only thing we can really do to protect ourselves is close the connection - this._source.end(new Error('Too many open streams')) + this.abort(new Error('Too many open streams')) return } @@ -286,43 +309,57 @@ export class MplexStreamMuxer implements StreamMuxer { if (stream == null) { log('missing stream %s for message type %s', id, MessageTypeNames[type]) + // if the remote keeps sending us messages for streams that have been + // closed or were never opened they may be attacking us so if they do + // this very quickly all we can do is close the connection + try { + await this.rateLimiter.consume('missing-stream', 1) + } catch { + log('rate limit hit when receiving messages for streams that do not exist - closing remote connection') + // since there's no backpressure in mplex, the only thing we can really do to protect ourselves is close the connection + this.abort(new Error('Too many messages for missing streams')) + return + } + return } const maxBufferSize = this._init.maxStreamBufferSize ?? MAX_STREAM_BUFFER_SIZE - switch (type) { - case MessageTypes.MESSAGE_INITIATOR: - case MessageTypes.MESSAGE_RECEIVER: - if (stream.sourceReadableLength() > maxBufferSize) { - // Stream buffer has got too large, reset the stream - this._source.push({ - id: message.id, - type: type === MessageTypes.MESSAGE_INITIATOR ? MessageTypes.RESET_RECEIVER : MessageTypes.RESET_INITIATOR - }) - - // Inform the stream consumer they are not fast enough - const error = new CodeError('Input buffer full - increase Mplex maxBufferSize to accommodate slow consumers', 'ERR_STREAM_INPUT_BUFFER_FULL') - stream.abort(error) - - return - } + try { + switch (type) { + case MessageTypes.MESSAGE_INITIATOR: + case MessageTypes.MESSAGE_RECEIVER: + if (stream.sourceReadableLength() > maxBufferSize) { + // Stream buffer has got too large, reset the stream + this._source.push({ + id: message.id, + type: type === MessageTypes.MESSAGE_INITIATOR ? MessageTypes.RESET_RECEIVER : MessageTypes.RESET_INITIATOR + }) + + // Inform the stream consumer they are not fast enough + throw new CodeError('Input buffer full - increase Mplex maxBufferSize to accommodate slow consumers', 'ERR_STREAM_INPUT_BUFFER_FULL') + } - // We got data from the remote, push it into our local stream - stream.sourcePush(message.data) - break - case MessageTypes.CLOSE_INITIATOR: - case MessageTypes.CLOSE_RECEIVER: - // We should expect no more data from the remote, stop reading - stream.closeRead() - break - case MessageTypes.RESET_INITIATOR: - case MessageTypes.RESET_RECEIVER: - // Stop reading and writing to the stream immediately - stream.reset() - break - default: - log('unknown message type %s', type) + // We got data from the remote, push it into our local stream + stream.sourcePush(message.data) + break + case MessageTypes.CLOSE_INITIATOR: + case MessageTypes.CLOSE_RECEIVER: + // The remote has stopped writing, so we can stop reading + stream.remoteCloseWrite() + break + case MessageTypes.RESET_INITIATOR: + case MessageTypes.RESET_RECEIVER: + // The remote has errored, stop reading and writing to the stream immediately + stream.reset() + break + default: + log('unknown message type %s', type) + } + } catch (err: any) { + log.error('error while processing message', err) + stream.abort(err) } } } diff --git a/packages/stream-multiplexer-mplex/src/stream.ts b/packages/stream-multiplexer-mplex/src/stream.ts index 45fc21b57a..cc8024e3cf 100644 --- a/packages/stream-multiplexer-mplex/src/stream.ts +++ b/packages/stream-multiplexer-mplex/src/stream.ts @@ -1,4 +1,5 @@ import { AbstractStream, type AbstractStreamInit } from '@libp2p/interface/stream-muxer/stream' +import { logger } from '@libp2p/logger' import { Uint8ArrayList } from 'uint8arraylist' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' import { MAX_MSG_SIZE } from './decode.js' @@ -7,7 +8,7 @@ import type { Message } from './message-types.js' export interface Options { id: number - send: (msg: Message) => void + send: (msg: Message) => Promise name?: string onEnd?: (err?: Error) => void type?: 'initiator' | 'receiver' @@ -17,14 +18,21 @@ export interface Options { interface MplexStreamInit extends AbstractStreamInit { streamId: number name: string - send: (msg: Message) => void + send: (msg: Message) => Promise + + /** + * The maximum allowable data size, any data larger than this will be + * chunked and sent in multiple data messages + */ + maxDataSize: number } -class MplexStream extends AbstractStream { +export class MplexStream extends AbstractStream { private readonly name: string private readonly streamId: number - private readonly send: (msg: Message) => void + private readonly send: (msg: Message) => Promise private readonly types: Record + private readonly maxDataSize: number constructor (init: MplexStreamInit) { super(init) @@ -33,25 +41,37 @@ class MplexStream extends AbstractStream { this.send = init.send this.name = init.name this.streamId = init.streamId + this.maxDataSize = init.maxDataSize } - sendNewStream (): void { - this.send({ id: this.streamId, type: InitiatorMessageTypes.NEW_STREAM, data: new Uint8ArrayList(uint8ArrayFromString(this.name)) }) + async sendNewStream (): Promise { + await this.send({ id: this.streamId, type: InitiatorMessageTypes.NEW_STREAM, data: new Uint8ArrayList(uint8ArrayFromString(this.name)) }) } - sendData (data: Uint8ArrayList): void { - this.send({ id: this.streamId, type: this.types.MESSAGE, data }) + async sendData (data: Uint8ArrayList): Promise { + data = data.sublist() + + while (data.byteLength > 0) { + const toSend = Math.min(data.byteLength, this.maxDataSize) + await this.send({ + id: this.streamId, + type: this.types.MESSAGE, + data: data.sublist(0, toSend) + }) + + data.consume(toSend) + } } - sendReset (): void { - this.send({ id: this.streamId, type: this.types.RESET }) + async sendReset (): Promise { + await this.send({ id: this.streamId, type: this.types.RESET }) } - sendCloseWrite (): void { - this.send({ id: this.streamId, type: this.types.CLOSE }) + async sendCloseWrite (): Promise { + await this.send({ id: this.streamId, type: this.types.CLOSE }) } - sendCloseRead (): void { + async sendCloseRead (): Promise { // mplex does not support close read, only close write } } @@ -66,6 +86,7 @@ export function createStream (options: Options): MplexStream { direction: type === 'initiator' ? 'outbound' : 'inbound', maxDataSize: maxMsgSize, onEnd, - send + send, + log: logger(`libp2p:mplex:stream:${type}:${id}`) }) } diff --git a/packages/stream-multiplexer-mplex/test/mplex.spec.ts b/packages/stream-multiplexer-mplex/test/mplex.spec.ts index d6e9a63bb9..ae7b3bfb03 100644 --- a/packages/stream-multiplexer-mplex/test/mplex.spec.ts +++ b/packages/stream-multiplexer-mplex/test/mplex.spec.ts @@ -150,14 +150,19 @@ describe('mplex', () => { // collect outgoing mplex messages const muxerFinished = pDefer() - let messages: Message[] = [] + const messages: Message[] = [] void Promise.resolve().then(async () => { - messages = await all(decode()(muxer.source)) + try { + // collect as many messages as possible + for await (const message of decode()(muxer.source)) { + messages.push(message) + } + } catch {} muxerFinished.resolve() }) // the muxer processes the messages - await muxer.sink(encode(input)) + void muxer.sink(encode(input)) // source should have errored with appropriate code const err = await streamSourceError.promise @@ -201,7 +206,7 @@ describe('mplex', () => { await stream.sink(async function * () { yield * input }()) - stream.close() + await stream.close() streamFinished.resolve() }) diff --git a/packages/stream-multiplexer-mplex/test/stream.spec.ts b/packages/stream-multiplexer-mplex/test/stream.spec.ts index ee6e97e2dd..0009c157c3 100644 --- a/packages/stream-multiplexer-mplex/test/stream.spec.ts +++ b/packages/stream-multiplexer-mplex/test/stream.spec.ts @@ -15,7 +15,7 @@ import { MessageTypes, MessageTypeNames } from '../src/message-types.js' import { createStream } from '../src/stream.js' import { messageWithBytes } from './fixtures/utils.js' import type { Message } from '../src/message-types.js' -import type { MplexStream } from '../src/mplex.js' +import type { MplexStream } from '../src/stream.js' function randomInput (min = 1, max = 100): Uint8ArrayList[] { return Array.from(Array(randomInt(min, max)), () => new Uint8ArrayList(randomBytes(randomInt(1, 128)))) @@ -25,28 +25,6 @@ function expectMsgType (actual: keyof typeof MessageTypeNames, expected: keyof t expect(MessageTypeNames[actual]).to.equal(MessageTypeNames[expected]) } -function echoedMessage (message: Message): Message { - if (message.type !== MessageTypes.MESSAGE_RECEIVER) { - throw new Error('Message was not a receiver message') - } - - return bufferToMessage(message.data.slice()) -} - -function expectMessages (messages: Message[], codes: Array): void { - messages.slice(0, codes.length).forEach((msg, index) => { - expect(msg).to.have.property('type', codes[index]) - - if (msg.type === MessageTypes.MESSAGE_INITIATOR) { - expect(messageWithBytes(msg)).to.have.property('data').that.equalBytes([index - 1]) - } - }) -} - -function expectEchoedMessages (messages: Message[], codes: Array): void { - expectMessages(messages.slice(0, codes.length).map(echoedMessage), codes) -} - const msgToBuffer = (msg: Message): Uint8ArrayList => { const m: any = { ...msg @@ -66,17 +44,17 @@ interface onMessage { } export interface StreamPair { - initiatorMessages: Message[] - receiverMessages: Message[] + initiatorSentMessages: Message[] + receiverSentMessages: Message[] } async function streamPair (n: number, onInitiatorMessage?: onMessage, onReceiverMessage?: onMessage): Promise { - const receiverMessages: Message[] = [] - const initiatorMessages: Message[] = [] + const receiverSentMessages: Message[] = [] + const initiatorSentMessages: Message[] = [] const id = 5 - const mockInitiatorSend = (msg: Message): void => { - initiatorMessages.push(msg) + const mockInitiatorSend = async (msg: Message): Promise => { + initiatorSentMessages.push(msg) if (onInitiatorMessage != null) { onInitiatorMessage(msg, initiator, receiver) @@ -84,8 +62,8 @@ async function streamPair (n: number, onInitiatorMessage?: onMessage, onReceiver receiver.sourcePush(msgToBuffer(msg)) } - const mockReceiverSend = (msg: Message): void => { - receiverMessages.push(msg) + const mockReceiverSend = async (msg: Message): Promise => { + receiverSentMessages.push(msg) if (onReceiverMessage != null) { onReceiverMessage(msg, initiator, receiver) @@ -104,7 +82,7 @@ async function streamPair (n: number, onInitiatorMessage?: onMessage, onReceiver // when the initiator sends a CLOSE message, we call close if (msg.type === MessageTypes.CLOSE_INITIATOR) { - receiver.closeRead() + receiver.remoteCloseWrite() } // when the initiator sends a RESET message, we call close @@ -124,7 +102,7 @@ async function streamPair (n: number, onInitiatorMessage?: onMessage, onReceiver // when the receiver sends a CLOSE message, we call close if (msg.type === MessageTypes.CLOSE_RECEIVER) { - initiator.close() + initiator.remoteCloseWrite() } // when the receiver sends a RESET message, we call close @@ -139,15 +117,15 @@ async function streamPair (n: number, onInitiatorMessage?: onMessage, onReceiver } return { - receiverMessages, - initiatorMessages + receiverSentMessages, + initiatorSentMessages } } describe('stream', () => { it('should initiate stream with NEW_STREAM message', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const stream = createStream({ id, send: mockSend }) const input = randomInput() @@ -161,7 +139,7 @@ describe('stream', () => { it('should initiate named stream with NEW_STREAM message', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = `STREAM${Date.now()}` const stream = createStream({ id, name, send: mockSend }) @@ -176,7 +154,7 @@ describe('stream', () => { it('should end a stream when it is aborted', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = `STREAM${Date.now()}` const deferred = defer() @@ -191,7 +169,7 @@ describe('stream', () => { it('should end a stream when it is reset', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = `STREAM${Date.now()}` const deferred = defer() @@ -206,7 +184,7 @@ describe('stream', () => { it('should send data with MESSAGE_INITIATOR messages if stream initiator', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = id.toString() const stream = createStream({ id, name, send: mockSend, type: 'initiator' }) @@ -227,7 +205,7 @@ describe('stream', () => { it('should send data with MESSAGE_RECEIVER messages if stream receiver', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = id.toString() const stream = createStream({ id, name, send: mockSend, type: 'receiver' }) @@ -248,7 +226,7 @@ describe('stream', () => { it('should close stream with CLOSE_INITIATOR message if stream initiator', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = id.toString() const stream = createStream({ id, name, send: mockSend, type: 'initiator' }) @@ -265,7 +243,7 @@ describe('stream', () => { it('should close stream with CLOSE_RECEIVER message if stream receiver', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = id.toString() const stream = createStream({ id, name, send: mockSend, type: 'receiver' }) @@ -282,7 +260,7 @@ describe('stream', () => { it('should reset stream on error with RESET_INITIATOR message if stream initiator', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = id.toString() const stream = createStream({ id, name, send: mockSend, type: 'initiator' }) @@ -308,7 +286,7 @@ describe('stream', () => { it('should reset stream on error with RESET_RECEIVER message if stream receiver', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = id.toString() const stream = createStream({ id, name, send: mockSend, type: 'receiver' }) @@ -335,13 +313,12 @@ describe('stream', () => { it('should close for reading (remote close)', async () => { const dataLength = 5 const { - initiatorMessages, - receiverMessages + initiatorSentMessages, + receiverSentMessages } = await streamPair(dataLength) // 1x NEW_STREAM, dataLength x MESSAGE_INITIATOR 1x CLOSE_INITIATOR - expect(initiatorMessages).to.have.lengthOf(1 + dataLength + 1) - expectMessages(initiatorMessages, [ + expect(initiatorSentMessages.map(m => m.type)).to.deep.equal([ MessageTypes.NEW_STREAM, MessageTypes.MESSAGE_INITIATOR, MessageTypes.MESSAGE_INITIATOR, @@ -351,18 +328,17 @@ describe('stream', () => { MessageTypes.CLOSE_INITIATOR ]) - // all the initiator messages plus CLOSE_RECEIVER - expect(receiverMessages).to.have.lengthOf(8) - expectEchoedMessages(receiverMessages, [ - MessageTypes.NEW_STREAM, - MessageTypes.MESSAGE_INITIATOR, - MessageTypes.MESSAGE_INITIATOR, - MessageTypes.MESSAGE_INITIATOR, - MessageTypes.MESSAGE_INITIATOR, - MessageTypes.MESSAGE_INITIATOR, - MessageTypes.CLOSE_INITIATOR + // echoes the initiator messages back plus CLOSE_RECEIVER + expect(receiverSentMessages.map(m => m.type)).to.deep.equal([ + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.CLOSE_RECEIVER ]) - expect(receiverMessages[receiverMessages.length - 1]).to.have.property('type', MessageTypes.CLOSE_RECEIVER) }) it('should close for reading and writing (abort on local error)', async () => { @@ -372,8 +348,8 @@ describe('stream', () => { const dataLength = 5 const { - initiatorMessages, - receiverMessages + initiatorSentMessages, + receiverSentMessages } = await streamPair(dataLength, (initiatorMessage, initiator) => { messages++ @@ -382,16 +358,14 @@ describe('stream', () => { } }) - expect(initiatorMessages).to.have.lengthOf(3) - expect(initiatorMessages[0]).to.have.property('type', MessageTypes.NEW_STREAM) - expect(initiatorMessages[1]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[2]).to.have.property('type', MessageTypes.RESET_INITIATOR) - - // Reset after two messages - expect(receiverMessages).to.have.lengthOf(2) - expectEchoedMessages(receiverMessages, [ + expect(initiatorSentMessages.map(m => m.type)).to.deep.equal([ MessageTypes.NEW_STREAM, - MessageTypes.MESSAGE_INITIATOR + MessageTypes.MESSAGE_INITIATOR, + MessageTypes.RESET_INITIATOR + ]) + + expect(receiverSentMessages.map(m => m.type)).to.deep.equal([ + MessageTypes.MESSAGE_RECEIVER ]) }) @@ -402,8 +376,8 @@ describe('stream', () => { const dataLength = 5 const { - initiatorMessages, - receiverMessages + initiatorSentMessages, + receiverSentMessages } = await streamPair(dataLength, (initiatorMessage, initiator, recipient) => { messages++ @@ -413,22 +387,21 @@ describe('stream', () => { }) // All messages sent to recipient - expect(initiatorMessages).to.have.lengthOf(7) - expect(initiatorMessages[0]).to.have.property('type', MessageTypes.NEW_STREAM) - expect(initiatorMessages[1]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[2]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[3]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[4]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[5]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[6]).to.have.property('type', MessageTypes.CLOSE_INITIATOR) - - // Recipient reset after two messages - expect(receiverMessages).to.have.lengthOf(3) - expectEchoedMessages(receiverMessages, [ + expect(initiatorSentMessages.map(m => m.type)).to.deep.equal([ MessageTypes.NEW_STREAM, + MessageTypes.MESSAGE_INITIATOR, + MessageTypes.MESSAGE_INITIATOR, + MessageTypes.MESSAGE_INITIATOR, + MessageTypes.MESSAGE_INITIATOR, MessageTypes.MESSAGE_INITIATOR ]) - expect(receiverMessages[receiverMessages.length - 1]).to.have.property('type', MessageTypes.RESET_RECEIVER) + + // Recipient reset after two messages + expect(receiverSentMessages.map(m => m.type)).to.deep.equal([ + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.RESET_RECEIVER + ]) }) it('should close immediately for reading and writing (reset on local error)', async () => { @@ -438,8 +411,8 @@ describe('stream', () => { const dataLength = 5 const { - initiatorMessages, - receiverMessages + initiatorSentMessages, + receiverSentMessages } = await streamPair(dataLength, () => { messages++ @@ -448,15 +421,15 @@ describe('stream', () => { } }) - expect(initiatorMessages).to.have.lengthOf(3) - expect(initiatorMessages[0]).to.have.property('type', MessageTypes.NEW_STREAM) - expect(initiatorMessages[1]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[2]).to.have.property('type', MessageTypes.RESET_INITIATOR) + expect(initiatorSentMessages.map(m => m.type)).to.deep.equal([ + MessageTypes.NEW_STREAM, + MessageTypes.MESSAGE_INITIATOR, + MessageTypes.RESET_INITIATOR + ]) // Reset after two messages - expect(receiverMessages).to.have.lengthOf(1) - expectEchoedMessages(receiverMessages, [ - MessageTypes.NEW_STREAM + expect(receiverSentMessages.map(m => m.type)).to.deep.equal([ + MessageTypes.MESSAGE_RECEIVER ]) }) @@ -467,8 +440,8 @@ describe('stream', () => { const dataLength = 5 const { - initiatorMessages, - receiverMessages + initiatorSentMessages, + receiverSentMessages } = await streamPair(dataLength, () => {}, () => { messages++ @@ -478,29 +451,28 @@ describe('stream', () => { }) // All messages sent to recipient - expect(initiatorMessages).to.have.lengthOf(7) - expect(initiatorMessages[0]).to.have.property('type', MessageTypes.NEW_STREAM) - expect(initiatorMessages[1]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[2]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[3]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[4]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[5]).to.have.property('type', MessageTypes.MESSAGE_INITIATOR) - expect(initiatorMessages[6]).to.have.property('type', MessageTypes.CLOSE_INITIATOR) - - // Recipient reset after two messages - expect(receiverMessages).to.have.lengthOf(3) - expectEchoedMessages(receiverMessages, [ + expect(initiatorSentMessages.map(m => m.type)).to.deep.equal([ MessageTypes.NEW_STREAM, + MessageTypes.MESSAGE_INITIATOR, + MessageTypes.MESSAGE_INITIATOR, + MessageTypes.MESSAGE_INITIATOR, + MessageTypes.MESSAGE_INITIATOR, MessageTypes.MESSAGE_INITIATOR ]) - expect(receiverMessages[receiverMessages.length - 1]).to.have.property('type', MessageTypes.RESET_RECEIVER) + + // Recipient reset after two messages + expect(receiverSentMessages.map(m => m.type)).to.deep.equal([ + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.MESSAGE_RECEIVER, + MessageTypes.RESET_RECEIVER + ]) }) it('should call onEnd only when both sides have closed', async () => { - const send = (msg: Message): void => { + const send = async (msg: Message): Promise => { if (msg.type === MessageTypes.CLOSE_INITIATOR) { // simulate remote closing connection - stream.closeRead() + await stream.closeRead() } else if (msg.type === MessageTypes.MESSAGE_INITIATOR) { stream.sourcePush(msgToBuffer(msg)) } @@ -522,7 +494,7 @@ describe('stream', () => { }) it('should call onEnd with error for local error', async () => { - const send = (): void => { + const send = async (): Promise => { throw new Error(`Local boom ${Date.now()}`) } const id = randomInt(1000) @@ -543,9 +515,9 @@ describe('stream', () => { it('should split writes larger than max message size', async () => { const messages: Message[] = [] - const send = (msg: Message): void => { + const send = async (msg: Message): Promise => { if (msg.type === MessageTypes.CLOSE_INITIATOR) { - stream.closeRead() + await stream.closeRead() } else if (msg.type === MessageTypes.MESSAGE_INITIATOR) { messages.push(msg) } @@ -567,8 +539,21 @@ describe('stream', () => { expect(messages[1]).to.have.nested.property('data.length', maxMsgSize) }) - it('should error on double-sink', async () => { - const send = (): void => {} + it('should error on double sink', async () => { + const send = async (): Promise => {} + const id = randomInt(1000) + const stream = createStream({ id, send }) + + // first sink is ok + void stream.sink([]) + + // cannot sink twice + await expect(stream.sink([])) + .to.eventually.be.rejected.with.property('code', 'ERR_SINK_INVALID_STATE') + }) + + it('should error on double sink after sink has ended', async () => { + const send = async (): Promise => {} const id = randomInt(1000) const stream = createStream({ id, send }) @@ -577,12 +562,12 @@ describe('stream', () => { // cannot sink twice await expect(stream.sink([])) - .to.eventually.be.rejected.with.property('code', 'ERR_DOUBLE_SINK') + .to.eventually.be.rejected.with.property('code', 'ERR_SINK_INVALID_STATE') }) it('should chunk really big messages', async () => { const msgs: Message[] = [] - const mockSend = (msg: Message): void => { msgs.push(msg) } + const mockSend = async (msg: Message): Promise => { msgs.push(msg) } const id = randomInt(1000) const name = `STREAM${Date.now()}` const maxMsgSize = 10 diff --git a/packages/stream-multiplexer-yamux/package.json b/packages/stream-multiplexer-yamux/package.json index 325ca1c695..c97dbb26f2 100644 --- a/packages/stream-multiplexer-yamux/package.json +++ b/packages/stream-multiplexer-yamux/package.json @@ -168,9 +168,9 @@ "@libp2p/interface": "~0.0.1", "@libp2p/logger": "^2.0.0", "abortable-iterator": "^5.0.1", - "any-signal": "^4.1.1", + "it-foreach": "^2.0.3", "it-pipe": "^3.0.1", - "it-pushable": "^3.1.3", + "it-pushable": "^3.2.0", "uint8arraylist": "^2.4.3" }, "devDependencies": { diff --git a/packages/stream-multiplexer-yamux/src/muxer.ts b/packages/stream-multiplexer-yamux/src/muxer.ts index 24d10cab21..83009366b0 100644 --- a/packages/stream-multiplexer-yamux/src/muxer.ts +++ b/packages/stream-multiplexer-yamux/src/muxer.ts @@ -1,6 +1,6 @@ import { CodeError } from '@libp2p/interface/errors' +import { logger, type Logger } from '@libp2p/logger' import { abortableSource } from 'abortable-iterator' -import { anySignal, type ClearableSignal } from 'any-signal' import { pipe } from 'it-pipe' import { pushable, type Pushable } from 'it-pushable' import { type Config, defaultConfig, verifyConfig } from './config.js' @@ -9,13 +9,14 @@ import { Decoder } from './decode.js' import { encodeHeader } from './encode.js' import { Flag, type FrameHeader, FrameType, GoAwayCode, stringifyHeader } from './frame.js' import { StreamState, YamuxStream } from './stream.js' +import type { AbortOptions } from '@libp2p/interface' import type { Stream } from '@libp2p/interface/connection' import type { StreamMuxer, StreamMuxerFactory, StreamMuxerInit } from '@libp2p/interface/stream-muxer' -import type { Logger } from '@libp2p/logger' import type { Sink, Source } from 'it-stream-types' import type { Uint8ArrayList } from 'uint8arraylist' const YAMUX_PROTOCOL_ID = '/yamux/1.0.0' +const CLOSE_TIMEOUT = 500 export interface YamuxMuxerInit extends StreamMuxerInit, Partial { } @@ -36,12 +37,15 @@ export class Yamux implements StreamMuxerFactory { } } +export interface CloseOptions extends AbortOptions { + reason?: GoAwayCode +} + export class YamuxMuxer implements StreamMuxer { protocol = YAMUX_PROTOCOL_ID source: Pushable sink: Sink, Promise> - private readonly _init: YamuxMuxerInit private readonly config: Config private readonly log?: Logger @@ -75,7 +79,6 @@ export class YamuxMuxer implements StreamMuxer { private readonly onStreamEnd?: (stream: Stream) => void constructor (init: YamuxMuxerInit) { - this._init = init this.client = init.direction === 'outbound' this.config = { ...defaultConfig, ...init } this.log = this.config.log @@ -89,22 +92,19 @@ export class YamuxMuxer implements StreamMuxer { this._streams = new Map() this.source = pushable({ - onEnd: (err?: Error): void => { + onEnd: (): void => { this.log?.trace('muxer source ended') - this.close(err) + + this._streams.forEach(stream => { + stream.destroy() + }) } }) this.sink = async (source: Source): Promise => { - let signal: ClearableSignal | undefined - - if (this._init.signal != null) { - signal = anySignal([this.closeController.signal, this._init.signal]) - } - source = abortableSource( source, - signal ?? this.closeController.signal, + this.closeController.signal, { returnOnAbort: true } ) @@ -133,15 +133,15 @@ export class YamuxMuxer implements StreamMuxer { } error = err as Error - } finally { - if (signal != null) { - signal.clear() - } } this.log?.trace('muxer sink ended') - this.close(error, reason) + if (error != null) { + this.abort(error, reason) + } else { + await this.close({ reason }) + } } this.numInboundStreams = 0 @@ -261,34 +261,48 @@ export class YamuxMuxer implements StreamMuxer { /** * Close the muxer - * - * @param err - * @param reason - The GoAway reason to be sent */ - close (err?: Error, reason?: GoAwayCode): void { + async close (options: CloseOptions = {}): Promise { if (this.closeController.signal.aborted) { // already closed return } - // If reason was provided, use that, otherwise use the presence of `err` to determine the reason - reason = reason ?? (err === undefined ? GoAwayCode.InternalError : GoAwayCode.NormalTermination) + const reason = options?.reason ?? GoAwayCode.NormalTermination - if (err != null) { - this.log?.error('muxer close reason=%s error=%s', GoAwayCode[reason], err) - } else { - this.log?.trace('muxer close reason=%s', GoAwayCode[reason]) + this.log?.trace('muxer close reason=%s', reason) + + options.signal = options.signal ?? AbortSignal.timeout(CLOSE_TIMEOUT) + + try { + // If err is provided, abort all underlying streams, else close all underlying streams + await Promise.all( + [...this._streams.values()].map(async s => s.close(options)) + ) + + // send reason to the other side, allow the other side to close gracefully + this.sendGoAway(reason) + + this._closeMuxer() + } catch (err: any) { + this.abort(err) } + } - // If err is provided, abort all underlying streams, else close all underlying streams - if (err === undefined) { - for (const stream of this._streams.values()) { - stream.close() - } - } else { - for (const stream of this._streams.values()) { - stream.abort(err) - } + abort (err: Error, reason?: GoAwayCode): void { + if (this.closeController.signal.aborted) { + // already closed + return + } + + reason = reason ?? GoAwayCode.InternalError + + // If reason was provided, use that, otherwise use the presence of `err` to determine the reason + this.log?.error('muxer abort reason=%s error=%s', reason, err) + + // Abort all underlying streams + for (const stream of this._streams.values()) { + stream.abort(err) } // send reason to the other side, allow the other side to close gracefully @@ -319,16 +333,16 @@ export class YamuxMuxer implements StreamMuxer { } const stream = new YamuxStream({ - id, + id: id.toString(), name, state, direction, sendFrame: this.sendFrame.bind(this), - onStreamEnd: () => { + onEnd: () => { this.closeStream(id) this.onStreamEnd?.(stream) }, - log: this.log, + log: logger(`libp2p:yamux:${direction}:${id}`), config: this.config, getRTT: this.getRTT.bind(this) }) diff --git a/packages/stream-multiplexer-yamux/src/stream.ts b/packages/stream-multiplexer-yamux/src/stream.ts index b7f9d4a1d9..20e89094de 100644 --- a/packages/stream-multiplexer-yamux/src/stream.ts +++ b/packages/stream-multiplexer-yamux/src/stream.ts @@ -1,12 +1,10 @@ import { CodeError } from '@libp2p/interface/errors' -import { abortableSource } from 'abortable-iterator' -import { pushable, type Pushable } from 'it-pushable' -import { ERR_RECV_WINDOW_EXCEEDED, ERR_STREAM_ABORT, ERR_STREAM_RESET, INITIAL_STREAM_WINDOW } from './constants.js' +import { AbstractStream, type AbstractStreamInit } from '@libp2p/interface/stream-muxer/stream' +import each from 'it-foreach' +import { ERR_RECV_WINDOW_EXCEEDED, ERR_STREAM_ABORT, INITIAL_STREAM_WINDOW } from './constants.js' import { Flag, type FrameHeader, FrameType, HEADER_LENGTH } from './frame.js' import type { Config } from './config.js' -import type { Direction, Stream, StreamTimeline } from '@libp2p/interface/connection' -import type { Logger } from '@libp2p/logger' -import type { Sink, Source } from 'it-stream-types' +import type { AbortOptions } from '@libp2p/interface' import type { Uint8ArrayList } from 'uint8arraylist' export enum StreamState { @@ -23,25 +21,17 @@ export enum HalfStreamState { Reset, } -export interface YamuxStreamInit { - id: number +export interface YamuxStreamInit extends AbstractStreamInit { name?: string sendFrame: (header: FrameHeader, body?: Uint8Array) => void - onStreamEnd: () => void getRTT: () => number config: Config state: StreamState - log?: Logger - direction: 'inbound' | 'outbound' } /** YamuxStream is used to represent a logical stream within a session */ -export class YamuxStream implements Stream { - id: string +export class YamuxStream extends AbstractStream { name?: string - direction: Direction - timeline: StreamTimeline - metadata: Record state: StreamState /** Used to track received FIN/RST */ @@ -49,15 +39,7 @@ export class YamuxStream implements Stream { /** Used to track sent FIN/RST */ writeState: HalfStreamState - /** Input to the read side of the stream */ - sourceInput: Pushable - /** Read side of the stream */ - source: AsyncGenerator - /** Write side of the stream */ - sink: Sink, Promise> - private readonly config: Config - private readonly log?: Logger private readonly _id: number /** The number of available bytes to send */ @@ -78,23 +60,34 @@ export class YamuxStream implements Stream { private epochStart: number private readonly getRTT: () => number - /** Used to stop the sink */ - private readonly abortController: AbortController - private readonly sendFrame: (header: FrameHeader, body?: Uint8Array) => void - private readonly onStreamEnd: () => void constructor (init: YamuxStreamInit) { + super({ + ...init, + onEnd: (err?: Error) => { + this.state = StreamState.Finished + init.onEnd?.(err) + }, + onCloseRead: () => { + this.readState = HalfStreamState.Closed + }, + onCloseWrite: () => { + this.writeState = HalfStreamState.Closed + }, + onReset: () => { + this.readState = HalfStreamState.Reset + this.writeState = HalfStreamState.Reset + }, + onAbort: () => { + this.readState = HalfStreamState.Reset + this.writeState = HalfStreamState.Reset + } + }) + this.config = init.config - this.log = init.log - this._id = init.id - this.id = String(init.id) + this._id = parseInt(init.id, 10) this.name = init.name - this.direction = init.direction - this.timeline = { - open: Date.now() - } - this.metadata = {} this.state = init.state this.readState = HalfStreamState.Open @@ -106,177 +99,88 @@ export class YamuxStream implements Stream { this.epochStart = Date.now() this.getRTT = init.getRTT - this.abortController = new AbortController() - this.sendFrame = init.sendFrame - this.onStreamEnd = init.onStreamEnd - this.sourceInput = pushable({ - onEnd: (err?: Error) => { - if (err != null) { - this.log?.error('stream source ended id=%s', this._id, err) - } else { - this.log?.trace('stream source ended id=%s', this._id) - } - - this.closeRead() - } + this.source = each(this.source, () => { + this.sendWindowUpdate() }) - - this.source = this.createSource() - - this.sink = async (source: Source): Promise => { - if (this.writeState !== HalfStreamState.Open) { - throw new Error('stream closed for writing') - } - - source = abortableSource(source, this.abortController.signal, { returnOnAbort: true }) - - try { - for await (let data of source) { - // send in chunks, waiting for window updates - while (data.length !== 0) { - // wait for the send window to refill - if (this.sendWindowCapacity === 0) await this.waitForSendWindowCapacity() - - // send as much as we can - const toSend = Math.min(this.sendWindowCapacity, this.config.maxMessageSize - HEADER_LENGTH, data.length) - this.sendData(data.subarray(0, toSend)) - this.sendWindowCapacity -= toSend - data = data.subarray(toSend) - } - } - } catch (e) { - this.log?.error('stream sink error id=%s', this._id, e) - } finally { - this.log?.trace('stream sink ended id=%s', this._id) - this.closeWrite() - } - } } - private async * createSource (): AsyncGenerator { - try { - for await (const val of this.sourceInput) { - this.sendWindowUpdate() - yield val - } - } catch (err) { - const errCode = (err as { code: string }).code - if (errCode !== ERR_STREAM_ABORT) { - this.log?.error('stream source error id=%s', this._id, err) - throw err - } - } - } - - close (): void { - this.log?.trace('stream close id=%s', this._id) - this.closeRead() - this.closeWrite() - } - - closeRead (): void { - if (this.state === StreamState.Finished) { - return - } - - if (this.readState !== HalfStreamState.Open) { - return - } - - this.log?.trace('stream close read id=%s', this._id) - - this.readState = HalfStreamState.Closed - - // close the source - this.sourceInput.end() + /** + * Send a message to the remote muxer informing them a new stream is being + * opened + */ + async sendNewStream (): Promise { - // If the both read and write are closed, finish it - if (this.writeState !== HalfStreamState.Open) { - this.finish() - } } - closeWrite (): void { - if (this.state === StreamState.Finished) { - return - } - - if (this.writeState !== HalfStreamState.Open) { - return - } + /** + * Send a data message to the remote muxer + */ + async sendData (buf: Uint8ArrayList, options: AbortOptions = {}): Promise { + buf = buf.sublist() + + // send in chunks, waiting for window updates + while (buf.byteLength !== 0) { + // wait for the send window to refill + if (this.sendWindowCapacity === 0) { + await this.waitForSendWindowCapacity(options) + } - this.log?.trace('stream close write id=%s', this._id) + // check we didn't close while waiting for send window capacity + if (this.status !== 'open') { + return + } - this.writeState = HalfStreamState.Closed + // send as much as we can + const toSend = Math.min(this.sendWindowCapacity, this.config.maxMessageSize - HEADER_LENGTH, buf.length) + const flags = this.getSendFlags() - this.sendClose() + this.sendFrame({ + type: FrameType.Data, + flag: flags, + streamID: this._id, + length: toSend + }, buf.subarray(0, toSend)) - // close the sink - this.abortController.abort() + this.sendWindowCapacity -= toSend - // If the both read and write are closed, finish it - if (this.readState !== HalfStreamState.Open) { - this.finish() + buf.consume(toSend) } } - abort (err?: Error): void { - switch (this.state) { - case StreamState.Finished: - return - case StreamState.Init: - // we haven't sent anything, so we don't need to send a reset. - break - case StreamState.SYNSent: - case StreamState.SYNReceived: - case StreamState.Established: - // at least one direction is open, we need to send a reset. - this.sendReset() - break - default: - throw new Error('unreachable') - } - - if (err != null) { - this.log?.error('stream abort id=%s error=%s', this._id, err) - } else { - this.log?.trace('stream abort id=%s', this._id) - } - - this.onReset(new CodeError(String(err) ?? 'stream aborted', ERR_STREAM_ABORT)) + /** + * Send a reset message to the remote muxer + */ + async sendReset (): Promise { + this.sendFrame({ + type: FrameType.WindowUpdate, + flag: Flag.RST, + streamID: this._id, + length: 0 + }) } - reset (): void { - if (this.state === StreamState.Finished) { - return - } - - this.log?.trace('stream reset id=%s', this._id) - - this.onReset(new CodeError('stream reset', ERR_STREAM_RESET)) + /** + * Send a message to the remote muxer, informing them no more data messages + * will be sent by this end of the stream + */ + async sendCloseWrite (): Promise { + const flags = this.getSendFlags() | Flag.FIN + this.sendFrame({ + type: FrameType.WindowUpdate, + flag: flags, + streamID: this._id, + length: 0 + }) } /** - * Called when initiating and receiving a stream reset + * Send a message to the remote muxer, informing them no more data messages + * will be read by this end of the stream */ - private onReset (err: Error): void { - // Update stream state to reset / finished - if (this.writeState === HalfStreamState.Open) { - this.writeState = HalfStreamState.Reset - } - if (this.readState === HalfStreamState.Open) { - this.readState = HalfStreamState.Reset - } - this.state = StreamState.Finished + async sendCloseRead (): Promise { - // close both the source and sink - this.sourceInput.end(err) - this.abortController.abort() - - // and finish the stream - this.finish() } /** @@ -284,25 +188,34 @@ export class YamuxStream implements Stream { * * Will throw with ERR_STREAM_ABORT if the stream gets aborted */ - async waitForSendWindowCapacity (): Promise { - if (this.abortController.signal.aborted) { - throw new CodeError('stream aborted', ERR_STREAM_ABORT) - } + async waitForSendWindowCapacity (options: AbortOptions = {}): Promise { if (this.sendWindowCapacity > 0) { return } + + let resolve: () => void let reject: (err: Error) => void const abort = (): void => { - reject(new CodeError('stream aborted', ERR_STREAM_ABORT)) - } - this.abortController.signal.addEventListener('abort', abort) - await new Promise((_resolve, _reject) => { - this.sendWindowCapacityUpdate = () => { - this.abortController.signal.removeEventListener('abort', abort) - _resolve(undefined) + if (this.status === 'open') { + reject(new CodeError('stream aborted', ERR_STREAM_ABORT)) + } else { + // the stream was closed already, ignore the failure to send + resolve() } - reject = _reject - }) + } + options.signal?.addEventListener('abort', abort) + + try { + await new Promise((_resolve, _reject) => { + this.sendWindowCapacityUpdate = () => { + _resolve() + } + reject = _reject + resolve = _resolve + }) + } finally { + options.signal?.removeEventListener('abort', abort) + } } /** @@ -335,7 +248,8 @@ export class YamuxStream implements Stream { const data = await readData() this.recvWindowCapacity -= header.length - this.sourceInput.push(data) + + this.sourcePush(data) } /** @@ -348,23 +262,13 @@ export class YamuxStream implements Stream { } } if ((flags & Flag.FIN) === Flag.FIN) { - this.closeRead() + this.remoteCloseWrite() } if ((flags & Flag.RST) === Flag.RST) { this.reset() } } - /** - * finish sets the state and triggers eventual garbage collection of the stream - */ - private finish (): void { - this.log?.trace('stream finished id=%s', this._id) - this.state = StreamState.Finished - this.timeline.close = Date.now() - this.onStreamEnd() - } - /** * getSendFlags determines any flags that are appropriate * based on the current stream state. @@ -421,33 +325,4 @@ export class YamuxStream implements Stream { length: delta }) } - - private sendData (data: Uint8Array): void { - const flags = this.getSendFlags() - this.sendFrame({ - type: FrameType.Data, - flag: flags, - streamID: this._id, - length: data.length - }, data) - } - - private sendClose (): void { - const flags = this.getSendFlags() | Flag.FIN - this.sendFrame({ - type: FrameType.WindowUpdate, - flag: flags, - streamID: this._id, - length: 0 - }) - } - - private sendReset (): void { - this.sendFrame({ - type: FrameType.WindowUpdate, - flag: Flag.RST, - streamID: this._id, - length: 0 - }) - } } diff --git a/packages/stream-multiplexer-yamux/test/bench/comparison.bench.ts b/packages/stream-multiplexer-yamux/test/bench/comparison.bench.ts index 6b05bb8158..c601b3006c 100644 --- a/packages/stream-multiplexer-yamux/test/bench/comparison.bench.ts +++ b/packages/stream-multiplexer-yamux/test/bench/comparison.bench.ts @@ -23,7 +23,7 @@ describe('comparison benchmark', () => { id: `${name} send and receive ${numMessages} ${msgSize / 1024}KB chunks`, beforeEach: () => impl({ onIncomingStream: (stream) => { - void pipe(stream, drain).then(() => { stream.close() }) + void pipe(stream, drain).then(async () => { await stream.close() }) } }), fn: async ({ client, server }) => { diff --git a/packages/stream-multiplexer-yamux/test/muxer.spec.ts b/packages/stream-multiplexer-yamux/test/muxer.spec.ts index 3827e91387..017125865c 100644 --- a/packages/stream-multiplexer-yamux/test/muxer.spec.ts +++ b/packages/stream-multiplexer-yamux/test/muxer.spec.ts @@ -4,15 +4,28 @@ import { expect } from 'aegir/chai' import { duplexPair } from 'it-pair/duplex' import { pipe } from 'it-pipe' import { ERR_MUXER_LOCAL_CLOSED } from '../src/constants.js' -import { sleep, testClientServer, testYamuxMuxer } from './util.js' +import { sleep, testClientServer, testYamuxMuxer, type YamuxFixture } from './util.js' describe('muxer', () => { + let client: YamuxFixture + let server: YamuxFixture + + afterEach(async () => { + if (client != null) { + await client.close() + } + + if (server != null) { + await server.close() + } + }) + it('test repeated close', async () => { const client1 = testYamuxMuxer('libp2p:yamux:1', true) // inspect logs to ensure its only closed once - client1.close() - client1.close() - client1.close() + await client1.close() + await client1.close() + await client1.close() }) it('test client<->client', async () => { @@ -46,7 +59,7 @@ describe('muxer', () => { }) it('test ping', async () => { - const { client, server } = testClientServer() + ({ client, server } = testClientServer()) server.pauseRead() const clientRTT = client.ping() @@ -59,13 +72,10 @@ describe('muxer', () => { await sleep(10) server.unpauseWrite() expect(await serverRTT).to.not.equal(0) - - client.close() - server.close() }) it('test multiple simultaneous pings', async () => { - const { client } = testClientServer() + ({ client, server } = testClientServer()) client.pauseWrite() const promise = [ @@ -84,32 +94,32 @@ describe('muxer', () => { // eslint-disable-next-line @typescript-eslint/dot-notation expect(client['nextPingID']).to.equal(1) - client.close() + await client.close() }) - it('test go away', () => { - const { client } = testClientServer() - client.close() - try { + it('test go away', async () => { + ({ client, server } = testClientServer()) + await client.close() + + expect(() => { client.newStream() - expect.fail('should not be able to open a stream after close') - } catch (e) { - expect((e as { code: string }).code).to.equal(ERR_MUXER_LOCAL_CLOSED) - } + }).to.throw().with.property('code', ERR_MUXER_LOCAL_CLOSED, 'should not be able to open a stream after close') }) it('test keep alive', async () => { - const { client } = testClientServer({ enableKeepAlive: true, keepAliveInterval: 10 }) + ({ client, server } = testClientServer({ enableKeepAlive: true, keepAliveInterval: 10 })) - await sleep(100) + await sleep(1000) // eslint-disable-next-line @typescript-eslint/dot-notation expect(client['nextPingID']).to.be.gt(2) - client.close() + await client.close() + await server.close() }) it('test max inbound streams', async () => { - const { client, server } = testClientServer({ maxInboundStreams: 1 }) + ({ client, server } = testClientServer({ maxInboundStreams: 1 })) + client.newStream() client.newStream() await sleep(10) @@ -119,7 +129,8 @@ describe('muxer', () => { }) it('test max outbound streams', async () => { - const { client, server } = testClientServer({ maxOutboundStreams: 1 }) + ({ client, server } = testClientServer({ maxOutboundStreams: 1 })) + client.newStream() await sleep(10) diff --git a/packages/stream-multiplexer-yamux/test/stream.spec.ts b/packages/stream-multiplexer-yamux/test/stream.spec.ts index ee5aa23c8c..b32e761d20 100644 --- a/packages/stream-multiplexer-yamux/test/stream.spec.ts +++ b/packages/stream-multiplexer-yamux/test/stream.spec.ts @@ -4,14 +4,28 @@ import { expect } from 'aegir/chai' import { pipe } from 'it-pipe' import { type Pushable, pushable } from 'it-pushable' import { defaultConfig } from '../src/config.js' -import { ERR_STREAM_RESET } from '../src/constants.js' +import { ERR_RECV_WINDOW_EXCEEDED } from '../src/constants.js' import { GoAwayCode } from '../src/frame.js' import { HalfStreamState, StreamState } from '../src/stream.js' -import { sleep, testClientServer } from './util.js' +import { sleep, testClientServer, type YamuxFixture } from './util.js' +import type { Uint8ArrayList } from 'uint8arraylist' describe('stream', () => { + let client: YamuxFixture + let server: YamuxFixture + + afterEach(async () => { + if (client != null) { + await client.close() + } + + if (server != null) { + await server.close() + } + }) + it('test send data - small', async () => { - const { client, server } = testClientServer({ initialStreamWindowSize: defaultConfig.initialStreamWindowSize }) + ({ client, server } = testClientServer({ initialStreamWindowSize: defaultConfig.initialStreamWindowSize })) const { default: drain } = await import('it-drain') const p = pushable() @@ -37,7 +51,7 @@ describe('stream', () => { }) it('test send data - large', async () => { - const { client, server } = testClientServer({ initialStreamWindowSize: defaultConfig.initialStreamWindowSize }) + ({ client, server } = testClientServer({ initialStreamWindowSize: defaultConfig.initialStreamWindowSize })) const { default: drain } = await import('it-drain') const p = pushable() @@ -65,7 +79,7 @@ describe('stream', () => { }) it('test send data - large with increasing recv window size', async () => { - const { client, server } = testClientServer({ initialStreamWindowSize: defaultConfig.initialStreamWindowSize }) + ({ client, server } = testClientServer({ initialStreamWindowSize: defaultConfig.initialStreamWindowSize })) const { default: drain } = await import('it-drain') const p = pushable() @@ -97,7 +111,7 @@ describe('stream', () => { }) it('test many streams', async () => { - const { client, server } = testClientServer() + ({ client, server } = testClientServer()) for (let i = 0; i < 1000; i++) { client.newStream() } @@ -108,11 +122,11 @@ describe('stream', () => { }) it('test many streams - ping pong', async () => { - const numStreams = 10 - const { client, server } = testClientServer({ + ({ client, server } = testClientServer({ // echo on incoming streams onIncomingStream: (stream) => { void pipe(stream, stream) } - }) + })) + const numStreams = 10 const p: Array> = [] for (let i = 0; i < numStreams; i++) { @@ -131,14 +145,14 @@ describe('stream', () => { expect(client.streams.length).to.equal(numStreams) expect(server.streams.length).to.equal(numStreams) - client.close() + await client.close() }) it('test stream close', async () => { - const { client, server } = testClientServer() + ({ client, server } = testClientServer()) const c1 = client.newStream() - c1.close() + await c1.close() await sleep(5) expect(c1.state).to.equal(StreamState.Finished) @@ -149,10 +163,10 @@ describe('stream', () => { }) it('test stream close read', async () => { - const { client, server } = testClientServer() + ({ client, server } = testClientServer()) const c1 = client.newStream() - c1.closeRead() + await c1.closeRead() await sleep(5) expect(c1.readState).to.equal(HalfStreamState.Closed) @@ -165,10 +179,10 @@ describe('stream', () => { }) it('test stream close write', async () => { - const { client, server } = testClientServer() + ({ client, server } = testClientServer()) const c1 = client.newStream() - c1.closeWrite() + await c1.closeWrite() await sleep(5) expect(c1.readState).to.equal(HalfStreamState.Open) @@ -181,7 +195,7 @@ describe('stream', () => { }) it('test window overflow', async () => { - const { client, server } = testClientServer({ maxMessageSize: defaultConfig.initialStreamWindowSize, initialStreamWindowSize: defaultConfig.initialStreamWindowSize }) + ({ client, server } = testClientServer({ maxMessageSize: defaultConfig.initialStreamWindowSize, initialStreamWindowSize: defaultConfig.initialStreamWindowSize })) const { default: drain } = await import('it-drain') const p = pushable() @@ -191,12 +205,10 @@ describe('stream', () => { const s1 = server.streams[0] const sendPipe = pipe(p, c1) - // eslint-disable-next-line @typescript-eslint/dot-notation - const c1SendData = c1['sendData'].bind(c1) - // eslint-disable-next-line @typescript-eslint/dot-notation - ;(c1 as any)['sendData'] = (data: Uint8Array): void => { - // eslint-disable-next-line @typescript-eslint/dot-notation - c1SendData(data) + const c1SendData = c1.sendData.bind(c1) + + c1.sendData = async (data: Uint8ArrayList): Promise => { + await c1SendData(data) // eslint-disable-next-line @typescript-eslint/dot-notation c1['sendWindowCapacity'] = defaultConfig.initialStreamWindowSize * 10 } @@ -211,16 +223,15 @@ describe('stream', () => { try { await Promise.all([sendPipe, recvPipe]) } catch (e) { - expect((e as { code: string }).code).to.equal(ERR_STREAM_RESET) + expect((e as { code: string }).code).to.equal(ERR_RECV_WINDOW_EXCEEDED) } - // eslint-disable-next-line @typescript-eslint/dot-notation - expect(client['remoteGoAway']).to.equal(GoAwayCode.ProtocolError) - // eslint-disable-next-line @typescript-eslint/dot-notation - expect(server['localGoAway']).to.equal(GoAwayCode.ProtocolError) + + expect(client).to.have.property('remoteGoAway', GoAwayCode.ProtocolError) + expect(server).to.have.property('localGoAway', GoAwayCode.ProtocolError) }) it('test stream sink error', async () => { - const { client, server } = testClientServer() + ({ client, server } = testClientServer()) // don't let the server respond server.pauseRead() @@ -237,7 +248,7 @@ describe('stream', () => { await sleep(10) // the client should close gracefully even though it was waiting to send more data - client.close() + await client.close() p.end() await sendPipe diff --git a/packages/stream-multiplexer-yamux/test/util.ts b/packages/stream-multiplexer-yamux/test/util.ts index e71b14b42b..3c28a7ff76 100644 --- a/packages/stream-multiplexer-yamux/test/util.ts +++ b/packages/stream-multiplexer-yamux/test/util.ts @@ -68,19 +68,16 @@ export function pauseableTransform (): { transform: Transform, Asy return { transform, pause, unpause } } +export interface YamuxFixture extends YamuxMuxer { + pauseRead: () => void + unpauseRead: () => void + pauseWrite: () => void + unpauseWrite: () => void +} + export function testClientServer (conf: YamuxMuxerInit = {}): { - client: YamuxMuxer & { - pauseRead: () => void - unpauseRead: () => void - pauseWrite: () => void - unpauseWrite: () => void - } - server: YamuxMuxer & { - pauseRead: () => void - unpauseRead: () => void - pauseWrite: () => void - unpauseWrite: () => void - } + client: YamuxFixture + server: YamuxFixture } { const pair = duplexPair() const client = testYamuxMuxer('libp2p:yamux:client', true, conf) diff --git a/packages/transport-tcp/src/constants.ts b/packages/transport-tcp/src/constants.ts index 8402e3a703..6501500744 100644 --- a/packages/transport-tcp/src/constants.ts +++ b/packages/transport-tcp/src/constants.ts @@ -4,7 +4,7 @@ export const CODE_CIRCUIT = 290 export const CODE_UNIX = 400 // Time to wait for a connection to close gracefully before destroying it manually -export const CLOSE_TIMEOUT = 2000 +export const CLOSE_TIMEOUT = 500 // Close the socket if there is no activity after this long in ms export const SOCKET_TIMEOUT = 5 * 60000 // 5 mins diff --git a/packages/transport-tcp/src/socket-to-conn.ts b/packages/transport-tcp/src/socket-to-conn.ts index e08d2dbecb..fba2708330 100644 --- a/packages/transport-tcp/src/socket-to-conn.ts +++ b/packages/transport-tcp/src/socket-to-conn.ts @@ -7,7 +7,7 @@ import { CLOSE_TIMEOUT, SOCKET_TIMEOUT } from './constants.js' import { multiaddrToNetConfig } from './utils.js' import type { MultiaddrConnection } from '@libp2p/interface/connection' import type { CounterGroup } from '@libp2p/interface/metrics' -import type { Multiaddr } from '@multiformats/multiaddr' +import type { AbortOptions, Multiaddr } from '@multiformats/multiaddr' import type { Socket } from 'net' const log = logger('libp2p:tcp:socket') @@ -120,75 +120,61 @@ export const toMultiaddrConnection = (socket: Socket, options: ToConnectionOptio timeline: { open: Date.now() }, - async close () { + async close (options: AbortOptions = {}) { if (socket.destroyed) { log('%s socket was already destroyed when trying to close', lOptsStr) return } - log('%s closing socket', lOptsStr) - await new Promise((resolve, reject) => { - const start = Date.now() + options.signal = options.signal ?? AbortSignal.timeout(closeTimeout) - let timeout: NodeJS.Timeout | undefined - - socket.once('close', () => { - log('%s socket closed', lOptsStr) - // socket completely closed - if (timeout !== undefined) { - clearTimeout(timeout) - } - resolve() - }) - socket.once('error', (err: Error) => { - log('%s socket error', lOptsStr, err) - - // error closing socket - if (maConn.timeline.close == null) { - maConn.timeline.close = Date.now() - } + try { + log('%s closing socket', lOptsStr) + await new Promise((resolve, reject) => { + socket.once('close', () => { + // socket completely closed + log('%s socket closed', lOptsStr) + resolve() + }) + socket.once('error', (err: Error) => { + log('%s socket error', lOptsStr, err) - if (socket.destroyed) { - if (timeout !== undefined) { - clearTimeout(timeout) + // error closing socket + if (maConn.timeline.close == null) { + maConn.timeline.close = Date.now() } - } - reject(err) - }) + reject(err) + }) - // shorten inactivity timeout - socket.setTimeout(closeTimeout) + // shorten inactivity timeout + socket.setTimeout(closeTimeout) - // close writable end of the socket - socket.end() + // close writable end of the socket + socket.end() - if (socket.writableLength > 0) { - // Attempt to end the socket. If it takes longer to close than the - // timeout, destroy it manually. - timeout = setTimeout(() => { - if (socket.destroyed) { - log('%s is already destroyed', lOptsStr) - resolve() - } else { - log('%s socket close timeout after %dms, destroying it manually', lOptsStr, Date.now() - start) + if (socket.writableLength > 0) { + // there are outgoing bytes waiting to be sent + socket.once('drain', () => { + log('%s socket drained', lOptsStr) - // will trigger 'error' and 'close' events that resolves promise - socket.destroy(new CodeError('Socket close timeout', 'ERR_SOCKET_CLOSE_TIMEOUT')) - } - }, closeTimeout).unref() - // there are outgoing bytes waiting to be sent - socket.once('drain', () => { - log('%s socket drained', lOptsStr) - - // all bytes have been sent we can destroy the socket (maybe) before the timeout + // all bytes have been sent we can destroy the socket (maybe) before the timeout + socket.destroy() + }) + } else { + // nothing to send, destroy immediately, no need for the timeout socket.destroy() - }) - } else { - // nothing to send, destroy immediately, no need the timeout - socket.destroy() - } - }) + } + }) + } catch (err: any) { + this.abort(err) + } + }, + + abort: (err: Error) => { + log('%s socket abort due to error', lOptsStr, err) + + socket.destroy(err) } } diff --git a/packages/transport-tcp/test/socket-to-conn.spec.ts b/packages/transport-tcp/test/socket-to-conn.spec.ts index e4f8dbfd49..2a97f962fe 100644 --- a/packages/transport-tcp/test/socket-to-conn.spec.ts +++ b/packages/transport-tcp/test/socket-to-conn.spec.ts @@ -372,8 +372,6 @@ describe('socket-to-conn', () => { // promise that is resolved when our outgoing socket errors const serverErrored = defer() - let maConnCloseError: Error | undefined - const inboundMaConn = toMultiaddrConnection(serverSocket, { socketInactivityTimeout: 100, socketCloseTimeout: 100 @@ -402,10 +400,7 @@ describe('socket-to-conn', () => { serverSocket.write('goodbyeeeeeeeeeeeeee') } - await inboundMaConn.close().catch(err => { - // should throw this error - maConnCloseError = err - }) + await inboundMaConn.close() // server socket should no longer be writable expect(serverSocket.writable).to.be.false() @@ -414,10 +409,7 @@ describe('socket-to-conn', () => { await expect(serverClosed.promise).to.eventually.be.true() // remote didn't read our data - await expect(serverErrored.promise).to.eventually.have.property('code', 'ERR_SOCKET_CLOSE_TIMEOUT') - - // closing should have thrown - expect(maConnCloseError).to.have.property('code', 'ERR_SOCKET_CLOSE_TIMEOUT') + await expect(serverErrored.promise).to.eventually.have.property('code', 'ERR_SOCKET_READ_TIMEOUT') // the connection closing was recorded expect(inboundMaConn.timeline.close).to.be.a('number') diff --git a/packages/transport-webrtc/examples/browser-to-browser/package.json b/packages/transport-webrtc/examples/browser-to-browser/package.json index 94368cbe64..cfabb369fa 100644 --- a/packages/transport-webrtc/examples/browser-to-browser/package.json +++ b/packages/transport-webrtc/examples/browser-to-browser/package.json @@ -17,7 +17,7 @@ "@libp2p/mplex": "^8.0.1", "@libp2p/webrtc": "file:../../", "@multiformats/multiaddr": "^12.0.0", - "it-pushable": "^3.1.0", + "it-pushable": "^3.2.0", "libp2p": "^0.45.0", "vite": "^4.2.1" }, diff --git a/packages/transport-webrtc/examples/browser-to-server/package.json b/packages/transport-webrtc/examples/browser-to-server/package.json index 22e88e114e..0b084031a7 100644 --- a/packages/transport-webrtc/examples/browser-to-server/package.json +++ b/packages/transport-webrtc/examples/browser-to-server/package.json @@ -15,7 +15,7 @@ "@chainsafe/libp2p-noise": "^12.0.0", "@libp2p/webrtc": "file:../../", "@multiformats/multiaddr": "^12.0.0", - "it-pushable": "^3.1.0", + "it-pushable": "^3.2.0", "libp2p": "^0.45.0", "vite": "^4.2.1" }, diff --git a/packages/transport-webrtc/package.json b/packages/transport-webrtc/package.json index 824bca6152..c0cffcc1d6 100644 --- a/packages/transport-webrtc/package.json +++ b/packages/transport-webrtc/package.json @@ -54,9 +54,9 @@ "abortable-iterator": "^5.0.1", "detect-browser": "^5.3.0", "it-length-prefixed": "^9.0.1", - "it-pb-stream": "^4.0.1", + "it-protobuf-stream": "^1.0.0", "it-pipe": "^3.0.1", - "it-pushable": "^3.1.3", + "it-pushable": "^3.2.0", "it-stream-types": "^2.0.1", "it-to-buffer": "^4.0.2", "multiformats": "^12.0.1", diff --git a/packages/transport-webrtc/src/maconn.ts b/packages/transport-webrtc/src/maconn.ts index 2281db9fa9..c32d88760c 100644 --- a/packages/transport-webrtc/src/maconn.ts +++ b/packages/transport-webrtc/src/maconn.ts @@ -2,7 +2,7 @@ import { logger } from '@libp2p/logger' import { nopSink, nopSource } from './util.js' import type { MultiaddrConnection, MultiaddrConnectionTimeline } from '@libp2p/interface/connection' import type { CounterGroup } from '@libp2p/interface/metrics' -import type { Multiaddr } from '@multiformats/multiaddr' +import type { AbortOptions, Multiaddr } from '@multiformats/multiaddr' import type { Source, Sink } from 'it-stream-types' const log = logger('libp2p:webrtc:connection') @@ -72,14 +72,19 @@ export class WebRTCMultiaddrConnection implements MultiaddrConnection { } } - async close (err?: Error | undefined): Promise { - if (err !== undefined) { - log.error('error closing connection', err) - } + async close (options?: AbortOptions): Promise { log.trace('closing connection') - this.timeline.close = Date.now() this.peerConnection.close() + this.timeline.close = Date.now() this.metrics?.increment({ close: true }) } + + abort (err: Error): void { + log.error('closing connection due to error', err) + + this.peerConnection.close() + this.timeline.close = Date.now() + this.metrics?.increment({ abort: true }) + } } diff --git a/packages/transport-webrtc/src/muxer.ts b/packages/transport-webrtc/src/muxer.ts index 93825a923b..7e5b655a3c 100644 --- a/packages/transport-webrtc/src/muxer.ts +++ b/packages/transport-webrtc/src/muxer.ts @@ -4,6 +4,7 @@ import type { DataChannelOpts } from './stream.js' import type { Stream } from '@libp2p/interface/connection' import type { CounterGroup } from '@libp2p/interface/metrics' import type { StreamMuxer, StreamMuxerFactory, StreamMuxerInit } from '@libp2p/interface/stream-muxer' +import type { AbortOptions } from '@multiformats/multiaddr' import type { Source, Sink } from 'it-stream-types' import type { Uint8ArrayList } from 'uint8arraylist' @@ -93,9 +94,14 @@ export class DataChannelMuxer implements StreamMuxer { private readonly metrics?: CounterGroup /** - * Close or abort all tracked streams and stop the muxer + * Gracefully close all tracked streams and stop the muxer */ - close: (err?: Error | undefined) => void = () => { } + close: (options?: AbortOptions) => Promise = async () => { } + + /** + * Abort all tracked streams and stop the muxer + */ + abort: (err: Error) => void = () => { } /** * The stream source, a no-op as the transport natively supports multiplexing diff --git a/packages/transport-webrtc/src/private-to-private/handler.ts b/packages/transport-webrtc/src/private-to-private/handler.ts index 9818a224b8..2b0e09e19d 100644 --- a/packages/transport-webrtc/src/private-to-private/handler.ts +++ b/packages/transport-webrtc/src/private-to-private/handler.ts @@ -1,6 +1,6 @@ import { logger } from '@libp2p/logger' import { abortableDuplex } from 'abortable-iterator' -import { pbStream } from 'it-pb-stream' +import { pbStream } from 'it-protobuf-stream' import pDefer, { type DeferredPromise } from 'p-defer' import { DataChannelMuxerFactory } from '../muxer.js' import { Message } from './pb/message.js' @@ -28,8 +28,8 @@ export async function handleIncomingStream ({ rtcConfiguration, dataChannelOptio // candidate callbacks pc.onicecandidate = ({ candidate }) => { answerSentPromise.promise.then( - () => { - stream.write({ + async () => { + await stream.write({ type: Message.Type.ICE_CANDIDATE, data: (candidate != null) ? JSON.stringify(candidate.toJSON()) : '' }) @@ -64,7 +64,7 @@ export async function handleIncomingStream ({ rtcConfiguration, dataChannelOptio throw new Error('Failed to create answer') }) // write the answer to the remote - stream.write({ type: Message.Type.SDP_ANSWER, data: answer.sdp }) + await stream.write({ type: Message.Type.SDP_ANSWER, data: answer.sdp }) await pc.setLocalDescription(answer).catch(err => { log.error('could not execute setLocalDescription', err) @@ -107,15 +107,18 @@ export async function initiateConnection ({ rtcConfiguration, dataChannelOptions // setup callback to write ICE candidates to the remote // peer pc.onicecandidate = ({ candidate }) => { - stream.write({ + void stream.write({ type: Message.Type.ICE_CANDIDATE, data: (candidate != null) ? JSON.stringify(candidate.toJSON()) : '' }) + .catch(err => { + log.error('error sending ICE candidate', err) + }) } // create an offer const offerSdp = await pc.createOffer() // write the offer to the stream - stream.write({ type: Message.Type.SDP_OFFER, data: offerSdp.sdp }) + await stream.write({ type: Message.Type.SDP_OFFER, data: offerSdp.sdp }) // set offer as local description await pc.setLocalDescription(offerSdp).catch(err => { log.error('could not execute setLocalDescription', err) diff --git a/packages/transport-webrtc/src/private-to-private/transport.ts b/packages/transport-webrtc/src/private-to-private/transport.ts index ca24931930..67b5b38abf 100644 --- a/packages/transport-webrtc/src/private-to-private/transport.ts +++ b/packages/transport-webrtc/src/private-to-private/transport.ts @@ -114,11 +114,11 @@ export class WebRTCTransport implements Transport, Startable { ) // close the stream if SDP has been exchanged successfully - signalingStream.close() + await signalingStream.close() return result - } catch (err) { + } catch (err: any) { // reset the stream in case of any error - signalingStream.reset() + signalingStream.abort(err) throw err } finally { // Close the signaling connection @@ -144,8 +144,8 @@ export class WebRTCTransport implements Transport, Startable { skipProtection: true, muxerFactory }) - } catch (err) { - stream.reset() + } catch (err: any) { + stream.abort(err) throw err } finally { // Close the signaling connection diff --git a/packages/transport-webrtc/src/stream.ts b/packages/transport-webrtc/src/stream.ts index bfed4fc88c..7b26bf7ac1 100644 --- a/packages/transport-webrtc/src/stream.ts +++ b/packages/transport-webrtc/src/stream.ts @@ -6,7 +6,7 @@ import { type Pushable, pushable } from 'it-pushable' import { pEvent, TimeoutError } from 'p-event' import { Uint8ArrayList } from 'uint8arraylist' import { Message } from './pb/message.js' -import type { Direction, Stream } from '@libp2p/interface/connection' +import type { Direction } from '@libp2p/interface/connection' const log = logger('libp2p:webrtc:stream') @@ -26,6 +26,8 @@ export interface WebRTCStreamInit extends AbstractStreamInit { channel: RTCDataChannel dataChannelOptions?: Partial + + maxDataSize: number } // Max message size that can be sent to the DataChannel @@ -40,7 +42,7 @@ const BUFFERED_AMOUNT_LOW_TIMEOUT = 30 * 1000 // protobuf field definition overhead const PROTOBUF_OVERHEAD = 3 -class WebRTCStream extends AbstractStream { +export class WebRTCStream extends AbstractStream { /** * The data channel used to send and receive data */ @@ -58,6 +60,7 @@ class WebRTCStream extends AbstractStream { private readonly incomingData: Pushable private messageQueue?: Uint8ArrayList + private readonly maxDataSize: number constructor (init: WebRTCStreamInit) { super(init) @@ -71,6 +74,7 @@ class WebRTCStream extends AbstractStream { maxBufferedAmount: init.dataChannelOptions?.maxBufferedAmount ?? MAX_BUFFERED_AMOUNT, maxMessageSize: init.dataChannelOptions?.maxMessageSize ?? MAX_MESSAGE_SIZE } + this.maxDataSize = init.maxDataSize // set up initial state switch (this.channel.readyState) { @@ -107,7 +111,9 @@ class WebRTCStream extends AbstractStream { } this.channel.onclose = (_evt) => { - this.close() + void this.close().catch(err => { + log.error('error closing stream after channel closed', err) + }) } this.channel.onerror = (evt) => { @@ -153,7 +159,6 @@ class WebRTCStream extends AbstractStream { await pEvent(this.channel, 'bufferedamountlow', { timeout: this.dataChannelOptions.bufferedAmountLowEventTimeout }) } catch (err: any) { if (err instanceof TimeoutError) { - this.abort(err) throw new Error('Timed out waiting for DataChannel buffer to clear') } @@ -184,10 +189,17 @@ class WebRTCStream extends AbstractStream { } async sendData (data: Uint8ArrayList): Promise { - const msgbuf = Message.encode({ message: data.subarray() }) - const sendbuf = lengthPrefixed.encode.single(msgbuf) + data = data.sublist() - await this._sendMessage(sendbuf) + while (data.byteLength > 0) { + const toSend = Math.min(data.byteLength, this.maxDataSize) + const buf = data.subarray(0, toSend) + const msgbuf = Message.encode({ message: buf }) + const sendbuf = lengthPrefixed.encode.single(msgbuf) + await this._sendMessage(sendbuf) + + data.consume(toSend) + } } async sendReset (): Promise { @@ -212,7 +224,7 @@ class WebRTCStream extends AbstractStream { if (message.flag === Message.Flag.FIN) { // We should expect no more data from the remote, stop reading this.incomingData.end() - this.closeRead() + this.remoteCloseWrite() } if (message.flag === Message.Flag.RESET) { @@ -222,7 +234,7 @@ class WebRTCStream extends AbstractStream { if (message.flag === Message.Flag.STOP_SENDING) { // The remote has stopped reading - this.closeWrite() + this.remoteCloseRead() } } @@ -259,7 +271,7 @@ export interface WebRTCStreamOptions { onEnd?: (err?: Error | undefined) => void } -export function createStream (options: WebRTCStreamOptions): Stream { +export function createStream (options: WebRTCStreamOptions): WebRTCStream { const { channel, direction, onEnd, dataChannelOptions } = options return new WebRTCStream({ @@ -268,6 +280,7 @@ export function createStream (options: WebRTCStreamOptions): Stream { maxDataSize: (dataChannelOptions?.maxMessageSize ?? MAX_MESSAGE_SIZE) - PROTOBUF_OVERHEAD, dataChannelOptions, onEnd, - channel + channel, + log: logger(`libp2p:mplex:stream:${direction}:${channel.id}`) }) } diff --git a/packages/transport-webrtc/test/peer.browser.spec.ts b/packages/transport-webrtc/test/peer.browser.spec.ts index ae3b79d325..918ed2d0a5 100644 --- a/packages/transport-webrtc/test/peer.browser.spec.ts +++ b/packages/transport-webrtc/test/peer.browser.spec.ts @@ -5,7 +5,7 @@ import { expect } from 'aegir/chai' import { detect } from 'detect-browser' import { pair } from 'it-pair' import { duplexPair } from 'it-pair/duplex' -import { pbStream } from 'it-pb-stream' +import { pbStream } from 'it-protobuf-stream' import Sinon from 'sinon' import { initiateConnection, handleIncomingStream } from '../src/private-to-private/handler' import { Message } from '../src/private-to-private/pb/message.js' @@ -47,7 +47,7 @@ describe('webrtc receiver', () => { const receiverPeerConnectionPromise = handleIncomingStream({ stream: mockStream(receiver), connection }) const stream = pbStream(initiator).pb(Message) - stream.write({ type: Message.Type.SDP_OFFER, data: 'bad' }) + await stream.write({ type: Message.Type.SDP_OFFER, data: 'bad' }) await expect(receiverPeerConnectionPromise).to.be.rejectedWith(/Failed to set remoteDescription/) }) }) @@ -64,7 +64,7 @@ describe('webrtc dialer', () => { expect(offerMessage.type).to.eq(Message.Type.SDP_OFFER) } - stream.write({ type: Message.Type.SDP_ANSWER, data: 'bad' }) + await stream.write({ type: Message.Type.SDP_ANSWER, data: 'bad' }) await expect(initiatorPeerConnectionPromise).to.be.rejectedWith(/Failed to set remoteDescription/) }) @@ -76,7 +76,7 @@ describe('webrtc dialer', () => { const pc = new RTCPeerConnection() pc.onicecandidate = ({ candidate }) => { - stream.write({ type: Message.Type.ICE_CANDIDATE, data: JSON.stringify(candidate?.toJSON()) }) + void stream.write({ type: Message.Type.ICE_CANDIDATE, data: JSON.stringify(candidate?.toJSON()) }) } { const offerMessage = await stream.read() diff --git a/packages/transport-webrtc/test/stream.browser.spec.ts b/packages/transport-webrtc/test/stream.browser.spec.ts index 785c7d4f9b..9bf0172c5e 100644 --- a/packages/transport-webrtc/test/stream.browser.spec.ts +++ b/packages/transport-webrtc/test/stream.browser.spec.ts @@ -3,11 +3,11 @@ import delay from 'delay' import * as lengthPrefixed from 'it-length-prefixed' import { bytes } from 'multiformats' import { Message } from '../src/pb/message.js' -import { createStream } from '../src/stream' +import { createStream, type WebRTCStream } from '../src/stream.js' import type { Stream } from '@libp2p/interface/connection' const TEST_MESSAGE = 'test_message' -function setup (): { peerConnection: RTCPeerConnection, dataChannel: RTCDataChannel, stream: Stream } { +function setup (): { peerConnection: RTCPeerConnection, dataChannel: RTCDataChannel, stream: WebRTCStream } { const peerConnection = new RTCPeerConnection() const dataChannel = peerConnection.createDataChannel('whatever', { negotiated: true, id: 91 }) const stream = createStream({ channel: dataChannel, direction: 'outbound' }) @@ -25,7 +25,7 @@ function generatePbByFlag (flag?: Message.Flag): Uint8Array { } describe('Stream Stats', () => { - let stream: Stream + let stream: WebRTCStream beforeEach(async () => { ({ stream } = setup()) @@ -35,30 +35,30 @@ describe('Stream Stats', () => { expect(stream.timeline.close).to.not.exist() }) - it('close marks it closed', () => { + it('close marks it closed', async () => { expect(stream.timeline.close).to.not.exist() - stream.close() + await stream.close() expect(stream.timeline.close).to.be.a('number') }) - it('closeRead marks it read-closed only', () => { + it('closeRead marks it read-closed only', async () => { expect(stream.timeline.close).to.not.exist() - stream.closeRead() + await stream.closeRead() expect(stream.timeline.close).to.not.exist() expect(stream.timeline.closeRead).to.be.greaterThanOrEqual(stream.timeline.open) }) - it('closeWrite marks it write-closed only', () => { + it('closeWrite marks it write-closed only', async () => { expect(stream.timeline.close).to.not.exist() - stream.closeWrite() + await stream.closeWrite() expect(stream.timeline.close).to.not.exist() expect(stream.timeline.closeWrite).to.be.greaterThanOrEqual(stream.timeline.open) }) it('closeWrite AND closeRead = close', async () => { expect(stream.timeline.close).to.not.exist() - stream.closeWrite() - stream.closeRead() + await stream.closeWrite() + await stream.closeRead() expect(stream.timeline.close).to.be.a('number') expect(stream.timeline.closeWrite).to.be.greaterThanOrEqual(stream.timeline.open) expect(stream.timeline.closeRead).to.be.greaterThanOrEqual(stream.timeline.open) diff --git a/packages/transport-webrtc/test/stream.spec.ts b/packages/transport-webrtc/test/stream.spec.ts index dd44b1ebab..1812ca2c7a 100644 --- a/packages/transport-webrtc/test/stream.spec.ts +++ b/packages/transport-webrtc/test/stream.spec.ts @@ -99,17 +99,15 @@ describe('Max message size', () => { } }) - const p = pushable() - p.push(new Uint8Array(1)) - p.end() - const t0 = Date.now() - await expect(webrtcStream.sink(p)).to.eventually.be.rejected + await expect(webrtcStream.sink([new Uint8Array(1)])).to.eventually.be.rejected .with.property('message', 'Timed out waiting for DataChannel buffer to clear') const t1 = Date.now() expect(t1 - t0).greaterThan(timeout) expect(t1 - t0).lessThan(timeout + 1000) // Some upper bound expect(closed).true() + expect(webrtcStream.timeline.close).to.be.greaterThan(webrtcStream.timeline.open) + expect(webrtcStream.timeline.abort).to.be.greaterThan(webrtcStream.timeline.open) }) }) diff --git a/packages/transport-websockets/package.json b/packages/transport-websockets/package.json index 37e23fc6e1..176302c318 100644 --- a/packages/transport-websockets/package.json +++ b/packages/transport-websockets/package.json @@ -78,7 +78,6 @@ "abortable-iterator": "^5.0.1", "it-ws": "^6.0.0", "p-defer": "^4.0.0", - "p-timeout": "^6.0.0", "wherearewe": "^2.0.1", "ws": "^8.12.1" }, diff --git a/packages/transport-websockets/src/constants.ts b/packages/transport-websockets/src/constants.ts index e8c3939e1d..93b937c0d3 100644 --- a/packages/transport-websockets/src/constants.ts +++ b/packages/transport-websockets/src/constants.ts @@ -7,4 +7,4 @@ export const CODE_WS = 477 export const CODE_WSS = 478 // Time to wait for a connection to close gracefully before destroying it manually -export const CLOSE_TIMEOUT = 2000 +export const CLOSE_TIMEOUT = 500 diff --git a/packages/transport-websockets/src/socket-to-conn.ts b/packages/transport-websockets/src/socket-to-conn.ts index b7de27dd63..206ad7459b 100644 --- a/packages/transport-websockets/src/socket-to-conn.ts +++ b/packages/transport-websockets/src/socket-to-conn.ts @@ -1,6 +1,6 @@ +import { CodeError } from '@libp2p/interface/errors' import { logger } from '@libp2p/logger' import { abortableSource } from 'abortable-iterator' -import pTimeout from 'p-timeout' import { CLOSE_TIMEOUT } from './constants.js' import type { AbortOptions } from '@libp2p/interface' import type { MultiaddrConnection } from '@libp2p/interface/connection' @@ -39,22 +39,37 @@ export function socketToMaConn (stream: DuplexWebSocket, remoteAddr: Multiaddr, timeline: { open: Date.now() }, - async close () { + async close (options: AbortOptions = {}) { const start = Date.now() + options.signal = options.signal ?? AbortSignal.timeout(CLOSE_TIMEOUT) - try { - await pTimeout(stream.close(), { - milliseconds: CLOSE_TIMEOUT - }) - } catch (err) { + const listener = (): void => { const { host, port } = maConn.remoteAddr.toOptions() log('timeout closing stream to %s:%s after %dms, destroying it manually', host, port, Date.now() - start) - stream.destroy() + this.abort(new CodeError('Socket close timeout', 'ERR_SOCKET_CLOSE_TIMEOUT')) + } + + options.signal.addEventListener('abort', listener) + + try { + await stream.close() + } catch (err: any) { + this.abort(err) } finally { + options.signal.removeEventListener('abort', listener) maConn.timeline.close = Date.now() } + }, + + abort (err: Error): void { + const { host, port } = maConn.remoteAddr.toOptions() + log('timeout closing stream to %s:%s due to error', + host, port, err) + + stream.destroy() + maConn.timeline.close = Date.now() } } diff --git a/packages/transport-webtransport/src/index.ts b/packages/transport-webtransport/src/index.ts index 10bf86da01..951beec20f 100644 --- a/packages/transport-webtransport/src/index.ts +++ b/packages/transport-webtransport/src/index.ts @@ -2,7 +2,7 @@ import { noise } from '@chainsafe/libp2p-noise' import { type Transport, symbol, type CreateListenerOptions, type DialOptions, type Listener } from '@libp2p/interface/transport' import { logger } from '@libp2p/logger' import { peerIdFromString } from '@libp2p/peer-id' -import { type Multiaddr, protocols } from '@multiformats/multiaddr' +import { type Multiaddr, protocols, type AbortOptions } from '@multiformats/multiaddr' import { bases, digest } from 'multiformats/basics' import { Uint8ArrayList } from 'uint8arraylist' import type { Connection, Direction, MultiaddrConnection, Stream } from '@libp2p/interface/connection' @@ -92,50 +92,87 @@ async function webtransportBiDiStreamToStream (bidiStream: any, streamId: string let sinkSunk = false const stream: Stream = { id: streamId, - abort (_err: Error) { + status: 'open', + writeStatus: 'ready', + readStatus: 'ready', + abort (err: Error) { if (!writerClosed) { writer.abort() writerClosed = true } - stream.closeRead() + stream.abort(err) readerClosed = true + + this.status = 'aborted' + this.writeStatus = 'closed' + this.readStatus = 'closed' + + this.timeline.reset = + this.timeline.close = + this.timeline.closeRead = + this.timeline.closeWrite = Date.now() + cleanupStreamFromActiveStreams() }, - close () { - stream.closeRead() - stream.closeWrite() + async close (options?: AbortOptions) { + this.status = 'closing' + + await Promise.all([ + stream.closeRead(options), + stream.closeWrite(options) + ]) + cleanupStreamFromActiveStreams() + + this.status = 'closed' + this.timeline.close = Date.now() }, - closeRead () { + async closeRead (options?: AbortOptions) { if (!readerClosed) { - reader.cancel().catch((err: any) => { + this.readStatus = 'closing' + + try { + await reader.cancel() + } catch (err: any) { if (err.toString().includes('RESET_STREAM') === true) { writerClosed = true } - }) + } + + this.timeline.closeRead = Date.now() + this.readStatus = 'closed' + readerClosed = true } + if (writerClosed) { cleanupStreamFromActiveStreams() } }, - closeWrite () { + + async closeWrite (options?: AbortOptions) { if (!writerClosed) { writerClosed = true - writer.close().catch((err: any) => { + + this.writeStatus = 'closing' + + try { + await writer.close() + } catch (err: any) { if (err.toString().includes('RESET_STREAM') === true) { readerClosed = true } - }) + } + + this.timeline.closeWrite = Date.now() + this.writeStatus = 'closed' } + if (readerClosed) { cleanupStreamFromActiveStreams() } }, - reset () { - stream.close() - }, direction, timeline: { open: Date.now() }, metadata: {}, @@ -159,6 +196,8 @@ async function webtransportBiDiStreamToStream (bidiStream: any, streamId: string } sinkSunk = true try { + this.writeStatus = 'writing' + for await (const chunks of source) { if (chunks instanceof Uint8Array) { await writer.write(chunks) @@ -168,8 +207,13 @@ async function webtransportBiDiStreamToStream (bidiStream: any, streamId: string } } } + + this.writeStatus = 'done' } finally { - stream.closeWrite() + this.timeline.closeWrite = Date.now() + this.writeStatus = 'closed' + + await stream.closeWrite() } } } @@ -325,10 +369,12 @@ class WebTransportTransport implements Transport { } const maConn: MultiaddrConnection = { - close: async (err?: Error) => { - if (err != null) { - log('Closing webtransport with err:', err) - } + close: async (options?: AbortOptions) => { + log('Closing webtransport') + await wt.close() + }, + abort: (err: Error) => { + log('Aborting webtransport with err:', err) wt.close() }, remoteAddr: ma, @@ -462,23 +508,18 @@ class WebTransportTransport implements Transport { /** * Close or abort all tracked streams and stop the muxer */ - close: (err?: Error) => { - if (err != null) { - log('Closing webtransport muxer with err:', err) - } + close: async (options?: AbortOptions) => { + log('Closing webtransport muxer') + await wt.close() + }, + abort: (err: Error) => { + log('Aborting webtransport muxer with err:', err) wt.close() }, // This stream muxer is webtransport native. Therefore it doesn't plug in with any other duplex. ...inertDuplex() } - try { - init?.signal?.throwIfAborted() - } catch (e) { - wt.close() - throw e - } - return muxer } } diff --git a/packages/transport-webtransport/test/browser.ts b/packages/transport-webtransport/test/browser.ts index 07c18d5e0d..3f62bf33ab 100644 --- a/packages/transport-webtransport/test/browser.ts +++ b/packages/transport-webtransport/test/browser.ts @@ -66,7 +66,7 @@ describe('libp2p-webtransport', () => { res = Date.now() - now })()) - stream.close() + await stream.close() expect(res).to.be.greaterThan(-1) } @@ -119,7 +119,7 @@ describe('libp2p-webtransport', () => { // the address is unreachable but we can parse it correctly const stream = await node.dialProtocol(ma, '/ipfs/ping/1.0.0') - stream.close() + await stream.close() await node.stop() }) @@ -162,7 +162,7 @@ describe('libp2p-webtransport', () => { expect(expectedNextNumber).to.equal(16) // Close read, we've should have closed the write side during sink - stream.closeRead() + await stream.closeRead() expect(stream.timeline.close).to.be.greaterThan(0) diff --git a/packages/utils/package.json b/packages/utils/package.json index 31aa72607b..1264dec606 100644 --- a/packages/utils/package.json +++ b/packages/utils/package.json @@ -89,7 +89,6 @@ "@libp2p/interface": "~0.0.1", "@libp2p/logger": "^2.0.0", "@multiformats/multiaddr": "^12.1.3", - "abortable-iterator": "^5.0.1", "is-loopback-addr": "^2.0.1", "it-stream-types": "^2.0.1", "private-ip": "^3.0.0", diff --git a/packages/utils/src/stream-to-ma-conn.ts b/packages/utils/src/stream-to-ma-conn.ts index 4fc812e38e..4e6aa7b86c 100644 --- a/packages/utils/src/stream-to-ma-conn.ts +++ b/packages/utils/src/stream-to-ma-conn.ts @@ -1,36 +1,12 @@ import { logger } from '@libp2p/logger' -import { abortableSource } from 'abortable-iterator' -import type { MultiaddrConnection } from '@libp2p/interface/connection' +import type { AbortOptions } from '@libp2p/interface' +import type { MultiaddrConnection, Stream } from '@libp2p/interface/connection' import type { Multiaddr } from '@multiformats/multiaddr' -import type { Duplex, Source } from 'it-stream-types' -import type { Uint8ArrayList } from 'uint8arraylist' const log = logger('libp2p:stream:converter') -export interface Timeline { - /** - * Connection opening timestamp - */ - open: number - - /** - * Connection upgraded timestamp - */ - upgraded?: number - - /** - * Connection closed timestamp - */ - close?: number -} - -export interface StreamOptions { - signal?: AbortSignal - -} - export interface StreamProperties { - stream: Duplex, Source> + stream: Stream remoteAddr: Multiaddr localAddr: Multiaddr } @@ -39,7 +15,7 @@ export interface StreamProperties { * Convert a duplex iterable into a MultiaddrConnection. * https://github.com/libp2p/interface-transport#multiaddrconnection */ -export function streamToMaConnection (props: StreamProperties, options: StreamOptions = {}): MultiaddrConnection { +export function streamToMaConnection (props: StreamProperties): MultiaddrConnection { const { stream, remoteAddr } = props const { sink, source } = stream @@ -55,13 +31,9 @@ export function streamToMaConnection (props: StreamProperties, options: StreamOp const maConn: MultiaddrConnection = { async sink (source) { - if (options.signal != null) { - source = abortableSource(source, options.signal) - } - try { await sink(source) - await close() + close() } catch (err: any) { // If aborted we can safely ignore if (err.type !== 'aborted') { @@ -72,22 +44,23 @@ export function streamToMaConnection (props: StreamProperties, options: StreamOp } } }, - source: (options.signal != null) ? abortableSource(mapSource, options.signal) : mapSource, + source: mapSource, remoteAddr, timeline: { open: Date.now(), close: undefined }, - async close () { - await sink(async function * () { - yield new Uint8Array(0) - }()) - await close() + async close (options?: AbortOptions) { + close() + await stream.close(options) + }, + abort (err: Error): void { + close() + stream.abort(err) } } - async function close (): Promise { + function close (): void { if (maConn.timeline.close == null) { maConn.timeline.close = Date.now() } - await Promise.resolve() } return maConn diff --git a/packages/utils/test/stream-to-ma-conn.spec.ts b/packages/utils/test/stream-to-ma-conn.spec.ts index 3e58c82d3e..c3f6b5521a 100644 --- a/packages/utils/test/stream-to-ma-conn.spec.ts +++ b/packages/utils/test/stream-to-ma-conn.spec.ts @@ -14,17 +14,19 @@ import type { Uint8ArrayList } from 'uint8arraylist' function toMuxedStream (stream: Duplex, Source, Promise>): Stream { const muxedStream: Stream = { ...stream, - close: () => {}, - closeRead: () => {}, - closeWrite: () => {}, + close: async () => {}, + closeRead: async () => {}, + closeWrite: async () => {}, abort: () => {}, - reset: () => {}, direction: 'outbound', timeline: { open: Date.now() }, metadata: {}, - id: `muxed-stream-${Math.random()}` + id: `muxed-stream-${Math.random()}`, + status: 'open', + readStatus: 'ready', + writeStatus: 'ready' } return muxedStream