Skip to content

Commit

Permalink
Merge pull request #11019 from nestjs/fix/gracefully-reconnect-rmq
Browse files Browse the repository at this point in the history
fix(microservices): rmq should gracefully reconnect upon error
  • Loading branch information
kamilmysliwiec committed Feb 2, 2023
2 parents 4ad3cbc + 9eb3b89 commit 2ed509a
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 58 deletions.
122 changes: 77 additions & 45 deletions packages/microservices/client/client-rmq.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,60 @@ import { loadPackage } from '@nestjs/common/utils/load-package.util';
import { randomStringGenerator } from '@nestjs/common/utils/random-string-generator.util';
import { isFunction } from '@nestjs/common/utils/shared.utils';
import { EventEmitter } from 'events';
import { EmptyError, fromEvent, lastValueFrom, merge, Observable } from 'rxjs';
import { first, map, retryWhen, scan, share, switchMap } from 'rxjs/operators';
import {
EmptyError,
firstValueFrom,
fromEvent,
merge,
Observable,
ReplaySubject,
} from 'rxjs';
import { first, map, retryWhen, scan, skip, switchMap } from 'rxjs/operators';
import {
CONNECT_EVENT,
CONNECT_FAILED_EVENT,
DISCONNECTED_RMQ_MESSAGE,
DISCONNECT_EVENT,
ERROR_EVENT,
RQM_DEFAULT_IS_GLOBAL_PREFETCH_COUNT,
RQM_DEFAULT_NOACK,
RQM_DEFAULT_NO_ASSERT,
RQM_DEFAULT_PERSISTENT,
RQM_DEFAULT_PREFETCH_COUNT,
RQM_DEFAULT_QUEUE,
RQM_DEFAULT_QUEUE_OPTIONS,
RQM_DEFAULT_URL,
RQM_DEFAULT_NO_ASSERT,
} from '../constants';
import { RmqUrl } from '../external/rmq-url.interface';
import { ReadPacket, RmqOptions, WritePacket } from '../interfaces';
import { RmqRecord } from '../record-builders';
import { RmqRecordSerializer } from '../serializers/rmq-record.serializer';
import { ClientProxy } from './client-proxy';

// import type {
// AmqpConnectionManager,
// ChannelWrapper,
// } from 'amqp-connection-manager';
// import type { Channel, ConsumeMessage } from 'amqplib';

type Channel = any;
type ChannelWrapper = any;
type ConsumeMessage = any;
type AmqpConnectionManager = any;

let rqmPackage: any = {};

const REPLY_QUEUE = 'amq.rabbitmq.reply-to';

