Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: make EmalRep optional #219

Merged
merged 2 commits into from Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 17 additions & 17 deletions backend/api/endpoints/analyze.py
Expand Up @@ -3,7 +3,7 @@
from pydantic import ValidationError
from redis import Redis

from backend import clients, deps, schemas, settings
from backend import clients, dependencies, schemas, settings
from backend.factories.response import ResponseFactory

router = APIRouter()
Expand All @@ -13,7 +13,7 @@ async def _analyze(
file: bytes,
*,
spam_assassin: clients.SpamAssassin,
email_rep: clients.EmailRep,
optional_email_rep: clients.EmailRep | None = None,
optional_inquest: clients.InQuest | None = None,
optional_vt: clients.VirusTotal | None = None,
optional_urlscan: clients.UrlScan | None = None,
Expand All @@ -28,7 +28,7 @@ async def _analyze(

return await ResponseFactory.call(
payload.file,
email_rep=email_rep,
optional_email_rep=optional_email_rep,
spam_assassin=spam_assassin,
optional_inquest=optional_inquest,
optional_urlscan=optional_urlscan,
Expand Down Expand Up @@ -56,17 +56,17 @@ async def analyze(
payload: schemas.Payload,
*,
background_tasks: BackgroundTasks,
optional_redis: deps.OptionalRedis,
spam_assassin: deps.SpamAssassin,
email_rep: deps.EmailRep,
optional_inquest: deps.OptionalInQuest,
optional_vt: deps.OptionalVirusTotal,
optional_urlscan: deps.OptionalUrlScan,
spam_assassin: dependencies.SpamAssassin,
optional_redis: dependencies.OptionalRedis,
optional_email_rep: dependencies.OptionalEmailRep,
optional_inquest: dependencies.OptionalInQuest,
optional_vt: dependencies.OptionalVirusTotal,
optional_urlscan: dependencies.OptionalUrlScan,
) -> schemas.Response:
response = await _analyze(
payload.file.encode(),
email_rep=email_rep,
spam_assassin=spam_assassin,
optional_email_rep=optional_email_rep,
optional_inquest=optional_inquest,
optional_urlscan=optional_urlscan,
optional_vt=optional_vt,
Expand All @@ -90,16 +90,16 @@ async def analyze_file(
file: bytes = File(...),
*,
background_tasks: BackgroundTasks,
optional_redis: deps.OptionalRedis,
spam_assassin: deps.SpamAssassin,
email_rep: deps.EmailRep,
optional_inquest: deps.OptionalInQuest,
optional_vt: deps.OptionalVirusTotal,
optional_urlscan: deps.OptionalUrlScan,
optional_redis: dependencies.OptionalRedis,
spam_assassin: dependencies.SpamAssassin,
optional_email_rep: dependencies.OptionalEmailRep,
optional_inquest: dependencies.OptionalInQuest,
optional_vt: dependencies.OptionalVirusTotal,
optional_urlscan: dependencies.OptionalUrlScan,
) -> schemas.Response:
response = await _analyze(
file,
email_rep=email_rep,
optional_email_rep=optional_email_rep,
spam_assassin=spam_assassin,
optional_inquest=optional_inquest,
optional_urlscan=optional_urlscan,
Expand Down
4 changes: 2 additions & 2 deletions backend/api/endpoints/cache.py
@@ -1,6 +1,6 @@
from fastapi import APIRouter, HTTPException, status

from backend import deps, settings
from backend import dependencies, settings

router = APIRouter()

Expand All @@ -11,7 +11,7 @@
summary="Get analysis cache keys",
description="Try to get analysis cache keys",
)
async def cache_keys(optional_redis: deps.OptionalRedis) -> list[str]:
async def cache_keys(optional_redis: dependencies.OptionalRedis) -> list[str]:
if optional_redis is None:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
Expand Down
6 changes: 4 additions & 2 deletions backend/api/endpoints/lookup.py
@@ -1,6 +1,6 @@
from fastapi import APIRouter, HTTPException, status

from backend import deps, schemas, settings
from backend import dependencies, schemas, settings

router = APIRouter()

Expand All @@ -11,7 +11,9 @@
summary="Lookup cached analysis",
description="Try to fetch existing analysis from database",
)
async def lookup(id: str, *, optional_redis: deps.OptionalRedis) -> schemas.Response:
async def lookup(
id: str, *, optional_redis: dependencies.OptionalRedis
) -> schemas.Response:
if optional_redis is None:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
Expand Down
6 changes: 3 additions & 3 deletions backend/api/endpoints/submit.py
@@ -1,7 +1,7 @@
import httpx
from fastapi import APIRouter, HTTPException, status

from backend import deps, schemas
from backend import dependencies, schemas
from backend.schemas.eml import Attachment
from backend.utils import attachment_to_file

Expand All @@ -16,7 +16,7 @@
status_code=200,
)
async def submit_to_inquest(
attachment: Attachment, *, optional_inquest: deps.OptionalInQuest
attachment: Attachment, *, optional_inquest: dependencies.OptionalInQuest
) -> schemas.SubmissionResult:
# check ext type
valid_types = ["doc", "docx", "ppt", "pptx", "xls", "xlsx"]
Expand Down Expand Up @@ -49,7 +49,7 @@ async def submit_to_inquest(
status_code=200,
)
async def submit_to_virustotal(
attachment: Attachment, *, optional_vt: deps.OptionalVirusTotal
attachment: Attachment, *, optional_vt: dependencies.OptionalVirusTotal
) -> schemas.SubmissionResult:
if optional_vt is None:
raise HTTPException(
Expand Down
8 changes: 6 additions & 2 deletions backend/clients/emailrep.py
@@ -1,11 +1,15 @@
import httpx
from starlette.datastructures import Secret

from backend import schemas


class EmailRep(httpx.AsyncClient):
def __init__(self) -> None:
super().__init__(base_url="https://emailrep.io")
def __init__(self, api_key: Secret) -> None:
super().__init__(
base_url="https://emailrep.io",
headers={"key": str(api_key), "user-agent": "EML-Analyzer"},
)

async def lookup(self, email: str) -> schemas.EmailRepLookup:
r = await self.get(f"/{email}")
Expand Down
15 changes: 9 additions & 6 deletions backend/deps.py → backend/dependencies.py
Expand Up @@ -71,13 +71,16 @@ async def get_optional_urlscan():


@asynccontextmanager
async def _get_email_rep():
async with clients.EmailRep() as client:
yield client
async def _get_optional_email_rep(api_key: Secret | None = settings.EMAIL_REP_API_KEY):
if api_key is None:
yield None
else:
async with clients.EmailRep(api_key=api_key) as client:
yield client


async def get_email_rep():
async with _get_email_rep() as client:
async def get_optional_email_rep():
async with _get_optional_email_rep(settings.EMAIL_REP_API_KEY) as client:
yield client


Expand All @@ -101,5 +104,5 @@ def get_spam_assassin() -> clients.SpamAssassin:
clients.UrlScan | None, Depends(get_optional_urlscan)
]

EmailRep = typing.Annotated[clients.EmailRep, Depends(get_email_rep)]
OptionalEmailRep = typing.Annotated[clients.EmailRep, Depends(get_optional_email_rep)]
SpamAssassin = typing.Annotated[clients.SpamAssassin, Depends(get_spam_assassin)]
10 changes: 5 additions & 5 deletions backend/factories/response.py
Expand Up @@ -75,8 +75,8 @@ async def set_verdicts(
response: schemas.Response,
*,
eml_file: bytes,
email_rep: clients.EmailRep,
spam_assassin: clients.SpamAssassin,
optional_email_rep: clients.EmailRep | None = None,
optional_vt: clients.VirusTotal | None = None,
optional_urlscan: clients.UrlScan | None = None,
optional_inquest: clients.InQuest | None = None,
Expand All @@ -86,9 +86,9 @@ async def set_verdicts(
get_oleid_verdict(response.eml.attachments),
]

if response.eml.header.from_ is not None:
if response.eml.header.from_ is not None and optional_email_rep is not None:
f_results.append(
get_email_rep_verdicts(response.eml.header.from_, client=email_rep)
get_email_rep_verdicts(response.eml.header.from_, client=optional_email_rep)
)

if optional_vt is not None:
Expand Down Expand Up @@ -117,8 +117,8 @@ async def call(
cls,
eml_file: bytes,
*,
email_rep: clients.EmailRep,
spam_assassin: clients.SpamAssassin,
optional_email_rep: clients.EmailRep | None,
optional_vt: clients.VirusTotal | None = None,
optional_urlscan: clients.UrlScan | None = None,
optional_inquest: clients.InQuest | None = None,
Expand All @@ -129,7 +129,7 @@ async def call(
partial(
set_verdicts,
eml_file=eml_file,
email_rep=email_rep,
optional_email_rep=optional_email_rep,
spam_assassin=spam_assassin,
optional_vt=optional_vt,
optional_urlscan=optional_urlscan,
Expand Down
3 changes: 3 additions & 0 deletions backend/settings.py
Expand Up @@ -35,6 +35,9 @@
)
INQUEST_API_KEY: Secret | None = config("INQUEST_API_KEY", cast=Secret, default=None)
URLSCAN_API_KEY: Secret | None = config("URLSCAN_API_KEY", cast=Secret, default=None)
EMAIL_REP_API_KEY: Secret | None = config(
"EMAIL_REP_API_KEY", cast=Secret, default=None
)

# Async/aiometer
ASYNC_MAX_AT_ONCE: int | None = config("ASYNC_MAX_AT_ONCE", cast=int, default=None)
Expand Down