diff --git a/services/apps/integration_stream_worker/src/service/integrationStreamService.ts b/services/apps/integration_stream_worker/src/service/integrationStreamService.ts index b5b7e21a88..e869bbdef6 100644 --- a/services/apps/integration_stream_worker/src/service/integrationStreamService.ts +++ b/services/apps/integration_stream_worker/src/service/integrationStreamService.ts @@ -2,7 +2,7 @@ import { addSeconds, singleOrDefault } from '@crowd/common' import { DbStore } from '@crowd/database' import { INTEGRATION_SERVICES, IProcessStreamContext } from '@crowd/integrations' import { Logger, LoggerBase, getChildLogger } from '@crowd/logging' -import { RedisCache, RedisClient } from '@crowd/redis' +import { RedisCache, RedisClient, RateLimiter } from '@crowd/redis' import { IntegrationDataWorkerEmitter, IntegrationRunWorkerEmitter, @@ -181,6 +181,8 @@ export default class IntegrationStreamService extends LoggerBase { this.log, ) + const globalCache = new RedisCache(`int-global`, this.redisClient, this.log) + const nangoConfig = NANGO_CONFIG() const context: IProcessStreamContext = { @@ -209,6 +211,7 @@ export default class IntegrationStreamService extends LoggerBase { log: this.log, cache, + globalCache, publishData: async (data) => { await this.publishData( @@ -241,6 +244,9 @@ export default class IntegrationStreamService extends LoggerBase { this.log.error({ message }, 'Aborting run with error!') await this.triggerRunError(streamInfo.runId, 'stream-run-abort', message, metadata, error) }, + getRateLimiter: (maxRequests: number, timeWindowSeconds: number, counterKey: string) => { + return new RateLimiter(globalCache, maxRequests, timeWindowSeconds, counterKey) + }, } this.log.debug('Marking stream as in progress!') diff --git a/services/libs/integrations/package-lock.json b/services/libs/integrations/package-lock.json index 01fcc8c315..fcc9fd6d52 100644 --- a/services/libs/integrations/package-lock.json +++ b/services/libs/integrations/package-lock.json @@ -66,6 
+66,27 @@ "typescript": "^5.0.4" } }, + "../redis": { + "name": "@crowd/redis", + "version": "1.0.0", + "extraneous": true, + "dependencies": { + "@crowd/common": "file:../common", + "@crowd/logging": "file:../logging", + "@crowd/types": "file:../types", + "redis": "^4.6.6" + }, + "devDependencies": { + "@types/node": "^18.16.3", + "@typescript-eslint/eslint-plugin": "^5.59.2", + "@typescript-eslint/parser": "^5.59.2", + "eslint": "^8.39.0", + "eslint-config-prettier": "^8.8.0", + "eslint-plugin-prettier": "^4.2.1", + "prettier": "^2.8.8", + "typescript": "^5.0.4" + } + }, "../types": { "name": "@crowd/types", "version": "1.0.0", diff --git a/services/libs/integrations/src/integrations/reddit/api/getComments.ts b/services/libs/integrations/src/integrations/reddit/api/getComments.ts index 1497eed46e..800ef3a2c3 100644 --- a/services/libs/integrations/src/integrations/reddit/api/getComments.ts +++ b/services/libs/integrations/src/integrations/reddit/api/getComments.ts @@ -4,6 +4,7 @@ import { IProcessStreamContext } from '@/types' import { PlatformType } from '@crowd/types' import { RedditGetCommentsInput, RedditCommentsResponse } from '../types' import { timeout } from '@crowd/common' +import { getRateLimiter } from './handleRateLimit' /** * Get the comment tree of a post. @@ -16,12 +17,17 @@ async function getComments( ctx: IProcessStreamContext, ): Promise { try { + const rateLimiter = getRateLimiter(ctx) + ctx.log.info({ message: 'Fetching comments from a post in a sub-reddit', input }) // Wait for 1.5s for rate limits. 
// eslint-disable-next-line no-promise-executor-return await timeout(1500) + // Check if we can make a request - if not, it will throw a RateLimitError + await rateLimiter.checkRateLimit('getComments') + // Gett an access token from Nango const accessToken = await getNangoToken(input.nangoId, PlatformType.REDDIT, ctx) @@ -36,6 +42,9 @@ async function getComments( }, } + // we are going to make a request, so increment the rate limit + await rateLimiter.incrementRateLimit() + const response: RedditCommentsResponse = (await axios(config)).data return response } catch (err) { diff --git a/services/libs/integrations/src/integrations/reddit/api/getMoreComments.ts b/services/libs/integrations/src/integrations/reddit/api/getMoreComments.ts index 30f8b4bf8e..f2b1183c78 100644 --- a/services/libs/integrations/src/integrations/reddit/api/getMoreComments.ts +++ b/services/libs/integrations/src/integrations/reddit/api/getMoreComments.ts @@ -4,6 +4,7 @@ import { IProcessStreamContext } from '@/types' import { PlatformType } from '@crowd/types' import { RedditMoreCommentsInput, RedditMoreCommentsResponse } from '../types' import { timeout } from '@crowd/common' +import { getRateLimiter } from './handleRateLimit' /** * Expand a list of comment IDs into a comment tree. @@ -17,12 +18,17 @@ async function getMoreComments( ctx: IProcessStreamContext, ): Promise { try { + const rateLimiter = getRateLimiter(ctx) + ctx.log.info({ message: 'Fetching more comments from a sub-reddit', input }) // Wait for 1.5s for rate limits. 
// eslint-disable-next-line no-promise-executor-return await timeout(1500) + // Check if we can make a request - if not, it will throw a RateLimitError + await rateLimiter.checkRateLimit('getMoreComments') + // Gett an access token from Nango const accessToken = await getNangoToken(input.nangoId, PlatformType.REDDIT, ctx) @@ -39,6 +45,9 @@ async function getMoreComments( }, } + // we are going to make a request, so increment the rate limit + await rateLimiter.incrementRateLimit() + const response: RedditMoreCommentsResponse = (await axios(config)).data return response } catch (err) { diff --git a/services/libs/integrations/src/integrations/reddit/api/getPosts.ts b/services/libs/integrations/src/integrations/reddit/api/getPosts.ts index ff29d4745a..52548cab89 100644 --- a/services/libs/integrations/src/integrations/reddit/api/getPosts.ts +++ b/services/libs/integrations/src/integrations/reddit/api/getPosts.ts @@ -4,6 +4,7 @@ import { IProcessStreamContext } from '@/types' import { PlatformType } from '@crowd/types' import { RedditGetPostsInput, RedditPostsResponse, REDDIT_MAX_RETROSPECT_IN_HOURS } from '../types' import { timeout } from '@crowd/common' +import { getRateLimiter } from './handleRateLimit' /** * Get paginated posts from a subreddit @@ -16,12 +17,17 @@ async function getPosts( ctx: IProcessStreamContext, ): Promise { try { + const rateLimiter = getRateLimiter(ctx) + ctx.log.info({ message: 'Fetching posts from a sub-reddit', input }) // Wait for 1.5s for rate limits. 
// Request budget applied to Reddit API calls: at most MAX_REQUESTS_PER_WINDOW
// requests per WINDOW_SECONDS seconds, counted under COUNTER_KEY in the cache
// that backs ctx.getRateLimiter.
const MAX_REQUESTS_PER_WINDOW = 100
const WINDOW_SECONDS = 60
const COUNTER_KEY = 'reddit-request-count'

/**
 * Builds the rate limiter used by the Reddit API helpers.
 *
 * Every call asks the stream context for a limiter configured with the same
 * limit, window and counter key, so all limiters returned here count against
 * one counter.
 *
 * @param ctx - stream processing context that supplies the limiter factory
 * @returns whatever `ctx.getRateLimiter` returns for the Reddit settings
 */
export const getRateLimiter = (ctx: IProcessStreamContext) =>
  ctx.getRateLimiter(MAX_REQUESTS_PER_WINDOW, WINDOW_SECONDS, COUNTER_KEY)
/**
 * Fixed-window request counter stored in a cache.
 *
 * Usage is two-step: call `checkRateLimit` before a request (it throws when the
 * budget is used up) and `incrementRateLimit` once the request is about to be
 * sent. The two steps are separate cache operations, so concurrent callers can
 * overshoot `maxRequests` by a few requests.
 */
export class RateLimiter implements IRateLimiter {
  /**
   * @param cache - cache holding the counter
   * @param maxRequests - number of requests allowed per window
   * @param timeWindowSeconds - window length in seconds
   * @param counterKey - cache key under which the counter is stored
   */
  constructor(
    private readonly cache: ICache,
    private readonly maxRequests: number,
    private readonly timeWindowSeconds: number,
    private readonly counterKey: string,
  ) {
    // Parameter properties already initialise the fields; nothing else to do.
  }

  /**
   * Verifies that another request fits into the current window.
   *
   * @param endpoint - name of the endpoint about to be called; passed on to the error
   * @throws RateLimitError when the stored count has reached `maxRequests`. The
   *   first argument is the window length plus a random offset in
   *   `[0, maxRequests)` so that callers backing off do not all retry at once.
   */
  public async checkRateLimit(endpoint: string): Promise<void> {
    const stored = await this.cache.get(this.counterKey)
    // A missing key means no request has been counted in the current window.
    const requestCount = stored === null ? 0 : Number.parseInt(stored, 10)

    // Written as a negated "<" on purpose: an unparsable (NaN) counter keeps
    // failing closed, exactly as before.
    if (!(requestCount < this.maxRequests)) {
      const sleepTime = this.timeWindowSeconds + Math.floor(Math.random() * this.maxRequests)
      throw new RateLimitError(sleepTime, endpoint)
    }
  }

  /**
   * Counts one request against the current window.
   *
   * The expiry is armed only when a window opens (count becomes 1) instead of
   * on every increment. Refreshing it on every call keeps pushing the expiry
   * forward, so under steady traffic the counter never resets and the limiter
   * trips after `maxRequests` requests in total rather than per window.
   */
  public async incrementRateLimit(): Promise<void> {
    // No ttlSeconds here: a plain increment must not touch an existing expiry.
    const count = await this.cache.increment(this.counterKey, 1)

    // Arm the expiry when the window opens. Arm it again once the limit is hit
    // so that a counter which somehow lost its expiry (e.g. the arming call
    // below failed earlier) still clears instead of blocking forever; callers
    // are told to back off for at least one full window anyway.
    if (count === 1 || count >= this.maxRequests) {
      // Incrementing by 0 leaves the count as is; it is used only to attach
      // the TTL without overwriting the counter value.
      await this.cache.increment(this.counterKey, 0, this.timeWindowSeconds)
    }
  }
}