diff --git a/packages/cubejs-api-gateway/package.json b/packages/cubejs-api-gateway/package.json index f6e1f5429b0b6..197bbc3d4b965 100644 --- a/packages/cubejs-api-gateway/package.json +++ b/packages/cubejs-api-gateway/package.json @@ -15,7 +15,7 @@ "typings": "dist/src/index.d.ts", "scripts": { "test": "npm run unit", - "unit": "jest --coverage dist/test", + "unit": "CUBE_JS_NATIVE_API_GATEWAY_INTERNAL=true jest --coverage --forceExit dist/test", "build": "rm -rf dist && npm run tsc", "tsc": "tsc", "watch": "tsc -w", diff --git a/packages/cubejs-api-gateway/src/SubscriptionServer.ts b/packages/cubejs-api-gateway/src/SubscriptionServer.ts index c45c302915d96..b6df6d501ae9c 100644 --- a/packages/cubejs-api-gateway/src/SubscriptionServer.ts +++ b/packages/cubejs-api-gateway/src/SubscriptionServer.ts @@ -53,7 +53,7 @@ export class SubscriptionServer { } if (message.authorization) { - authContext = { isSubscription: true }; + authContext = { isSubscription: true, protocol: 'ws' }; await this.apiGateway.checkAuthFn(authContext, message.authorization); const acceptanceResult = await this.contextAcceptor(authContext); if (!acceptanceResult.accepted) { diff --git a/packages/cubejs-api-gateway/src/gateway.ts b/packages/cubejs-api-gateway/src/gateway.ts index d7180ff3fc6af..28972b895a9fe 100644 --- a/packages/cubejs-api-gateway/src/gateway.ts +++ b/packages/cubejs-api-gateway/src/gateway.ts @@ -120,8 +120,10 @@ function systemAsyncHandler(handler: (req: Request & { context: ExtendedRequestC }; } -// Prepared CheckAuthFn, default or from config: always async, returns nothing -type PreparedCheckAuthFn = (ctx: any, authorization?: string) => Promise; +// Prepared CheckAuthFn, default or from config: always async +type PreparedCheckAuthFn = (ctx: any, authorization?: string) => Promise<{ + securityContext: any; +}>; class ApiGateway { protected readonly refreshScheduler: any; @@ -148,9 +150,9 @@ class ApiGateway { public readonly checkAuthSystemFn: PreparedCheckAuthFn; - protected readonly contextToApiScopesFn: ContextToApiScopesFn; + public readonly contextToApiScopesFn: ContextToApiScopesFn; - protected readonly contextToApiScopesDefFn: ContextToApiScopesFn = + public readonly contextToApiScopesDefFn: ContextToApiScopesFn = async () => ['graphql', 'meta', 'data']; protected readonly requestLoggerMiddleware: RequestLoggerMiddlewareFn; @@ -544,20 +546,24 @@ class ApiGateway { } if (getEnv('nativeApiGateway')) { - const proxyMiddleware = createProxyMiddleware({ - target: `http://127.0.0.1:${this.sqlServer.getNativeGatewayPort()}/v2`, - changeOrigin: true, - }); - - app.use( - `${this.basePath}/v2`, - proxyMiddleware as any - ); + this.enableNativeApiGateway(app); } app.use(this.handleErrorMiddleware); } + protected enableNativeApiGateway(app: ExpressApplication) { + const proxyMiddleware = createProxyMiddleware({ + target: `http://127.0.0.1:${this.sqlServer.getNativeGatewayPort()}/v2`, + changeOrigin: true, + }); + + app.use( + `${this.basePath}/v2`, + proxyMiddleware as any + ); + } + public initSubscriptionServer(sendMessage: WebSocketSendMessageFn) { return new SubscriptionServer(this, sendMessage, this.subscriptionStore, this.wsContextAcceptor); } @@ -2250,6 +2256,10 @@ class ApiGateway { showWarningAboutNotObject = true; } + + return { + securityContext: req.securityContext + }; }; } @@ -2333,6 +2343,10 @@ class ApiGateway { // @todo Move it to 401 or 400 throw new CubejsHandlerError(403, 'Forbidden', 'Authorization header isn\'t set'); } + + return { + securityContext: req.securityContext + }; }; } @@ -2343,6 +2357,7 @@ class ApiGateway { if (this.playgroundAuthSecret) { const systemCheckAuthFn = this.createCheckAuthSystemFn(); + return async (ctx, authorization) => { // TODO: separate two auth workflows try { @@ -2354,6 +2369,10 @@ class ApiGateway { throw mainAuthError; } } + + return { + securityContext: ctx.securityContext, + }; }; } @@ -2371,6 +2390,10 @@ class ApiGateway { return async (ctx, authorization) => { await systemCheckAuthFn(ctx, authorization); + + return { + securityContext: ctx.securityContext + }; }; } diff --git a/packages/cubejs-api-gateway/src/sql-server.ts b/packages/cubejs-api-gateway/src/sql-server.ts index bce155d52e615..a8faa00cd8ae9 100644 --- a/packages/cubejs-api-gateway/src/sql-server.ts +++ b/packages/cubejs-api-gateway/src/sql-server.ts @@ -107,6 +107,17 @@ export class SQLServer { this.sqlInterfaceInstance = await registerInterface({ gatewayPort: this.gatewayPort, pgPort: options.pgSqlPort, + contextToApiScopes: async ({ securityContext }) => this.apiGateway.contextToApiScopesFn( + securityContext, + getEnv('defaultApiScope') || await this.apiGateway.contextToApiScopesDefFn() + ), + checkAuth: async ({ request, token }) => { + const { securityContext } = await this.apiGateway.checkAuthFn(request, token); + + return { + securityContext + }; + }, checkSqlAuth: async ({ request, user, password }) => { const { password: returnedPassword, superuser, securityContext, skipPasswordCheck } = await checkSqlAuth(request, user, password); diff --git a/packages/cubejs-api-gateway/test/auth.test.ts b/packages/cubejs-api-gateway/test/auth.test.ts index fe52541d2b172..5ee4f47903483 100644 --- a/packages/cubejs-api-gateway/test/auth.test.ts +++ b/packages/cubejs-api-gateway/test/auth.test.ts @@ -4,21 +4,49 @@ import express, { Application as ExpressApplication, RequestHandler } from 'expr import request from 'supertest'; import jwt from 'jsonwebtoken'; import { pausePromise } from '@cubejs-backend/shared'; +import { resetLogger } from '@cubejs-backend/native'; -import { ApiGateway, ApiGatewayOptions, CubejsHandlerError, Request } from '../src'; +import { ApiGateway, ApiGatewayOptions, CubejsHandlerError, Request, RequestContext } from '../src'; import { AdapterApiMock, DataSourceStorageMock } from './mocks'; -import { RequestContext } from '../src/interfaces'; import { generateAuthToken } from './utils'; +class ApiGatewayOpenAPI extends ApiGateway { + protected isRunning: Promise | null = null; + + public coerceForSqlQuery(query, context: RequestContext) { + return super.coerceForSqlQuery(query, context); + } + + public async startSQLServer(): Promise { + if (this.isRunning) { + return this.isRunning; + } + + this.isRunning = this.sqlServer.init({}); + + return this.isRunning; + } + + public async shutdownSQLServer(): Promise { + try { + await this.sqlServer.shutdown('fast'); + } finally { + this.isRunning = null; + } + + // SQLServer changes logger for rust side with setupLogger in the constructor, but it leads + // to a memory leak, that's why jest doesn't allow to shut down tests + resetLogger( + process.env.CUBEJS_LOG_LEVEL === 'trace' ? 'trace' : 'warn' + ); + } +} + function createApiGateway(handler: RequestHandler, logger: () => any, options: Partial) { const adapterApi: any = new AdapterApiMock(); const dataSourceStorage: any = new DataSourceStorageMock(); - class ApiGatewayFake extends ApiGateway { - public coerceForSqlQuery(query, context: RequestContext) { - return super.coerceForSqlQuery(query, context); - } - + class ApiGatewayFake extends ApiGatewayOpenAPI { public initApp(app: ExpressApplication) { const userMiddlewares: RequestHandler[] = [ this.checkAuth, @@ -26,6 +54,7 @@ function createApiGateway(handler: RequestHandler, logger: () => any, options: P ]; app.get('/test-auth-fake', userMiddlewares, handler); + this.enableNativeApiGateway(app); app.use(this.handleErrorMiddleware); } @@ -41,6 +70,7 @@ function createApiGateway(handler: RequestHandler, logger: () => any, options: P }); process.env.NODE_ENV = 'unknown'; + const app = express(); apiGateway.initApp(app); @@ -50,6 +80,119 @@ function createApiGateway(handler: RequestHandler, logger: () => any, options: P }; } +describe('test authorization with native gateway', () => { + let app: ExpressApplication; + let apiGateway: ApiGatewayOpenAPI; + + const handlerMock = jest.fn(() => { + // nothing, we are using it to verify that we don't got to express code + }); + const loggerMock = jest.fn(() => { + // + }); + const checkAuthMock = jest.fn((req, token) => { + jwt.verify(token, 'secret'); + + return { + security_context: {} + }; + }); + + beforeAll(async () => { + const result = createApiGateway(handlerMock, loggerMock, { + checkAuth: checkAuthMock, + gatewayPort: 8585, + }); + + app = result.app; + apiGateway = result.apiGateway; + + await result.apiGateway.startSQLServer(); + }); + + beforeEach(() => { + handlerMock.mockClear(); + loggerMock.mockClear(); + checkAuthMock.mockClear(); + }); + + afterAll(async () => { + await apiGateway.shutdownSQLServer(); + }); + + it('default authorization - success', async () => { + const token = generateAuthToken({ uid: 5, }); + + await request(app) + .get('/cubejs-api/v2/stream') + .set('Authorization', `${token}`) + .send() + .expect(501); + + // No bad logs + expect(loggerMock.mock.calls.length).toEqual(0); + // We should not call js handler, request should go into rust code + expect(handlerMock.mock.calls.length).toEqual(0); + + // Verify that we passed token to JS side + expect(checkAuthMock.mock.calls.length).toEqual(1); + expect(checkAuthMock.mock.calls[0][0].protocol).toEqual('http'); + expect(checkAuthMock.mock.calls[0][1]).toEqual(token); + }); + + it('default authorization - success (bearer prefix)', async () => { + const token = generateAuthToken({ uid: 5, }); + + await request(app) + .get('/cubejs-api/v2/stream') + .set('Authorization', `Bearer ${token}`) + .send() + .expect(501); + + // No bad logs + expect(loggerMock.mock.calls.length).toEqual(0); + // We should not call js handler, request should go into rust code + expect(handlerMock.mock.calls.length).toEqual(0); + + // Verify that we passed token to JS side + expect(checkAuthMock.mock.calls.length).toEqual(1); + expect(checkAuthMock.mock.calls[0][0].protocol).toEqual('http'); + expect(checkAuthMock.mock.calls[0][1]).toEqual(token); + }); + + it('default authorization - wrong secret', async () => { + const badToken = 'SUPER_LARGE_BAD_TOKEN_WHICH_IS_NOT_A_TOKEN'; + + await request(app) + .get('/cubejs-api/v2/stream') + .set('Authorization', `${badToken}`) + .send() + .expect(401); + + // No bad logs + expect(loggerMock.mock.calls.length).toEqual(0); + // We should not call js handler, request should go into rust code + expect(handlerMock.mock.calls.length).toEqual(0); + + // Verify that we passed token to JS side + expect(checkAuthMock.mock.calls.length).toEqual(1); + expect(checkAuthMock.mock.calls[0][0].protocol).toEqual('http'); + expect(checkAuthMock.mock.calls[0][1]).toEqual(badToken); + }); + + it('default authorization - missing auth header', async () => { + await request(app) + .get('/cubejs-api/v2/stream') + .send() + .expect(401); + + // No bad logs + expect(loggerMock.mock.calls.length).toEqual(0); + // We should not call js handler, request should go into rust code + expect(handlerMock.mock.calls.length).toEqual(0); + }); +}); + describe('test authorization', () => { test('default authorization', async () => { const loggerMock = jest.fn(() => { diff --git a/packages/cubejs-backend-native/Cargo.lock b/packages/cubejs-backend-native/Cargo.lock index f6a82ef21e4ea..d0d4541edcd30 100644 --- a/packages/cubejs-backend-native/Cargo.lock +++ b/packages/cubejs-backend-native/Cargo.lock @@ -209,13 +209,13 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.7.5" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", - "base64 0.21.5", + "base64 0.22.1", "bytes", "futures-util", "http", @@ -235,10 +235,10 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sha1", - "sync_wrapper 1.0.1", + "sync_wrapper", "tokio", "tokio-tungstenite", - "tower", + "tower 0.5.2", "tower-layer", "tower-service", "tracing", @@ -246,9 +246,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" dependencies = [ "async-trait", "bytes", @@ -259,7 +259,7 @@ dependencies = [ "mime", "pin-project-lite", "rustversion", - "sync_wrapper 0.1.2", + "sync_wrapper", "tower-layer", "tower-service", "tracing", @@ -466,9 +466,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.5.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" @@ -747,7 +747,6 @@ dependencies = [ "async-channel", "async-trait", "axum", - "bytes", "convert_case 0.6.0", "cubenativeutils", "cubeorchestrator", @@ -762,10 +761,8 @@ dependencies = [ "minijinja", "neon", "once_cell", - "pin-project", "pyo3", "pyo3-asyncio", - "regex", "serde", "serde_json", "simple_logger", @@ -1507,7 +1504,7 @@ dependencies = [ "pin-project-lite", "socket2", "tokio", - "tower", + "tower 0.4.13", "tower-service", "tracing", ] @@ -2826,7 +2823,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper 1.0.1", + "sync_wrapper", "tokio", "tokio-rustls", "tower-service", @@ -3338,12 +3335,6 @@ dependencies = [ "syn 2.0.98", ] -[[package]] -name = "sync_wrapper" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" - [[package]] name = "sync_wrapper" version = "1.0.1" @@ -3581,9 +3572,9 @@ dependencies = [ [[package]] name = "tokio-tungstenite" -version = "0.21.0" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" dependencies = [ "futures-util", "log", @@ -3636,20 +3627,35 @@ dependencies = [ "tokio", "tower-layer", "tower-service", +] + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", "tracing", ] [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -3691,9 +3697,9 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "tungstenite" -version = "0.21.0" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" dependencies = [ "byteorder", "bytes", @@ -3704,7 +3710,6 @@ dependencies = [ "rand", "sha1", "thiserror 1.0.69", - "url", "utf-8", ] diff --git a/packages/cubejs-backend-native/Cargo.toml b/packages/cubejs-backend-native/Cargo.toml index ba826fda486ff..ec59d90f8250e 100644 --- a/packages/cubejs-backend-native/Cargo.toml +++ b/packages/cubejs-backend-native/Cargo.toml @@ -24,11 +24,10 @@ anyhow = "1.0" async-channel = { version = "2" } async-trait = "0.1.36" convert_case = "0.6.0" -pin-project = "1.1.5" findshlibs = "0.10.2" futures = "0.3.30" http-body-util = "0.1" -axum = { version = "0.7.5", features = ["default", "ws"] } +axum = { version = "0.7.9", features = ["default", "ws"] } libc = "0.2" log = "0.4.21" log-reroute = "0.1" @@ -45,8 +44,6 @@ serde_json = "1.0.127" simple_logger = "1.7.0" tokio = { version = "1", features = ["full", "rt"] } uuid = { version = "1", features = ["v4"] } -bytes = "1.5.0" -regex = "1.10.2" [dependencies.neon] version = "=1" diff --git a/packages/cubejs-backend-native/js/index.ts b/packages/cubejs-backend-native/js/index.ts index 871d0c4ede570..077c3de15986f 100644 --- a/packages/cubejs-backend-native/js/index.ts +++ b/packages/cubejs-backend-native/js/index.ts @@ -27,13 +27,30 @@ export interface Request { } export interface CheckAuthResponse { + securityContext: any, +} + +export interface CheckSQLAuthResponse { password: string | null, superuser: boolean, securityContext: any, skipPasswordCheck?: boolean, } +export interface ContextToApiScopesPayload { + securityContext: any, +} + +export type ContextToApiScopesResponse = string[]; + +export interface CheckAuthPayloadRequestMeta extends BaseMeta {} + export interface CheckAuthPayload { + request: Request, + token: string, +} + +export interface CheckSQLAuthPayload { request: Request, user: string | null, password: string | null, @@ -88,7 +105,9 @@ export interface CanSwitchUserPayload { export type SQLInterfaceOptions = { pgPort?: number, - checkSqlAuth: (payload: CheckAuthPayload) => CheckAuthResponse | Promise, + contextToApiScopes: (payload: ContextToApiScopesPayload) => ContextToApiScopesResponse | Promise, + checkAuth: (payload: CheckAuthPayload) => CheckAuthResponse | Promise, + checkSqlAuth: (payload: CheckSQLAuthPayload) => CheckSQLAuthResponse | Promise, load: (payload: LoadPayload) => unknown | Promise, sql: (payload: SqlPayload) => unknown | Promise, meta: (payload: MetaPayload) => unknown | Promise, @@ -335,6 +354,12 @@ export const setupLogger = (logger: (extra: any) => unknown, logLevel: LogLevel) native.setupLogger({ logger: wrapNativeFunctionWithChannelCallback(logger), logLevel }); }; +/// Reset local to default implementation, which uses STDOUT +export const resetLogger = (logLevel: LogLevel): void => { + const native = loadNative(); + native.resetLogger({ logLevel }); +}; + export const isFallbackBuild = (): boolean => { const native = loadNative(); return native.isFallbackBuild(); @@ -347,6 +372,14 @@ export const registerInterface = async (options: SQLInterfaceOptions): Promise, + check_auth: Arc>, check_sql_auth: Arc>, + context_to_api_scopes: Arc>, +} + +pub struct NodeBridgeAuthServiceOptions { + pub check_auth: Root, + pub check_sql_auth: Root, + pub context_to_api_scopes: Root, } impl NodeBridgeAuthService { - pub fn new(channel: Channel, check_sql_auth: Root) -> Self { + pub fn new(channel: Channel, options: NodeBridgeAuthServiceOptions) -> Self { Self { channel: Arc::new(channel), - check_sql_auth: Arc::new(check_sql_auth), + check_auth: Arc::new(options.check_auth), + check_sql_auth: Arc::new(options.check_sql_auth), + context_to_api_scopes: Arc::new(options.context_to_api_scopes), } } } @@ -36,14 +51,14 @@ pub struct TransportRequest { } #[derive(Debug, Serialize)] -struct CheckSQLAuthRequest { +struct CheckSQLAuthTransportRequest { request: TransportRequest, user: Option, password: Option, } #[derive(Debug, Deserialize)] -struct CheckSQLAuthResponse { +struct CheckSQLAuthTransportResponse { password: Option, superuser: bool, #[serde(rename = "securityContext", skip_serializing_if = "Option::is_none")] @@ -53,16 +68,24 @@ struct CheckSQLAuthResponse { } #[derive(Debug)] -pub struct NativeAuthContext { +pub struct NativeSQLAuthContext { pub user: Option, pub superuser: bool, pub security_context: Option, } -impl AuthContext for NativeAuthContext { +impl AuthContext for NativeSQLAuthContext { fn as_any(&self) -> &dyn Any { self } + + fn user(&self) -> Option<&String> { + self.user.as_ref() + } + + fn security_context(&self) -> Option<&serde_json::Value> { + self.security_context.as_ref() + } } #[async_trait] @@ -72,11 +95,11 @@ impl SqlAuthService for NodeBridgeAuthService { user: Option, password: Option, ) -> Result { - trace!("[auth] Request ->"); + trace!("[sql auth] Request ->"); let request_id = Uuid::new_v4().to_string(); - let extra = serde_json::to_string(&CheckSQLAuthRequest { + let extra = serde_json::to_string(&CheckSQLAuthTransportRequest { request: TransportRequest { id: format!("{}-span-1", request_id), meta: None, @@ -84,16 +107,16 @@ impl SqlAuthService for NodeBridgeAuthService { user: user.clone(), password: password.clone(), })?; - let response: CheckSQLAuthResponse = call_js_with_channel_as_callback( + let response: CheckSQLAuthTransportResponse = call_js_with_channel_as_callback( self.channel.clone(), self.check_sql_auth.clone(), Some(extra), ) .await?; - trace!("[auth] Request <- {:?}", response); + trace!("[sql auth] Request <- {:?}", response); Ok(AuthenticateResponse { - context: Arc::new(NativeAuthContext { + context: Arc::new(NativeSQLAuthContext { user, superuser: response.superuser, security_context: response.security_context, @@ -104,4 +127,96 @@ impl SqlAuthService for NodeBridgeAuthService { } } -di_service!(NodeBridgeAuthService, [SqlAuthService]); +#[derive(Debug, Serialize)] +struct CheckAuthTransportRequest { + request: GatewayCheckAuthRequest, + token: String, +} + +#[derive(Debug, Deserialize)] +struct CheckAuthTransportResponse { + #[serde(rename = "securityContext", skip_serializing_if = "Option::is_none")] + security_context: Option, +} + +#[derive(Debug)] +pub struct NativeGatewayAuthContext { + pub security_context: Option, +} + +impl GatewayAuthContext for NativeGatewayAuthContext { + fn as_any(&self) -> &dyn Any { + self + } + + fn user(&self) -> Option<&String> { + None + } + + fn security_context(&self) -> Option<&serde_json::Value> { + self.security_context.as_ref() + } +} + +#[derive(Debug, Serialize)] +struct ContextToApiScopesTransportRequest<'ref_auth_context> { + security_context: &'ref_auth_context Option, +} + +type ContextToApiScopesTransportResponse = Vec; + +#[async_trait] +impl GatewayAuthService for NodeBridgeAuthService { + async fn authenticate( + &self, + request: GatewayCheckAuthRequest, + token: String, + ) -> Result { + trace!("[auth] Request ->"); + + let extra = serde_json::to_string(&CheckAuthTransportRequest { + request, + token: token.clone(), + })?; + let response: CheckAuthTransportResponse = call_js_with_channel_as_callback( + self.channel.clone(), + self.check_auth.clone(), + Some(extra), + ) + .await?; + trace!("[auth] Request <- {:?}", response); + + Ok(GatewayAuthenticateResponse { + context: Arc::new(NativeGatewayAuthContext { + security_context: response.security_context, + }), + }) + } + + async fn context_to_api_scopes( + &self, + auth_context: &GatewayAuthContextRef, + ) -> Result { + trace!("[context_to_api_scopes] Request ->"); + + let native_auth = auth_context + .as_any() + .downcast_ref::() + .expect("Unable to cast AuthContext to NativeGatewayAuthContext"); + + let extra = serde_json::to_string(&ContextToApiScopesTransportRequest { + security_context: &native_auth.security_context, + })?; + let response: ContextToApiScopesTransportResponse = call_js_with_channel_as_callback( + self.channel.clone(), + self.context_to_api_scopes.clone(), + Some(extra), + ) + .await?; + trace!("[context_to_api_scopes] Request <- {:?}", response); + + Ok(GatewayContextToApiScopesResponse { scopes: response }) + } +} + +di_service!(NodeBridgeAuthService, [SqlAuthService, GatewayAuthService]); diff --git a/packages/cubejs-backend-native/src/channel.rs b/packages/cubejs-backend-native/src/channel.rs index 607631c68c661..debcd89ac7b80 100644 --- a/packages/cubejs-backend-native/src/channel.rs +++ b/packages/cubejs-backend-native/src/channel.rs @@ -398,7 +398,9 @@ impl SqlGenerator for NodeSqlGenerator { impl Drop for NodeSqlGenerator { fn drop(&mut self) { let channel = self.channel.clone(); - let sql_generator_obj = self.sql_generator_obj.take().unwrap(); + // Safety: Safe, because on_track take is used only for dropping + let sql_generator_obj = self.sql_generator_obj.take().expect("Unable to take sql_generator_object while dropping NodeSqlGenerator, it was already taken"); + channel.send(move |mut cx| { let _ = match Arc::try_unwrap(sql_generator_obj) { Ok(v) => v.into_inner(&mut cx), diff --git a/packages/cubejs-backend-native/src/config.rs b/packages/cubejs-backend-native/src/config.rs index a15074cf9e20a..819b2a284e5b9 100644 --- a/packages/cubejs-backend-native/src/config.rs +++ b/packages/cubejs-backend-native/src/config.rs @@ -1,5 +1,7 @@ use crate::gateway::server::ApiGatewayServerImpl; -use crate::gateway::{ApiGatewayRouterBuilder, ApiGatewayServer}; +use crate::gateway::{ + ApiGatewayRouterBuilder, ApiGatewayServer, ApiGatewayState, GatewayAuthService, +}; use crate::{auth::NodeBridgeAuthService, transport::NodeBridgeTransport}; use async_trait::async_trait; use cubesql::config::injection::Injector; @@ -72,6 +74,7 @@ impl NodeCubeServices { .injector .get_service_typed::() .await; + gateway_server.stop_processing(shutdown_mode).await?; } @@ -154,8 +157,14 @@ impl NodeConfiguration for NodeConfigurationImpl { .register_typed::(|_| async move { transport }) .await; + let auth_to_move = auth.clone(); + injector + .register_typed::(|_| async move { auth_to_move }) + .await; + + let auth_to_move = auth.clone(); injector - .register_typed::(|_| async move { auth }) + .register_typed::(|_| async move { auth_to_move }) .await; if let Some(api_gateway_address) = &self.api_gateway_address { @@ -163,10 +172,12 @@ impl NodeConfiguration for NodeConfigurationImpl { injector .register_typed::(|i| async move { + let state = Arc::new(ApiGatewayState::new(i)); + ApiGatewayServerImpl::new( - ApiGatewayRouterBuilder::new(), + ApiGatewayRouterBuilder::new(state.clone()), api_gateway_address, - i.clone(), + state, ) }) .await; diff --git a/packages/cubejs-backend-native/src/cubesql_utils.rs b/packages/cubejs-backend-native/src/cubesql_utils.rs index 36a7e7fafa007..eb3a3f5ff6f25 100644 --- a/packages/cubejs-backend-native/src/cubesql_utils.rs +++ b/packages/cubejs-backend-native/src/cubesql_utils.rs @@ -8,12 +8,12 @@ use cubesql::config::ConfigObj; use cubesql::sql::{Session, SessionManager}; use cubesql::CubeError; -use crate::auth::NativeAuthContext; +use crate::auth::NativeSQLAuthContext; use crate::config::NodeCubeServices; pub async fn create_session( services: &NodeCubeServices, - native_auth_ctx: Arc, + native_auth_ctx: Arc, ) -> Result, CubeError> { let config = services .injector() @@ -53,7 +53,7 @@ pub async fn create_session( pub async fn with_session( services: &NodeCubeServices, - native_auth_ctx: Arc, + native_auth_ctx: Arc, f: F, ) -> Result where diff --git a/packages/cubejs-backend-native/src/gateway/auth_middleware.rs b/packages/cubejs-backend-native/src/gateway/auth_middleware.rs new file mode 100644 index 0000000000000..a4f1385ecaf8c --- /dev/null +++ b/packages/cubejs-backend-native/src/gateway/auth_middleware.rs @@ -0,0 +1,74 @@ +use crate::gateway::http_error::HttpError; +use crate::gateway::state::ApiGatewayStateRef; +use crate::gateway::{GatewayAuthContextRef, GatewayAuthService, GatewayCheckAuthRequest}; +use axum::extract::State; +use axum::http::HeaderValue; +use axum::response::IntoResponse; + +#[derive(Debug, Clone)] +pub struct AuthExtension { + auth_context: GatewayAuthContextRef, +} + +impl AuthExtension { + pub fn auth_context(&self) -> &GatewayAuthContextRef { + &self.auth_context + } +} + +fn parse_token(header_value: &HeaderValue) -> Result<&str, HttpError> { + let trimmed = header_value.to_str()?.trim(); + + let stripped = if let Some(stripped) = trimmed.strip_prefix("Bearer ") { + stripped + } else if let Some(stripped) = trimmed.strip_prefix("bearer ") { + stripped + } else { + trimmed + }; + + if stripped.is_empty() { + Err(HttpError::unauthorized( + "Value for authorization header cannot be empty".to_string(), + )) + } else { + Ok(stripped) + } +} + +pub async fn gateway_auth_middleware( + State(state): State, + mut req: axum::extract::Request, + next: axum::middleware::Next, +) -> Result { + let Some(token_header_value) = req.headers().get("authorization") else { + return Err(HttpError::unauthorized( + "No authorization header".to_string(), + )); + }; + + let bearer_token = parse_token(token_header_value)?; + + let auth = state + .injector_ref() + .get_service_typed::() + .await; + + let auth_fut = auth.authenticate( + GatewayCheckAuthRequest { + protocol: "http".to_string(), + }, + bearer_token.to_string(), + ); + + let auth_response = auth_fut + .await + .map_err(|_err| HttpError::unauthorized("Authentication error".to_string()))?; + + req.extensions_mut().insert(AuthExtension { + auth_context: auth_response.context, + }); + + let response = next.run(req).await; + Ok(response) +} diff --git a/packages/cubejs-backend-native/src/gateway/auth_service.rs b/packages/cubejs-backend-native/src/gateway/auth_service.rs new file mode 100644 index 0000000000000..d6458ea27b190 --- /dev/null +++ b/packages/cubejs-backend-native/src/gateway/auth_service.rs @@ -0,0 +1,46 @@ +use std::{any::Any, fmt::Debug, sync::Arc}; + +use crate::CubeError; +use async_trait::async_trait; +use serde::Serialize; + +// We cannot use generic here. It's why there is this trait +// Any type will allow us to split (with downcast) auth context +pub trait GatewayAuthContext: Debug + Send + Sync { + fn as_any(&self) -> &dyn Any; + + fn user(&self) -> Option<&String>; + + fn security_context(&self) -> Option<&serde_json::Value>; +} + +pub type GatewayAuthContextRef = Arc; + +#[derive(Debug)] +pub struct GatewayAuthenticateResponse { + pub context: GatewayAuthContextRef, +} + +#[derive(Debug, Serialize)] +pub struct GatewayCheckAuthRequest { + pub protocol: String, +} + +#[derive(Debug)] +pub struct GatewayContextToApiScopesResponse { + pub scopes: Vec, +} + +#[async_trait] +pub trait GatewayAuthService: Send + Sync + Debug { + async fn authenticate( + &self, + req: GatewayCheckAuthRequest, + token: String, + ) -> Result; + + async fn context_to_api_scopes( + &self, + auth_context: &GatewayAuthContextRef, + ) -> Result; +} diff --git a/packages/cubejs-backend-native/src/gateway/handlers/stream.rs b/packages/cubejs-backend-native/src/gateway/handlers/stream.rs index b5b88d59e6d39..afe162350d4d4 100644 --- a/packages/cubejs-backend-native/src/gateway/handlers/stream.rs +++ b/packages/cubejs-backend-native/src/gateway/handlers/stream.rs @@ -1,7 +1,10 @@ -use crate::gateway::ApiGatewayState; +use crate::gateway::auth_middleware::AuthExtension; +use crate::gateway::http_error::HttpError; +use crate::gateway::state::ApiGatewayStateRef; use axum::extract::State; use axum::http::StatusCode; -use axum::Json; +use axum::response::IntoResponse; +use axum::{Extension, Json}; use serde::Serialize; #[derive(Serialize)] @@ -10,12 +13,17 @@ pub struct HandlerResponse { } pub async fn stream_handler_v2( - State(_state): State, -) -> (StatusCode, Json) { - ( + State(gateway_state): State, + Extension(auth): Extension, +) -> Result { + gateway_state + .assert_api_scope(auth.auth_context(), "data") + .await?; + + Ok(( StatusCode::NOT_IMPLEMENTED, Json(HandlerResponse { - message: "Not implemented".to_string(), + message: "/v2/stream is not implemented".to_string(), }), - ) + )) } diff --git a/packages/cubejs-backend-native/src/gateway/http_error.rs b/packages/cubejs-backend-native/src/gateway/http_error.rs new file mode 100644 index 0000000000000..d45f33630f118 --- /dev/null +++ b/packages/cubejs-backend-native/src/gateway/http_error.rs @@ -0,0 +1,84 @@ +use axum::http::header::ToStrError; +use axum::response::{IntoResponse, Response}; +use axum::Json; +use cubesql::{CubeError, CubeErrorCauseType}; +use serde::Serialize; + +// Re-export axum::http::StatusCode as public API +pub type HttpStatusCode = axum::http::StatusCode; + +pub enum HttpErrorCode { + StatusCode(HttpStatusCode), +} + +pub struct HttpError { + code: HttpErrorCode, + message: String, +} + +impl HttpError { + pub fn forbidden(message: String) -> HttpError { + Self { + code: HttpErrorCode::StatusCode(HttpStatusCode::FORBIDDEN), + message, + } + } + + pub fn unauthorized(message: String) -> HttpError { + Self { + code: HttpErrorCode::StatusCode(HttpStatusCode::UNAUTHORIZED), + message, + } + } + + pub fn status_code(&self) -> HttpStatusCode { + match self.code { + HttpErrorCode::StatusCode(code) => code, + } + } + + /// CubeError may contain unsafe error message, when it's internal error + /// We cannot map this error to HTTP status code, that's why we pass it as argument + pub fn from_user_with_status_code(error: CubeError, code: HttpErrorCode) -> Self { + Self { + code: match error.cause { + CubeErrorCauseType::User(_) => code, + CubeErrorCauseType::Internal(_) => { + HttpErrorCode::StatusCode(axum::http::StatusCode::INTERNAL_SERVER_ERROR) + } + }, + message: match error.cause { + CubeErrorCauseType::User(_) => error.message, + CubeErrorCauseType::Internal(_) => "Internal Server Error".to_string(), + }, + } + } +} + +#[derive(Serialize)] +pub struct HttpErrorResponse { + message: String, +} + +impl IntoResponse for HttpError { + fn into_response(self) -> Response { + let status_code = self.status_code(); + + ( + status_code, + Json(HttpErrorResponse { + message: self.message, + }), + ) + .into_response() + } +} + +impl From for HttpError { + fn from(value: ToStrError) -> Self { + HttpError { + code: HttpErrorCode::StatusCode(axum::http::StatusCode::INTERNAL_SERVER_ERROR), + message: value.to_string(), + } + } +} diff --git a/packages/cubejs-backend-native/src/gateway/mod.rs b/packages/cubejs-backend-native/src/gateway/mod.rs index 99970913c6533..149afca40aea6 100644 --- a/packages/cubejs-backend-native/src/gateway/mod.rs +++ b/packages/cubejs-backend-native/src/gateway/mod.rs @@ -1,8 +1,16 @@ +pub mod auth_middleware; +pub mod auth_service; pub mod handlers; +pub mod http_error; pub mod router; pub mod server; pub mod state; -pub use router::ApiGatewayRouterBuilder; +pub use auth_middleware::gateway_auth_middleware; +pub use auth_service::{ + GatewayAuthContext, GatewayAuthContextRef, GatewayAuthService, GatewayAuthenticateResponse, + GatewayCheckAuthRequest, +}; +pub use router::{ApiGatewayRouterBuilder, RApiGatewayRouter}; pub use server::{ApiGatewayServer, ApiGatewayServerImpl}; -pub use state::ApiGatewayState; +pub use state::{ApiGatewayState, ApiGatewayStateRef}; diff --git a/packages/cubejs-backend-native/src/gateway/router.rs b/packages/cubejs-backend-native/src/gateway/router.rs index 1d29749378014..16c1c36849770 100644 --- a/packages/cubejs-backend-native/src/gateway/router.rs +++ b/packages/cubejs-backend-native/src/gateway/router.rs @@ -1,34 +1,37 @@ +use crate::gateway::gateway_auth_middleware; use crate::gateway::handlers::stream_handler_v2; -use crate::gateway::ApiGatewayState; +use crate::gateway::state::ApiGatewayStateRef; use axum::routing::{get, MethodRouter}; use axum::Router; +pub type RApiGatewayRouter = Router; + #[derive(Debug, Clone)] pub struct ApiGatewayRouterBuilder { - router: Router, -} - -impl Default for ApiGatewayRouterBuilder { - fn default() -> Self { - Self::new() - } + router: RApiGatewayRouter, } impl ApiGatewayRouterBuilder { - pub fn new() -> Self { + pub fn new(state: ApiGatewayStateRef) -> Self { let router = Router::new(); - let router = router.route("/v2/stream", get(stream_handler_v2)); + let router = router.route( + "/v2/stream", + get(stream_handler_v2).layer(axum::middleware::from_fn_with_state( + state, + gateway_auth_middleware, + )), + ); Self { router } } - pub fn route(self, path: &str, method_router: MethodRouter) -> Self { + pub fn route(self, path: &str, method_router: MethodRouter) -> Self { Self { router: self.router.route(path, method_router), } } - pub fn build(self) -> Router { + pub fn build(self) -> RApiGatewayRouter { self.router } } diff --git a/packages/cubejs-backend-native/src/gateway/server.rs b/packages/cubejs-backend-native/src/gateway/server.rs index c00fbf9a301ef..f551f8ce1f301 100644 --- a/packages/cubejs-backend-native/src/gateway/server.rs +++ b/packages/cubejs-backend-native/src/gateway/server.rs @@ -1,6 +1,6 @@ -use crate::gateway::{ApiGatewayRouterBuilder, ApiGatewayState}; +use crate::gateway::state::ApiGatewayStateRef; +use crate::gateway::ApiGatewayRouterBuilder; use async_trait::async_trait; -use cubesql::config::injection::Injector; use cubesql::config::processing_loop::{ProcessingLoop, ShutdownMode}; use cubesql::CubeError; use std::sync::Arc; @@ -34,13 +34,11 @@ impl ApiGatewayServerImpl { pub fn new( router_builder: ApiGatewayRouterBuilder, address: String, - injector: Arc, + state: ApiGatewayStateRef, ) -> Arc { let (close_socket_tx, close_socket_rx) = watch::channel(false); - let router = router_builder - .build() - .with_state(ApiGatewayState::new(injector)); + let router = router_builder.build().with_state(state); Arc::new(Self { inner_factory_state: Mutex::new(Some(InnerFactoryState { @@ -87,7 +85,13 @@ impl ProcessingLoop for ApiGatewayServerImpl { async fn stop_processing(&self, _mode: ShutdownMode) -> Result<(), CubeError> { // ShutdownMode was added for Postgres protocol and its use here has not yet been considered. - self.close_socket_tx.send(true)?; + self.close_socket_tx.send(true).map_err(|err| { + CubeError::internal(format!( + "Failed to send close signal to ApiGatewayServer: {}", + err + )) + })?; + Ok(()) } } diff --git a/packages/cubejs-backend-native/src/gateway/state.rs b/packages/cubejs-backend-native/src/gateway/state.rs index ebf5f827ee28d..9651768b0f53d 100644 --- a/packages/cubejs-backend-native/src/gateway/state.rs +++ b/packages/cubejs-backend-native/src/gateway/state.rs @@ -1,11 +1,14 @@ +use crate::gateway::http_error::{HttpError, HttpErrorCode, HttpStatusCode}; +use crate::gateway::{GatewayAuthContextRef, GatewayAuthService}; use cubesql::config::injection::Injector; use std::sync::Arc; -#[derive(Clone)] pub struct ApiGatewayState { injector: Arc, } +pub type ApiGatewayStateRef = Arc; + impl ApiGatewayState { pub fn new(injector: Arc) -> Self { Self { injector } @@ -14,4 +17,35 @@ impl ApiGatewayState { pub fn injector_ref(&self) -> &Arc { &self.injector } + + pub async fn assert_api_scope( + &self, + gateway_auth_context: &GatewayAuthContextRef, + api_scope: &str, + ) -> Result<(), HttpError> { + let auth_service = self + .injector_ref() + .get_service_typed::() + .await; + + let api_scopes_res = auth_service + .context_to_api_scopes(gateway_auth_context) + .await + .map_err(|err| { + log::error!("Error getting API scopes: {}", err); + + HttpError::from_user_with_status_code( + err, + HttpErrorCode::StatusCode(HttpStatusCode::INTERNAL_SERVER_ERROR), + ) + })?; + if !api_scopes_res.scopes.contains(&api_scope.to_string()) { + Err(HttpError::forbidden(format!( + "API scope is missing: {}", + api_scope + ))) + } else { + Ok(()) + } + } } diff --git a/packages/cubejs-backend-native/src/lib.rs b/packages/cubejs-backend-native/src/lib.rs index eb475af8857b6..6f16a6ec3777d 100644 --- a/packages/cubejs-backend-native/src/lib.rs +++ b/packages/cubejs-backend-native/src/lib.rs @@ -23,11 +23,9 @@ pub mod transport; pub mod utils; use crate::config::NodeConfigurationImpl; -use cubesql::telemetry::{LocalReporter, ReportingLogger}; use cubesql::CubeError; use neon::prelude::*; use once_cell::sync::OnceCell; -use simple_logger::SimpleLogger; use tokio::runtime::{Builder, Runtime}; pub fn tokio_runtime_node<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> { @@ -48,29 +46,13 @@ pub fn tokio_runtime() -> Result<&'static Runtime, CubeError> { }) } -pub fn create_logger(log_level: log::Level) -> SimpleLogger { - SimpleLogger::new() - .with_level(log::Level::Error.to_level_filter()) - .with_module_level("cubesql", log_level.to_level_filter()) - .with_module_level("cubejs_native", log_level.to_level_filter()) - .with_module_level("datafusion", log::Level::Warn.to_level_filter()) - .with_module_level("pg_srv", log::Level::Warn.to_level_filter()) -} - #[cfg(feature = "neon-entrypoint")] #[neon::main] fn main(cx: ModuleContext) -> NeonResult<()> { // We use log_rerouter to swap logger, because we init logger from js side in api-gateway log_reroute::init().unwrap(); - let logger = Box::new(create_logger(log::Level::Error)); - log_reroute::reroute_boxed(logger); - - ReportingLogger::init( - Box::new(LocalReporter::new()), - log::Level::Error.to_level_filter(), - ) - .unwrap(); + node_export::setup_local_logger(log::Level::Error); node_export::register_module_exports::(cx)?; diff --git a/packages/cubejs-backend-native/src/logger.rs b/packages/cubejs-backend-native/src/logger.rs index 5fed32bc2ef9c..0c8b0200af0d0 100644 --- a/packages/cubejs-backend-native/src/logger.rs +++ b/packages/cubejs-backend-native/src/logger.rs @@ -11,14 +11,14 @@ use crate::channel::call_js_with_channel_as_callback; #[derive(Debug)] pub struct NodeBridgeLogger { channel: Arc, - on_track: Arc>, + on_track: Option>>, } impl NodeBridgeLogger { pub fn new(channel: Channel, on_track: Root) -> Self { Self { channel: Arc::new(channel), - on_track: Arc::new(on_track), + on_track: Some(Arc::new(on_track)), } } } @@ -35,12 +35,42 @@ impl LogReporter for NodeBridgeLogger { let extra = serde_json::to_string(&EventBox { event: props }).unwrap(); let channel = self.channel.clone(); - let on_track = self.on_track.clone(); + // Safety: Safe, because on_track take is used only for dropping + let on_track = self + .on_track + .as_ref() + .expect( + "Unable to unwrap on_track to log event for NodeBridgeLogger. Logger was dropped?", + ) + .clone(); + // TODO: Move to spawning loops spawn(async move { log(channel, on_track, Some(extra)).await }); } } +impl Drop for NodeBridgeLogger { + fn drop(&mut self) { + let channel = self.channel.clone(); + let on_track = self.on_track.take().expect( + "Unable to take on_track while dropping NodeBridgeLogger, it was already taken", + ); + + channel.send(move |mut cx| { + let _ = match Arc::try_unwrap(on_track) { + Ok(v) => { + v.into_inner(&mut cx) + }, + Err(_) => { + log::error!("Unable to drop sql generator: reference is copied somewhere else. Potential memory leak"); + return Ok(()); + }, + }; + Ok(()) + }); + } +} + async fn log(channel: Arc, on_track: Arc>, extra: Option) { let _ = call_js_with_channel_as_callback::(channel, on_track, extra).await; } diff --git a/packages/cubejs-backend-native/src/node_export.rs b/packages/cubejs-backend-native/src/node_export.rs index cd8a0b26ba0f9..b9744fa3784f3 100644 --- a/packages/cubejs-backend-native/src/node_export.rs +++ b/packages/cubejs-backend-native/src/node_export.rs @@ -8,7 +8,7 @@ use serde_json::Map; use tokio::sync::Semaphore; use uuid::Uuid; -use crate::auth::{NativeAuthContext, NodeBridgeAuthService}; +use crate::auth::{NativeSQLAuthContext, NodeBridgeAuthService, NodeBridgeAuthServiceOptions}; use crate::channel::call_js_fn; use crate::config::{NodeConfiguration, NodeConfigurationFactoryOptions, NodeCubeServices}; use crate::cross::CLRepr; @@ -30,9 +30,11 @@ use cubesqlplanner::planner::base_query::BaseQuery; use std::rc::Rc; use std::sync::Arc; +use cubesql::telemetry::LocalReporter; use cubesql::{telemetry::ReportingLogger, CubeError}; - use neon::prelude::*; +use neon::result::Throw; +use simple_logger::SimpleLogger; pub(crate) struct SQLInterface { pub(crate) services: Arc, @@ -46,29 +48,35 @@ impl SQLInterface { } } +fn get_function_from_options( + options: Handle, + name: &str, + cx: &mut FunctionContext, +) -> Result, Throw> { + let fun = options.get_opt::(cx, name)?; + if let Some(fun) = fun { + Ok(fun.downcast_or_throw::(cx)?.root(cx)) + } else { + cx.throw_error(format!( + "{} is required, must be passed as option in registerInterface", + name + )) + } +} + fn register_interface(mut cx: FunctionContext) -> JsResult { let options = cx.argument::(0)?; - let check_sql_auth = options - .get::(&mut cx, "checkSqlAuth")? - .root(&mut cx); - let transport_sql_api_load = options - .get::(&mut cx, "sqlApiLoad")? - .root(&mut cx); - let transport_sql = options - .get::(&mut cx, "sql")? - .root(&mut cx); - let transport_meta = options - .get::(&mut cx, "meta")? - .root(&mut cx); - let transport_log_load_event = options - .get::(&mut cx, "logLoadEvent")? - .root(&mut cx); - let transport_sql_generator = options - .get::(&mut cx, "sqlGenerators")? - .root(&mut cx); - let transport_can_switch_user_for_session = options - .get::(&mut cx, "canSwitchUserForSession")? - .root(&mut cx); + + let context_to_api_scopes = get_function_from_options(options, "contextToApiScopes", &mut cx)?; + let check_auth = get_function_from_options(options, "checkAuth", &mut cx)?; + let check_sql_auth = get_function_from_options(options, "checkSqlAuth", &mut cx)?; + let transport_sql_api_load = get_function_from_options(options, "sqlApiLoad", &mut cx)?; + let transport_sql = get_function_from_options(options, "sql", &mut cx)?; + let transport_meta = get_function_from_options(options, "meta", &mut cx)?; + let transport_log_load_event = get_function_from_options(options, "logLoadEvent", &mut cx)?; + let transport_sql_generator = get_function_from_options(options, "sqlGenerators", &mut cx)?; + let transport_can_switch_user_for_session = + get_function_from_options(options, "canSwitchUserForSession", &mut cx)?; let pg_port_handle = options.get_value(&mut cx, "pgPort")?; let pg_port = if pg_port_handle.is_a::(&mut cx) { @@ -101,7 +109,14 @@ fn register_interface(mut cx: FunctionContext) -> JsResult transport_sql_generator, transport_can_switch_user_for_session, ); - let auth_service = NodeBridgeAuthService::new(cx.channel(), check_sql_auth); + let auth_service = NodeBridgeAuthService::new( + cx.channel(), + NodeBridgeAuthServiceOptions { + check_auth, + check_sql_auth, + context_to_api_scopes, + }, + ); std::thread::spawn(move || { let config = C::new(NodeConfigurationFactoryOptions { @@ -179,7 +194,7 @@ const CHUNK_DELIM: &str = "\n"; async fn handle_sql_query( services: Arc, - native_auth_ctx: Arc, + native_auth_ctx: Arc, channel: Arc, stream_methods: WritableStreamMethods, sql_query: &str, @@ -411,7 +426,7 @@ fn exec_sql(mut cx: FunctionContext) -> JsResult { let channel = Arc::new(cx.channel()); let node_stream_arc = Arc::new(node_stream_root); - let native_auth_ctx = Arc::new(NativeAuthContext { + let native_auth_ctx = Arc::new(NativeSQLAuthContext { user: Some(String::from("unknown")), superuser: false, security_context, @@ -483,16 +498,13 @@ fn is_fallback_build(mut cx: FunctionContext) -> JsResult { Ok(JsBoolean::new(&mut cx, true)) } -pub fn setup_logger(mut cx: FunctionContext) -> JsResult { - let options = cx.argument::(0)?; - let cube_logger = options - .get::(&mut cx, "logger")? - .root(&mut cx); - - let log_level_handle = options.get_value(&mut cx, "logLevel")?; - let log_level = if log_level_handle.is_a::(&mut cx) { - let value = log_level_handle.downcast_or_throw::(&mut cx)?; - let log_level = match value.value(&mut cx).as_str() { +fn get_log_level_from_variable( + log_level_handle: Handle, + cx: &mut FunctionContext, +) -> NeonResult { + if log_level_handle.is_a::(cx) { + let value = log_level_handle.downcast_or_throw::(cx)?; + let log_level = match value.value(cx).as_str() { "error" => log::Level::Error, "warn" => log::Level::Warn, "info" => log::Level::Info, @@ -500,12 +512,23 @@ pub fn setup_logger(mut cx: FunctionContext) -> JsResult { "trace" => log::Level::Trace, x => cx.throw_error(format!("Unrecognized log level: {}", x))?, }; - log_level + + Ok(log_level) } else { - log::Level::Trace - }; + Ok(log::Level::Trace) + } +} + +pub fn setup_logger(mut cx: FunctionContext) -> JsResult { + let options = cx.argument::(0)?; + let cube_logger = options + .get::(&mut cx, "logger")? + .root(&mut cx); - let logger = crate::create_logger(log_level); + let log_level_handle = options.get_value(&mut cx, "logLevel")?; + let log_level = get_log_level_from_variable(log_level_handle, &mut cx)?; + + let logger = create_logger(log_level); log_reroute::reroute_boxed(Box::new(logger)); ReportingLogger::init( @@ -517,6 +540,37 @@ pub fn setup_logger(mut cx: FunctionContext) -> JsResult { Ok(cx.undefined()) } +pub fn create_logger(log_level: log::Level) -> SimpleLogger { + SimpleLogger::new() + .with_level(log::Level::Error.to_level_filter()) + .with_module_level("cubesql", log_level.to_level_filter()) + .with_module_level("cubejs_native", log_level.to_level_filter()) + .with_module_level("datafusion", log::Level::Warn.to_level_filter()) + .with_module_level("pg_srv", log::Level::Warn.to_level_filter()) +} + +pub fn setup_local_logger(log_level: log::Level) { + let logger = create_logger(log_level); + log_reroute::reroute_boxed(Box::new(logger)); + + ReportingLogger::init( + Box::new(LocalReporter::new()), + log::Level::Error.to_level_filter(), + ) + .unwrap(); +} + +pub fn reset_logger(mut cx: FunctionContext) -> JsResult { + let options = cx.argument::(0)?; + + let log_level_handle = options.get_value(&mut cx, "logLevel")?; + let log_level = get_log_level_from_variable(log_level_handle, &mut cx)?; + + setup_local_logger(log_level); + + Ok(cx.undefined()) +} + //============ sql planner =================== fn build_sql_and_params(cx: FunctionContext) -> JsResult { @@ -556,6 +610,7 @@ pub fn register_module_exports( mut cx: ModuleContext, ) -> NeonResult<()> { cx.export_function("setupLogger", setup_logger)?; + cx.export_function("resetLogger", reset_logger)?; cx.export_function("registerInterface", register_interface::)?; cx.export_function("shutdownInterface", shutdown_interface)?; cx.export_function("execSql", exec_sql)?; diff --git a/packages/cubejs-backend-native/src/sql4sql.rs b/packages/cubejs-backend-native/src/sql4sql.rs index bfeef78b49153..71c71df4bdd19 100644 --- a/packages/cubejs-backend-native/src/sql4sql.rs +++ b/packages/cubejs-backend-native/src/sql4sql.rs @@ -12,7 +12,7 @@ use cubesql::sql::{Session, CUBESQL_PENALIZE_POST_PROCESSING_VAR}; use cubesql::transport::MetaContext; use cubesql::CubeError; -use crate::auth::NativeAuthContext; +use crate::auth::NativeSQLAuthContext; use crate::config::NodeCubeServices; use crate::cubesql_utils::with_session; use crate::tokio_runtime_node; @@ -157,7 +157,7 @@ async fn get_sql( async fn handle_sql4sql_query( services: Arc, - native_auth_ctx: Arc, + native_auth_ctx: Arc, sql_query: &str, disable_post_processing: bool, ) -> Result { @@ -205,7 +205,7 @@ pub fn sql4sql(mut cx: FunctionContext) -> JsResult { let channel = cx.channel(); - let native_auth_ctx = Arc::new(NativeAuthContext { + let native_auth_ctx = Arc::new(NativeSQLAuthContext { user: Some(String::from("unknown")), superuser: false, security_context, diff --git a/packages/cubejs-backend-native/src/transport.rs b/packages/cubejs-backend-native/src/transport.rs index a54aa3f46bda3..5f8f942868780 100644 --- a/packages/cubejs-backend-native/src/transport.rs +++ b/packages/cubejs-backend-native/src/transport.rs @@ -3,7 +3,7 @@ use neon::prelude::*; use std::collections::HashMap; use std::fmt::Display; -use crate::auth::NativeAuthContext; +use crate::auth::NativeSQLAuthContext; use crate::channel::{call_raw_js_with_channel_as_callback, NodeSqlGenerator, ValueFromJs}; use crate::node_obj_serializer::NodeObjSerializer; use crate::orchestrator::ResultWrapper; @@ -116,7 +116,7 @@ impl TransportService for NodeBridgeTransport { let native_auth = ctx .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Unable to cast AuthContext to NativeAuthContext"); let request_id = Uuid::new_v4().to_string(); @@ -217,7 +217,7 @@ impl TransportService for NodeBridgeTransport { async fn compiler_id(&self, ctx: AuthContextRef) -> Result { let native_auth = ctx .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Unable to cast AuthContext to NativeAuthContext"); let request_id = Uuid::new_v4().to_string(); @@ -263,7 +263,7 @@ impl TransportService for NodeBridgeTransport { ) -> Result { let native_auth = ctx .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Unable to cast AuthContext to NativeAuthContext"); let request_id = span_id @@ -343,7 +343,7 @@ impl TransportService for NodeBridgeTransport { let native_auth = ctx .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Unable to cast AuthContext to NativeAuthContext"); let request_id = span_id @@ -508,7 +508,7 @@ impl TransportService for NodeBridgeTransport { req_seq_id += 1; let native_auth = ctx .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Unable to cast AuthContext to NativeAuthContext"); let extra = serde_json::to_string(&LoadRequest { @@ -555,7 +555,7 @@ impl TransportService for NodeBridgeTransport { ) -> Result { let native_auth = ctx .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Unable to cast AuthContext to NativeAuthContext"); let res = call_raw_js_with_channel_as_callback( @@ -594,7 +594,7 @@ impl TransportService for NodeBridgeTransport { ) -> Result<(), CubeError> { let native_auth = ctx .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Unable to cast AuthContext to NativeAuthContext"); let request_id = span_id diff --git a/packages/cubejs-backend-native/test/sql.test.ts b/packages/cubejs-backend-native/test/sql.test.ts index 70dad6b96f6a3..3a94f8713509b 100644 --- a/packages/cubejs-backend-native/test/sql.test.ts +++ b/packages/cubejs-backend-native/test/sql.test.ts @@ -103,6 +103,22 @@ function interfaceMethods() { dataSourceToSqlGenerator: {}, }; }), + contextToApiScopes: jest.fn(async ({ request, token }) => { + console.log('[js] contextToApiScopes', { + request, + token, + }); + + return ['data', 'meta', 'graphql']; + }), + checkAuth: jest.fn(async ({ request, token }) => { + console.log('[js] checkAuth', { + request, + token, + }); + + throw new Error('checkAuth is not implemented'); + }), checkSqlAuth: jest.fn(async ({ request, user }) => { console.log('[js] checkSqlAuth', { request, diff --git a/rust/cubesql/cubesql/src/sql/auth_service.rs b/rust/cubesql/cubesql/src/sql/auth_service.rs index 1079c9ec22fd1..f29002f862d24 100644 --- a/rust/cubesql/cubesql/src/sql/auth_service.rs +++ b/rust/cubesql/cubesql/src/sql/auth_service.rs @@ -1,13 +1,17 @@ use std::{any::Any, env, fmt::Debug, sync::Arc}; -use async_trait::async_trait; - use crate::CubeError; +use async_trait::async_trait; +use serde_json::Value; // We cannot use generic here. It's why there is this trait // Any type will allow us to split (with downcast) auth context into HTTP (standalone) or Native pub trait AuthContext: Debug + Send + Sync { fn as_any(&self) -> &dyn Any; + + fn user(&self) -> Option<&String>; + + fn security_context(&self) -> Option<&serde_json::Value>; } pub type AuthContextRef = Arc; @@ -22,6 +26,14 @@ impl AuthContext for HttpAuthContext { fn as_any(&self) -> &dyn Any { self } + + fn user(&self) -> Option<&String> { + None + } + + fn security_context(&self) -> Option<&Value> { + None + } } #[derive(Debug)]