Skip to content

Commit

Permalink
Merge pull request #57 from ninoseki/update-logic
Browse files: browse the repository at this point in the history
refactor: renew internal logics
  • Loading branch information
ninoseki committed Sep 12, 2023
2 parents 9f158e0 + 9c758c7 commit 85a0f47
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 30 deletions.
4 changes: 3 additions & 1 deletion abuse_whois/api/endpoints/whois.py
Expand Up @@ -3,7 +3,7 @@
from fastapi import APIRouter, HTTPException, status

from abuse_whois import get_abuse_contacts, schemas
from abuse_whois.errors import InvalidAddressError
from abuse_whois.errors import InvalidAddressError, RateLimitError

router = APIRouter(prefix="/whois")

Expand All @@ -16,3 +16,5 @@ async def whois(query: schemas.Query) -> schemas.Contacts:
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except asyncio.TimeoutError as e:
raise HTTPException(status.HTTP_408_REQUEST_TIMEOUT, detail=str(e)) from e
except RateLimitError as e:
raise HTTPException(status.HTTP_429_TOO_MANY_REQUESTS, detail=str(e)) from e
27 changes: 18 additions & 9 deletions abuse_whois/main.py
Expand Up @@ -4,9 +4,12 @@
import typer

from . import schemas
from .errors import InvalidAddressError
from .errors import InvalidAddressError, RateLimitError
from .matchers.shared_hosting import get_shared_hosting_provider
from .matchers.whois import get_optional_whois_contact
from .matchers.whois import (
get_optional_whois_contact,
get_whois_contact_by_whois_record,
)
from .utils import (
get_hostname,
get_registered_domain,
Expand All @@ -24,12 +27,19 @@ async def get_abuse_contacts(address: str) -> schemas.Contacts:

hostname = get_hostname(address) # Domain or IP address

domain: str | None = None
try:
whois_record = await get_whois_record(hostname)
except asyncio.TimeoutError as e:
raise asyncio.TimeoutError(f"whois timeout for {hostname}") from e
except RateLimitError as e:
raise asyncio.TimeoutError(f"whois rate limit error for {hostname}") from e

ip_address: str | None = None
registered_domain: str | None = None
registrar: schemas.Contact | None = None

if is_domain(hostname):
domain = hostname
# set registered domain
registered_domain = get_registered_domain(hostname)

# get IP address by domain
Expand All @@ -38,15 +48,14 @@ async def get_abuse_contacts(address: str) -> schemas.Contacts:
except OSError:
pass

# get registrar contact
registrar = get_whois_contact_by_whois_record(whois_record)

if is_ip_address(hostname):
ip_address = hostname

whois_record = await get_whois_record(hostname)
shared_hosting_provider = get_shared_hosting_provider(hostname)

registrar, hosting_provider = await asyncio.gather(
get_optional_whois_contact(domain), get_optional_whois_contact(ip_address)
)
hosting_provider = await get_optional_whois_contact(ip_address)

return schemas.Contacts(
address=address,
Expand Down
18 changes: 13 additions & 5 deletions abuse_whois/matchers/whois/__init__.py
@@ -1,5 +1,7 @@
import asyncio
import re

from abuse_whois.errors import RateLimitError
from abuse_whois.schemas import Contact, WhoisRecord
from abuse_whois.utils import is_email
from abuse_whois.whois import get_whois_record
Expand Down Expand Up @@ -51,21 +53,27 @@ def get_whois_abuse_contact(record: WhoisRecord) -> Contact | None:
return Contact(provider=provider, address=email)


async def get_whois_contact(
hostname: str,
def get_whois_contact_by_whois_record(
whois_record: WhoisRecord,
) -> Contact | None:
rules = load_rules()
for rule in rules:
if await rule.match(hostname):
if rule.match(whois_record):
return rule.contact

return get_whois_abuse_contact(whois_record)


async def get_whois_contact(
hostname: str,
) -> Contact | None:
# Use whois registrar & abuse data
try:
whois_record = await get_whois_record(hostname)
except Exception:
except (asyncio.TimeoutError, RateLimitError):
return None

return get_whois_abuse_contact(whois_record)
return get_whois_contact_by_whois_record(whois_record)


async def get_optional_whois_contact(hostname: str | None) -> Contact | None:
Expand Down
8 changes: 3 additions & 5 deletions abuse_whois/matchers/whois/rule.py
@@ -1,9 +1,7 @@
from ...schemas import BaseRule
from ...whois import get_whois_record
from abuse_whois import schemas


class WhoisRule(BaseRule):
async def match(self, hostname: str) -> bool:
whois_record = await get_whois_record(hostname)
class WhoisRule(schemas.BaseRule):
def match(self, whois_record: schemas.WhoisRecord) -> bool:
data = whois_record.model_dump(by_alias=True)
return super().match(data)
4 changes: 1 addition & 3 deletions abuse_whois/schemas/contact.py
Expand Up @@ -40,6 +40,4 @@ class Contacts(APIModel):
registrar: Contact | None = Field(None, description="Registrar")
hosting_provider: Contact | None = Field(None, description="Hosting provider")

whois_record: WhoisRecord | None = Field(
None, description="Whois record of hostname"
)
whois_record: WhoisRecord = Field(description="Whois record of hostname")
21 changes: 14 additions & 7 deletions abuse_whois/whois.py
Expand Up @@ -8,7 +8,7 @@
from whois_parser import WhoisParser, WhoisRecord

from . import schemas, settings
from .errors import RateLimitError
from .errors import InvalidAddressError, RateLimitError
from .utils import get_registered_domain, is_domain, is_ip_address

whois_parser = WhoisParser()
Expand Down Expand Up @@ -40,16 +40,23 @@ async def get_whois_record(
hostname: str,
*,
timeout: int = settings.WHOIS_LOOKUP_TIMEOUT,
parser: WhoisParser = whois_parser
parser: WhoisParser = whois_parser,
) -> schemas.WhoisRecord:
if not is_ip_address(hostname):
hostname = get_registered_domain(hostname) or hostname

query_result = await query(hostname, timeout=timeout)
query_result = "\n".join(query_result.splitlines())
try:
query_result = await query(hostname, timeout=timeout)
query_result = "\n".join(query_result.splitlines())
parsed = parse(query_result, hostname, parser=parser)

parsed = parse(query_result, hostname, parser=parser)
if parsed.is_rate_limited:
raise RateLimitError()
if parsed.is_rate_limited:
raise RateLimitError(f"whois rate limit error for {hostname}")

if parsed.raw_text.startswith("No match for"):
raise InvalidAddressError(f"whois no match for {hostname}")

except asyncio.TimeoutError as e:
raise asyncio.TimeoutError(f"whois timeout for {hostname}") from e

return schemas.WhoisRecord.model_validate(asdict(parsed))

0 comments on commit 85a0f47

Please sign in to comment.