Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(adapter-drizzle): add option to pass in schema #8561

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 23 additions & 6 deletions packages/adapter-drizzle/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ import { BaseSQLiteDatabase, SQLiteTableFn } from "drizzle-orm/sqlite-core"
import { mySqlDrizzleAdapter } from "./lib/mysql.js"
import { pgDrizzleAdapter } from "./lib/pg.js"
import { SQLiteDrizzleAdapter } from "./lib/sqlite.js"
import { SqlFlavorOptions, TableFn } from "./lib/utils.js"
import {
ClientFlavors,
MinimumSchema,
SqlFlavorOptions,
TableFn,
} from "./lib/utils.js"
import { is } from "drizzle-orm"

import type { Adapter } from "@auth/core/adapters"
Expand All @@ -48,8 +53,11 @@ import type { Adapter } from "@auth/core/adapters"
* ```
*
* :::info
* If you're using multi-project schemas, you can pass your table function as a second argument
* If you're using multi-project schemas, you can pass your table function as a second argument.
* Alternatively, you can pass your tables as an object if your tables includes other
* attributes you want to be returned from the adapter.
* :::
*
*
* ## Setup
*
Expand Down Expand Up @@ -253,14 +261,23 @@ import type { Adapter } from "@auth/core/adapters"
**/
export function DrizzleAdapter<SqlFlavor extends SqlFlavorOptions>(
db: SqlFlavor,
table?: TableFn<SqlFlavor>
tableFnOrTables?: TableFn<SqlFlavor> | Partial<ClientFlavors<SqlFlavor>>
): Adapter {
if (is(db, MySqlDatabase)) {
return mySqlDrizzleAdapter(db, table as MySqlTableFn)
return mySqlDrizzleAdapter(
db,
tableFnOrTables as MySqlTableFn | MinimumSchema["mysql"] | undefined
)
} else if (is(db, PgDatabase)) {
return pgDrizzleAdapter(db, table as PgTableFn)
return pgDrizzleAdapter(
db,
tableFnOrTables as PgTableFn | MinimumSchema["pg"] | undefined
)
} else if (is(db, BaseSQLiteDatabase)) {
return SQLiteDrizzleAdapter(db, table as SQLiteTableFn)
return SQLiteDrizzleAdapter(
db,
tableFnOrTables as SQLiteTableFn | MinimumSchema["sqlite"] | undefined
)
}

throw new Error(
Expand Down
14 changes: 11 additions & 3 deletions packages/adapter-drizzle/src/lib/mysql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
} from "drizzle-orm/mysql-core"

import type { Adapter, AdapterAccount } from "@auth/core/adapters"
import { MinimumSchema } from "./utils"

export function createTables(mySqlTable: MySqlTableFn) {
const users = mySqlTable("user", {
Expand Down Expand Up @@ -78,10 +79,17 @@ export type DefaultSchema = ReturnType<typeof createTables>

export function mySqlDrizzleAdapter(
client: InstanceType<typeof MySqlDatabase>,
tableFn = defaultMySqlTableFn
tableFnOrTables?: MySqlTableFn | Partial<MinimumSchema["mysql"]>
): Adapter {
const { users, accounts, sessions, verificationTokens } =
createTables(tableFn)
const defaultTables = createTables(
typeof tableFnOrTables === "function"
? tableFnOrTables
: defaultMySqlTableFn
)
const { users, accounts, sessions, verificationTokens } = {
...defaultTables,
...(typeof tableFnOrTables === "object" ? tableFnOrTables : {}),
}

return {
async createUser(data) {
Expand Down
12 changes: 9 additions & 3 deletions packages/adapter-drizzle/src/lib/pg.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
} from "drizzle-orm/pg-core"

import type { Adapter, AdapterAccount } from "@auth/core/adapters"
import { MinimumSchema } from "./utils"

export function createTables(pgTable: PgTableFn) {
const users = pgTable("user", {
Expand Down Expand Up @@ -69,10 +70,15 @@ export type DefaultSchema = ReturnType<typeof createTables>

export function pgDrizzleAdapter(
client: InstanceType<typeof PgDatabase>,
tableFn = defaultPgTableFn
tableFnOrTables?: PgTableFn | Partial<MinimumSchema["pg"]>
): Adapter {
const { users, accounts, sessions, verificationTokens } =
createTables(tableFn)
const defaultTables = createTables(
typeof tableFnOrTables === "function" ? tableFnOrTables : defaultPgTableFn
)
const { users, accounts, sessions, verificationTokens } = {
...defaultTables,
...(typeof tableFnOrTables === "object" ? tableFnOrTables : {}),
}

return {
async createUser(data) {
Expand Down
14 changes: 11 additions & 3 deletions packages/adapter-drizzle/src/lib/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
} from "drizzle-orm/sqlite-core"

import type { Adapter, AdapterAccount } from "@auth/core/adapters"
import { MinimumSchema } from "./utils"

export function createTables(sqliteTable: SQLiteTableFn) {
const users = sqliteTable("user", {
Expand Down Expand Up @@ -68,10 +69,17 @@ export type DefaultSchema = ReturnType<typeof createTables>

export function SQLiteDrizzleAdapter(
client: InstanceType<typeof BaseSQLiteDatabase>,
tableFn = defaultSqliteTableFn
tableFnOrTables?: SQLiteTableFn | Partial<MinimumSchema["sqlite"]>
): Adapter {
const { users, accounts, sessions, verificationTokens } =
createTables(tableFn)
const defaultTables = createTables(
typeof tableFnOrTables === "function"
? tableFnOrTables
: defaultSqliteTableFn
)
const { users, accounts, sessions, verificationTokens } = {
...defaultTables,
...(typeof tableFnOrTables === "object" ? tableFnOrTables : {}),
}

return {
async createUser(data) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import type { Config } from "drizzle-kit"

export default {
schema: "./tests/mysql/schema.ts",
out: "./tests/mysql/.drizzle",
driver: "mysql2",
dbCredentials: {
host: "localhost",
user: "root",
password: "password",
database: "next-auth",
},
} satisfies Config
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { runBasicTests } from "../../../adapter-test"
import { DrizzleAdapter } from "../../src"
import { db, sessions, verificationTokens, accounts, users } from "./schema"
import { eq, and } from "drizzle-orm"
import { fixtures } from "../fixtures"

globalThis.crypto ??= require("node:crypto").webcrypto

runBasicTests({
adapter: DrizzleAdapter(db, { users }),
fixtures,
db: {
connect: async () => {
await Promise.all([
db.delete(sessions),
db.delete(accounts),
db.delete(verificationTokens),
db.delete(users),
])
},
disconnect: async () => {
await Promise.all([
db.delete(sessions),
db.delete(accounts),
db.delete(verificationTokens),
db.delete(users),
])
},
user: async (id) => {
const user = await db
.select()
.from(users)
.where(eq(users.id, id))
.then((res) => res[0] ?? null)
return user
},
session: async (sessionToken) => {
const session = await db
.select()
.from(sessions)
.where(eq(sessions.sessionToken, sessionToken))
.then((res) => res[0] ?? null)

return session
},
account: (provider_providerAccountId) => {
const account = db
.select()
.from(accounts)
.where(
eq(
accounts.providerAccountId,
provider_providerAccountId.providerAccountId
)
)
.then((res) => res[0] ?? null)
return account
},
verificationToken: (identifier_token) =>
db
.select()
.from(verificationTokens)
.where(
and(
eq(verificationTokens.token, identifier_token.token),
eq(verificationTokens.identifier, identifier_token.identifier)
)
)
.then((res) => res[0]) ?? null,
},
})
30 changes: 30 additions & 0 deletions packages/adapter-drizzle/tests/mysql-custom-tables/schema.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { mysqlTable, timestamp, varchar } from "drizzle-orm/mysql-core"
import { drizzle } from "drizzle-orm/mysql2"
import { createPool } from "mysql2"
import { createTables } from "../../src/lib/mysql"

const poolConnection = createPool({
host: "localhost",
user: "root",
Dismissed Show dismissed Hide dismissed
password: "password",
database: "next-auth",
})

export const { accounts, sessions, verificationTokens } =
createTables(mysqlTable)
export const users = mysqlTable("user", {
id: varchar("id", { length: 255 }).notNull().primaryKey(),
name: varchar("name", { length: 255 }),
email: varchar("email", { length: 255 }).notNull(),
emailVerified: timestamp("emailVerified", {
mode: "date",
fsp: 3,
}).defaultNow(),
image: varchar("image", { length: 255 }),

// Some other attribute we wan't returned in the session callback
foo: varchar("foo", { length: 255 }),
})
export const schema = { users, accounts, sessions, verificationTokens }

export const db = drizzle(poolConnection, { schema })
22 changes: 22 additions & 0 deletions packages/adapter-drizzle/tests/mysql-custom-tables/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env bash

echo "Initializing container for MySQL tests."

MYSQL_DATABASE=next-auth
MYSQL_ROOT_PASSWORD=password
MYSQL_CONTAINER_NAME=next-auth-mysql-test

docker run -d --rm \
-e MYSQL_DATABASE=${MYSQL_DATABASE} \
-e MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD} \
--name "${MYSQL_CONTAINER_NAME}" \
-p 3306:3306 \
mysql:8 \
--default-authentication-plugin=mysql_native_password

echo "Waiting 15 sec for db to start..." && sleep 15

drizzle-kit generate:mysql --config=./tests/mysql/drizzle.config.ts
drizzle-kit push:mysql --config=./tests/mysql/drizzle.config.ts
jest ./tests/mysql/index.test.ts --forceExit
docker stop ${MYSQL_CONTAINER_NAME}