diff --git a/packages/grpc-js/src/server-credentials.ts b/packages/grpc-js/src/server-credentials.ts index 17ab29805..0dd5f8cae 100644 --- a/packages/grpc-js/src/server-credentials.ts +++ b/packages/grpc-js/src/server-credentials.ts @@ -26,6 +26,7 @@ export interface KeyCertPair { export abstract class ServerCredentials { abstract _isSecure(): boolean; abstract _getSettings(): SecureServerOptions | null; + abstract _equals(other: ServerCredentials): boolean; static createInsecure(): ServerCredentials { return new InsecureServerCredentials(); @@ -48,8 +49,8 @@ export abstract class ServerCredentials { throw new TypeError('checkClientCertificate must be a boolean'); } - const cert = []; - const key = []; + const cert: Buffer[] = []; + const key: Buffer[] = []; for (let i = 0; i < keyCertPairs.length; i++) { const pair = keyCertPairs[i]; @@ -71,7 +72,7 @@ export abstract class ServerCredentials { } return new SecureServerCredentials({ - ca: rootCerts || getDefaultRootsData() || undefined, + ca: rootCerts ?? getDefaultRootsData() ?? undefined, cert, key, requestCert: checkClientCertificate, @@ -88,6 +89,10 @@ class InsecureServerCredentials extends ServerCredentials { _getSettings(): null { return null; } + + _equals(other: ServerCredentials): boolean { + return other instanceof InsecureServerCredentials; + } } class SecureServerCredentials extends ServerCredentials { @@ -105,4 +110,82 @@ class SecureServerCredentials extends ServerCredentials { _getSettings(): SecureServerOptions { return this.options; } + + /** + * Checks equality by checking the options that are actually set by + * createSsl. + * @param other + * @returns + */ + _equals(other: ServerCredentials): boolean { + if (this === other) { + return true; + } + if (!(other instanceof SecureServerCredentials)) { + return false; + } + // options.ca equality check + if (Buffer.isBuffer(this.options.ca) && Buffer.isBuffer(other.options.ca)) { + if (!this.options.ca.equals(other.options.ca)) { + return false; + } + } else { + if (this.options.ca !== other.options.ca) { + return false; + } + } + // options.cert equality check + if (Array.isArray(this.options.cert) && Array.isArray(other.options.cert)) { + if (this.options.cert.length !== other.options.cert.length) { + return false; + } + for (let i = 0; i < this.options.cert.length; i++) { + const thisCert = this.options.cert[i]; + const otherCert = other.options.cert[i]; + if (Buffer.isBuffer(thisCert) && Buffer.isBuffer(otherCert)) { + if (!thisCert.equals(otherCert)) { + return false; + } + } else { + if (thisCert !== otherCert) { + return false; + } + } + } + } else { + if (this.options.cert !== other.options.cert) { + return false; + } + } + // options.key equality check + if (Array.isArray(this.options.key) && Array.isArray(other.options.key)) { + if (this.options.key.length !== other.options.key.length) { + return false; + } + for (let i = 0; i < this.options.key.length; i++) { + const thisKey = this.options.key[i]; + const otherKey = other.options.key[i]; + if (Buffer.isBuffer(thisKey) && Buffer.isBuffer(otherKey)) { + if (!thisKey.equals(otherKey)) { + return false; + } + } else { + if (thisKey !== otherKey) { + return false; + } + } + } + } else { + if (this.options.key !== other.options.key) { + return false; + } + } + // options.requestCert equality check + if (this.options.requestCert !== other.options.requestCert) { + return false; + } + /* ciphers is derived from a value that is constant for the process, so no + * equality check is needed. */ + return true; + } } diff --git a/packages/grpc-js/src/server.ts b/packages/grpc-js/src/server.ts index f04cd9810..8ff0d0f22 100644 --- a/packages/grpc-js/src/server.ts +++ b/packages/grpc-js/src/server.ts @@ -17,7 +17,6 @@ import * as http2 from 'http2'; import * as util from 'util'; -import { AddressInfo } from 'net'; import { ServiceError } from './call'; import { Status, LogVerbosity } from './constants'; @@ -54,12 +53,11 @@ import { import * as logging from './logging'; import { SubchannelAddress, - TcpSubchannelAddress, isTcpSubchannelAddress, subchannelAddressToString, stringToSubchannelAddress, } from './subchannel-address'; -import { parseUri } from './uri-parser'; +import { GrpcUri, combineHostPort, parseUri, splitHostPort, uriToString } from './uri-parser'; import { ChannelzCallTracker, ChannelzChildrenTracker, @@ -83,9 +81,17 @@ const { HTTP2_HEADER_PATH } = http2.constants; const TRACER_NAME = 'server'; +type AnyHttp2Server = http2.Http2Server | http2.Http2SecureServer; + interface BindResult { port: number; count: number; + errors: string[]; +} + +interface SingleAddressBindResult { + port: number; + error?: string; } function noop(): void {} @@ -161,11 +167,61 @@ interface ChannelzSessionInfo { lastMessageReceivedTimestamp: Date | null; } +/** + * Information related to a single invocation of bindAsync. This should be + * tracked in a map keyed by target string, normalized with a pass through + * parseUri -> mapUriDefaultScheme -> uriToString. If the target has a port + * number and the port number is 0, the target string is modified with the + * concrete bound port. + */ +interface BoundPort { + /** + * The key used to refer to this object in the boundPorts map. + */ + mapKey: string; + /** + * The target string, passed through parseUri -> mapUriDefaultScheme. Used + * to determine the final key when the port number is 0. + */ + originalUri: GrpcUri; + /** + * If there is a pending bindAsync operation, this is a promise that resolves + * with the port number when that operation succeeds. If there is no such + * operation pending, this is null. + */ + completionPromise: Promise | null; + /** + * The port number that was actually bound. Populated only after + * completionPromise resolves. + */ + portNumber: number; + /** + * Set by unbind if called while pending is true. + */ + cancelled: boolean; + /** + * The credentials object passed to the original bindAsync call. + */ + credentials: ServerCredentials; + /** + * The set of servers associated with this listening port. A target string + * that expands to multiple addresses will result in multiple listening + * servers. + */ + listeningServers: Set +} + +/** + * Should be in a map keyed by AnyHttp2Server. + */ +interface Http2ServerInfo { + channelzRef: SocketRef; + sessions: Set; +} + export class Server { - private http2ServerList: { - server: http2.Http2Server | http2.Http2SecureServer; - channelzRef: SocketRef; - }[] = []; + private boundPorts: Map= new Map(); + private http2Servers: Map = new Map(); private handlers: Map = new Map< string, @@ -194,6 +250,12 @@ export class Server { private readonly keepaliveTimeMs: number; private readonly keepaliveTimeoutMs: number; + /** + * Options that will be used to construct all Http2Server instances for this + * Server. + */ + private commonServerOptions: http2.ServerOptions; + constructor(options?: ChannelOptions) { this.options = options ?? {}; if (this.options['grpc.enable_channelz'] === 0) { @@ -215,6 +277,24 @@ export class Server { this.options['grpc.keepalive_time_ms'] ?? KEEPALIVE_MAX_TIME_MS; this.keepaliveTimeoutMs = this.options['grpc.keepalive_timeout_ms'] ?? KEEPALIVE_TIMEOUT_MS; + this.commonServerOptions = { + maxSendHeaderBlockLength: Number.MAX_SAFE_INTEGER, + }; + if ('grpc-node.max_session_memory' in this.options) { + this.commonServerOptions.maxSessionMemory = + this.options['grpc-node.max_session_memory']; + } else { + /* By default, set a very large max session memory limit, to effectively + * disable enforcement of the limit. Some testing indicates that Node's + * behavior degrades badly when this limit is reached, so we solve that + * by disabling the check entirely. */ + this.commonServerOptions.maxSessionMemory = Number.MAX_SAFE_INTEGER; + } + if ('grpc.max_concurrent_streams' in this.options) { + this.commonServerOptions.settings = { + maxConcurrentStreams: this.options['grpc.max_concurrent_streams'], + }; + } this.trace('Server constructed'); } @@ -382,6 +462,238 @@ export class Server { throw new Error('Not implemented. Use bindAsync() instead'); } + private registerListenerToChannelz(boundAddress: SubchannelAddress) { + return registerChannelzSocket( + subchannelAddressToString(boundAddress), + () => { + return { + localAddress: boundAddress, + remoteAddress: null, + security: null, + remoteName: null, + streamsStarted: 0, + streamsSucceeded: 0, + streamsFailed: 0, + messagesSent: 0, + messagesReceived: 0, + keepAlivesSent: 0, + lastLocalStreamCreatedTimestamp: null, + lastRemoteStreamCreatedTimestamp: null, + lastMessageSentTimestamp: null, + lastMessageReceivedTimestamp: null, + localFlowControlWindow: null, + remoteFlowControlWindow: null, + }; + }, + this.channelzEnabled + ); + } + + private createHttp2Server(credentials: ServerCredentials) { + let http2Server: http2.Http2Server | http2.Http2SecureServer; + if (credentials._isSecure()) { + const secureServerOptions = Object.assign( + this.commonServerOptions, + credentials._getSettings()! + ); + secureServerOptions.enableTrace = + this.options['grpc-node.tls_enable_trace'] === 1; + http2Server = http2.createSecureServer(secureServerOptions); + http2Server.on('secureConnection', (socket: TLSSocket) => { + /* These errors need to be handled by the user of Http2SecureServer, + * according to https://github.com/nodejs/node/issues/35824 */ + socket.on('error', (e: Error) => { + this.trace( + 'An incoming TLS connection closed with error: ' + e.message + ); + }); + }); + } else { + http2Server = http2.createServer(this.commonServerOptions); + } + + http2Server.setTimeout(0, noop); + this._setupHandlers(http2Server); + return http2Server; + } + + private bindOneAddress(address: SubchannelAddress, boundPortObject: BoundPort): Promise { + this.trace( + 'Attempting to bind ' + subchannelAddressToString(address) + ); + const http2Server = this.createHttp2Server(boundPortObject.credentials); + return new Promise((resolve, reject) => { + const onError = (err: Error) => { + this.trace( + 'Failed to bind ' + + subchannelAddressToString(address) + + ' with error ' + + err.message + ); + resolve({ + port: 'port' in address ? address.port : 1, + error: err.message + }); + }; + + http2Server.once('error', onError); + + http2Server.listen(address, () => { + const boundAddress = http2Server.address()!; + let boundSubchannelAddress: SubchannelAddress; + if (typeof boundAddress === 'string') { + boundSubchannelAddress = { + path: boundAddress, + }; + } else { + boundSubchannelAddress = { + host: boundAddress.address, + port: boundAddress.port, + }; + } + + const channelzRef = this.registerListenerToChannelz(boundSubchannelAddress); + if (this.channelzEnabled) { + this.listenerChildrenTracker.refChild(channelzRef); + } + this.http2Servers.set(http2Server, { + channelzRef: channelzRef, + sessions: new Set() + }); + boundPortObject.listeningServers.add(http2Server); + this.trace( + 'Successfully bound ' + + subchannelAddressToString(boundSubchannelAddress) + ); + resolve({ + port: 'port' in boundSubchannelAddress + ? boundSubchannelAddress.port + : 1 + }); + http2Server.removeListener('error', onError); + }); + }); + } + + private async bindManyPorts(addressList: SubchannelAddress[], boundPortObject: BoundPort): Promise { + if (addressList.length === 0) { + return { + count: 0, + port: 0, + errors: [] + }; + } + if (isTcpSubchannelAddress(addressList[0]) && addressList[0].port === 0) { + /* If binding to port 0, first try to bind the first address, then bind + * the rest of the address list to the specific port that it binds. */ + const firstAddressResult = await this.bindOneAddress(addressList[0], boundPortObject); + if (firstAddressResult.error) { + /* If the first address fails to bind, try the same operation starting + * from the second item in the list. */ + const restAddressResult = await this.bindManyPorts(addressList.slice(1), boundPortObject); + return { + ...restAddressResult, + errors: [firstAddressResult.error, ...restAddressResult.errors] + }; + } else { + const restAddresses = addressList.slice(1).map(address => isTcpSubchannelAddress(address) ? {host: address.host, port: firstAddressResult.port} : address) + const restAddressResult = await Promise.all(restAddresses.map(address => this.bindOneAddress(address, boundPortObject))); + const allResults = [firstAddressResult, ...restAddressResult]; + return { + count: allResults.filter(result => result.error === undefined).length, + port: firstAddressResult.port, + errors: allResults.filter(result => result.error).map(result => result.error!) + }; + } + } else { + const allResults = await Promise.all(addressList.map(address => this.bindOneAddress(address, boundPortObject))); + return { + count: allResults.filter(result => result.error === undefined).length, + port: allResults[0].port, + errors: allResults.filter(result => result.error).map(result => result.error!) + }; + } + } + + private async bindAddressList(addressList: SubchannelAddress[], boundPortObject: BoundPort): Promise { + let bindResult: BindResult; + try { + bindResult = await this.bindManyPorts(addressList, boundPortObject); + } catch (error) { + throw error; + } + if (bindResult.count > 0) { + if (bindResult.count < addressList.length) { + logging.log( + LogVerbosity.INFO, + `WARNING Only ${bindResult.count} addresses added out of total ${addressList.length} resolved` + ); + } + return bindResult.port; + } else { + const errorString = `No address added out of total ${addressList.length} resolved`; + logging.log(LogVerbosity.ERROR, errorString); + throw new Error(`${errorString} errors: [${bindResult.errors.join(',')}]`); + } + } + + private resolvePort(port: GrpcUri): Promise { + return new Promise((resolve, reject) => { + const resolverListener: ResolverListener = { + onSuccessfulResolution: ( + endpointList, + serviceConfig, + serviceConfigError + ) => { + // We only want one resolution result. Discard all future results + resolverListener.onSuccessfulResolution = () => {}; + const addressList = ([] as SubchannelAddress[]).concat( + ...endpointList.map(endpoint => endpoint.addresses) + ); + if (addressList.length === 0) { + reject( + new Error(`No addresses resolved for port ${port}`) + ); + return; + } + resolve(addressList); + }, + onError: error => { + reject(new Error(error.details)); + }, + }; + const resolver = createResolver(port, resolverListener, this.options); + resolver.updateResolution(); + }); + } + + private async bindPort(port: GrpcUri, boundPortObject: BoundPort): Promise { + const addressList = await this.resolvePort(port); + if (boundPortObject.cancelled) { + this.completeUnbind(boundPortObject); + throw new Error('bindAsync operation cancelled by unbind call'); + } + const portNumber = await this.bindAddressList(addressList, boundPortObject); + if (boundPortObject.cancelled) { + this.completeUnbind(boundPortObject); + throw new Error('bindAsync operation cancelled by unbind call'); + } + return portNumber; + } + + private normalizePort(port: string): GrpcUri { + + const initialPortUri = parseUri(port); + if (initialPortUri === null) { + throw new Error(`Could not parse port "${port}"`); + } + const portUri = mapUriDefaultScheme(initialPortUri); + if (portUri === null) { + throw new Error(`Could not get a default scheme for port "${port}"`); + } + return portUri; + } + bindAsync( port: string, creds: ServerCredentials, @@ -399,331 +711,162 @@ export class Server { throw new TypeError('callback must be a function'); } - const initialPortUri = parseUri(port); - if (initialPortUri === null) { - throw new Error(`Could not parse port "${port}"`); - } - const portUri = mapUriDefaultScheme(initialPortUri); - if (portUri === null) { - throw new Error(`Could not get a default scheme for port "${port}"`); - } + this.trace('bindAsync port=' + port); - const serverOptions: http2.ServerOptions = { - maxSendHeaderBlockLength: Number.MAX_SAFE_INTEGER, - }; - if ('grpc-node.max_session_memory' in this.options) { - serverOptions.maxSessionMemory = - this.options['grpc-node.max_session_memory']; - } else { - /* By default, set a very large max session memory limit, to effectively - * disable enforcement of the limit. Some testing indicates that Node's - * behavior degrades badly when this limit is reached, so we solve that - * by disabling the check entirely. */ - serverOptions.maxSessionMemory = Number.MAX_SAFE_INTEGER; - } - if ('grpc.max_concurrent_streams' in this.options) { - serverOptions.settings = { - maxConcurrentStreams: this.options['grpc.max_concurrent_streams'], - }; - } + const portUri = this.normalizePort(port); const deferredCallback = (error: Error | null, port: number) => { process.nextTick(() => callback(error, port)); }; - const setupServer = (): http2.Http2Server | http2.Http2SecureServer => { - let http2Server: http2.Http2Server | http2.Http2SecureServer; - if (creds._isSecure()) { - const secureServerOptions = Object.assign( - serverOptions, - creds._getSettings()! - ); - secureServerOptions.enableTrace = - this.options['grpc-node.tls_enable_trace'] === 1; - http2Server = http2.createSecureServer(secureServerOptions); - http2Server.on('secureConnection', (socket: TLSSocket) => { - /* These errors need to be handled by the user of Http2SecureServer, - * according to https://github.com/nodejs/node/issues/35824 */ - socket.on('error', (e: Error) => { - this.trace( - 'An incoming TLS connection closed with error: ' + e.message - ); - }); - }); + /* First, if this port is already bound or that bind operation is in + * progress, use that result. */ + let boundPortObject = this.boundPorts.get(uriToString(portUri)); + if (boundPortObject) { + if (!creds._equals(boundPortObject.credentials)) { + deferredCallback(new Error(`${port} already bound with incompatible credentials`), 0); + return; + } + /* If that operation has previously been cancelled by an unbind call, + * uncancel it. */ + boundPortObject.cancelled = false; + if (boundPortObject.completionPromise) { + boundPortObject.completionPromise.then(portNum => callback(null, portNum), error => callback(error as Error, 0)); } else { - http2Server = http2.createServer(serverOptions); + deferredCallback(null, boundPortObject.portNumber); } - - http2Server.setTimeout(0, noop); - this._setupHandlers(http2Server); - return http2Server; + return; + } + boundPortObject = { + mapKey: uriToString(portUri), + originalUri: portUri, + completionPromise: null, + cancelled: false, + portNumber: 0, + credentials: creds, + listeningServers: new Set() }; - - const bindSpecificPort = ( - addressList: SubchannelAddress[], - portNum: number, - previousCount: number - ): Promise => { - if (addressList.length === 0) { - return Promise.resolve({ port: portNum, count: previousCount }); - } - return Promise.all( - addressList.map(address => { - this.trace( - 'Attempting to bind ' + subchannelAddressToString(address) - ); - let addr: SubchannelAddress; - if (isTcpSubchannelAddress(address)) { - addr = { - host: (address as TcpSubchannelAddress).host, - port: portNum, - }; - } else { - addr = address; - } - - const http2Server = setupServer(); - return new Promise((resolve, reject) => { - const onError = (err: Error) => { - this.trace( - 'Failed to bind ' + - subchannelAddressToString(address) + - ' with error ' + - err.message - ); - resolve(err); - }; - - http2Server.once('error', onError); - - http2Server.listen(addr, () => { - const boundAddress = http2Server.address()!; - let boundSubchannelAddress: SubchannelAddress; - if (typeof boundAddress === 'string') { - boundSubchannelAddress = { - path: boundAddress, - }; - } else { - boundSubchannelAddress = { - host: boundAddress.address, - port: boundAddress.port, - }; - } - - const channelzRef = registerChannelzSocket( - subchannelAddressToString(boundSubchannelAddress), - () => { - return { - localAddress: boundSubchannelAddress, - remoteAddress: null, - security: null, - remoteName: null, - streamsStarted: 0, - streamsSucceeded: 0, - streamsFailed: 0, - messagesSent: 0, - messagesReceived: 0, - keepAlivesSent: 0, - lastLocalStreamCreatedTimestamp: null, - lastRemoteStreamCreatedTimestamp: null, - lastMessageSentTimestamp: null, - lastMessageReceivedTimestamp: null, - localFlowControlWindow: null, - remoteFlowControlWindow: null, - }; - }, - this.channelzEnabled - ); - if (this.channelzEnabled) { - this.listenerChildrenTracker.refChild(channelzRef); - } - this.http2ServerList.push({ - server: http2Server, - channelzRef: channelzRef, - }); - this.trace( - 'Successfully bound ' + - subchannelAddressToString(boundSubchannelAddress) - ); - resolve( - 'port' in boundSubchannelAddress - ? boundSubchannelAddress.port - : portNum - ); - http2Server.removeListener('error', onError); - }); - }); - }) - ).then(results => { - let count = 0; - for (const result of results) { - if (typeof result === 'number') { - count += 1; - if (result !== portNum) { - throw new Error( - 'Invalid state: multiple port numbers added from single address' - ); - } - } - } - return { - port: portNum, - count: count + previousCount, + const splitPort = splitHostPort(portUri.path); + const completionPromise = this.bindPort(portUri, boundPortObject); + boundPortObject.completionPromise = completionPromise; + /* If the port number is 0, defer populating the map entry until after the + * bind operation completes and we have a specific port number. Otherwise, + * populate it immediately. */ + if (splitPort?.port === 0) { + completionPromise.then(portNum => { + const finalUri: GrpcUri = { + scheme: portUri.scheme, + authority: portUri.authority, + path: combineHostPort({host: splitPort.host, port: portNum}) }; + boundPortObject!.mapKey = uriToString(finalUri); + boundPortObject!.completionPromise = null; + boundPortObject!.portNumber = portNum; + this.boundPorts.set(boundPortObject!.mapKey, boundPortObject!); + callback(null, portNum); + }, error => { + callback(error, 0); + }) + } else { + this.boundPorts.set(boundPortObject.mapKey, boundPortObject); + completionPromise.then(portNum => { + boundPortObject!.completionPromise = null; + boundPortObject!.portNumber = portNum; + callback(null, portNum); + }, error => { + callback(error, 0); }); - }; + } + } - const bindWildcardPort = ( - addressList: SubchannelAddress[] - ): Promise => { - if (addressList.length === 0) { - return Promise.resolve({ port: 0, count: 0 }); + private closeServer(server: AnyHttp2Server, callback?: () => void) { + this.trace('Closing server with address ' + JSON.stringify(server.address())); + const serverInfo = this.http2Servers.get(server); + server.close(() => { + if (this.channelzEnabled && serverInfo) { + this.listenerChildrenTracker.unrefChild(serverInfo.channelzRef); + unregisterChannelzRef(serverInfo.channelzRef); } - const address = addressList[0]; - const http2Server = setupServer(); - return new Promise((resolve, reject) => { - const onError = (err: Error) => { - this.trace( - 'Failed to bind ' + - subchannelAddressToString(address) + - ' with error ' + - err.message - ); - resolve(bindWildcardPort(addressList.slice(1))); - }; + this.http2Servers.delete(server); + callback?.(); + }); - http2Server.once('error', onError); + } - http2Server.listen(address, () => { - const boundAddress = http2Server.address() as AddressInfo; - const boundSubchannelAddress: SubchannelAddress = { - host: boundAddress.address, - port: boundAddress.port, - }; - const channelzRef = registerChannelzSocket( - subchannelAddressToString(boundSubchannelAddress), - () => { - return { - localAddress: boundSubchannelAddress, - remoteAddress: null, - security: null, - remoteName: null, - streamsStarted: 0, - streamsSucceeded: 0, - streamsFailed: 0, - messagesSent: 0, - messagesReceived: 0, - keepAlivesSent: 0, - lastLocalStreamCreatedTimestamp: null, - lastRemoteStreamCreatedTimestamp: null, - lastMessageSentTimestamp: null, - lastMessageReceivedTimestamp: null, - localFlowControlWindow: null, - remoteFlowControlWindow: null, - }; - }, - this.channelzEnabled - ); - if (this.channelzEnabled) { - this.listenerChildrenTracker.refChild(channelzRef); - } - this.http2ServerList.push({ - server: http2Server, - channelzRef: channelzRef, - }); - this.trace( - 'Successfully bound ' + - subchannelAddressToString(boundSubchannelAddress) - ); - resolve(bindSpecificPort(addressList.slice(1), boundAddress.port, 1)); - http2Server.removeListener('error', onError); - }); - }); + private closeSession(session: http2.ServerHttp2Session, callback?: () => void) { + this.trace('Closing session initiated by ' + session.socket?.remoteAddress); + const sessionInfo = this.sessions.get(session); + const closeCallback = () => { + if (this.channelzEnabled && sessionInfo) { + this.sessionChildrenTracker.unrefChild(sessionInfo.ref); + unregisterChannelzRef(sessionInfo.ref); + } + this.sessions.delete(session); + callback?.(); }; + if (session.closed) { + process.nextTick(closeCallback); + } else { + session.close(closeCallback); + } + } - const resolverListener: ResolverListener = { - onSuccessfulResolution: ( - endpointList, - serviceConfig, - serviceConfigError - ) => { - // We only want one resolution result. Discard all future results - resolverListener.onSuccessfulResolution = () => {}; - const addressList = ([] as SubchannelAddress[]).concat( - ...endpointList.map(endpoint => endpoint.addresses) - ); - if (addressList.length === 0) { - deferredCallback( - new Error(`No addresses resolved for port ${port}`), - 0 - ); - return; - } - let bindResultPromise: Promise; - if (isTcpSubchannelAddress(addressList[0])) { - if (addressList[0].port === 0) { - bindResultPromise = bindWildcardPort(addressList); - } else { - bindResultPromise = bindSpecificPort( - addressList, - addressList[0].port, - 0 - ); - } - } else { - // Use an arbitrary non-zero port for non-TCP addresses - bindResultPromise = bindSpecificPort(addressList, 1, 0); + private completeUnbind(boundPortObject: BoundPort) { + for (const server of boundPortObject.listeningServers) { + const serverInfo = this.http2Servers.get(server); + this.closeServer(server, () => { + boundPortObject.listeningServers.delete(server); + }); + if (serverInfo) { + for (const session of serverInfo.sessions) { + this.closeSession(session); } - bindResultPromise.then( - bindResult => { - if (bindResult.count === 0) { - const errorString = `No address added out of total ${addressList.length} resolved`; - logging.log(LogVerbosity.ERROR, errorString); - deferredCallback(new Error(errorString), 0); - } else { - if (bindResult.count < addressList.length) { - logging.log( - LogVerbosity.INFO, - `WARNING Only ${bindResult.count} addresses added out of total ${addressList.length} resolved` - ); - } - deferredCallback(null, bindResult.port); - } - }, - error => { - const errorString = `No address added out of total ${addressList.length} resolved`; - logging.log(LogVerbosity.ERROR, errorString); - deferredCallback(new Error(errorString), 0); - } - ); - }, - onError: error => { - deferredCallback(new Error(error.details), 0); - }, - }; + } + } + this.boundPorts.delete(boundPortObject.mapKey); + } - const resolver = createResolver(portUri, resolverListener, this.options); - resolver.updateResolution(); + /** + * Unbind a previously bound port, or cancel an in-progress bindAsync + * operation. If port 0 was bound, only the actual bound port can be + * unbound. For example, if bindAsync was called with "localhost:0" and the + * bound port result was 54321, it can be unbound as "localhost:54321". + * @param port + */ + unbind(port: string): void { + this.trace('unbind port=' + port); + const portUri = this.normalizePort(port); + const splitPort = splitHostPort(portUri.path); + if (splitPort?.port === 0) { + throw new Error('Cannot unbind port 0'); + } + const boundPortObject = this.boundPorts.get(uriToString(portUri)); + if (boundPortObject) { + this.trace('unbinding ' + boundPortObject.mapKey + ' originally bound as ' + uriToString(boundPortObject.originalUri)); + /* If the bind operation is pending, the cancelled flag will trigger + * the unbind operation later. */ + if (boundPortObject.completionPromise) { + boundPortObject.cancelled = true; + } else { + this.completeUnbind(boundPortObject); + } + } } forceShutdown(): void { + for (const boundPortObject of this.boundPorts.values()) { + boundPortObject.cancelled = true; + } + this.boundPorts.clear(); // Close the server if it is still running. - - for (const { server: http2Server, channelzRef: ref } of this - .http2ServerList) { - if (http2Server.listening) { - http2Server.close(() => { - if (this.channelzEnabled) { - this.listenerChildrenTracker.unrefChild(ref); - unregisterChannelzRef(ref); - } - }); - } + for (const server of this.http2Servers.keys()) { + this.closeServer(server); } // Always destroy any available sessions. It's possible that one or more // tryShutdown() calls are in progress. Don't wait on them to finish. this.sessions.forEach((channelzInfo, session) => { + this.closeSession(session); // Cast NGHTTP2_CANCEL to any because TypeScript doesn't seem to // recognize destroy(code) as a valid signature. // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -766,9 +909,9 @@ export class Server { @deprecate('Calling start() is no longer necessary. It can be safely omitted.') start(): void { if ( - this.http2ServerList.length === 0 || - this.http2ServerList.every( - ({ server: http2Server }) => http2Server.listening !== true + this.http2Servers.size === 0 || + [...this.http2Servers.keys()].every( + server => !server.listening ) ) { throw new Error('server must be bound in order to start'); @@ -797,26 +940,24 @@ export class Server { } } - for (const { server: http2Server, channelzRef: ref } of this - .http2ServerList) { - if (http2Server.listening) { - pendingChecks++; - http2Server.close(() => { - if (this.channelzEnabled) { - this.listenerChildrenTracker.unrefChild(ref); - unregisterChannelzRef(ref); - } - maybeCallback(); - }); - } + for (const server of this.http2Servers.keys()) { + pendingChecks++; + const serverString = this.http2Servers.get(server)!.channelzRef.name; + this.trace('Waiting for server ' + serverString + ' to close'); + this.closeServer(server, () => { + this.trace('Server ' + serverString + ' finished closing'); + maybeCallback(); + }); + } + for (const session of this.sessions.keys()) { + pendingChecks++; + const sessionString = session.socket?.remoteAddress; + this.trace('Waiting for session ' + sessionString + ' to close'); + this.closeSession(session, () => { + this.trace('Session ' + sessionString + ' finished closing'); + maybeCallback(); + }); } - - this.sessions.forEach((channelzInfo, session) => { - if (!session.closed) { - pendingChecks += 1; - session.close(maybeCallback); - } - }); if (pendingChecks === 0) { wrappedCallback(); } @@ -1077,6 +1218,7 @@ export class Server { lastMessageReceivedTimestamp: null, }; + this.http2Servers.get(http2Server)?.sessions.add(session); this.sessions.set(session, channelzSessionInfo); const clientAddress = session.socket.remoteAddress; if (this.channelzEnabled) { @@ -1164,6 +1306,7 @@ export class Server { if (keeapliveTimeTimer) { clearTimeout(keeapliveTimeTimer); } + this.http2Servers.get(http2Server)?.sessions.delete(session); this.sessions.delete(session); }); }); diff --git a/packages/grpc-js/src/subchannel.ts b/packages/grpc-js/src/subchannel.ts index cf49ced77..d9a2dbd80 100644 --- a/packages/grpc-js/src/subchannel.ts +++ b/packages/grpc-js/src/subchannel.ts @@ -120,6 +120,7 @@ export class Subchannel { this.backoffTimeout = new BackoffTimeout(() => { this.handleBackoffTimer(); }, backoffOptions); + this.backoffTimeout.unref(); this.subchannelAddressString = subchannelAddressToString(subchannelAddress); this.keepaliveTime = options['grpc.keepalive_time_ms'] ?? -1; diff --git a/packages/grpc-js/src/uri-parser.ts b/packages/grpc-js/src/uri-parser.ts index 20c3d53b3..2b2efeca0 100644 --- a/packages/grpc-js/src/uri-parser.ts +++ b/packages/grpc-js/src/uri-parser.ts @@ -101,6 +101,19 @@ export function splitHostPort(path: string): HostPort | null { } } +export function combineHostPort(hostPort: HostPort): string { + if (hostPort.port === undefined) { + return hostPort.host; + } else { + // Only an IPv6 host should include a colon + if (hostPort.host.includes(':')) { + return `[${hostPort.host}]:${hostPort.port}`; + } else { + return `${hostPort.host}:${hostPort.port}`; + } + } +} + export function uriToString(uri: GrpcUri): string { let result = ''; if (uri.scheme !== undefined) { diff --git a/packages/grpc-js/test/test-server.ts b/packages/grpc-js/test/test-server.ts index d1b485ec3..56388a868 100644 --- a/packages/grpc-js/test/test-server.ts +++ b/packages/grpc-js/test/test-server.ts @@ -63,6 +63,13 @@ const cert = fs.readFileSync(path.join(__dirname, 'fixtures', 'server1.pem')); function noop(): void {} describe('Server', () => { + let server: Server; + beforeEach(() => { + server = new Server(); + }); + afterEach(() => { + server.forceShutdown(); + }); describe('constructor', () => { it('should work with no arguments', () => { assert.doesNotThrow(() => { @@ -140,6 +147,85 @@ describe('Server', () => { ); }, /callback must be a function/); }); + + it('succeeds when called with an already bound port', done => { + server.bindAsync('localhost:0', ServerCredentials.createInsecure(), (err, port) => { + assert.ifError(err); + server.bindAsync(`localhost:${port}`, ServerCredentials.createInsecure(), (err2, port2) => { + assert.ifError(err2); + assert.strictEqual(port, port2); + done(); + }); + }); + }); + + it('fails when called on a bound port with different credentials', done => { + const secureCreds = ServerCredentials.createSsl( + ca, + [{ private_key: key, cert_chain: cert }], + true + ); + server.bindAsync('localhost:0', ServerCredentials.createInsecure(), (err, port) => { + assert.ifError(err); + server.bindAsync(`localhost:${port}`, secureCreds, (err2, port2) => { + assert(err2 !== null); + assert.match(err2.message, /credentials/); + done(); + }) + }); + }) + }); + + describe('unbind', () => { + let client: grpc.Client | null = null; + beforeEach(() => { + client = null; + }); + afterEach(() => { + client?.close(); + }); + it('refuses to unbind port 0', done => { + assert.throws(() => { + server.unbind('localhost:0'); + }, /port 0/); + server.bindAsync('localhost:0', ServerCredentials.createInsecure(), (err, port) => { + assert.ifError(err); + assert.notStrictEqual(port, 0); + assert.throws(() => { + server.unbind('localhost:0'); + }, /port 0/); + done(); + }) + }); + + it('successfully unbinds a bound ephemeral port', done => { + server.bindAsync('localhost:0', ServerCredentials.createInsecure(), (err, port) => { + client = new grpc.Client(`localhost:${port}`, grpc.credentials.createInsecure()); + client.makeUnaryRequest('/math.Math/Div', x => x, x => x, Buffer.from('abc'), (callError1, result) => { + assert(callError1); + // UNIMPLEMENTED means that the request reached the call handling code + assert.strictEqual(callError1.code, grpc.status.UNIMPLEMENTED); + server.unbind(`localhost:${port}`); + const deadline = new Date(); + deadline.setSeconds(deadline.getSeconds() + 1); + client!.makeUnaryRequest('/math.Math/Div', x => x, x => x, Buffer.from('abc'), {deadline: deadline}, (callError2, result) => { + assert(callError2); + // DEADLINE_EXCEEDED means that the server is unreachable + assert.strictEqual(callError2.code, grpc.status.DEADLINE_EXCEEDED); + done(); + }); + }); + }) + }); + + it('cancels a bindAsync in progress', done => { + server.bindAsync('localhost:50051', ServerCredentials.createInsecure(), (err, port) => { + assert(err); + assert.match(err.message, /cancelled by unbind/); + done(); + }); + server.unbind('localhost:50051'); + }); }); describe('start', () => {