Skip to content

Commit

Permalink
pyright for api_helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Oct 24, 2023
1 parent b7ef0f9 commit 4020e60
Show file tree
Hide file tree
Showing 20 changed files with 189 additions and 155 deletions.
2 changes: 1 addition & 1 deletion api/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@
app.include_router(client_router, prefix="/api/client", tags=["Client"])

# requests from the GUI
app.include_router(gui_router, prefix="/api/gui", tags=["GUI"])
app.include_router(gui_router, prefix="/api/gui", tags=["GUI"])
5 changes: 2 additions & 3 deletions api_helpers/clients/_get_mongo_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import os
from motor.motor_asyncio import AsyncIOMotorClient
from ..core.settings import get_settings

Expand All @@ -8,13 +7,13 @@ def _get_mongo_client() -> AsyncIOMotorClient:
# We want one async mongo client per event loop
loop = asyncio.get_event_loop()
if hasattr(loop, '_mongo_client'):
return loop._mongo_client
return loop._mongo_client # type: ignore

# Otherwise, create a new client and store it in the global variable
mongo_uri = get_settings().MONGO_URI
if mongo_uri is None:
print('MONGO_URI environment variable not set')
raise Exception("MONGO_URI environment variable not set")
raise KeyError("MONGO_URI environment variable not set")

client = AsyncIOMotorClient(mongo_uri)
setattr(loop, '_mongo_client', client)
Expand Down
25 changes: 14 additions & 11 deletions api_helpers/clients/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
async def fetch_projects_for_user(user_id: Union[str, None]) -> List[ProtocaasProject]:
client = _get_mongo_client()
projects_collection = client['protocaas']['projects']
projects = await projects_collection.find({}).to_list(length=None)
projects = await projects_collection.find({}).to_list(length=None) # type: ignore
for project in projects:
_remove_id_field(project)
projects = [ProtocaasProject(**project) for project in projects] # validate projects
Expand All @@ -26,13 +26,13 @@ async def fetch_projects_with_tag(tag: str) -> List[ProtocaasProject]:
projects = await projects_collection.find({
# When you use a query like { "tags": tag } against an array field in MongoDB, it checks if any element of the array matches the value.
'tags': tag
}).to_list(length=None)
}).to_list(length=None) # type: ignore
for project in projects:
_remove_id_field(project)
projects = [ProtocaasProject(**project) for project in projects] # validate projects
return projects

