Skip to content

Commit 50ffecb

Browse files
feat(sso): add InResponseTo validation (#6557)
1 parent a0a1633 commit 50ffecb

File tree

7 files changed

+922
-9
lines changed

7 files changed

+922
-9
lines changed

docs/content/docs/plugins/sso.mdx

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,115 @@ mapping: {
707707
}
708708
```
709709

710+
## SAML Security
711+
712+
The SSO plugin includes optional security features to protect against common SAML vulnerabilities.
713+
714+
### AuthnRequest / InResponseTo Validation
715+
716+
You can enable InResponseTo validation for SP-initiated SAML flows. When enabled, the plugin tracks AuthnRequest IDs and validates the `InResponseTo` attribute in SAML responses. This prevents:
717+
718+
- **Unsolicited responses**: Responses not triggered by a legitimate login request
719+
- **Replay attacks**: Reusing old SAML responses
720+
- **Cross-provider injection**: Responses meant for a different provider
721+
722+
<Callout type="info">
723+
This feature is **opt-in** to ensure backward compatibility. Enable it explicitly for enhanced security.
724+
</Callout>
725+
726+
#### Enabling Validation (Single Instance)
727+
728+
For single-instance deployments, enable validation with the built-in in-memory store:
729+
730+
```ts title="auth.ts"
731+
import { betterAuth } from "better-auth";
732+
import { sso } from "@better-auth/sso";
733+
734+
const auth = betterAuth({
735+
plugins: [
736+
sso({
737+
saml: {
738+
// Enable InResponseTo validation
739+
enableInResponseToValidation: true,
740+
// Optionally reject IdP-initiated SSO (stricter security)
741+
allowIdpInitiated: false,
742+
// Custom TTL for AuthnRequest validity (default: 5 minutes)
743+
requestTTL: 10 * 60 * 1000, // 10 minutes
744+
},
745+
}),
746+
],
747+
});
748+
```
749+
750+
#### Options
751+
752+
| Option | Type | Default | Description |
753+
|--------|------|---------|-------------|
754+
| `enableInResponseToValidation` | `boolean` | `false` | Enable InResponseTo validation for SP-initiated flows. |
755+
| `allowIdpInitiated` | `boolean` | `true` | Allow IdP-initiated SSO (responses without InResponseTo). Set to `false` for stricter security. Only applies when validation is enabled. |
756+
| `requestTTL` | `number` | `300000` (5 min) | Time-to-live for AuthnRequest records in milliseconds. Requests older than this will be rejected. |
757+
| `authnRequestStore` | `AuthnRequestStore` | In-memory | Custom store implementation. Providing a custom store automatically enables validation. |
758+
759+
#### Multi-Instance Deployments (Production)
760+
761+
<Callout type="warning">
762+
For multi-instance deployments (load-balanced servers, serverless, etc.), you **must** provide a shared store like Redis. The default in-memory store only works for single-instance deployments.
763+
</Callout>
764+
765+
Providing a custom `authnRequestStore` automatically enables InResponseTo validation:
766+
767+
```ts title="auth.ts"
768+
import { betterAuth } from "better-auth";
769+
import { sso, type AuthnRequestStore } from "@better-auth/sso";
770+
771+
// Example Redis-backed store
772+
const redisAuthnRequestStore: AuthnRequestStore = {
773+
async save(record) {
774+
const ttl = Math.ceil((record.expiresAt - Date.now()) / 1000);
775+
await redis.set(
776+
`authn:${record.id}`,
777+
JSON.stringify(record),
778+
"EX",
779+
ttl
780+
);
781+
},
782+
async get(id) {
783+
const data = await redis.get(`authn:${id}`);
784+
if (!data) return null;
785+
const record = JSON.parse(data);
786+
if (record.expiresAt < Date.now()) {
787+
await redis.del(`authn:${id}`);
788+
return null;
789+
}
790+
return record;
791+
},
792+
async delete(id) {
793+
await redis.del(`authn:${id}`);
794+
},
795+
};
796+
797+
const auth = betterAuth({
798+
plugins: [
799+
sso({
800+
saml: {
801+
// Providing a store automatically enables validation
802+
authnRequestStore: redisAuthnRequestStore,
803+
// Optionally configure other options
804+
allowIdpInitiated: false,
805+
},
806+
}),
807+
],
808+
});
809+
```
810+
811+
#### Error Handling
812+
813+
When InResponseTo validation fails, users are redirected with an error query parameter:
814+
815+
- `?error=invalid_saml_response&error_description=Unknown+or+expired+request+ID` — The request ID was not found or has expired
816+
- `?error=invalid_saml_response&error_description=Provider+mismatch` — The response was meant for a different provider
817+
- `?error=unsolicited_response&error_description=IdP-initiated+SSO+not+allowed` — IdP-initiated SSO is disabled
818+
710819
## Domain verification
711820

712821
Domain verification allows your application to automatically trust a new SSO provider
@@ -965,6 +1074,32 @@ If you want to allow account linking for specific trusted providers, enable the
9651074
}
9661075
},
9671076
},
1077+
saml: {
1078+
description: "SAML security options for AuthnRequest/InResponseTo validation.",
1079+
type: "object",
1080+
properties: {
1081+
enableInResponseToValidation: {
1082+
description: "Enable InResponseTo validation for SP-initiated SAML flows. Opt-in for backward compatibility.",
1083+
type: "boolean",
1084+
default: false,
1085+
},
1086+
allowIdpInitiated: {
1087+
description: "Allow IdP-initiated SSO (unsolicited SAML responses). Set to false for stricter security. Only applies when validation is enabled.",
1088+
type: "boolean",
1089+
default: true,
1090+
},
1091+
requestTTL: {
1092+
description: "TTL for AuthnRequest records in milliseconds. Only applies when validation is enabled.",
1093+
type: "number",
1094+
default: 300000,
1095+
},
1096+
authnRequestStore: {
1097+
description: "Custom AuthnRequest store for multi-instance deployments (e.g., Redis). Providing a store automatically enables validation.",
1098+
type: "AuthnRequestStore",
1099+
required: false,
1100+
},
1101+
},
1102+
},
9681103
modelName: {
9691104
description: "The model name for the SSO provider table",
9701105
type: "string",
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/**
2+
* AuthnRequest Store
3+
*
4+
* Tracks SAML AuthnRequest IDs to enable InResponseTo validation.
5+
* This prevents:
6+
* - Unsolicited SAML responses
7+
* - Cross-provider response injection
8+
* - Replay attacks
9+
* - Expired login completions
10+
*/
11+
12+
export interface AuthnRequestRecord {
13+
id: string;
14+
providerId: string;
15+
createdAt: number;
16+
expiresAt: number;
17+
}
18+
19+
export interface AuthnRequestStore {
20+
save(record: AuthnRequestRecord): Promise<void>;
21+
get(id: string): Promise<AuthnRequestRecord | null>;
22+
delete(id: string): Promise<void>;
23+
}
24+
25+
/**
26+
* Default TTL for AuthnRequest records (5 minutes).
27+
* This should be sufficient for most IdPs while protecting against stale requests.
28+
*/
29+
export const DEFAULT_AUTHN_REQUEST_TTL_MS = 5 * 60 * 1000;
30+
31+
/**
32+
* In-memory implementation of AuthnRequestStore.
33+
* ⚠️ Only suitable for testing or single-instance non-serverless deployments.
34+
* For production, rely on the default behavior (uses verification table)
35+
* or provide a custom Redis-backed store.
36+
*/
37+
export function createInMemoryAuthnRequestStore(): AuthnRequestStore {
38+
const store = new Map<string, AuthnRequestRecord>();
39+
40+
const cleanup = () => {
41+
const now = Date.now();
42+
for (const [id, record] of store.entries()) {
43+
if (record.expiresAt < now) {
44+
store.delete(id);
45+
}
46+
}
47+
};
48+
49+
const cleanupInterval = setInterval(cleanup, 60 * 1000);
50+
51+
if (typeof cleanupInterval.unref === "function") {
52+
cleanupInterval.unref();
53+
}
54+
55+
return {
56+
async save(record: AuthnRequestRecord): Promise<void> {
57+
store.set(record.id, record);
58+
},
59+
60+
async get(id: string): Promise<AuthnRequestRecord | null> {
61+
const record = store.get(id);
62+
if (!record) {
63+
return null;
64+
}
65+
if (record.expiresAt < Date.now()) {
66+
store.delete(id);
67+
return null;
68+
}
69+
return record;
70+
},
71+
72+
async delete(id: string): Promise<void> {
73+
store.delete(id);
74+
},
75+
};
76+
}
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import { describe, expect, it } from "vitest";
2+
import {
3+
createInMemoryAuthnRequestStore,
4+
DEFAULT_AUTHN_REQUEST_TTL_MS,
5+
} from "./authn-request-store";
6+
7+
describe("AuthnRequest Store", () => {
8+
describe("In-Memory Store", () => {
9+
it("should save and retrieve an AuthnRequest record", async () => {
10+
const store = createInMemoryAuthnRequestStore();
11+
12+
const record = {
13+
id: "_test-request-id-1",
14+
providerId: "saml-provider-1",
15+
createdAt: Date.now(),
16+
expiresAt: Date.now() + DEFAULT_AUTHN_REQUEST_TTL_MS,
17+
};
18+
19+
await store.save(record);
20+
const retrieved = await store.get(record.id);
21+
22+
expect(retrieved).toEqual(record);
23+
});
24+
25+
it("should return null for non-existent request ID", async () => {
26+
const store = createInMemoryAuthnRequestStore();
27+
28+
const retrieved = await store.get("_non-existent-id");
29+
30+
expect(retrieved).toBeNull();
31+
});
32+
33+
it("should return null for expired request ID", async () => {
34+
const store = createInMemoryAuthnRequestStore();
35+
36+
const record = {
37+
id: "_expired-request-id",
38+
providerId: "saml-provider-1",
39+
createdAt: Date.now() - 10000,
40+
expiresAt: Date.now() - 1000, // Already expired
41+
};
42+
43+
await store.save(record);
44+
const retrieved = await store.get(record.id);
45+
46+
expect(retrieved).toBeNull();
47+
});
48+
49+
it("should delete a request ID", async () => {
50+
const store = createInMemoryAuthnRequestStore();
51+
52+
const record = {
53+
id: "_delete-me",
54+
providerId: "saml-provider-1",
55+
createdAt: Date.now(),
56+
expiresAt: Date.now() + DEFAULT_AUTHN_REQUEST_TTL_MS,
57+
};
58+
59+
await store.save(record);
60+
await store.delete(record.id);
61+
62+
const retrieved = await store.get(record.id);
63+
expect(retrieved).toBeNull();
64+
});
65+
66+
it("should handle multiple providers with different request IDs", async () => {
67+
const store = createInMemoryAuthnRequestStore();
68+
69+
const record1 = {
70+
id: "_request-provider-1",
71+
providerId: "saml-provider-1",
72+
createdAt: Date.now(),
73+
expiresAt: Date.now() + DEFAULT_AUTHN_REQUEST_TTL_MS,
74+
};
75+
76+
const record2 = {
77+
id: "_request-provider-2",
78+
providerId: "saml-provider-2",
79+
createdAt: Date.now(),
80+
expiresAt: Date.now() + DEFAULT_AUTHN_REQUEST_TTL_MS,
81+
};
82+
83+
await store.save(record1);
84+
await store.save(record2);
85+
86+
const retrieved1 = await store.get(record1.id);
87+
const retrieved2 = await store.get(record2.id);
88+
89+
expect(retrieved1?.providerId).toBe("saml-provider-1");
90+
expect(retrieved2?.providerId).toBe("saml-provider-2");
91+
});
92+
});
93+
94+
describe("DEFAULT_AUTHN_REQUEST_TTL_MS", () => {
95+
it("should be 5 minutes in milliseconds", () => {
96+
expect(DEFAULT_AUTHN_REQUEST_TTL_MS).toBe(5 * 60 * 1000);
97+
});
98+
});
99+
});

packages/sso/src/index.ts

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
import type { BetterAuthPlugin } from "better-auth";
22
import { XMLValidator } from "fast-xml-parser";
33
import * as saml from "samlify";
4+
import type {
5+
AuthnRequestRecord,
6+
AuthnRequestStore,
7+
} from "./authn-request-store";
8+
import {
9+
createInMemoryAuthnRequestStore,
10+
DEFAULT_AUTHN_REQUEST_TTL_MS,
11+
} from "./authn-request-store";
412
import {
513
requestDomainVerification,
614
verifyDomain,
@@ -16,6 +24,8 @@ import {
1624
import type { OIDCConfig, SAMLConfig, SSOOptions, SSOProvider } from "./types";
1725

1826
export type { SAMLConfig, OIDCConfig, SSOOptions, SSOProvider };
27+
export type { AuthnRequestStore, AuthnRequestRecord };
28+
export { createInMemoryAuthnRequestStore, DEFAULT_AUTHN_REQUEST_TTL_MS };
1929

2030
const fastValidator = {
2131
async validate(xml: string) {
@@ -71,19 +81,21 @@ export function sso<O extends SSOOptions>(
7181
};
7282

7383
export function sso<O extends SSOOptions>(options?: O | undefined): any {
84+
const optionsWithStore = options as O;
85+
7486
let endpoints = {
7587
spMetadata: spMetadata(),
76-
registerSSOProvider: registerSSOProvider(options as O),
77-
signInSSO: signInSSO(options as O),
78-
callbackSSO: callbackSSO(options as O),
79-
callbackSSOSAML: callbackSSOSAML(options as O),
80-
acsEndpoint: acsEndpoint(options as O),
88+
registerSSOProvider: registerSSOProvider(optionsWithStore),
89+
signInSSO: signInSSO(optionsWithStore),
90+
callbackSSO: callbackSSO(optionsWithStore),
91+
callbackSSOSAML: callbackSSOSAML(optionsWithStore),
92+
acsEndpoint: acsEndpoint(optionsWithStore),
8193
};
8294

8395
if (options?.domainVerification?.enabled) {
8496
const domainVerificationEndpoints = {
85-
requestDomainVerification: requestDomainVerification(options as O),
86-
verifyDomain: verifyDomain(options as O),
97+
requestDomainVerification: requestDomainVerification(optionsWithStore),
98+
verifyDomain: verifyDomain(optionsWithStore),
8799
};
88100

89101
endpoints = {

0 commit comments

Comments
 (0)