Skip to content

Commit

Permalink
chore(backend): Refactor existing cross-origin check - drop forwarded…
Browse files Browse the repository at this point in the history
…Port
  • Loading branch information
dimkl committed Jul 19, 2023
1 parent 6a9a128 commit 623194d
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 77 deletions.
7 changes: 3 additions & 4 deletions packages/backend/src/tokens/interstitialRule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,13 @@ export const nonBrowserRequestInDevRule: InterstitialRule = options => {
};

export const crossOriginRequestWithoutHeader: InterstitialRule = options => {
const { origin, host, forwardedHost, forwardedPort, forwardedProto } = options;
const { origin, host, forwardedHost, forwardedProto } = options;
const isCrossOrigin =
origin &&
checkCrossOrigin({
originURL: new URL(origin),
host,
forwardedHost,
forwardedPort,
forwardedProto,
});

Expand Down Expand Up @@ -80,9 +79,9 @@ export const potentialFirstLoadInDevWhenUATMissing: InterstitialRule = options =
* It is expected that a primary app will trigger a redirect back to the satellite app.
*/
export const potentialRequestAfterSignInOrOutFromClerkHostedUiInDev: InterstitialRule = options => {
const { apiKey, secretKey, referrer, host, forwardedHost, forwardedPort, forwardedProto } = options;
const { apiKey, secretKey, referrer, host, forwardedHost, forwardedProto } = options;
const crossOriginReferrer =
referrer && checkCrossOrigin({ originURL: new URL(referrer), host, forwardedHost, forwardedPort, forwardedProto });
referrer && checkCrossOrigin({ originURL: new URL(referrer), host, forwardedHost, forwardedProto });
const key = secretKey || apiKey || '';

if (isDevelopmentFromApiKey(key) && crossOriginReferrer) {
Expand Down
2 changes: 1 addition & 1 deletion packages/backend/src/tokens/request.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ export default (QUnit: QUnit) => {
const requestState = await authenticateRequest({
...defaultMockAuthenticateRequestOptions,
origin: 'https://clerk.com',
forwardedProto: '80',
forwardedProto: 'http',
cookieToken: mockJwt,
});

Expand Down
22 changes: 9 additions & 13 deletions packages/backend/src/util/request.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ export default (QUnit: QUnit) => {
assert.true(checkCrossOrigin({ originURL, host, forwardedHost, forwardedProto }));
});

test('is CO when HTTPS to HTTP with forwarded port', assert => {
test('is CO when HTTPS to HTTP with forwarded proto', assert => {
const originURL = new URL('https://localhost');
const host = new URL('http://localhost').host;
const forwardedPort = '80';
const forwardedProto = 'http';

assert.true(checkCrossOrigin({ originURL, host, forwardedPort }));
assert.true(checkCrossOrigin({ originURL, host, forwardedProto }));
});

test('is CO with cross origin auth domain', assert => {
Expand All @@ -56,9 +56,8 @@ export default (QUnit: QUnit) => {

test('is CO when forwarded port overrides host derived port', assert => {
const originURL = new URL('https://localhost:443');
const host = new URL('https://localhost').host;
const forwardedPort = '3001';
assert.true(checkCrossOrigin({ originURL, host, forwardedPort }));
const host = new URL('https://localhost:3001').host;
assert.true(checkCrossOrigin({ originURL, host }));
});

test('is not CO with port included in x-forwarded host', assert => {
Expand All @@ -80,26 +79,23 @@ export default (QUnit: QUnit) => {
test('is not CO when forwarded port and origin does not contain a port - http', assert => {
const originURL = new URL('http://localhost');
const host = new URL('http://localhost').host;
const forwardedPort = '80';

assert.false(checkCrossOrigin({ originURL, host, forwardedPort }));
assert.false(checkCrossOrigin({ originURL, host }));
});

test('is not CO when forwarded port and origin does not contain a port - https', assert => {
const originURL = new URL('https://localhost');
const host = originURL.host;
const forwardedPort = '443';
const host = new URL('https://localhost').host;

assert.false(checkCrossOrigin({ originURL, host, forwardedPort }));
assert.false(checkCrossOrigin({ originURL, host }));
});

test('is not CO based on referrer with forwarded host & port and referer', assert => {
const host = '';
const forwardedPort = '80';
const forwardedHost = 'example.com';
const referrer = 'http://example.com/';

assert.false(checkCrossOrigin({ originURL: new URL(referrer), host, forwardedPort, forwardedHost }));
assert.false(checkCrossOrigin({ originURL: new URL(referrer), host, forwardedHost }));
});

test('is not CO for AWS', assert => {
Expand Down
56 changes: 3 additions & 53 deletions packages/backend/src/util/request.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { buildOrigin } from '../utils';
/**
* This function is only used in the case where:
* - DevOrStaging key is present
Expand All @@ -9,45 +10,15 @@ export function checkCrossOrigin({
originURL,
host,
forwardedHost,
forwardedPort,
forwardedProto,
}: {
originURL: URL;
host?: string | null;
forwardedHost?: string | null;
forwardedPort?: string | null;
forwardedProto?: string | null;
}) {
const fwdProto = getFirstValueFromHeaderValue(forwardedProto);
let fwdPort = getFirstValueFromHeaderValue(forwardedPort);

// If forwardedPort mismatch with forwardedProto determine forwardedPort
// from forwardedProto as fallback (if exists)
// This check fixes the Railway App issue
const fwdProtoHasMoreValuesThanFwdPorts =
(forwardedProto || '').split(',').length > (forwardedPort || '').split(',').length;
if (fwdProto && fwdProtoHasMoreValuesThanFwdPorts) {
fwdPort = getPortFromProtocol(fwdProto);
}

const originProtocol = getProtocolVerb(originURL.protocol);
if (fwdProto && fwdProto !== originProtocol) {
return true;
}

const protocol = fwdProto || originProtocol;
/* The forwarded host prioritised over host to be checked against the referrer. */
const finalURL = convertHostHeaderValueToURL(forwardedHost || host || undefined, protocol);
finalURL.port = fwdPort || finalURL.port;

if (getPort(finalURL) !== getPort(originURL)) {
return true;
}
if (finalURL.hostname !== originURL.hostname) {
return true;
}

return false;
const finalURL = buildOrigin({ forwardedProto, forwardedHost, protocol: originURL.protocol, host });
return finalURL && new URL(finalURL).origin !== originURL.origin;
}

export function convertHostHeaderValueToURL(host?: string, protocol = 'https'): URL {
Expand All @@ -58,27 +29,6 @@ export function convertHostHeaderValueToURL(host?: string, protocol = 'https'):
return new URL(`${protocol}://${host}`);
}

const PROTOCOL_TO_PORT_MAPPING: Record<string, string> = {
http: '80',
https: '443',
} as const;

function getPort(url: URL) {
return url.port || getPortFromProtocol(url.protocol);
}

function getPortFromProtocol(protocol: string) {
return PROTOCOL_TO_PORT_MAPPING[protocol];
}

function getFirstValueFromHeaderValue(value?: string | null) {
return value?.split(',')[0]?.trim() || '';
}

function getProtocolVerb(protocol: string) {
return protocol?.replace(/:$/, '') || '';
}

type ErrorFields = {
message: string;
long_message: string;
Expand Down
30 changes: 24 additions & 6 deletions packages/backend/src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
import { constants } from './constants';

const getHeader = (req: Request, key: string) => req.headers.get(key);
const getFirstValueFromHeader = (req: Request, key: string) => getHeader(req, key)?.split(',')[0];
const getFirstValueFromHeader = (value?: string | null) => value?.split(',')[0];

type BuildRequestUrl = (request: Request, path?: string) => URL;
export const buildRequestUrl: BuildRequestUrl = (request, path) => {
const initialUrl = new URL(request.url);

const forwardedProto = getFirstValueFromHeader(request, constants.Headers.ForwardedProto);
const forwardedHost = getFirstValueFromHeader(request, constants.Headers.ForwardedHost);
const forwardedProto = getHeader(request, constants.Headers.ForwardedProto);
const forwardedHost = getHeader(request, constants.Headers.ForwardedHost);
const host = getHeader(request, constants.Headers.Host);
const protocol = initialUrl.protocol;

const resolvedHost = forwardedHost ?? host ?? initialUrl.host;
const resolvedProtocol = forwardedProto ?? initialUrl.protocol.replace(/[:/]/, '');
const base = buildOrigin({ protocol, forwardedProto, forwardedHost, host: host || initialUrl.host });

return new URL(path || initialUrl.pathname, `${resolvedProtocol}://${resolvedHost}`);
return new URL(path || initialUrl.pathname, base);
};

type BuildOriginParams = {
protocol?: string;
forwardedProto?: string | null;
forwardedHost?: string | null;
host?: string | null;
};
type BuildOrigin = (params: BuildOriginParams) => string;
export const buildOrigin: BuildOrigin = ({ protocol, forwardedProto, forwardedHost, host }) => {
const resolvedHost = getFirstValueFromHeader(forwardedHost) ?? host;
const resolvedProtocol = getFirstValueFromHeader(forwardedProto) ?? protocol?.replace(/[:/]/, '');

if (!resolvedHost || !resolvedProtocol) {
return '';
}

return `${resolvedProtocol}://${resolvedHost}`;
};

0 comments on commit 623194d

Please sign in to comment.