-
-
Notifications
You must be signed in to change notification settings - Fork 595
/
dependencies.py
243 lines (185 loc) · 7.75 KB
/
dependencies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import shutil
import tempfile
from collections.abc import AsyncGenerator, Callable, Generator
from pathlib import Path
from uuid import uuid4
import fastapi
from fastapi import BackgroundTasks, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm.session import Session
from mealie.core.config import get_app_dirs, get_app_settings
from mealie.db.db_setup import generate_session
from mealie.repos.all_repositories import get_repositories
from mealie.schema.user import PrivateUser, TokenData
from mealie.schema.user.user import DEFAULT_INTEGRATION_ID, GroupInDB
# Bearer-token extractor used by dependencies that require a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
# Same extractor but with auto_error=False; used by the dependencies below that
# must handle a missing token themselves (presumably yields None instead of
# raising when no token is supplied -- confirm against the FastAPI docs).
oauth2_scheme_soft_fail = OAuth2PasswordBearer(tokenUrl="/api/auth/token", auto_error=False)
# Algorithm passed to every jwt.decode call in this module.
ALGORITHM = "HS256"
# Application directories and settings, resolved once at import time.
app_dirs = get_app_dirs()
settings = get_app_settings()
async def is_logged_in(token: str = Depends(oauth2_scheme_soft_fail), session=Depends(generate_session)) -> bool:
    """
    Report whether the request carries a valid login, without loading the user.

    Unlike 'get_current_user', this dependency never raises an auth exception:
    any failure while decoding or validating the token results in ``False``.

    Args:
        token (str, optional): Bearer token taken from the soft-fail OAuth2 scheme.
        session (optional): Database session used to validate long-lived tokens.

    Returns:
        bool: True = Valid User / False = Not User
    """
    try:
        claims = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])

        # Long-lived API tokens are confirmed against the database; a truthy
        # result means the token is known and therefore the caller is logged in.
        is_long_lived = claims.get("long_token") is not None
        if is_long_lived and validate_long_live_token(session, token, claims.get("id")):
            return True

        # Otherwise a login is recognised by the presence of a "sub" claim.
        return claims.get("sub") is not None
    except Exception:
        # Decode errors and failed long-lived-token validation both mean "not logged in".
        return False
async def get_public_group(group_slug: str = fastapi.Path(...), session=Depends(generate_session)) -> GroupInDB:
    """
    Resolve the group named in the path, but only if it is publicly visible.

    Raises:
        HTTPException: 404 when the group does not exist, is marked private, or
            does not have public recipes enabled.
    """
    found = get_repositories(session).groups.get_by_slug_or_id(group_slug)

    # The same 404 is used for "missing" and "not public".
    if not found or found.preferences.private_group or not found.preferences.recipe_public:
        raise HTTPException(404, "group not found")

    return found
async def try_get_current_user(
    request: Request,
    token: str = Depends(oauth2_scheme_soft_fail),
    session=Depends(generate_session),
) -> PrivateUser | None:
    """
    Best-effort variant of 'get_current_user'.

    Returns:
        PrivateUser | None: the authenticated user, or None when
        'get_current_user' raises any exception.
    """
    try:
        current = await get_current_user(request, token, session)
    except Exception:
        return None
    return current
