Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions graphql/server/src/middleware/__tests__/upload.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ describe('createUploadAuthenticateMiddleware', () => {
const res = makeRes();
const next = makeNext();

// typed rls_settings query returns no rows (table may not exist yet)
rootPool.query.mockResolvedValueOnce({ rows: [] });
// legacy api_modules fallback
rootPool.query.mockResolvedValueOnce({
rows: [
{
Expand Down Expand Up @@ -282,6 +285,9 @@ describe('createUploadAuthenticateMiddleware', () => {
const res = makeRes();
const next = makeNext();

// typed rls_settings query returns no rows (table may not exist yet)
rootPool.query.mockResolvedValueOnce({ rows: [] });
// legacy api_modules fallback
rootPool.query.mockResolvedValueOnce({
rows: [
{
Expand Down Expand Up @@ -330,7 +336,13 @@ describe('createUploadAuthenticateMiddleware', () => {
const res = makeRes();
const next = makeNext();

// typed rls_settings query returns no rows
rootPool.query.mockResolvedValueOnce({ rows: [] });
// legacy api_modules by database_id returns no rows
rootPool.query.mockResolvedValueOnce({ rows: [] });
// typed rls_settings by dbname returns no rows
rootPool.query.mockResolvedValueOnce({ rows: [] });
// legacy api_modules by dbname returns no rows
rootPool.query.mockResolvedValueOnce({ rows: [] });

await middleware(req, res, next);
Expand Down
68 changes: 62 additions & 6 deletions graphql/server/src/middleware/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ const RLS_MODULE_SQL = `
LIMIT 1
`;

const RLS_SETTINGS_SQL = `
SELECT
auth_schema.schema_name AS authenticate_schema,
role_schema.schema_name AS role_schema,
auth_fn.name AS authenticate,
auth_strict_fn.name AS authenticate_strict,
role_fn.name AS current_role,
role_id_fn.name AS current_role_id,
ua_fn.name AS current_user_agent,
ip_fn.name AS current_ip_address
FROM services_public.rls_settings rs
LEFT JOIN metaschema_public.schema auth_schema ON rs.authenticate_schema_id = auth_schema.id
LEFT JOIN metaschema_public.schema role_schema ON rs.role_schema_id = role_schema.id
LEFT JOIN metaschema_public.function auth_fn ON rs.authenticate_function_id = auth_fn.id
LEFT JOIN metaschema_public.function auth_strict_fn ON rs.authenticate_strict_function_id = auth_strict_fn.id
LEFT JOIN metaschema_public.function role_fn ON rs.current_role_function_id = role_fn.id
LEFT JOIN metaschema_public.function role_id_fn ON rs.current_role_id_function_id = role_id_fn.id
LEFT JOIN metaschema_public.function ua_fn ON rs.current_user_agent_function_id = ua_fn.id
LEFT JOIN metaschema_public.function ip_fn ON rs.current_ip_address_function_id = ip_fn.id
WHERE rs.database_id = $1
LIMIT 1
`;

/**
* Discover auth settings table location via public metaschema tables.
* Joins sessions_module with metaschema_public.schema to resolve
Expand Down Expand Up @@ -249,6 +272,24 @@ const toRlsModule = (row: RlsModuleRow | null): RlsModule | undefined => {
};
};

const toRlsModuleFromSettings = (row: RlsModuleData | null): RlsModule | undefined => {
if (!row) return undefined;
return {
authenticate: row.authenticate,
authenticateStrict: row.authenticate_strict,
privateSchema: {
schemaName: row.authenticate_schema,
},
publicSchema: {
schemaName: row.role_schema,
},
currentRole: row.current_role,
currentRoleId: row.current_role_id,
currentIpAddress: row.current_ip_address,
currentUserAgent: row.current_user_agent,
};
};

const toAuthSettings = (row: AuthSettingsRow | null): AuthSettings | undefined => {
if (!row) return undefined;
return {
Expand All @@ -263,14 +304,14 @@ const toAuthSettings = (row: AuthSettingsRow | null): AuthSettings | undefined =
};
};

const toApiStructure = (row: ApiRow, opts: ApiOptions, rlsModuleRow?: RlsModuleRow | null, authSettingsRow?: AuthSettingsRow | null): ApiStructure => ({
const toApiStructure = (row: ApiRow, opts: ApiOptions, rlsModule?: RlsModule, authSettingsRow?: AuthSettingsRow | null): ApiStructure => ({
apiId: row.api_id,
dbname: row.dbname || opts.pg?.database || '',
anonRole: row.anon_role || 'anon',
roleName: row.role_name || 'authenticated',
schema: row.schemas || [],
apiModules: [],
rlsModule: toRlsModule(rlsModuleRow ?? null),
rlsModule,
domains: [],
databaseId: row.database_id,
isPublic: row.is_public,
Expand Down Expand Up @@ -329,9 +370,24 @@ const queryApiList = async (pool: Pool, isPublic: boolean): Promise<ApiListRow[]
return result.rows;
};

const queryRlsModule = async (pool: Pool, apiId: string): Promise<RlsModuleRow | null> => {
const queryRlsSettings = async (pool: Pool, databaseId: string): Promise<RlsModule | undefined> => {
try {
const result = await pool.query<RlsModuleData>(RLS_SETTINGS_SQL, [databaseId]);
return toRlsModuleFromSettings(result.rows[0] ?? null);
} catch {
return undefined;
}
};

const queryRlsModuleLegacy = async (pool: Pool, apiId: string): Promise<RlsModule | undefined> => {
const result = await pool.query<RlsModuleRow>(RLS_MODULE_SQL, [apiId]);
return result.rows[0] ?? null;
return toRlsModule(result.rows[0] ?? null);
};

const queryRlsModule = async (pool: Pool, databaseId: string, apiId: string): Promise<RlsModule | undefined> => {
const fromSettings = await queryRlsSettings(pool, databaseId);
if (fromSettings) return fromSettings;
return queryRlsModuleLegacy(pool, apiId);
};

/**
Expand Down Expand Up @@ -423,7 +479,7 @@ const resolveApiNameHeader = async (ctx: ResolveContext): Promise<ApiStructure |
return null;
}

const rlsModule = await queryRlsModule(pool, row.api_id);
const rlsModule = await queryRlsModule(pool, row.database_id, row.api_id);
const authSettings = await queryAuthSettings(opts, row.dbname);
log.debug(`[api-name-lookup] resolved schemas: [${row.schemas?.join(', ')}], rlsModule: ${rlsModule ? 'found' : 'none'}, authSettings: ${authSettings ? 'found' : 'none'}`);
return toApiStructure(row, opts, rlsModule, authSettings);
Expand All @@ -449,7 +505,7 @@ const resolveDomainLookup = async (ctx: ResolveContext): Promise<ApiStructure |
return null;
}

const rlsModule = await queryRlsModule(pool, row.api_id);
const rlsModule = await queryRlsModule(pool, row.database_id, row.api_id);
const authSettings = await queryAuthSettings(opts, row.dbname);
log.debug(`[domain-lookup] resolved schemas: [${row.schemas?.join(', ')}], rlsModule: ${rlsModule ? 'found' : 'none'}, authSettings: ${authSettings ? 'found' : 'none'}`);
return toApiStructure(row, opts, rlsModule, authSettings);
Expand Down
83 changes: 83 additions & 0 deletions graphql/server/src/middleware/upload.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,53 @@ const RLS_MODULE_BY_DBNAME_SQL = `
LIMIT 1
`;

const RLS_SETTINGS_BY_DATABASE_ID_SQL = `
SELECT
auth_schema.schema_name AS authenticate_schema,
role_schema.schema_name AS role_schema,
auth_fn.name AS authenticate,
auth_strict_fn.name AS authenticate_strict,
role_fn.name AS current_role,
role_id_fn.name AS current_role_id,
ua_fn.name AS current_user_agent,
ip_fn.name AS current_ip_address
FROM services_public.rls_settings rs
LEFT JOIN metaschema_public.schema auth_schema ON rs.authenticate_schema_id = auth_schema.id
LEFT JOIN metaschema_public.schema role_schema ON rs.role_schema_id = role_schema.id
LEFT JOIN metaschema_public.function auth_fn ON rs.authenticate_function_id = auth_fn.id
LEFT JOIN metaschema_public.function auth_strict_fn ON rs.authenticate_strict_function_id = auth_strict_fn.id
LEFT JOIN metaschema_public.function role_fn ON rs.current_role_function_id = role_fn.id
LEFT JOIN metaschema_public.function role_id_fn ON rs.current_role_id_function_id = role_id_fn.id
LEFT JOIN metaschema_public.function ua_fn ON rs.current_user_agent_function_id = ua_fn.id
LEFT JOIN metaschema_public.function ip_fn ON rs.current_ip_address_function_id = ip_fn.id
WHERE rs.database_id = $1
LIMIT 1
`;

const RLS_SETTINGS_BY_DBNAME_SQL = `
SELECT
auth_schema.schema_name AS authenticate_schema,
role_schema.schema_name AS role_schema,
auth_fn.name AS authenticate,
auth_strict_fn.name AS authenticate_strict,
role_fn.name AS current_role,
role_id_fn.name AS current_role_id,
ua_fn.name AS current_user_agent,
ip_fn.name AS current_ip_address
FROM services_public.rls_settings rs
JOIN services_public.apis a ON rs.database_id = a.database_id
LEFT JOIN metaschema_public.schema auth_schema ON rs.authenticate_schema_id = auth_schema.id
LEFT JOIN metaschema_public.schema role_schema ON rs.role_schema_id = role_schema.id
LEFT JOIN metaschema_public.function auth_fn ON rs.authenticate_function_id = auth_fn.id
LEFT JOIN metaschema_public.function auth_strict_fn ON rs.authenticate_strict_function_id = auth_strict_fn.id
LEFT JOIN metaschema_public.function role_fn ON rs.current_role_function_id = role_fn.id
LEFT JOIN metaschema_public.function role_id_fn ON rs.current_role_id_function_id = role_id_fn.id
LEFT JOIN metaschema_public.function ua_fn ON rs.current_user_agent_function_id = ua_fn.id
LEFT JOIN metaschema_public.function ip_fn ON rs.current_ip_address_function_id = ip_fn.id
WHERE a.dbname = $1
LIMIT 1
`;

interface RlsModuleData {
authenticate: string;
authenticate_strict: string;
Expand Down Expand Up @@ -111,6 +158,20 @@ const toRlsModule = (row: RlsModuleRow | null): RlsModule | undefined => {
};
};

const toRlsModuleFromSettings = (row: RlsModuleData | null): RlsModule | undefined => {
if (!row) return undefined;
return {
authenticate: row.authenticate,
authenticateStrict: row.authenticate_strict,
privateSchema: { schemaName: row.authenticate_schema },
publicSchema: { schemaName: row.role_schema },
currentRole: row.current_role,
currentRoleId: row.current_role_id,
currentIpAddress: row.current_ip_address,
currentUserAgent: row.current_user_agent,
};
};

const getBearerToken = (authorization?: string): string | null => {
if (!authorization) return null;
const [authType, authToken] = authorization.split(' ');
Expand All @@ -120,7 +181,27 @@ const getBearerToken = (authorization?: string): string | null => {
return authToken;
};

const queryRlsSettingsByDatabaseId = async (pool: Pool, databaseId: string): Promise<RlsModule | undefined> => {
try {
const result = await pool.query<RlsModuleData>(RLS_SETTINGS_BY_DATABASE_ID_SQL, [databaseId]);
return toRlsModuleFromSettings(result.rows[0] ?? null);
} catch {
return undefined;
}
};

const queryRlsSettingsByDbname = async (pool: Pool, dbname: string): Promise<RlsModule | undefined> => {
try {
const result = await pool.query<RlsModuleData>(RLS_SETTINGS_BY_DBNAME_SQL, [dbname]);
return toRlsModuleFromSettings(result.rows[0] ?? null);
} catch {
return undefined;
}
};

const queryRlsModuleByDatabaseId = async (pool: Pool, databaseId: string): Promise<RlsModule | undefined> => {
const fromSettings = await queryRlsSettingsByDatabaseId(pool, databaseId);
if (fromSettings) return fromSettings;
const result = await pool.query<RlsModuleRow>(RLS_MODULE_BY_DATABASE_ID_SQL, [databaseId]);
return toRlsModule(result.rows[0] ?? null);
};
Expand All @@ -131,6 +212,8 @@ const queryRlsModuleByApiId = async (pool: Pool, apiId: string): Promise<RlsModu
};

const queryRlsModuleByDbname = async (pool: Pool, dbname: string): Promise<RlsModule | undefined> => {
const fromSettings = await queryRlsSettingsByDbname(pool, dbname);
if (fromSettings) return fromSettings;
const result = await pool.query<RlsModuleRow>(RLS_MODULE_BY_DBNAME_SQL, [dbname]);
return toRlsModule(result.rows[0] ?? null);
};
Expand Down
Loading