Skip to content

Commit 2e10984

Browse files
committed
Implement access token strategy db adapter
1 parent 7b19595 commit 2e10984

File tree

7 files changed

+241
-43
lines changed

7 files changed

+241
-43
lines changed

fastapi_users_db_sqlalchemy/__init__.py

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,16 @@
11
"""FastAPI Users database adapter for SQLAlchemy + encode/databases."""
2-
import uuid
32
from typing import Mapping, Optional, Type
43

54
from databases import Database
65
from fastapi_users.db.base import BaseUserDatabase
76
from fastapi_users.models import UD
87
from pydantic import UUID4
98
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, func, select
10-
from sqlalchemy.dialects.postgresql import UUID
119
from sqlalchemy.ext.declarative import declared_attr
12-
from sqlalchemy.types import CHAR, TypeDecorator
1310

14-
__version__ = "1.0.0"
15-
16-
17-
class GUID(TypeDecorator): # pragma: no cover
18-
"""Platform-independent GUID type.
11+
from fastapi_users_db_sqlalchemy.guid import GUID
1912

20-
Uses PostgreSQL's UUID type, otherwise uses
21-
CHAR(36), storing as regular strings.
22-
"""
23-
24-
class UUIDChar(CHAR):
25-
python_type = UUID4
26-
27-
impl = UUIDChar
28-
29-
def load_dialect_impl(self, dialect):
30-
if dialect.name == "postgresql":
31-
return dialect.type_descriptor(UUID())
32-
else:
33-
return dialect.type_descriptor(CHAR(36))
34-
35-
def process_bind_param(self, value, dialect):
36-
if value is None:
37-
return value
38-
elif dialect.name == "postgresql":
39-
return str(value)
40-
else:
41-
if not isinstance(value, uuid.UUID):
42-
return str(uuid.UUID(value))
43-
else:
44-
return str(value)
45-
46-
def process_result_value(self, value, dialect):
47-
if value is None:
48-
return value
49-
else:
50-
if not isinstance(value, uuid.UUID):
51-
value = uuid.UUID(value)
52-
return value
13+
__version__ = "1.0.0"
5314

5415

5516
class SQLAlchemyBaseUserTable:
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from datetime import datetime
2+
from typing import Generic, Optional, Type
3+
4+
from databases import Database
5+
from fastapi_users.authentication.strategy.db import A, AccessTokenDatabase
6+
from sqlalchemy import Column, DateTime, ForeignKey, String, Table
7+
from sqlalchemy.ext.declarative import declared_attr
8+
9+
from fastapi_users_db_sqlalchemy.guid import GUID
10+
11+
12+
class SQLAlchemyBaseAccessTokenTable:
13+
"""Base SQLAlchemy access token table definition."""
14+
15+
__tablename__ = "accesstoken"
16+
17+
token = Column(String(length=43), primary_key=True)
18+
created_at = Column(DateTime(timezone=True), index=True, nullable=False)
19+
20+
@declared_attr
21+
def user_id(cls):
22+
return Column(GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False)
23+
24+
25+
class SQLAlchemyAccessTokenDatabase(AccessTokenDatabase, Generic[A]):
26+
"""
27+
Access token database adapter for SQLAlchemy.
28+
29+
:param access_token_model: Pydantic model of a DB representation of an access token.
30+
:param database: `Database` instance from `encode/databases`.
31+
:param access_tokens: SQLAlchemy access token table instance.
32+
"""
33+
34+
def __init__(
35+
self, access_token_model: Type[A], database: Database, access_tokens: Table
36+
):
37+
self.access_token_model = access_token_model
38+
self.database = database
39+
self.access_tokens = access_tokens
40+
41+
async def get_by_token(
42+
self, token: str, max_age: Optional[datetime] = None
43+
) -> Optional[A]:
44+
query = self.access_tokens.select().where(self.access_tokens.c.token == token)
45+
if max_age is not None:
46+
query = query.where(self.access_tokens.c.created_at >= max_age)
47+
48+
access_token = await self.database.fetch_one(query)
49+
if access_token is not None:
50+
return self.access_token_model(**access_token)
51+
return None
52+
53+
async def create(self, access_token: A) -> A:
54+
query = self.access_tokens.insert()
55+
await self.database.execute(query, access_token.dict())
56+
return access_token
57+
58+
async def update(self, access_token: A) -> A:
59+
update_query = (
60+
self.access_tokens.update()
61+
.where(self.access_tokens.c.token == access_token.token)
62+
.values(access_token.dict())
63+
)
64+
await self.database.execute(update_query)
65+
return access_token
66+
67+
async def delete(self, access_token: A) -> None:
68+
query = self.access_tokens.delete().where(
69+
self.access_tokens.c.token == access_token.token
70+
)
71+
await self.database.execute(query)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import uuid
2+
3+
from pydantic import UUID4
4+
from sqlalchemy.dialects.postgresql import UUID
5+
from sqlalchemy.types import CHAR, TypeDecorator
6+
7+
8+
class GUID(TypeDecorator): # pragma: no cover
9+
"""Platform-independent GUID type.
10+
11+
Uses PostgreSQL's UUID type, otherwise uses
12+
CHAR(36), storing as regular strings.
13+
"""
14+
15+
class UUIDChar(CHAR):
16+
python_type = UUID4
17+
18+
impl = UUIDChar
19+
20+
def load_dialect_impl(self, dialect):
21+
if dialect.name == "postgresql":
22+
return dialect.type_descriptor(UUID())
23+
else:
24+
return dialect.type_descriptor(CHAR(36))
25+
26+
def process_bind_param(self, value, dialect):
27+
if value is None:
28+
return value
29+
elif dialect.name == "postgresql":
30+
return str(value)
31+
else:
32+
if not isinstance(value, uuid.UUID):
33+
return str(uuid.UUID(value))
34+
else:
35+
return str(value)
36+
37+
def process_result_value(self, value, dialect):
38+
if value is None:
39+
return value
40+
else:
41+
if not isinstance(value, uuid.UUID):
42+
value = uuid.UUID(value)
43+
return value

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
description-file = "README.md"
2323
requires-python = ">=3.7"
2424
requires = [
25-
"fastapi-users >= 6.1.2",
25+
"fastapi-users >= 9.1.0",
2626
"sqlalchemy >=1.4",
2727
"databases >=0.5"
2828
]

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
fastapi-users >= 6.1.2
1+
fastapi-users >= 9.1.0
22
sqlalchemy >=1.4
33
databases[postgresql, sqlite] >=0.5