async def get_current_user(
    request: Request, token: str | None = Depends(oauth2_scheme_soft_fail), session=Depends(generate_session)
) -> PrivateUser:
    """
    Resolve the authenticated user for the current request.

    The token comes from the soft-fail OAuth2 scheme; when that yields None the
    "mealie.access_token" cookie is used instead. Tokens carrying a
    "long_token" claim are validated through 'validate_long_live_token';
    all other tokens are resolved through their "sub" claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no "sub"
            claim, or no matching user exists.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    if token is None:
        # Try extract from cookie; an absent cookie leaves an empty token
        token = request.cookies.get("mealie.access_token", "")
    else:
        token = token or ""

    try:
        claims = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
        subject: str = claims.get("sub")

        if claims.get("long_token") is not None:
            return validate_long_live_token(session, token, claims.get("id"))

        if subject is None:
            raise unauthorized

        token_data = TokenData(user_id=subject)
    except JWTError as e:
        raise unauthorized from e

    user = get_repositories(session).users.get_one(token_data.user_id, "id", any_case=False)

    # If we don't commit here, lazy-loads from user relationships will leave some table lock in postgres
    # which can cause quite a bit of pain further down the line
    session.commit()

    if user is None:
        raise unauthorized
    return user
async def get_integration_id(token: str = Depends(oauth2_scheme)) -> str:
    """
    Read the "integration_id" claim from the bearer token.

    Returns:
        str: the claim value, or DEFAULT_INTEGRATION_ID when the claim is absent.

    Raises:
        HTTPException: 401 when the token cannot be decoded.
    """
    try:
        claims = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
    except JWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e

    return claims.get("integration_id", DEFAULT_INTEGRATION_ID)
async def get_admin_user(current_user: PrivateUser = Depends(get_current_user)) -> PrivateUser:
    """
    Return the current user only if they are an admin.

    Raises:
        HTTPException: 403 Forbidden when the user's ``admin`` flag is falsy.
    """
    if current_user.admin:
        return current_user

    raise HTTPException(status.HTTP_403_FORBIDDEN)
def validate_long_live_token(session: Session, client_token: str, user_id: str) -> PrivateUser:
    """
    Look up a long-lived API token and return the user attached to it.

    Args:
        session (Session): database session used to build the repositories.
        client_token (str): the raw token presented by the client.
        user_id (str): the user id the token must be stored under.

    Raises:
        HTTPException: 401 Unauthorized when no stored token matches.
    """
    lookup = {"token": client_token, "user_id": user_id}
    matches = get_repositories(session).api_tokens.multi_query(lookup)

    try:
        # Only the first match is used; an empty result surfaces as IndexError.
        return matches[0].user
    except IndexError as e:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED) from e
def validate_file_token(token: str | None = None) -> Path:
    """
    Decode a signed file token and return the path stored in its "file" claim.

    Args:
        token (str | None, optional): signed JWT carrying a "file" claim. Defaults to None.

    Raises:
        HTTPException: 400 Bad Request when no token is given, the token has no
            usable "file" claim, or the file doesn't exist
        HTTPException: 401 Unauthorized when the token is invalid

    Returns:
        Path: path of the existing file named by the token
    """
    if not token:
        raise HTTPException(status.HTTP_400_BAD_REQUEST)

    try:
        payload = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
    except JWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="could not validate file token",
        ) from e

    # A validly signed token without a (non-empty, string) "file" claim used to
    # reach Path(None) and fail with an unhandled TypeError; report it as a bad
    # request instead, mirroring how a missing "slug" is handled for recipe tokens.
    file_claim = payload.get("file")
    if not isinstance(file_claim, str) or not file_claim:
        raise HTTPException(status.HTTP_400_BAD_REQUEST)

    file_path = Path(file_claim)
    if not file_path.exists():
        raise HTTPException(status.HTTP_400_BAD_REQUEST)

    return file_path
def validate_recipe_token(token: str | None = None) -> str:
    """
    Decode a signed recipe token and return the slug stored in its "slug" claim.

    Args:
        token (str | None, optional): signed JWT carrying a "slug" claim. Defaults to None.

    Raises:
        HTTPException: 400 Bad Request when no token is given or the token has no slug
        HTTPException: 401 Unauthorized when the token cannot be decoded

    Returns:
        str: the recipe slug from the token
    """
    if not token:
        raise HTTPException(status.HTTP_400_BAD_REQUEST)

    try:
        claims = jwt.decode(token, settings.SECRET, algorithms=[ALGORITHM])
    except JWTError as e:
        # NOTE: the detail text says "file token" even though this validates a
        # recipe token; it is kept unchanged here.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="could not validate file token",
        ) from e

    slug: str | None = claims.get("slug")
    if slug is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST)

    return slug
async def temporary_zip_path() -> AsyncGenerator[Path, None]:
    """
    Yield a path for a temporary zip archive inside the app's temp directory.

    The temp directory is created if needed, and the archive (if one was
    written) is deleted once the consumer is finished with it.
    """
    # NOTE(review): the file name is fixed, so every consumer receives the same path.
    zip_path = app_dirs.TEMP_DIR / "my_zip_archive.zip"
    app_dirs.TEMP_DIR.mkdir(parents=True, exist_ok=True)

    try:
        yield zip_path
    finally:
        zip_path.unlink(missing_ok=True)
async def temporary_dir(background_tasks: BackgroundTasks) -> AsyncGenerator[Path, None]:
    """
    Yield a freshly created, uniquely named directory under the app's temp directory.

    Removal is not done inline: when the consumer finishes, an ``shutil.rmtree``
    call for the directory is queued on ``background_tasks``.
    """
    scratch_dir = app_dirs.TEMP_DIR / uuid4().hex
    scratch_dir.mkdir(parents=True, exist_ok=True)

    try:
        yield scratch_dir
    finally:
        background_tasks.add_task(shutil.rmtree, scratch_dir)
def temporary_file(ext: str = "") -> Callable[[], Generator[tempfile._TemporaryFileWrapper, None, None]]:
    """
    Build a generator dependency that yields a temporary file with the given extension.

    Args:
        ext (str, optional): suffix for the temporary file name (e.g. ".zip"). Defaults to "".

    Returns:
        A zero-argument generator function. Each call yields an open binary
        (``w+b``) ``NamedTemporaryFile``; the file is closed and deleted when
        the generator is finished or closed.
    """

    def func():
        # The previous implementation also touched (and later unlinked) an empty
        # placeholder file in the app temp directory. That file was never handed
        # to the caller, and creating it failed when the temp directory did not
        # exist yet, so it has been dropped. The yielded file is unchanged.
        with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
            yield f

    return func