diff --git a/apps/dev/nextjs/prisma/migrations/20240124035029_init/migration.sql b/apps/dev/nextjs/prisma/migrations/20240124035029_init/migration.sql new file mode 100644 index 0000000000..bc3b8e013b --- /dev/null +++ b/apps/dev/nextjs/prisma/migrations/20240124035029_init/migration.sql @@ -0,0 +1,71 @@ +-- CreateTable +CREATE TABLE "Account" ( + "userId" TEXT NOT NULL, + "type" TEXT NOT NULL, + "provider" TEXT NOT NULL, + "providerAccountId" TEXT NOT NULL, + "refresh_token" TEXT, + "access_token" TEXT, + "expires_at" INTEGER, + "token_type" TEXT, + "scope" TEXT, + "id_token" TEXT, + "session_state" TEXT, + + PRIMARY KEY ("provider", "providerAccountId"), + CONSTRAINT "Account_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE CASCADE ON UPDATE CASCADE +); + +-- CreateTable +CREATE TABLE "Session" ( + "id" TEXT NOT NULL PRIMARY KEY, + "sessionToken" TEXT NOT NULL, + "userId" TEXT NOT NULL, + "expires" DATETIME NOT NULL, + CONSTRAINT "Session_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE RESTRICT ON UPDATE CASCADE +); + +-- CreateTable +CREATE TABLE "User" ( + "id" TEXT NOT NULL PRIMARY KEY, + "name" TEXT, + "email" TEXT, + "emailVerified" DATETIME, + "image" TEXT +); + +-- CreateTable +CREATE TABLE "VerificationToken" ( + "identifier" TEXT NOT NULL, + "token" TEXT NOT NULL, + "expires" DATETIME NOT NULL +); + +-- CreateTable +CREATE TABLE "Authenticator" ( + "id" TEXT NOT NULL PRIMARY KEY, + "credentialID" TEXT NOT NULL, + "userId" TEXT NOT NULL, + "providerAccountId" TEXT NOT NULL, + "credentialPublicKey" TEXT NOT NULL, + "counter" INTEGER NOT NULL, + "credentialDeviceType" TEXT NOT NULL, + "credentialBackedUp" BOOLEAN NOT NULL, + "transports" TEXT, + CONSTRAINT "Authenticator_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE CASCADE ON UPDATE CASCADE +); + +-- CreateIndex +CREATE UNIQUE INDEX "Session_sessionToken_key" ON "Session"("sessionToken"); + +-- CreateIndex +CREATE UNIQUE INDEX "User_email_key" ON "User"("email"); + +-- CreateIndex +CREATE UNIQUE INDEX "VerificationToken_token_key" ON "VerificationToken"("token"); + +-- CreateIndex +CREATE UNIQUE INDEX "VerificationToken_identifier_token_key" ON "VerificationToken"("identifier", "token"); + +-- CreateIndex +CREATE UNIQUE INDEX "Authenticator_credentialID_key" ON "Authenticator"("credentialID"); diff --git a/apps/dev/nextjs/prisma/schema.prisma b/apps/dev/nextjs/prisma/schema.prisma index b905afedab..55cbd78290 100644 --- a/apps/dev/nextjs/prisma/schema.prisma +++ b/apps/dev/nextjs/prisma/schema.prisma @@ -8,7 +8,6 @@ generator client { } model Account { - id String @id @default(cuid()) userId String type String provider String @@ -20,9 +19,10 @@ model Account { scope String? id_token String? session_state String? - user User @relation(fields: [userId], references: [id]) - @@unique([provider, providerAccountId]) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@id([provider, providerAccountId]) } model Session { @@ -41,6 +41,7 @@ model User { image String? accounts Account[] sessions Session[] + Authenticator Authenticator[] } model VerificationToken { @@ -50,3 +51,17 @@ model VerificationToken { @@unique([identifier, token]) } + +model Authenticator { + id String @id @default(cuid()) + credentialID String @unique + userId String + providerAccountId String + credentialPublicKey String + counter Int + credentialDeviceType String + credentialBackedUp Boolean + transports String? + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) +} diff --git a/packages/adapter-hasura/schema.gql b/packages/adapter-hasura/schema.gql index 7334d7aa70..7caf7376aa 100644 --- a/packages/adapter-hasura/schema.gql +++ b/packages/adapter-hasura/schema.gql @@ -1192,6 +1192,7 @@ enum provider_type_enum { email oauth oidc + webauthn } """ diff --git a/packages/adapter-prisma/prisma/custom.prisma b/packages/adapter-prisma/prisma/custom.prisma index 6d62dbdb50..0f67d40919 100644 --- a/packages/adapter-prisma/prisma/custom.prisma +++ b/packages/adapter-prisma/prisma/custom.prisma @@ -17,6 +17,7 @@ model User { role String? accounts Account[] sessions Session[] + Authenticator Authenticator[] } model Account { @@ -51,3 +52,17 @@ model VerificationToken { @@id([identifier, token]) } + +model Authenticator { + id String @id @default(cuid()) + credentialID String @unique + userId String + providerAccountId String + credentialPublicKey String + counter Int + credentialDeviceType String + credentialBackedUp Boolean + transports String? + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) +} diff --git a/packages/adapter-prisma/prisma/mongodb.prisma b/packages/adapter-prisma/prisma/mongodb.prisma index 4f40809061..bb7f35d3de 100644 --- a/packages/adapter-prisma/prisma/mongodb.prisma +++ b/packages/adapter-prisma/prisma/mongodb.prisma @@ -41,6 +41,7 @@ model User { image String? accounts Account[] sessions Session[] + Authenticator Authenticator[] } model VerificationToken { @@ -51,3 +52,17 @@ model VerificationToken { @@unique([identifier, token]) } + +model Authenticator { + id String @id @default(auto()) @map("_id") @db.ObjectId + credentialID String @unique + userId String + providerAccountId String + credentialPublicKey String + counter Int + credentialDeviceType String + credentialBackedUp Boolean + transports String? + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) +} diff --git a/packages/adapter-prisma/prisma/schema.prisma b/packages/adapter-prisma/prisma/schema.prisma index d9f6605c99..8e084ba6fc 100644 --- a/packages/adapter-prisma/prisma/schema.prisma +++ b/packages/adapter-prisma/prisma/schema.prisma @@ -15,6 +15,7 @@ model User { image String? accounts Account[] sessions Session[] + Authenticator Authenticator[] } model Account { @@ -49,3 +50,17 @@ model VerificationToken { @@id([identifier, token]) } + +model Authenticator { + id String @id @default(cuid()) + credentialID String @unique + userId String + providerAccountId String + credentialPublicKey String + counter Int + credentialDeviceType String + credentialBackedUp Boolean + transports String? + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) +} diff --git a/packages/adapter-prisma/src/index.ts b/packages/adapter-prisma/src/index.ts index 2caa690fb7..122e86bada 100644 --- a/packages/adapter-prisma/src/index.ts +++ b/packages/adapter-prisma/src/index.ts @@ -16,7 +16,7 @@ * @module @auth/prisma-adapter */ import type { PrismaClient, Prisma } from "@prisma/client" -import type { Adapter, AdapterAccount } from "@auth/core/adapters" +import type { Adapter, AdapterAccount, AdapterAuthenticator, AdapterSession, AdapterUser } from "@auth/core/adapters" /** * ## Setup @@ -215,6 +215,20 @@ import type { Adapter, AdapterAccount } from "@auth/core/adapters" * @@unique([identifier, token]) * @@map("verificationtokens") * } + * + * model Authenticator { + * id String @id @default(cuid()) + * credentialID String @unique + * userId String + * providerAccountId String + * credentialPublicKey String + * counter Int + * credentialDeviceType String + * credentialBackedUp Boolean + * transports String? + * + * user User @relation(fields: [userId], references: [id], onDelete: Cascade) + * } * ``` * **/ @@ -234,10 +248,10 @@ export function PrismaAdapter( where: { provider_providerAccountId }, select: { user: true }, }) - return account?.user ?? null + return account?.user as AdapterUser ?? null }, - updateUser: ({ id, ...data }) => p.user.update({ where: { id }, data }), - deleteUser: (id) => p.user.delete({ where: { id } }), + updateUser: ({ id, ...data }) => p.user.update({ where: { id }, data }) as Promise, + deleteUser: (id) => p.user.delete({ where: { id } }) as Promise, linkAccount: (data) => p.account.create({ data }) as unknown as AdapterAccount, unlinkAccount: (provider_providerAccountId) => @@ -251,7 +265,7 @@ export function PrismaAdapter( }) if (!userAndSession) return null const { user, ...session } = userAndSession - return { user, session } + return { user, session } as { user: AdapterUser; session: AdapterSession } }, createSession: (data) => p.session.create({ data }), updateSession: (data) => @@ -280,5 +294,42 @@ export function PrismaAdapter( throw error } }, + async getAccount(providerAccountId, provider) { + return p.account.findFirst({ + where: { providerAccountId, provider } + }) as Promise + }, + async createAuthenticator(authenticator) { + return p.authenticator.create({ + data: authenticator + }).then(fromDBAuthenticator) + }, + async getAuthenticator(credentialID) { + const authenticator = await p.authenticator.findUnique({ where: { credentialID } }) + return authenticator ? fromDBAuthenticator(authenticator) : null + }, + async listAuthenticatorsByUserId(userId) { + const authenticators = await p.authenticator.findMany({ where: { userId } }) + + return authenticators.map(fromDBAuthenticator) + }, + async updateAuthenticatorCounter(credentialID, counter) { + return p.authenticator.update({ + where: { credentialID: credentialID }, + data: { counter }, + }).then(fromDBAuthenticator) + } + } +} + +type BasePrismaAuthenticator = Parameters[0]['data'] +type PrismaAuthenticator = BasePrismaAuthenticator & Required> + +function fromDBAuthenticator(authenticator: PrismaAuthenticator): AdapterAuthenticator { + const { transports, id, user, ...other } = authenticator + + return { + ...other, + transports: transports || undefined, } } diff --git a/packages/adapter-prisma/test/index.test.ts b/packages/adapter-prisma/test/index.test.ts index a1b2de7ae3..e4810a4375 100644 --- a/packages/adapter-prisma/test/index.test.ts +++ b/packages/adapter-prisma/test/index.test.ts @@ -8,6 +8,7 @@ const prisma = new PrismaClient().$extends(withAccelerate()) runBasicTests({ adapter: PrismaAdapter(prisma), + testWebAuthnMethods: true, db: { id() { if (process.env.CONTAINER_NAME !== "authjs-mongodb-test") return @@ -19,6 +20,7 @@ runBasicTests({ prisma.account.deleteMany({}), prisma.session.deleteMany({}), prisma.verificationToken.deleteMany({}), + prisma.authenticator.deleteMany({}), ]) }, disconnect: async () => { @@ -27,6 +29,7 @@ runBasicTests({ prisma.account.deleteMany({}), prisma.session.deleteMany({}), prisma.verificationToken.deleteMany({}), + prisma.authenticator.deleteMany({}), ]) await prisma.$disconnect() }, @@ -44,5 +47,7 @@ runBasicTests({ delete result.id return result }, + authenticator: (credentialID) => + prisma.authenticator.findUnique({ where: { credentialID } }), }, }) diff --git a/packages/adapter-surrealdb/src/index.ts b/packages/adapter-surrealdb/src/index.ts index 1776a4dc74..391368bd4e 100644 --- a/packages/adapter-surrealdb/src/index.ts +++ b/packages/adapter-surrealdb/src/index.ts @@ -31,7 +31,7 @@ export type AccountDoc = { userId: T refresh_token?: string access_token?: string - type: Extract + type: Extract provider: string providerAccountId: string expires_at?: number diff --git a/packages/utils/adapter.ts b/packages/utils/adapter.ts index 91fead6529..8d0da280ef 100644 --- a/packages/utils/adapter.ts +++ b/packages/utils/adapter.ts @@ -1,7 +1,7 @@ import { test, expect, beforeAll, afterAll } from "vitest" import type { Adapter } from "@auth/core/adapters" -import { createHash, randomUUID } from "crypto" +import { createHash, randomInt, randomUUID } from "crypto" export interface TestOptions { adapter: Adapter @@ -39,10 +39,18 @@ export interface TestOptions { * based on the user identifier and the verification token (hashed). */ verificationToken: (params: { identifier: string; token: string }) => any + /** + * A simple query function that returns an authenticator directly from the db. + */ + authenticator?: (credentialID: string) => any } - skipTests?: string[] + skipTests?: string[], + /** + * Enables testing of WebAuthn methods. + */ + testWebAuthnMethods?: boolean } -const testIf = (condition: boolean) => (condition ? test : test.skip) + /** * A wrapper to run the most basic tests. * Run this at the top of your test file. @@ -55,9 +63,25 @@ export async function runBasicTests(options: TestOptions) { await options.db.connect?.() }) - const { adapter: _adapter, db, skipTests } = options + const { adapter: _adapter, db, skipTests: skipTests = [], testWebAuthnMethods } = options const adapter = _adapter as Required + if (!testWebAuthnMethods) { + skipTests.push(...[ + "getAccount", + "getAuthenticator", + "createAuthenticator", + "listAuthenticatorsByUserId", + "updateAuthenticatorCounter" + ]) + } + + const maybeTest = ( + method: keyof Adapter, + ...args: Parameters extends [any, ...infer U] ? U : never + ) => + skipTests.includes(method) ? test.skip(method, ...args) : test(method, ...args) + afterAll(async () => { // @ts-expect-error This is only used for the TypeORM adapter await adapter.__disconnect?.() @@ -303,7 +327,7 @@ export async function runBasicTests(options: TestOptions) { expect(dbAccount).toBeNull() }) - testIf(!skipTests?.includes("deleteUser"))("deleteUser", async () => { + maybeTest("deleteUser", async () => { let dbUser = await db.user(user.id) expect(dbUser).toEqual(user) @@ -328,6 +352,221 @@ export async function runBasicTests(options: TestOptions) { // Account should not exist after user is deleted expect(dbAccount).toBeNull() }) + + maybeTest("getAccount", async () => { + // Setup + const providerAccountId = randomUUID() + const provider = "auth0" + const localUser = await adapter.createUser({ + id: randomUUID(), + email: "getAccount@example.com", + emailVerified: null, + }) + await adapter.linkAccount({ + provider, + providerAccountId, + type: "oauth", + userId: localUser.id, + }) + + // Test + const invalidBoth = await adapter.getAccount("invalid-provider-account-id", "invalid-provider") + expect(invalidBoth).toBeNull() + const invalidProvider = await adapter.getAccount(providerAccountId, "invalid-provider") + expect(invalidProvider).toBeNull() + const invalidProviderAccountId = await adapter.getAccount("invalid-provider-account-id", provider) + expect(invalidProviderAccountId).toBeNull() + const validAccount = await adapter.getAccount(providerAccountId, provider) + expect(validAccount).not.toBeNull() + + const dbAccount = await db.account({ + provider, + providerAccountId, + }) + expect(dbAccount).toMatchObject(validAccount || {}) + }) + maybeTest("createAuthenticator", async () => { + // Setup + const credentialID = randomUUID() + const localUser = await adapter.createUser({ + id: randomUUID(), + email: "createAuthenticator@example.com", + emailVerified: null, + }) + await adapter.linkAccount({ + provider: "webauthn", + providerAccountId: credentialID, + type: "webauthn", + userId: localUser.id, + }) + + // Test + const authenticatorData = { + credentialID, + providerAccountId: credentialID, + userId: localUser.id, + counter: randomInt(100), + credentialBackedUp: true, + credentialDeviceType: "platform", + credentialPublicKey: randomUUID(), + transports: "usb,ble,nfc", + } + const newAuthenticator = await adapter.createAuthenticator(authenticatorData) + expect(newAuthenticator).not.toBeNull() + expect(newAuthenticator).toMatchObject(authenticatorData) + + const dbAuthenticator = db.authenticator ? await db.authenticator( + credentialID, + ) : undefined + expect(dbAuthenticator).toMatchObject(newAuthenticator) + }) + maybeTest("getAuthenticator", async () => { + // Setup + const credentialID = randomUUID() + const localUser = await adapter.createUser({ + id: randomUUID(), + email: "getAuthenticator@example.com", + emailVerified: null, + }) + await adapter.linkAccount({ + provider: "webauthn", + providerAccountId: credentialID, + type: "webauthn", + userId: localUser.id, + }) + await adapter.createAuthenticator({ + credentialID, + providerAccountId: credentialID, + userId: localUser.id, + counter: randomInt(100), + credentialBackedUp: true, + credentialDeviceType: "platform", + credentialPublicKey: randomUUID(), + transports: "usb,ble,nfc", + }) + + // Test + const invalidAuthenticator = await adapter.getAuthenticator("invalid-credential-id") + expect(invalidAuthenticator).toBeNull() + + const validAuthenticator = await adapter.getAuthenticator(credentialID) + expect(validAuthenticator).not.toBeNull() + const dbAuthenticator = db.authenticator ? await db.authenticator( + credentialID + ) : undefined + expect(dbAuthenticator).toMatchObject(validAuthenticator || {}) + }) + maybeTest("listAuthenticatorsByUserId", async () => { + // Setup + const user1 = await adapter.createUser({ + id: randomUUID(), + email: "listAuthenticatorsByUserId1@example.com", + emailVerified: null, + }) + const user2 = await adapter.createUser({ + id: randomUUID(), + email: "listAuthenticatorsByUserId2@example.com", + emailVerified: null, + }) + const credentialID1 = randomUUID() + const credentialID2 = randomUUID() + const credentialID3 = randomUUID() + await adapter.linkAccount({ + provider: "webauthn", + providerAccountId: credentialID1, + type: "webauthn", + userId: user1.id, + }) + await adapter.linkAccount({ + provider: "webauthn", + providerAccountId: credentialID2, + type: "webauthn", + userId: user1.id, + }) + await adapter.linkAccount({ + provider: "webauthn", + providerAccountId: credentialID3, + type: "webauthn", + userId: user2.id, + }) + const authenticator1 = await adapter.createAuthenticator({ + credentialID: credentialID1, + providerAccountId: credentialID1, + userId: user1.id, + counter: randomInt(100), + credentialBackedUp: true, + credentialDeviceType: "platform", + credentialPublicKey: randomUUID(), + transports: "usb,ble,nfc", + }) + const authenticator2 = await adapter.createAuthenticator({ + credentialID: credentialID2, + providerAccountId: credentialID2, + userId: user1.id, + counter: randomInt(100), + credentialBackedUp: true, + credentialDeviceType: "platform", + credentialPublicKey: randomUUID(), + transports: "usb,nfc", + }) + const authenticator3 = await adapter.createAuthenticator({ + credentialID: credentialID3, + providerAccountId: credentialID3, + userId: user2.id, + counter: randomInt(100), + credentialBackedUp: true, + credentialDeviceType: "platform", + credentialPublicKey: randomUUID(), + transports: "usb,ble", + }) + + // Test + const authenticators0 = await adapter.listAuthenticatorsByUserId("invalid-user-id") + expect(authenticators0).toEqual([]) + + const authenticators1 = await adapter.listAuthenticatorsByUserId(user1.id) + expect(authenticators1).not.toBeNull() + expect([authenticator1, authenticator2]).toMatchObject(authenticators1 || []) + + const authenticators2 = await adapter.listAuthenticatorsByUserId(user2.id) + expect(authenticators2).not.toBeNull() + expect([authenticator3]).toMatchObject(authenticators2 || []) + }) + maybeTest("updateAuthenticatorCounter", async () => { + // Setup + const credentialID = randomUUID() + const localUser = await adapter.createUser({ + id: randomUUID(), + email: "updateAuthenticatorCounter@example.com", + emailVerified: null, + }) + await adapter.linkAccount({ + provider: "webauthn", + providerAccountId: credentialID, + type: "webauthn", + userId: localUser.id, + }) + const newAuthenticator = await adapter.createAuthenticator({ + credentialID, + providerAccountId: credentialID, + userId: localUser.id, + counter: randomInt(100), + credentialBackedUp: true, + credentialDeviceType: "platform", + credentialPublicKey: randomUUID(), + transports: "usb,ble,nfc", + }) + + // Test + await expect( + () => adapter.updateAuthenticatorCounter("invalid-credential-id", randomInt(100)) + ).rejects.toThrow() + + const newCounter = newAuthenticator.counter + randomInt(100) + const updatedAuthenticator = await adapter.updateAuthenticatorCounter(credentialID, newCounter) + expect(updatedAuthenticator).not.toBeNull() + expect(updatedAuthenticator.counter).toBe(newCounter) + }) } // UTILS