tests/test_access_token.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import sqlite3
2+
import uuid
3+
from datetime import datetime, timedelta, timezone
4+
from typing import AsyncGenerator
5+
6+
import pytest
7+
import sqlalchemy
8+
from databases import Database
9+
from fastapi_users.authentication.strategy.db.models import BaseAccessToken
10+
from pydantic import UUID4
11+
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
12+
13+
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTable
14+
from fastapi_users_db_sqlalchemy.access_token import (
15+
SQLAlchemyAccessTokenDatabase,
16+
SQLAlchemyBaseAccessTokenTable,
17+
)
18+
19+
20+
class AccessToken(BaseAccessToken):
21+
pass
22+
23+
24+
@pytest.fixture
25+
def user_id() -> UUID4:
26+
return uuid.uuid4()
27+
28+
29+
@pytest.fixture
30+
async def sqlalchemy_access_token_db(
31+
user_id: UUID4,
32+
) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase, None]:
33+
Base: DeclarativeMeta = declarative_base()
34+
35+
class AccessTokenTable(SQLAlchemyBaseAccessTokenTable, Base):
36+
pass
37+
38+
class UserTable(SQLAlchemyBaseUserTable, Base):
39+
pass
40+
41+
DATABASE_URL = "sqlite:///./test-sqlalchemy-access-token.db"
42+
database = Database(DATABASE_URL)
43+
44+
engine = sqlalchemy.create_engine(
45+
DATABASE_URL, connect_args={"check_same_thread": False}
46+
)
47+
Base.metadata.create_all(engine)
48+
49+
await database.connect()
50+
51+
# Create user
52+
query = UserTable.__table__.insert()
53+
await database.execute(
54+
query,
55+
{
56+
"id": user_id,
57+
"email": "lancelot@camelot.bt",
58+
"hashed_password": "guinevere",
59+
"is_active": True,
60+
"is_verified": False,
61+
"is_superuser": False,
62+
},
63+
)
64+
65+
yield SQLAlchemyAccessTokenDatabase(
66+
AccessToken, database, AccessTokenTable.__table__
67+
)
68+
69+
Base.metadata.drop_all(engine)
70+
await database.disconnect()
71+
72+
73+
@pytest.mark.asyncio
74+
@pytest.mark.db
75+
async def test_queries(
76+
sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken],
77+
user_id: UUID4,
78+
):
79+
access_token = AccessToken(token="TOKEN", user_id=user_id)
80+
81+
# Create
82+
access_token_db = await sqlalchemy_access_token_db.create(access_token)
83+
assert access_token_db.token == "TOKEN"
84+
assert access_token_db.user_id == user_id
85+
86+
# Update
87+
access_token_db.created_at = datetime.now(timezone.utc)
88+
await sqlalchemy_access_token_db.update(access_token_db)
89+
90+
# Get by token
91+
access_token_by_token = await sqlalchemy_access_token_db.get_by_token(
92+
access_token_db.token
93+
)
94+
assert access_token_by_token is not None
95+
96+
# Get by token expired
97+
access_token_by_token = await sqlalchemy_access_token_db.get_by_token(
98+
access_token_db.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1)
99+
)
100+
assert access_token_by_token is None
101+
102+
# Get by token not expired
103+
access_token_by_token = await sqlalchemy_access_token_db.get_by_token(
104+
access_token_db.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1)
105+
)
106+
assert access_token_by_token is not None
107+
108+
# Get by token unknown
109+
access_token_by_token = await sqlalchemy_access_token_db.get_by_token(
110+
"NOT_EXISTING_TOKEN"
111+
)
112+
assert access_token_by_token is None
113+
114+
# Exception when inserting existing token
115+
with pytest.raises(sqlite3.IntegrityError):
116+
await sqlalchemy_access_token_db.create(access_token_db)
117+
118+
# Delete token
119+
await sqlalchemy_access_token_db.delete(access_token_db)
120+
deleted_access_token = await sqlalchemy_access_token_db.get_by_token(
121+
access_token_db.token
122+
)
123+
assert deleted_access_token is None
File renamed without changes.

0 commit comments

Comments
 (0)