11"""FastAPI Users database adapter for SQLModel."""
22import uuid
3- from typing import Callable , Generic , Optional , Type , TypeVar
3+ from typing import Generic , Optional , Type , TypeVar
44
55from fastapi_users .db .base import BaseUserDatabase
66from fastapi_users .models import BaseOAuthAccount , BaseUserDB
77from pydantic import UUID4 , EmailStr
8- from sqlalchemy .ext .asyncio import AsyncEngine , AsyncSession
9- from sqlalchemy .future import Engine
10- from sqlalchemy .orm import selectinload , sessionmaker
8+ from sqlalchemy .ext .asyncio import AsyncSession
9+ from sqlalchemy .orm import selectinload
1110from sqlmodel import Field , Session , SQLModel , func , select
1211
1312__version__ = "0.0.3"
@@ -48,80 +47,74 @@ class SQLModelUserDatabase(Generic[UD, OA], BaseUserDatabase[UD]):
4847 Database adapter for SQLModel.
4948
5049 :param user_db_model: SQLModel model of a DB representation of a user.
51- :param engine : SQLAlchemy engine .
50+ :param session : SQLAlchemy session .
5251 """
5352
54- engine : Engine
53+ session : Session
5554 oauth_account_model : Optional [Type [OA ]]
5655
5756 def __init__ (
5857 self ,
5958 user_db_model : Type [UD ],
60- engine : Engine ,
59+ session : Session ,
6160 oauth_account_model : Optional [Type [OA ]] = None ,
6261 ):
6362 super ().__init__ (user_db_model )
64- self .engine = engine
63+ self .session = session
6564 self .oauth_account_model = oauth_account_model
6665
6766 async def get (self , id : UUID4 ) -> Optional [UD ]:
6867 """Get a single user by id."""
69- with Session (self .engine ) as session :
70- return session .get (self .user_db_model , id )
68+ return self .session .get (self .user_db_model , id )
7169
7270 async def get_by_email (self , email : str ) -> Optional [UD ]:
7371 """Get a single user by email."""
74- with Session (self .engine ) as session :
75- statement = select (self .user_db_model ).where (
76- func .lower (self .user_db_model .email ) == func .lower (email )
77- )
78- results = session .exec (statement )
79- return results .first ()
72+ statement = select (self .user_db_model ).where (
73+ func .lower (self .user_db_model .email ) == func .lower (email )
74+ )
75+ results = self .session .exec (statement )
76+ return results .first ()
8077
8178 async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UD ]:
8279 """Get a single user by OAuth account id."""
8380 if not self .oauth_account_model :
8481 raise NotSetOAuthAccountTableError ()
85- with Session (self .engine ) as session :
86- statement = (
87- select (self .oauth_account_model )
88- .where (self .oauth_account_model .oauth_name == oauth )
89- .where (self .oauth_account_model .account_id == account_id )
90- )
91- results = session .exec (statement )
92- oauth_account = results .first ()
93- if oauth_account :
94- user = oauth_account .user # type: ignore
95- return user
96- return None
82+ statement = (
83+ select (self .oauth_account_model )
84+ .where (self .oauth_account_model .oauth_name == oauth )
85+ .where (self .oauth_account_model .account_id == account_id )
86+ )
87+ results = self .session .exec (statement )
88+ oauth_account = results .first ()
89+ if oauth_account :
90+ user = oauth_account .user # type: ignore
91+ return user
92+ return None
9793
9894 async def create (self , user : UD ) -> UD :
9995 """Create a user."""
100- with Session (self .engine ) as session :
101- session .add (user )
102- if self .oauth_account_model is not None :
103- for oauth_account in user .oauth_accounts : # type: ignore
104- session .add (oauth_account )
105- session .commit ()
106- session .refresh (user )
107- return user
96+ self .session .add (user )
97+ if self .oauth_account_model is not None :
98+ for oauth_account in user .oauth_accounts : # type: ignore
99+ self .session .add (oauth_account )
100+ self .session .commit ()
101+ self .session .refresh (user )
102+ return user
108103
109104 async def update (self , user : UD ) -> UD :
110105 """Update a user."""
111- with Session (self .engine ) as session :
112- session .add (user )
113- if self .oauth_account_model is not None :
114- for oauth_account in user .oauth_accounts : # type: ignore
115- session .add (oauth_account )
116- session .commit ()
117- session .refresh (user )
118- return user
106+ self .session .add (user )
107+ if self .oauth_account_model is not None :
108+ for oauth_account in user .oauth_accounts : # type: ignore
109+ self .session .add (oauth_account )
110+ self .session .commit ()
111+ self .session .refresh (user )
112+ return user
119113
120114 async def delete (self , user : UD ) -> None :
121115 """Delete a user."""
122- with Session (self .engine ) as session :
123- session .delete (user )
124- session .commit ()
116+ self .session .delete (user )
117+ self .session .commit ()
125118
126119
127120class SQLModelUserDatabaseAsync (Generic [UD , OA ], BaseUserDatabase [UD ]):
@@ -132,81 +125,72 @@ class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]):
132125 :param engine: SQLAlchemy async engine.
133126 """
134127
135- engine : AsyncEngine
128+ session : AsyncSession
136129 oauth_account_model : Optional [Type [OA ]]
137130
138131 def __init__ (
139132 self ,
140133 user_db_model : Type [UD ],
141- engine : AsyncEngine ,
134+ session : AsyncSession ,
142135 oauth_account_model : Optional [Type [OA ]] = None ,
143136 ):
144137 super ().__init__ (user_db_model )
145- self .engine = engine
138+ self .session = session
146139 self .oauth_account_model = oauth_account_model
147- self .session_maker : Callable [[], AsyncSession ] = sessionmaker (
148- self .engine , class_ = AsyncSession , expire_on_commit = False
149- )
150140
151141 async def get (self , id : UUID4 ) -> Optional [UD ]:
152142 """Get a single user by id."""
153- async with self .session_maker () as session :
154- return await session .get (self .user_db_model , id )
143+ return await self .session .get (self .user_db_model , id )
155144
156145 async def get_by_email (self , email : str ) -> Optional [UD ]:
157146 """Get a single user by email."""
158- async with self .session_maker () as session :
159- statement = select (self .user_db_model ).where (
160- func .lower (self .user_db_model .email ) == func .lower (email )
161- )
162- results = await session .execute (statement )
163- object = results .first ()
164- if object is None :
165- return None
166- return object [0 ]
147+ statement = select (self .user_db_model ).where (
148+ func .lower (self .user_db_model .email ) == func .lower (email )
149+ )
150+ results = await self .session .execute (statement )
151+ object = results .first ()
152+ if object is None :
153+ return None
154+ return object [0 ]
167155
168156 async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UD ]:
169157 """Get a single user by OAuth account id."""
170158 if not self .oauth_account_model :
171159 raise NotSetOAuthAccountTableError ()
172- async with self .session_maker () as session :
173- statement = (
174- select (self .oauth_account_model )
175- .where (self .oauth_account_model .oauth_name == oauth )
176- .where (self .oauth_account_model .account_id == account_id )
177- .options (selectinload (self .oauth_account_model .user )) # type: ignore
178- )
179- results = await session .execute (statement )
180- oauth_account = results .first ()
181- if oauth_account :
182- user = oauth_account [0 ].user
183- return user
184- return None
160+ statement = (
161+ select (self .oauth_account_model )
162+ .where (self .oauth_account_model .oauth_name == oauth )
163+ .where (self .oauth_account_model .account_id == account_id )
164+ .options (selectinload (self .oauth_account_model .user )) # type: ignore
165+ )
166+ results = await self .session .execute (statement )
167+ oauth_account = results .first ()
168+ if oauth_account :
169+ user = oauth_account [0 ].user
170+ return user
171+ return None
185172
186173 async def create (self , user : UD ) -> UD :
187174 """Create a user."""
188- async with self .session_maker () as session :
189- session .add (user )
190- if self .oauth_account_model is not None :
191- for oauth_account in user .oauth_accounts : # type: ignore
192- session .add (oauth_account )
193- await session .commit ()
194- await session .refresh (user )
195- return user
175+ self .session .add (user )
176+ if self .oauth_account_model is not None :
177+ for oauth_account in user .oauth_accounts : # type: ignore
178+ self .session .add (oauth_account )
179+ await self .session .commit ()
180+ await self .session .refresh (user )
181+ return user
196182
197183 async def update (self , user : UD ) -> UD :
198184 """Update a user."""
199- async with self .session_maker () as session :
200- session .add (user )
201- if self .oauth_account_model is not None :
202- for oauth_account in user .oauth_accounts : # type: ignore
203- session .add (oauth_account )
204- await session .commit ()
205- await session .refresh (user )
206- return user
185+ self .session .add (user )
186+ if self .oauth_account_model is not None :
187+ for oauth_account in user .oauth_accounts : # type: ignore
188+ self .session .add (oauth_account )
189+ await self .session .commit ()
190+ await self .session .refresh (user )
191+ return user
207192
208193 async def delete (self , user : UD ) -> None :
209194 """Delete a user."""
210- async with self .session_maker () as session :
211- await session .delete (user )
212- await session .commit ()
195+ await self .session .delete (user )
196+ await self .session .commit ()
0 commit comments