diff --git a/indexer/services/socks/__tests__/lib/subscriptions.test.ts b/indexer/services/socks/__tests__/lib/subscriptions.test.ts index 04fdf84587..3d31f00f1d 100644 --- a/indexer/services/socks/__tests__/lib/subscriptions.test.ts +++ b/indexer/services/socks/__tests__/lib/subscriptions.test.ts @@ -10,7 +10,6 @@ import { btcTicker, invalidChannel, invalidTicker } from '../constants'; import { axiosRequest } from '../../src/lib/axios'; import { AxiosSafeServerError, makeAxiosSafeServerError } from '@dydxprotocol-indexer/base'; import { BlockedError } from '../../src/lib/errors'; -import { isRestrictedCountry } from '@dydxprotocol-indexer/compliance'; jest.mock('ws'); jest.mock('../../src/helpers/wss'); @@ -58,8 +57,7 @@ describe('Subscriptions', () => { [Channel.V4_TRADES]: ['/v4/trades/perpetualMarket/.+'], }; const initialMessage: Object = { a: 'b' }; - const restrictedCountry: string = 'US'; - const nonRestrictedCountry: string = 'AR'; + const country: string = 'AR'; beforeAll(async () => { await dbHelpers.migrate(); @@ -83,9 +81,6 @@ describe('Subscriptions', () => { axiosRequestMock = (axiosRequest as jest.Mock); axiosRequestMock.mockClear(); axiosRequestMock.mockImplementation(() => (JSON.stringify(initialMessage))); - (isRestrictedCountry as jest.Mock).mockImplementation((country: string): boolean => { - return country === restrictedCountry; - }); }); describe('subscribe', () => { @@ -106,7 +101,7 @@ describe('Subscriptions', () => { initialMsgId, id, false, - nonRestrictedCountry, + country, ); expect(sendMessageStringMock).toHaveBeenCalledTimes(1); @@ -126,6 +121,9 @@ describe('Subscriptions', () => { for (const urlPattern of urlPatterns) { expect(axiosRequestMock).toHaveBeenCalledWith(expect.objectContaining({ url: expect.stringMatching(RegExp(urlPattern)), + headers: { + 'cf-ipcountry': country, + }, })); } } else { @@ -150,7 +148,6 @@ describe('Subscriptions', () => { initialMsgId, id, false, - nonRestrictedCountry, ); expect(sendMessageMock).toHaveBeenCalledTimes(1); @@ -179,7 +176,6 @@ describe('Subscriptions', () => { initialMsgId, defaultId, false, - nonRestrictedCountry, ); }, ).rejects.toEqual(new Error(`Invalid channel: ${invalidChannel}`)); @@ -194,7 +190,6 @@ describe('Subscriptions', () => { initialMsgId, mockSubaccountId, false, - nonRestrictedCountry, ); expect(sendMessageMock).toHaveBeenCalledTimes(1); @@ -217,7 +212,6 @@ describe('Subscriptions', () => { initialMsgId, mockSubaccountId, false, - nonRestrictedCountry, ); expect(sendMessageMock).toHaveBeenCalledTimes(1); @@ -253,32 +247,7 @@ describe('Subscriptions', () => { initialMsgId, mockSubaccountId, false, - nonRestrictedCountry, - ); - - expect(sendMessageMock).toHaveBeenCalledTimes(1); - expect(sendMessageMock).toHaveBeenCalledWith( - mockWs, - connectionId, - expect.objectContaining({ - connection_id: connectionId, - type: 'error', - message: expectedError.message, - })); - expect(subscriptions.subscriptions[Channel.V4_ACCOUNTS]).toBeUndefined(); - expect(subscriptions.subscriptionLists[connectionId]).toBeUndefined(); - }); - - it('sends blocked error if subscribing to subaccount from restricted country', async () => { - const expectedError: BlockedError = new BlockedError(); - await subscriptions.subscribe( - mockWs, - Channel.V4_ACCOUNTS, - connectionId, - initialMsgId, - mockSubaccountId, - false, - restrictedCountry, + country, ); expect(sendMessageMock).toHaveBeenCalledTimes(1); @@ -305,7 +274,7 @@ describe('Subscriptions', () => { initialMsgId, mockSubaccountId, false, - nonRestrictedCountry, + country, ); expect(sendMessageStringMock).toHaveBeenCalledTimes(1); @@ -342,7 +311,7 @@ describe('Subscriptions', () => { initialMsgId, id, false, - nonRestrictedCountry, + country, ); subscriptions.unsubscribe( connectionId, @@ -362,7 +331,7 @@ describe('Subscriptions', () => { initialMsgId, mockSubaccountId, false, - nonRestrictedCountry, + country, ); subscriptions.unsubscribe( connectionId, @@ -386,7 +355,6 @@ describe('Subscriptions', () => { initialMsgId, validIds[channel], false, - nonRestrictedCountry, ); })); diff --git a/indexer/services/socks/__tests__/websocket/index.test.ts b/indexer/services/socks/__tests__/websocket/index.test.ts index 06dde78ca0..092f8dd664 100644 --- a/indexer/services/socks/__tests__/websocket/index.test.ts +++ b/indexer/services/socks/__tests__/websocket/index.test.ts @@ -14,15 +14,13 @@ import { } from '../../src/types'; import { InvalidMessageHandler } from '../../src/lib/invalid-message'; import { PingHandler } from '../../src/lib/ping'; -import config from '../../src/config'; -import { isRestrictedCountryHeaders, COUNTRY_HEADER_KEY } from '@dydxprotocol-indexer/compliance'; +import { COUNTRY_HEADER_KEY } from '@dydxprotocol-indexer/compliance'; jest.mock('uuid'); jest.mock('../../src/helpers/wss'); jest.mock('../../src/lib/subscription'); jest.mock('../../src/lib/invalid-message'); jest.mock('../../src/lib/ping'); -jest.mock('@dydxprotocol-indexer/compliance'); describe('Index', () => { let index: Index; @@ -32,12 +30,10 @@ describe('Index', () => { let mockConnect: (ws: WebSocket, req: IncomingMessage) => void; let wsOnSpy: jest.SpyInstance; let wsPingSpy: jest.SpyInstance; - let wsTerminateSpy: jest.SpyInstance; let invalidMsgHandlerSpy: jest.SpyInstance; let pingHandlerSpy: jest.SpyInstance; const connectionId: string = 'conId'; - const defaultGeoblockingEnabled: boolean = config.INDEXER_LEVEL_GEOBLOCKING_ENABLED; const countryCode: string = 'AR'; beforeAll(() => { @@ -58,7 +54,6 @@ describe('Index', () => { websocket = new WebSocket(null); wsOnSpy = jest.spyOn(websocket, 'on'); wsPingSpy = jest.spyOn(websocket, 'ping').mockImplementation(jest.fn()); - wsTerminateSpy = jest.spyOn(websocket, 'terminate').mockImplementation(jest.fn()); mockWss.onConnection = jest.fn().mockImplementation( (cb: (ws: WebSocket, req: IncomingMessage) => void) => { mockConnect = cb; @@ -97,46 +92,6 @@ describe('Index', () => { }), ); }); - - describe('geoblocking', () => { - const isRestrictedCountrySpy: jest.Mock = isRestrictedCountryHeaders as unknown as jest.Mock; - - beforeAll(() => { - config.INDEXER_LEVEL_GEOBLOCKING_ENABLED = true; - }); - - afterAll(() => { - config.INDEXER_LEVEL_GEOBLOCKING_ENABLED = defaultGeoblockingEnabled; - }); - - it('rejects connection if from restricted country', () => { - jest.spyOn(websocket, 'terminate').mockImplementation(jest.fn()); - // restricted country headers - isRestrictedCountrySpy.mockReturnValue(true); - - const message: IncomingMessage = new IncomingMessage(new Socket()); - mockConnect(websocket, message); - expect(websocket.terminate).toHaveBeenCalled(); - expect(Object.keys(index.connections)).toHaveLength(0); - expect(wsOnSpy).not.toHaveBeenCalled(); - expect(wsTerminateSpy).toHaveBeenCalled(); - expect(sendMessage).not.toHaveBeenCalled(); - }); - - it('does not reject connection if from restricted country', () => { - (v4 as unknown as jest.Mock).mockReturnValueOnce(connectionId); - // non-restricted country headers - isRestrictedCountrySpy.mockReturnValue(false); - - const message: IncomingMessage = new IncomingMessage(new Socket()); - mockConnect(websocket, message); - - // Test that the connection is tracked. - expect(index.connections[connectionId]).not.toBeUndefined(); - expect(index.connections[connectionId].ws).toEqual(websocket); - expect(index.connections[connectionId].messageId).toEqual(0); - }); - }); }); describe('handlers', () => { diff --git a/indexer/services/socks/src/helpers/header-utils.ts b/indexer/services/socks/src/helpers/header-utils.ts new file mode 100644 index 0000000000..0ce77bd07f --- /dev/null +++ b/indexer/services/socks/src/helpers/header-utils.ts @@ -0,0 +1,8 @@ +import { CountryHeaders } from '@dydxprotocol-indexer/compliance'; + +import { IncomingMessage } from '../types'; + +export function getCountry(req: IncomingMessage): string | undefined { + const countryHeaders: CountryHeaders = req.headers as CountryHeaders; + return countryHeaders['cf-ipcountry']; +} diff --git a/indexer/services/socks/src/lib/subscription.ts b/indexer/services/socks/src/lib/subscription.ts index 02aaff590a..7940dd3f33 100644 --- a/indexer/services/socks/src/lib/subscription.ts +++ b/indexer/services/socks/src/lib/subscription.ts @@ -3,7 +3,6 @@ import { logger, stats, } from '@dydxprotocol-indexer/base'; -import { isRestrictedCountry } from '@dydxprotocol-indexer/compliance'; import { CandleResolution, perpetualMarketRefresher } from '@dydxprotocol-indexer/postgres'; import WebSocket from 'ws'; @@ -491,13 +490,6 @@ export class Subscriptions { throw new Error('Invalid undefined id'); } - // TODO(IND-508): Change this to match technical spec for persistent geo-blocking. This may - // either have to replicate any blocking logic added on comlink, or re-direct to comlink to - // determine if subscribing to a specific subaccount is blocked. - if (country !== undefined && isRestrictedCountry(country)) { - throw new BlockedError(); - } - try { const { address, @@ -518,6 +510,9 @@ export class Subscriptions { method: RequestMethod.GET, url: `${COMLINK_URL}/v4/addresses/${address}/subaccountNumber/${subaccountNumber}`, timeout: config.INITIAL_GET_TIMEOUT_MS, + headers: { + 'cf-ipcountry': country, + }, transformResponse: (res) => res, }), // TODO(DEC-1462): Use the /active-orders endpoint once it's added. @@ -525,6 +520,9 @@ export class Subscriptions { method: RequestMethod.GET, url: `${COMLINK_URL}/v4/orders?address=${address}&subaccountNumber=${subaccountNumber}&status=OPEN,UNTRIGGERED,BEST_EFFORT_OPENED`, timeout: config.INITIAL_GET_TIMEOUT_MS, + headers: { + 'cf-ipcountry': country, + }, transformResponse: (res) => res, }), ]); @@ -597,6 +595,9 @@ export class Subscriptions { method: RequestMethod.GET, url: endpoint, timeout: config.INITIAL_GET_TIMEOUT_MS, + headers: { + 'cf-ipcountry': country, + }, transformResponse: (res) => res, // Disables JSON parsing }); } diff --git a/indexer/services/socks/src/websocket/index.ts b/indexer/services/socks/src/websocket/index.ts index 944e4243b9..003532c2ee 100644 --- a/indexer/services/socks/src/websocket/index.ts +++ b/indexer/services/socks/src/websocket/index.ts @@ -5,6 +5,7 @@ import { v4 as uuidv4 } from 'uuid'; import WebSocket from 'ws'; import config from '../config'; +import { getCountry } from '../helpers/header-utils'; import { createErrorMessage, createConnectedMessage, @@ -26,7 +27,6 @@ import { ALL_CHANNELS, WebsocketEvents, } from '../types'; -import { CountryRestrictor } from './restrict-countries'; const HEARTBEAT_INTERVAL_MS: number = config.WS_HEARTBEAT_INTERVAL_MS; const HEARTBEAT_TIMEOUT_MS: number = config.WS_HEARTBEAT_TIMEOUT_MS; @@ -42,7 +42,6 @@ export class Index { // Handlers for pings and invalid messages. private pingHandler: PingHandler; private invalidMessageHandler: InvalidMessageHandler; - private countryRestrictor: CountryRestrictor; constructor(wss: Wss, subscriptions: Subscriptions) { this.wss = wss; @@ -50,7 +49,6 @@ export class Index { this.subscriptions = subscriptions; this.pingHandler = new PingHandler(); this.invalidMessageHandler = new InvalidMessageHandler(); - this.countryRestrictor = new CountryRestrictor(); // Attach the new connection handler to the websocket server. this.wss.onConnection((ws: WebSocket, req: IncomingMessage) => this.onConnection(ws, req)); @@ -99,17 +97,13 @@ export class Index { * @param req HTTP request accompanying new connection request. */ private onConnection(ws: WebSocket, req: IncomingMessage): void { - // Terminate the connection if the connection requestion originated from a restricted country - if (this.countryRestrictor.isRestrictedCountry(req)) { - return ws.terminate(); - } const connectionId: string = uuidv4(); this.connections[connectionId] = { ws, messageId: 0, - countryCode: this.countryRestrictor.getCountry(req), + countryCode: getCountry(req), }; const numConcurrentConnections: number = Object.keys(this.connections).length; diff --git a/indexer/services/socks/src/websocket/restrict-countries.ts b/indexer/services/socks/src/websocket/restrict-countries.ts deleted file mode 100644 index f079c90660..0000000000 --- a/indexer/services/socks/src/websocket/restrict-countries.ts +++ /dev/null @@ -1,21 +0,0 @@ -import { - CountryHeaders, - isRestrictedCountryHeaders, -} from '@dydxprotocol-indexer/compliance'; - -import { IncomingMessage } from '../types'; - -export class CountryRestrictor { - public isRestrictedCountry(req: IncomingMessage): boolean { - if (isRestrictedCountryHeaders(req.headers as CountryHeaders)) { - return true; - } - - return false; - } - - public getCountry(req: IncomingMessage): string | undefined { - const countryHeaders: CountryHeaders = req.headers as CountryHeaders; - return countryHeaders['cf-ipcountry']; - } -}