async def fetch_project(project_id: str) -> ProtocaasProject:
async def fetch_project(project_id: str) -> Union[ProtocaasProject, None]:
client = _get_mongo_client()
projects_collection = client['protocaas']['projects']
project = await projects_collection.find_one({'projectId': project_id})
Expand All @@ -45,7 +45,7 @@ async def fetch_project(project_id: str) -> ProtocaasProject:
async def fetch_project_files(project_id: str) -> List[ProtocaasFile]:
client = _get_mongo_client()
files_collection = client['protocaas']['files']
files = await files_collection.find({'projectId': project_id}).to_list(length=None)
files = await files_collection.find({'projectId': project_id}).to_list(length=None) # type: ignore
for file in files:
_remove_id_field(file)
files = [ProtocaasFile(**file) for file in files] # validate files
Expand All @@ -54,7 +54,7 @@ async def fetch_project_files(project_id: str) -> List[ProtocaasFile]:
async def fetch_project_jobs(project_id: str, include_private_keys=False) -> List[ProtocaasJob]:
client = _get_mongo_client()
jobs_collection = client['protocaas']['jobs']
jobs = await jobs_collection.find({'projectId': project_id}).to_list(length=None)
jobs = await jobs_collection.find({'projectId': project_id}).to_list(length=None) # type: ignore
for job in jobs:
_remove_id_field(job)
jobs = [ProtocaasJob(**job) for job in jobs] # validate jobs
Expand Down Expand Up @@ -114,7 +114,7 @@ async def fetch_compute_resource(compute_resource_id: str):
async def fetch_compute_resources_for_user(user_id: str):
client = _get_mongo_client()
compute_resources_collection = client['protocaas']['computeResources']
compute_resources = await compute_resources_collection.find({'ownerId': user_id}).to_list(length=None)
compute_resources = await compute_resources_collection.find({'ownerId': user_id}).to_list(length=None) # type: ignore
for compute_resource in compute_resources:
_remove_id_field(compute_resource)
compute_resources = [ProtocaasComputeResource(**compute_resource) for compute_resource in compute_resources] # validate compute resources
Expand Down Expand Up @@ -142,7 +142,7 @@ async def register_compute_resource(compute_resource_id: str, name: str, user_id

compute_resource = await compute_resources_collection.find_one({'computeResourceId': compute_resource_id})
if compute_resource is not None:
compute_resources_collection.update_one({'computeResourceId': compute_resource_id}, {
await compute_resources_collection.update_one({'computeResourceId': compute_resource_id}, {
'$set': {
'ownerId': user_id,
'name': name,
Expand All @@ -157,7 +157,7 @@ async def register_compute_resource(compute_resource_id: str, name: str, user_id
timestampCreated=time.time(),
apps=[]
)
compute_resources_collection.insert_one(new_compute_resource.dict(exclude_none=True))
await compute_resources_collection.insert_one(new_compute_resource.dict(exclude_none=True))

async def fetch_compute_resource_jobs(compute_resource_id: str, statuses: Union[List[str], None], include_private_keys: bool) -> List[ProtocaasJob]:
client = _get_mongo_client()
Expand All @@ -166,11 +166,11 @@ async def fetch_compute_resource_jobs(compute_resource_id: str, statuses: Union[
jobs = await jobs_collection.find({
'computeResourceId': compute_resource_id,
'status': {'$in': statuses}
}).to_list(length=None)
}).to_list(length=None) # type: ignore
else:
jobs = await jobs_collection.find({
'computeResourceId': compute_resource_id
}).to_list(length=None)
}).to_list(length=None) # type: ignore
for job in jobs:
_remove_id_field(job)
jobs = [ProtocaasJob(**job) for job in jobs] # validate jobs
Expand All @@ -197,12 +197,15 @@ async def update_compute_resource_node(compute_resource_id: str, compute_resourc
}
}, upsert=True)

class ComputeResourceNotFoundError(Exception):
pass

async def set_compute_resource_spec(compute_resource_id: str, spec: ComputeResourceSpec):
client = _get_mongo_client()
compute_resources_collection = client['protocaas']['computeResources']
compute_resource = await compute_resources_collection.find_one({'computeResourceId': compute_resource_id})
if compute_resource is None:
raise Exception(f"No compute resource with ID {compute_resource_id}")
raise ComputeResourceNotFoundError(f"No compute resource with ID {compute_resource_id}")
await compute_resources_collection.update_one({'computeResourceId': compute_resource_id}, {
'$set': {
'spec': spec.dict(exclude_none=True)
Expand Down
7 changes: 5 additions & 2 deletions api_helpers/clients/pubsub.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
import json
import aiohttp
import urllib.parse
from ..core.settings import get_settings


class PubsubError(Exception):
pass

async def publish_pubsub_message(*, channel: str, message: dict):
settings = get_settings()
# see https://www.pubnub.com/docs/sdks/rest-api/publish-message-to-channel
Expand All @@ -23,5 +26,5 @@ async def publish_pubsub_message(*, channel: str, message: dict):
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as resp:
if resp.status != 200:
raise Exception(f"Error publishing to pubsub: {resp.status} {resp.text}")
raise PubsubError(f"Error publishing to pubsub: {resp.status} {resp.text}")
return True
19 changes: 10 additions & 9 deletions api_helpers/core/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
from pydantic import BaseModel
import os

Expand All @@ -7,19 +8,19 @@

class Settings(BaseModel):
# General app config
MONGO_URI: str = os.environ.get("MONGO_URI")
MONGO_URI: Optional[str] = os.environ.get("MONGO_URI")

PUBNUB_SUBSCRIBE_KEY: str = os.environ.get("VITE_PUBNUB_SUBSCRIBE_KEY")
PUBNUB_PUBLISH_KEY: str = os.environ.get("PUBNUB_PUBLISH_KEY")
PUBNUB_SUBSCRIBE_KEY: Optional[str] = os.environ.get("VITE_PUBNUB_SUBSCRIBE_KEY")
PUBNUB_PUBLISH_KEY: Optional[str] = os.environ.get("PUBNUB_PUBLISH_KEY")

GITHUB_CLIENT_ID: str = os.environ.get("VITE_GITHUB_CLIENT_ID")
GITHUB_CLIENT_SECRET: str = os.environ.get("GITHUB_CLIENT_SECRET")
GITHUB_CLIENT_ID: Optional[str] = os.environ.get("VITE_GITHUB_CLIENT_ID")
GITHUB_CLIENT_SECRET: Optional[str] = os.environ.get("GITHUB_CLIENT_SECRET")

DEFAULT_COMPUTE_RESOURCE_ID: str = os.environ.get("VITE_DEFAULT_COMPUTE_RESOURCE_ID")
DEFAULT_COMPUTE_RESOURCE_ID: Optional[str] = os.environ.get("VITE_DEFAULT_COMPUTE_RESOURCE_ID")

OUTPUT_BUCKET_URI: str = os.environ.get("OUTPUT_BUCKET_URI")
OUTPUT_BUCKET_CREDENTIALS: str = os.environ.get("OUTPUT_BUCKET_CREDENTIALS")
OUTPUT_BUCKET_BASE_URL: str = os.environ.get("OUTPUT_BUCKET_BASE_URL")
OUTPUT_BUCKET_URI: Optional[str] = os.environ.get("OUTPUT_BUCKET_URI")
OUTPUT_BUCKET_CREDENTIALS: Optional[str] = os.environ.get("OUTPUT_BUCKET_CREDENTIALS")
OUTPUT_BUCKET_BASE_URL: Optional[str] = os.environ.get("OUTPUT_BUCKET_BASE_URL")

def get_settings():
return Settings()
11 changes: 7 additions & 4 deletions api_helpers/routers/client/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ class GetProjectResponse(BaseModel):
project: ProtocaasProject
success: bool

class ProjectError(Exception):
pass

@router.get("/projects/{project_id}")
async def get_project(project_id) -> GetProjectResponse:
try:
project = await fetch_project(project_id)
if project is None:
raise Exception(f"No project with ID {project_id}")
raise ProjectError(f"No project with ID {project_id}")
return GetProjectResponse(project=project, success=True)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e

# get project files
class GetProjectFilesResponse(BaseModel):
Expand All @@ -35,7 +38,7 @@ async def get_project_files(project_id) -> GetProjectFilesResponse:
return GetProjectFilesResponse(files=files, success=True)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e

# get project jobs
class GetProjectJobsResponse(BaseModel):
Expand All @@ -49,4 +52,4 @@ async def get_project_jobs(project_id) -> GetProjectJobsResponse:
return GetProjectJobsResponse(jobs=jobs, success=True)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e
28 changes: 19 additions & 9 deletions api_helpers/routers/compute_resource/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class GetAppsResponse(BaseModel):
apps: List[ProtocaasComputeResourceApp]
success: bool

class ComputeResourceNotFoundException(Exception):
pass

@router.get("/compute_resources/{compute_resource_id}/apps")
async def compute_resource_get_apps(
compute_resource_id: str,
Expand All @@ -32,12 +35,12 @@ async def compute_resource_get_apps(

compute_resource = await fetch_compute_resource(compute_resource_id)
if compute_resource is None:
raise Exception(f"No compute resource with ID {compute_resource_id}")
raise ComputeResourceNotFoundException(f"No compute resource with ID {compute_resource_id}")
apps = compute_resource.apps
return GetAppsResponse(apps=apps, success=True)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e

# get pubsub subscription
class GetPubsubSubscriptionResponse(BaseModel):
Expand All @@ -62,10 +65,10 @@ async def compute_resource_get_pubsub_subscription(

compute_resource = await fetch_compute_resource(compute_resource_id)
if compute_resource is None:
raise Exception(f"No compute resource with ID {compute_resource_id}")
raise ComputeResourceNotFoundException(f"No compute resource with ID {compute_resource_id}")
VITE_PUBNUB_SUBSCRIBE_KEY = get_settings().PUBNUB_SUBSCRIBE_KEY
if VITE_PUBNUB_SUBSCRIBE_KEY is None:
raise Exception('Environment variable not set: VITE_PUBNUB_SUBSCRIBE_KEY')
raise KeyError('Environment variable not set: VITE_PUBNUB_SUBSCRIBE_KEY')
subscription = PubsubSubscription(
pubnubSubscribeKey=VITE_PUBNUB_SUBSCRIBE_KEY,
pubnubChannel=compute_resource_id,
Expand All @@ -74,7 +77,7 @@ async def compute_resource_get_pubsub_subscription(
return GetPubsubSubscriptionResponse(subscription=subscription, success=True)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e

# get unfinished jobs
class GetUnfinishedJobsResponse(BaseModel):
Expand Down Expand Up @@ -110,7 +113,7 @@ async def compute_resource_get_unfinished_jobs(
return GetUnfinishedJobsResponse(jobs=jobs, success=True)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e

# set spec
class SetSpecRequest(BaseModel):
Expand Down Expand Up @@ -143,7 +146,14 @@ async def compute_resource_set_spec(
return SetSpecResponse(success=True)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e)) from e

class UnexpectedException(Exception):
pass

class InvalidSignatureException(Exception):
pass


def _authenticate_compute_resource_request(
compute_resource_id: str,
Expand All @@ -152,6 +162,6 @@ def _authenticate_compute_resource_request(
expected_payload: str
):
if compute_resource_payload != expected_payload:
raise Exception('Unexpected payload')
raise UnexpectedException('Unexpected payload')
if not _verify_signature_str(compute_resource_payload, compute_resource_id, compute_resource_signature):
raise Exception('Invalid signature')
raise InvalidSignatureException('Invalid signature')
5 changes: 4 additions & 1 deletion api_helpers/routers/gui/_authenticate_gui_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ async def _authenticate_gui_request(github_access_token: str):
}
return user_id

class AuthException(Exception):
pass

async def _get_user_id_for_access_token(github_access_token: str):
url = 'https://api.github.com/user'
headers = {
Expand All @@ -31,6 +34,6 @@ async def _get_user_id_for_access_token(github_access_token: str):
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
if response.status != 200:
raise Exception(f'Error getting user ID from github access token: {response.status}')
raise AuthException(f'Error getting user ID from github access token: {response.status}')
data = await response.json()
return data['login']
Loading

0 comments on commit 4020e60

Please sign in to comment.