Skip to content

Commit

Permalink
refactor: asyncify query functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ninoseki committed Apr 23, 2022
1 parent e744922 commit 9ef65b1
Show file tree
Hide file tree
Showing 14 changed files with 81 additions and 37 deletions.
8 changes: 4 additions & 4 deletions abuse_whois/__init__.py
Expand Up @@ -17,7 +17,7 @@
__version__ = importlib_metadata.version(__name__)


def get_abuse_contacts(address: str) -> Contacts:
async def get_abuse_contacts(address: str) -> Contacts:
if not is_supported_address(address):
raise InvalidAddressError(f"{address} is not supported type address")

Expand All @@ -31,19 +31,19 @@ def get_abuse_contacts(address: str) -> Contacts:
shared_hosting_provider = get_shared_hosting_provider(hostname)

if is_domain(hostname):
registrar = get_contact_from_whois(hostname)
registrar = await get_contact_from_whois(hostname)

# get IP address by domain
try:
ip_address = resolve_ip_address(hostname)
ip_address = await resolve_ip_address(hostname)
except OSError:
pass

if is_ip_address(hostname):
ip_address = hostname

if ip_address is not None:
hosting_provider = get_contact_from_whois(ip_address)
hosting_provider = await get_contact_from_whois(ip_address)

return Contacts(
address=address,
Expand Down
4 changes: 2 additions & 2 deletions abuse_whois/api/endpoints/whois.py
Expand Up @@ -7,9 +7,9 @@


@router.post("/", response_model=schemas.Contacts)
def whois(query: schemas.Query):
async def whois(query: schemas.Query):
try:
return get_abuse_contacts(query.address)
return await get_abuse_contacts(query.address)
except InvalidAddressError as e:
raise HTTPException(400, detail=str(e))
except TimeoutError as e:
Expand Down
4 changes: 3 additions & 1 deletion abuse_whois/cli.py
@@ -1,5 +1,7 @@
import json
from functools import partial

import anyio
import typer

from . import get_abuse_contacts
Expand All @@ -13,7 +15,7 @@ def whois(
address: str = typer.Argument(..., help="URL, domain, IP address or email address")
):
try:
contacts = get_abuse_contacts(address)
contacts = anyio.run(partial(get_abuse_contacts, address))
print(contacts.json(by_alias=True)) # noqa: T001
except (InvalidAddressError, TimeoutError) as e:
print(json.dumps({"error": str(e)})) # noqa: T001
Expand Down
7 changes: 6 additions & 1 deletion abuse_whois/ip.py
@@ -1,6 +1,8 @@
import socket
from contextlib import contextmanager

from asyncer import asyncify

from . import settings
from .errors import TimeoutError

Expand All @@ -20,7 +22,10 @@ def socket_with_timeout(timeout: float):
socket.setdefaulttimeout(old_timeout)


def resolve_ip_address(hostname: str, *, timeout: int = settings.WHOIS_TIMEOUT) -> str:
def _resolve_ip_address(hostname: str, *, timeout: int = settings.WHOIS_TIMEOUT) -> str:
with socket_with_timeout(float(timeout)):
ip = socket.gethostbyname(hostname)
return ip


resolve_ip_address = asyncify(_resolve_ip_address)
6 changes: 3 additions & 3 deletions abuse_whois/matchers/whois/__init__.py
Expand Up @@ -55,17 +55,17 @@ def get_whois_abuse_contact(record: WhoisRecord) -> Optional[Contact]:
return Contact(provider=provider, address=email)


def get_contact_from_whois(
async def get_contact_from_whois(
hostname: str,
) -> Optional[Contact]:
rules = load_rules()
for rule in rules:
if rule.match(hostname):
if await rule.match(hostname):
return rule.contact

# Use whois registrar & abuse data
try:
whois_record = get_whois_record(hostname)
whois_record = await get_whois_record(hostname)
except Exception:
return None

Expand Down
4 changes: 2 additions & 2 deletions abuse_whois/matchers/whois/rule.py
Expand Up @@ -8,9 +8,9 @@ class WhoisRule(BaseRule):
contact: Contact
keywords: List[str]

def match(self, hostname: str) -> bool:
async def match(self, hostname: str) -> bool:
try:
whois_record = get_whois_record(hostname)
whois_record = await get_whois_record(hostname)
except Exception:
return False

Expand Down
4 changes: 1 addition & 3 deletions abuse_whois/schemas/rule.py
@@ -1,11 +1,9 @@
from typing import Optional

from .api_model import APIModel
from .contact import Contact


class BaseRule(APIModel):
contact: Contact

def match(self, hostname: str) -> Optional[Contact]:
async def match(self, hostname: str) -> bool:
raise NotImplementedError()
6 changes: 5 additions & 1 deletion abuse_whois/whois.py
Expand Up @@ -3,6 +3,7 @@
from typing import cast

import sh
from asyncer import asyncify
from cachetools import TTLCache, cached
from whois_parser import WhoisParser
from whois_parser.dataclasses import WhoisRecord
Expand Down Expand Up @@ -34,7 +35,7 @@ def get_whois_parser() -> WhoisParser:
maxsize=settings.WHOIS_RECORD_CACHE_SIZE, ttl=settings.WHOIS_RECORD_CACHE_TTL
)
)
def get_whois_record(
def _get_whois_record(
hostname: str, *, timeout: int = settings.WHOIS_TIMEOUT
) -> WhoisRecord:
if not is_ip_address(hostname):
Expand All @@ -52,3 +53,6 @@ def get_whois_record(

parser = get_whois_parser()
return parser.parse(whois_text, hostname=hostname)


get_whois_record = asyncify(_get_whois_record)
56 changes: 44 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -10,6 +10,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "^3.7"
asyncer = "^0.0.1"
cachetools = "^5.0.0"
email-validator = "^1.1.3"
fastapi = "^0.75.2"
Expand Down
5 changes: 3 additions & 2 deletions tests/matchers/whois/test_whois.py
Expand Up @@ -10,5 +10,6 @@
"github.com",
],
)
def test_get_contact_from_whois(hostname: str):
assert get_contact_from_whois(hostname) is not None
@pytest.mark.asyncio
async def test_get_contact_from_whois(hostname: str):
assert await get_contact_from_whois(hostname) is not None
5 changes: 3 additions & 2 deletions tests/test_abuse_whois.py
Expand Up @@ -16,6 +16,7 @@ def test_version():
("foo@test.com", "test.com"),
],
)
def test_get_abuse_contacts(address: str, hostname: str):
contacts = get_abuse_contacts(address)
@pytest.mark.asyncio
async def test_get_abuse_contacts(address: str, hostname: str):
contacts = await get_abuse_contacts(address)
assert contacts.hostname == hostname
4 changes: 2 additions & 2 deletions tests/test_ip.py
@@ -1,9 +1,9 @@
import pytest

from abuse_whois.errors import TimeoutError
from abuse_whois.ip import resolve_ip_address
from abuse_whois.ip import _resolve_ip_address


def test_timeout_error():
with pytest.raises(TimeoutError):
assert resolve_ip_address("github.com", timeout=-1)
assert _resolve_ip_address("github.com", timeout=-1)
4 changes: 2 additions & 2 deletions tests/test_whois.py
@@ -1,9 +1,9 @@
import pytest

from abuse_whois.errors import TimeoutError
from abuse_whois.whois import get_whois_record
from abuse_whois.whois import _get_whois_record


def test_timeout_error():
with pytest.raises(TimeoutError):
assert get_whois_record("github.com", timeout=-1)
assert _get_whois_record("github.com", timeout=-1)

0 comments on commit 9ef65b1

Please sign in to comment.