Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(oauth2-profiler): fix OAuth2 profiler params #562

Merged
merged 10 commits into from
Jan 23, 2024
41 changes: 41 additions & 0 deletions typegate/src/runtimes/typegate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ export class TypeGateRuntime extends Runtime {
if (name === "argInfoByPath") {
return this.argInfoByPath;
}
if (name === "findListQueries") {
return this.findListQueries;
}

return async ({ _: { parent }, ...args }) => {
const resolver = parent[stage.props.node];
Expand Down Expand Up @@ -200,6 +203,39 @@ export class TypeGateRuntime extends Runtime {

return paths.map((path) => walkPath(tg!.tg, input, 0, path));
};

findListQueries: Resolver = ({ typegraph }) => {
const tg = this.typegate.register.get(typegraph);

const root = tg!.tg.type(0, Type.OBJECT).properties.query;
const exposed = tg!.tg.type(root, Type.OBJECT).properties;

return Object.entries(exposed).map(([name, idx]) => {
const func = tg!.tg.type(idx, Type.FUNCTION);
const input = tg!.tg.type(func.input, Type.OBJECT);
const inputs = input.properties;
const output = tg!.tg.type(func.output);
if (output.type != Type.LIST) {
return null;
}
const outputItem = tg!.tg.type(output.items);
if (outputItem.type != Type.OBJECT) {
return null;
}

return {
name,
inputs: Object.keys(inputs).map((name) => {
return {
name,
type: walkPath(tg!.tg, input, 0, [name]),
};
}),
output: walkPath(tg!.tg, output, 0, []),
outputItem: walkPath(tg!.tg, outputItem, 0, []),
};
}).filter((e) => e != null);
};
}

function resolveOptional(tg: TypeGraph, node: TypeNode) {
Expand All @@ -225,6 +261,7 @@ function collectObjectFields(
// first generate all possible paths

const paths = [] as Array<Array<string>>;
const traversed = new Set();

const collectAllPaths = (
parent: TypeNode,
Expand All @@ -233,6 +270,10 @@ function collectObjectFields(
const node = resolveOptional(tg, parent).node;

if (node.type == Type.OBJECT) {
if (traversed.has(node.title)) {
return;
}
traversed.add(node.title);
for (const [keyName, fieldIdx] of Object.entries(node.properties)) {
collectAllPaths(tg.type(fieldIdx), [...currentPath, keyName]);
}
Expand Down
2 changes: 1 addition & 1 deletion typegate/src/services/auth/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export async function ensureJWT(

const [context, nextAuth] = await auth.tokenMiddleware(
token,
new URL(request.url),
request,
);
if (nextAuth !== null) {
// "" is valid as it signal to remove the token
Expand Down
2 changes: 1 addition & 1 deletion typegate/src/services/auth/protocols/basic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export class BasicAuth extends Protocol {

tokenMiddleware(
jwt: string,
_url: URL,
_request: Request,
): Promise<[Record<string, unknown>, string | null]> {
try {
const [username, token] = b64decode(jwt).split(
Expand Down
2 changes: 1 addition & 1 deletion typegate/src/services/auth/protocols/internal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export class InternalAuth extends Protocol {

async tokenMiddleware(
token: string,
_url: URL,
_request: Request,
): Promise<[Record<string, unknown>, string | null]> {
try {
const claims = await verifyJWT(token);
Expand Down
2 changes: 1 addition & 1 deletion typegate/src/services/auth/protocols/jwt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ export class JWTAuth extends Protocol {

async tokenMiddleware(
token: string,
_url: URL,
_request: Request,
): Promise<[Record<string, unknown>, string | null]> {
try {
const claims = await jwt.verify(token, this.signKey);
Expand Down
32 changes: 24 additions & 8 deletions typegate/src/services/auth/protocols/oauth2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,26 @@ class AuthProfiler {
});
}

async transform(profile: any, url: string) {
async transform(
profile: any,
request: Request,
) {
const { tg, runtimeReferences } = this.authParameters;
const funcNode = tg.type(this.funcIndex, Type.FUNCTION);
const mat = tg.materializer(funcNode.materializer);
const runtime = runtimeReferences[mat.runtime];
const validatorInputWeak = generateWeakValidator(tg, funcNode.input);
const validatorOutput = generateValidator(tg, funcNode.output);

const input = { ...profile, _: { info: { url } } };
const input = {
...profile,
_: {
info: {
url: new URL(request.url),
headers: Object.fromEntries(request.headers.entries()),
},
},
};
validatorInputWeak(input);

// Note: this assumes func is a simple t.func(inp, out, mat)
Expand Down Expand Up @@ -154,7 +165,7 @@ export class OAuth2Auth extends Protocol {
this.typegraphName,
);
const tokens = await client.code.getToken(url, { state, codeVerifier });
const token = await this.createJWT(tokens);
const token = await this.createJWT(tokens, request);
const headers = await setEncryptedSessionCookie(
url.hostname,
this.typegraphName,
Expand Down Expand Up @@ -211,8 +222,9 @@ export class OAuth2Auth extends Protocol {

async tokenMiddleware(
token: string,
url: URL,
request: Request,
): Promise<[Record<string, unknown>, string | null]> {
const url = new URL(request.url);
const typegraphPath = `/${this.typegraphName}`;
const client = new OAuth2Client({
...this.clientData,
Expand All @@ -236,7 +248,7 @@ export class OAuth2Auth extends Protocol {
if (new Date().valueOf() / 1000 > claims.refreshAt) {
try {
const newClaims = await client.refreshToken.refresh(refreshToken);
const token = await this.createJWT(newClaims);
const token = await this.createJWT(newClaims, request);
return [
claims,
token ?? "", // token or clear
Expand All @@ -252,6 +264,7 @@ export class OAuth2Auth extends Protocol {

private async getProfile(
token: Tokens,
request: Request,
): Promise<null | Record<string, unknown>> {
if (!this.profileUrl) {
return null;
Expand All @@ -270,7 +283,7 @@ export class OAuth2Auth extends Protocol {
let profile = await res.json();

if (this.authProfiler) {
profile = await this.authProfiler!.transform(profile, url);
profile = await this.authProfiler!.transform(profile, request);
}

return profile;
Expand All @@ -279,8 +292,11 @@ export class OAuth2Auth extends Protocol {
}
}

private async createJWT(token: Tokens): Promise<string> {
const profile = await this.getProfile(token);
private async createJWT(
token: Tokens,
request: Request,
): Promise<string> {
const profile = await this.getProfile(token, request);
const profileClaims: ProfileClaims = profile
? mapKeys(profile, (k) => `profile.${k}`)
: {};
Expand Down
2 changes: 1 addition & 1 deletion typegate/src/services/auth/protocols/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ export abstract class Protocol {

abstract tokenMiddleware(
token: string,
url: URL,
request: Request,
): Promise<[Record<string, unknown>, string | null]>;
}
Loading
Loading