Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions src/auth/AuthMiddleware.res
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,24 @@ let authenticateBearerToken = async (c: Hono.context, config: AuthTypes.Types.au
| Some(a) if String.startsWith(a, "Bearer ") => {
let token = String.substring(a, ~start=7, ~end=String.length(a))
try {
let payload = switch config.oidc {
| Some(oidc) => await Jwt.verifyJwt(token, (oidc :> Jwt.Types.oidcConfig))
| None => {
let decoded = Jwt.decodeJwt(token)
%raw(`decoded.payload`)
}
// Require OIDC config for OIDC/OAuth2 bearer-token methods. Without a
// JWKS source we have no way to verify signatures; failing closed is
// the only safe behaviour. Previously this branch silently decoded the
// payload of an unverified JWT and returned authenticated: true.
let oidc = switch config.oidc {
| Some(o) => o
| None =>
failwith(
"Bearer-token authentication requires OIDC configuration (no JWKS source to verify against)",
)
}
let payload = await Jwt.verifyJwt(token, (oidc :> Jwt.Types.oidcConfig))
let scope: option<string> = %raw(`payload.scope`)
{
authenticated: true,
method: #oidc,
subject: %raw("payload.sub"),
scopes: %raw("payload.scope")->Option.map(s => String.split(s, " "))->Option.getOr([]),
subject: payload.sub,
scopes: scope->Option.map(s => String.split(s, " "))->Option.getOr([]),
token: Obj.magic(payload),
}
} catch {
Expand Down Expand Up @@ -63,7 +69,7 @@ let authMiddleware = (config: AuthTypes.Types.authConfig) => {
| #none => {authenticated: true, method: #none}
| _ => {authenticated: false, method: #none, error: "Method not implemented"}
}

if res.authenticated {
result := res
}
Expand Down
11 changes: 7 additions & 4 deletions src/auth/AuthMiddleware.res.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import * as Jwt from "./Jwt.res.mjs";
import * as Belt_Array from "@rescript/runtime/lib/es6/Belt_Array.js";
import * as Stdlib_Option from "@rescript/runtime/lib/es6/Stdlib_Option.js";
import * as Pervasives from "@rescript/runtime/lib/es6/Pervasives.js";
import * as Core__Option from "@rescript/core/src/Core__Option.res.mjs";

let Hono = {};

Expand All @@ -24,13 +25,15 @@ async function authenticateBearerToken(c, config) {
}
let token = auth.substring(7, auth.length);
try {
let oidc = config.oidc;
let payload = oidc !== undefined ? await Jwt.verifyJwt(token, oidc) : (Jwt.decodeJwt(token), decoded.payload);
let o = config.oidc;
let oidc = o !== undefined ? o : Pervasives.failwith("Bearer-token authentication requires OIDC configuration (no JWKS source to verify against)");
let payload = await Jwt.verifyJwt(token, oidc);
let scope = payload.scope;
return {
authenticated: true,
method: "oidc",
subject: payload.sub,
scopes: Stdlib_Option.getOr(Stdlib_Option.map(payload.scope, s => s.split(" ")), []),
scopes: Core__Option.getOr(Core__Option.map(scope, s => s.split(" ")), []),
token: payload
};
} catch (exn) {
Expand Down
155 changes: 134 additions & 21 deletions src/auth/Jwt.res
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,19 @@ module Crypto = {
type cryptoKey
module Subtle = {
@val @scope(("crypto", "subtle"))
external importKey: (string, JSON.t, JSON.t, bool, array<string>) => promise<cryptoKey> =
external importKey: (JSON.t, JSON.t, JSON.t, bool, array<string>) => promise<cryptoKey> =
"importKey"

@val @scope(("crypto", "subtle"))
external verify: (JSON.t, cryptoKey, ArrayBuffer.t, ArrayBuffer.t) => promise<bool> = "verify"
external verify: (JSON.t, cryptoKey, Uint8Array.t, Uint8Array.t) => promise<bool> = "verify"
}
}

// TextEncoder for building the signing-input bytes.
type textEncoder
@new external makeTextEncoder: unit => textEncoder = "TextEncoder"
@send external textEncoderEncode: (textEncoder, string) => Uint8Array.t = "encode"

@val external atob: string => string = "atob"

let jwksCache: Map.t<string, Types.cachedJwks> = Map.make()
Expand All @@ -65,31 +70,49 @@ let fetchJwks = async (jwksUri: string): Types.jwks => {
failwith(`Failed to fetch JWKS: ${Int.toString(Fetch.Response.status(response))}`)
}
let jwks: Types.jwks = %raw("await response.json()")
Map.set(jwksCache, jwksUri, {jwks: jwks, expiresAt: now +. jwksCacheTtl})
Map.set(jwksCache, jwksUri, {jwks, expiresAt: now +. jwksCacheTtl})
jwks
}
}
}

// Base64url-decode a string into a Uint8Array.
let base64UrlDecode = (str: string): Uint8Array.t => {
let base64 = str
->String.replaceRegExp(%re("/-/g"), "+")
->String.replaceRegExp(%re("/_/g"), "/")
let mod4 = mod(String.length(base64), 4)
let padding = if mod4 > 0 { String.repeat("=", 4 - mod4) } else { "" }
let binary = atob(base64 ++ padding)
let bytes = Uint8Array.fromLength(String.length(binary))
for _i in 0 to String.length(binary) - 1 {
let _ = %raw("bytes[i] = binary.charCodeAt(i)")
let padding = if mod4 > 0 {
String.repeat("=", 4 - mod4)
} else {
""
}
let binary = atob(base64 ++ padding)
let len = String.length(binary)
let bytes = Uint8Array.fromLength(len)
%raw(`(function() { for (var i = 0; i < len; i++) { bytes[i] = binary.charCodeAt(i); } })()`)
bytes
}

let decodeJwt = (token: string) => {
// Result of decoding a JWT without verifying.
// Carries the raw base64url segments so verifyJwt can reconstruct the signing
// input.
type decoded = {
headerB64: string,
payloadB64: string,
signatureB64: string,
header: JSON.t,
payload: JSON.t,
}

let decodeJwt = (token: string): decoded => {
let parts = String.split(token, ".")
if Array.length(parts) != 3 {
failwith("Invalid JWT format")
}
let headerB64 = Array.getUnsafe(parts, 0)
let payloadB64 = Array.getUnsafe(parts, 1)
let signatureB64 = Array.getUnsafe(parts, 2)

let decodePart = (p: string): JSON.t => {
p
Expand All @@ -100,18 +123,94 @@ let decodeJwt = (token: string) => {
}

{
"header": decodePart(Array.getUnsafe(parts, 0)),
"payload": decodePart(Array.getUnsafe(parts, 1)),
headerB64,
payloadB64,
signatureB64,
header: decodePart(headerB64),
payload: decodePart(payloadB64),
}
}

// Map a JWT `alg` to a (importKey-algorithm, verify-algorithm) pair.
// 'none' and any unrecognised algorithm are rejected.
let algToWebCrypto = (alg: string): result<(JSON.t, JSON.t), string> => {
switch alg {
| "none" => Error("Algorithm 'none' is rejected for security reasons")
| "RS256" =>
Ok((
%raw(`{name: "RSASSA-PKCS1-v1_5", hash: "SHA-256"}`),
%raw(`{name: "RSASSA-PKCS1-v1_5"}`),
))
| "RS384" =>
Ok((
%raw(`{name: "RSASSA-PKCS1-v1_5", hash: "SHA-384"}`),
%raw(`{name: "RSASSA-PKCS1-v1_5"}`),
))
| "RS512" =>
Ok((
%raw(`{name: "RSASSA-PKCS1-v1_5", hash: "SHA-512"}`),
%raw(`{name: "RSASSA-PKCS1-v1_5"}`),
))
| "PS256" =>
Ok((
%raw(`{name: "RSA-PSS", hash: "SHA-256"}`),
%raw(`{name: "RSA-PSS", saltLength: 32}`),
))
| "PS384" =>
Ok((
%raw(`{name: "RSA-PSS", hash: "SHA-384"}`),
%raw(`{name: "RSA-PSS", saltLength: 48}`),
))
| "PS512" =>
Ok((
%raw(`{name: "RSA-PSS", hash: "SHA-512"}`),
%raw(`{name: "RSA-PSS", saltLength: 64}`),
))
| "ES256" =>
Ok((
%raw(`{name: "ECDSA", namedCurve: "P-256", hash: "SHA-256"}`),
%raw(`{name: "ECDSA", hash: "SHA-256"}`),
))
| "ES384" =>
Ok((
%raw(`{name: "ECDSA", namedCurve: "P-384", hash: "SHA-384"}`),
%raw(`{name: "ECDSA", hash: "SHA-384"}`),
))
| "ES512" =>
Ok((
%raw(`{name: "ECDSA", namedCurve: "P-521", hash: "SHA-512"}`),
%raw(`{name: "ECDSA", hash: "SHA-512"}`),
))
| "EdDSA" => Ok((%raw(`{name: "Ed25519"}`), %raw(`{name: "Ed25519"}`)))
| other => Error(`Unsupported JWT algorithm: ${other}`)
}
}

// Verify a JWT against the JWKS at `config.jwksUri`. Throws on:
// - malformed token
// - 'none' alg or unsupported alg
// - exp in the past
// - issuer mismatch
// - kid not found in JWKS
// - JWK→CryptoKey import failure
// - SIGNATURE INVALID (the central guarantee)
//
// Returns the payload only when the signature is valid.
let verifyJwt = async (token: string, config: Types.oidcConfig): Types.tokenPayload => {
let decoded = decodeJwt(token)

let alg: string = %raw(`decoded.header.alg`)
let kid: string = %raw(`decoded.header.kid`)

// Map alg first so we reject 'none' / unsupported BEFORE doing any other work.
let (importAlg, verifyAlg) = switch algToWebCrypto(alg) {
| Ok(pair) => pair
| Error(msg) => failwith(msg)
}

let payload: Types.tokenPayload = %raw(`decoded.payload`)
let _header = decoded["header"]

let now = Date.now() /. 1000.0

switch payload.exp {
| Some(exp) if exp < now => failwith("Token expired")
| _ => ()
Expand All @@ -122,15 +221,29 @@ let verifyJwt = async (token: string, config: Types.oidcConfig): Types.tokenPayl
}

let jwks = await fetchJwks(config.jwksUri)
let kid = %raw(`header.kid`)
let keyOpt = jwks.keys->Array.find(k => k.kid == kid)

switch keyOpt {
let jwk = switch keyOpt {
| None => failwith(`Key not found: ${kid}`)
| Some(_key) => {
// Import and verify logic here...
// (Simplified for now to match the scope of logic port)
payload
}
| Some(j) => j
}

// Import the JWK as a CryptoKey usable for verification only.
let jwkJson: JSON.t = Obj.magic(jwk)
let formatJwk: JSON.t = %raw(`"jwk"`)
let cryptoKey = await Crypto.Subtle.importKey(formatJwk, jwkJson, importAlg, false, ["verify"])

// Build the signing input: "<headerB64>.<payloadB64>" as UTF-8 bytes.
let signingInput =
makeTextEncoder()->textEncoderEncode(decoded.headerB64 ++ "." ++ decoded.payloadB64)

// Base64url-decode the signature segment.
let signatureBytes = base64UrlDecode(decoded.signatureB64)

// The central check.
let ok = await Crypto.Subtle.verify(verifyAlg, cryptoKey, signatureBytes, signingInput)
if !ok {
failwith("JWT signature verification failed")
}

payload
}
Loading
Loading