export class ClientRMQ extends ClientProxy {
protected readonly logger = new Logger(ClientProxy.name);
protected connection$: ReplaySubject<any>;
protected connection: Promise<any>;
protected client: any = null;
protected channel: any = null;
protected client: AmqpConnectionManager = null;
protected channel: ChannelWrapper = null;
protected urls: string[] | RmqUrl[];
protected queue: string;
protected queueOptions: any;
protected queueOptions: Record<string, any>;
protected responseEmitter: EventEmitter;
protected replyQueue: string;
protected persistent: boolean;
Expand Down Expand Up @@ -75,42 +95,44 @@ export class ClientRMQ extends ClientProxy {

public connect(): Promise<any> {
if (this.client) {
return this.connection;
return this.convertConnectionToPromise();
}
this.client = this.createClient();
this.handleError(this.client);
this.handleDisconnectError(this.client);

this.responseEmitter = new EventEmitter();
this.responseEmitter.setMaxListeners(0);

const connect$ = this.connect$(this.client);
this.connection = lastValueFrom(
this.mergeDisconnectEvent(this.client, connect$).pipe(
switchMap(() => this.createChannel()),
share(),
),
).catch(err => {
if (err instanceof EmptyError) {
return;
}
throw err;
});
const withDisconnect$ = this.mergeDisconnectEvent(
this.client,
connect$,
).pipe(switchMap(() => this.createChannel()));

const withReconnect$ = fromEvent(this.client, CONNECT_EVENT).pipe(skip(1));
const source$ = merge(withDisconnect$, withReconnect$);

return this.connection;
this.connection$ = new ReplaySubject(1);
source$.subscribe(this.connection$);

return this.convertConnectionToPromise();
}

public createChannel(): Promise<void> {
return new Promise(resolve => {
this.channel = this.client.createChannel({
json: false,
setup: (channel: any) => this.setupChannel(channel, resolve),
setup: (channel: Channel) => this.setupChannel(channel, resolve),
});
});
}

public createClient<T = any>(): T {
public createClient(): AmqpConnectionManager {
const socketOptions = this.getOptionsProp(this.options, 'socketOptions');
return rqmPackage.connect(this.urls, {
connectionOptions: socketOptions,
}) as T;
});
}

public mergeDisconnectEvent<T = any>(
Expand All @@ -119,7 +141,7 @@ export class ClientRMQ extends ClientProxy {
): Observable<T> {
const eventToError = (eventType: string) =>
fromEvent(instance, eventType).pipe(
map((err: any) => {
map((err: unknown) => {
throw err;
}),
);
Expand All @@ -138,10 +160,23 @@ export class ClientRMQ extends ClientProxy {
),
),
);
// If we ever decide to propagate all disconnect errors & re-emit them through
// the "connection" stream then comment out "first()" operator.
return merge(source$, disconnect$, connectFailed$).pipe(first());
}

public async setupChannel(channel: any, resolve: Function) {
public async convertConnectionToPromise() {
try {
return await firstValueFrom(this.connection$);
} catch (err) {
if (err instanceof EmptyError) {
return;
}
throw err;
}
}

public async setupChannel(channel: Channel, resolve: Function) {
const prefetchCount =
this.getOptionsProp(this.options, 'prefetchCount') ||
RQM_DEFAULT_PREFETCH_COUNT;
Expand All @@ -153,35 +188,30 @@ export class ClientRMQ extends ClientProxy {
await channel.assertQueue(this.queue, this.queueOptions);
}
await channel.prefetch(prefetchCount, isGlobalPrefetchCount);

this.responseEmitter = new EventEmitter();
this.responseEmitter.setMaxListeners(0);
await this.consumeChannel(channel);
resolve();
}

public async consumeChannel(channel: any) {
public async consumeChannel(channel: Channel) {
const noAck = this.getOptionsProp(this.options, 'noAck', RQM_DEFAULT_NOACK);
await channel.consume(
this.replyQueue,
(msg: any) =>
(msg: ConsumeMessage) =>
this.responseEmitter.emit(msg.properties.correlationId, msg),
{
noAck,
},
);
}

public handleError(client: any): void {
public handleError(client: AmqpConnectionManager): void {
client.addListener(ERROR_EVENT, (err: any) => this.logger.error(err));
}

public handleDisconnectError(client: any): void {
public handleDisconnectError(client: AmqpConnectionManager): void {
client.addListener(DISCONNECT_EVENT, (err: any) => {
this.logger.error(DISCONNECTED_RMQ_MESSAGE);
this.logger.error(err);

this.close();
});
}

Expand Down Expand Up @@ -231,7 +261,7 @@ export class ClientRMQ extends ClientProxy {
content,
options,
}: {
content: any;
content: Buffer;
options: Record<string, unknown>;
}) =>
this.handleMessage(JSON.parse(content.toString()), options, callback);
Expand All @@ -244,17 +274,19 @@ export class ClientRMQ extends ClientProxy {
delete serializedPacket.options;

this.responseEmitter.on(correlationId, listener);
this.channel.sendToQueue(
this.queue,
Buffer.from(JSON.stringify(serializedPacket)),
{
replyTo: this.replyQueue,
persistent: this.persistent,
...options,
headers: this.mergeHeaders(options?.headers),
correlationId,
},
);
this.channel
.sendToQueue(
this.queue,
Buffer.from(JSON.stringify(serializedPacket)),
{
replyTo: this.replyQueue,
persistent: this.persistent,
...options,
headers: this.mergeHeaders(options?.headers),
correlationId,
},
)
.catch(err => callback({ err }));
return () => this.responseEmitter.removeListener(correlationId, listener);
} catch (err) {
callback({ err });
Expand Down
25 changes: 12 additions & 13 deletions packages/microservices/test/client/client-rmq.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ describe('ClientRMQ', function () {
let createClientStub: sinon.SinonStub;
let handleErrorsSpy: sinon.SinonSpy;
let connect$Stub: sinon.SinonStub;
let mergeDisconnectEvent: sinon.SinonStub;

beforeEach(async () => {
client = new ClientRMQ({});
Expand All @@ -33,7 +32,7 @@ describe('ClientRMQ', function () {
return this;
},
}));
mergeDisconnectEvent = sinon
sinon
.stub(client, 'mergeDisconnectEvent')
.callsFake((_, source) => source);
});
Expand Down Expand Up @@ -173,18 +172,18 @@ describe('ClientRMQ', function () {
const pattern = 'test';
let msg: ReadPacket;
let connectSpy: sinon.SinonSpy,
sendToQueueSpy: sinon.SinonSpy,
sendToQueueStub: sinon.SinonStub,
eventSpy: sinon.SinonSpy;

beforeEach(() => {
client = new ClientRMQ({});
msg = { pattern, data: 'data' };
connectSpy = sinon.spy(client, 'connect');
eventSpy = sinon.spy();
sendToQueueSpy = sinon.spy();
sendToQueueStub = sinon.stub().callsFake(() => ({ catch: sinon.spy() }));

client['channel'] = {
sendToQueue: sendToQueueSpy,
sendToQueue: sendToQueueStub,
};
client['responseEmitter'] = new EventEmitter();
client['responseEmitter'].on(pattern, eventSpy);
Expand All @@ -196,15 +195,15 @@ describe('ClientRMQ', function () {

it('should send message to a proper queue', () => {
client['publish'](msg, () => {
expect(sendToQueueSpy.called).to.be.true;
expect(sendToQueueSpy.getCall(0).args[0]).to.be.eql(client['queue']);
expect(sendToQueueStub.called).to.be.true;
expect(sendToQueueStub.getCall(0).args[0]).to.be.eql(client['queue']);
});
});

it('should send buffer from stringified message', () => {
client['publish'](msg, () => {
expect(sendToQueueSpy.called).to.be.true;
expect(sendToQueueSpy.getCall(1).args[1]).to.be.eql(
expect(sendToQueueStub.called).to.be.true;
expect(sendToQueueStub.getCall(1).args[1]).to.be.eql(
Buffer.from(JSON.stringify(msg)),
);
});
Expand All @@ -231,7 +230,7 @@ describe('ClientRMQ', function () {
describe('headers', () => {
it('should not generate headers if none are configured', () => {
client['publish'](msg, () => {
expect(sendToQueueSpy.getCall(0).args[2].headers).to.be.undefined;
expect(sendToQueueStub.getCall(0).args[2].headers).to.be.undefined;
});
});

Expand All @@ -240,7 +239,7 @@ describe('ClientRMQ', function () {
msg.data = new RmqRecord('data', { headers: requestHeaders });

client['publish'](msg, () => {
expect(sendToQueueSpy.getCall(0).args[2].headers).to.eql(
expect(sendToQueueStub.getCall(0).args[2].headers).to.eql(
requestHeaders,
);
});
Expand All @@ -254,7 +253,7 @@ describe('ClientRMQ', function () {
msg.data = new RmqRecord('data', { headers: requestHeaders });

client['publish'](msg, () => {
expect(sendToQueueSpy.getCall(0).args[2].headers).to.eql({
expect(sendToQueueStub.getCall(0).args[2].headers).to.eql({
...staticHeaders,
...requestHeaders,
});
Expand All @@ -269,7 +268,7 @@ describe('ClientRMQ', function () {
msg.data = new RmqRecord('data', { headers: requestHeaders });

client['publish'](msg, () => {
expect(sendToQueueSpy.getCall(0).args[2].headers).to.eql(
expect(sendToQueueStub.getCall(0).args[2].headers).to.eql(
requestHeaders,
);
});
Expand Down

0 comments on commit 2ed509a

Please sign in to comment.