Skip to content

Commit

Permalink
refactor: slim down main
Browse files Browse the repository at this point in the history
  • Loading branch information
ninoseki committed Jul 9, 2023
1 parent e4e4bd9 commit 10ff14c
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 51 deletions.
49 changes: 4 additions & 45 deletions abuse_whois/main.py
@@ -1,64 +1,23 @@
import asyncio
import json
import socket
from contextlib import contextmanager

import typer
from asyncer import asyncify
from cachetools import TTLCache, cached

from . import schemas, settings
from . import schemas
from .errors import InvalidAddressError
from .matchers.shared_hosting import get_shared_hosting_provider
from .matchers.whois import get_whois_contact
from .matchers.whois import get_optional_whois_contact
from .utils import (
get_hostname,
get_registered_domain,
is_domain,
is_ip_address,
is_supported_address,
resolve,
)
from .whois import get_whois_record


@contextmanager
def with_socket_timeout(timeout: float):
    """Temporarily replace the default socket timeout for the managed block.

    The previous default is captured on entry and always restored on exit.
    A ``socket.timeout`` or ``ValueError`` raised while the temporary timeout
    is in effect (including by ``socket.setdefaulttimeout`` itself) is
    re-raised as ``asyncio.TimeoutError``, chained to the original exception
    so the root cause stays visible in tracebacks.

    Args:
        timeout: Default timeout, in seconds, to apply while the block runs.

    Raises:
        asyncio.TimeoutError: If ``socket.timeout`` or ``ValueError`` is
            raised while the temporary timeout is in effect.
    """
    # NOTE(review): socket.setdefaulttimeout() changes a module-wide default
    # for newly created socket objects; confirm this is safe when several
    # lookups run concurrently and that it really bounds the wrapped call.
    previous = socket.getdefaulttimeout()
    try:
        socket.setdefaulttimeout(timeout)
        yield
    except (socket.timeout, ValueError) as exc:
        # Keep the original exception as __cause__ instead of discarding it.
        raise asyncio.TimeoutError(
            f"{timeout} seconds have passed but there is no response"
        ) from exc
    finally:
        socket.setdefaulttimeout(previous)


# Cache for hostname lookups; size and entry lifetime come from settings.
_IP_ADDRESS_LOOKUP_CACHE = TTLCache(
    maxsize=settings.IP_ADDRESS_LOOKUP_CACHE_SIZE,
    ttl=settings.IP_ADDRESS_LOOKUP_CACHE_TTL,
)


@cached(cache=_IP_ADDRESS_LOOKUP_CACHE)
def _resolve(
    hostname: str, *, timeout: float = float(settings.IP_ADDRESS_LOOKUP_TIMEOUT)
) -> str:
    """Look up the IP address for *hostname* via ``socket.gethostbyname``.

    The lookup runs inside ``with_socket_timeout(timeout)`` and the result is
    stored in a TTL cache configured from ``settings``.

    Args:
        hostname: Host name to resolve.
        timeout: Timeout in seconds handed to ``with_socket_timeout``;
            defaults to ``settings.IP_ADDRESS_LOOKUP_TIMEOUT``.

    Returns:
        The address string returned by ``socket.gethostbyname``.
    """
    with with_socket_timeout(timeout):
        return socket.gethostbyname(hostname)


# Awaitable variant of ``_resolve`` produced by ``asyncer.asyncify``.
resolve = asyncify(_resolve)


async def get_contact(domain_or_ip: str | None):
    """Return the WHOIS contact for *domain_or_ip*, or ``None`` if it is ``None``.

    Args:
        domain_or_ip: Domain name or IP address to look up, or ``None`` when
            there is nothing to look up.

    Returns:
        Whatever ``get_whois_contact`` returns for the value, or ``None`` when
        no value was supplied.
    """
    if domain_or_ip is not None:
        return await get_whois_contact(domain_or_ip)

    return None


async def get_abuse_contacts(address: str) -> schemas.Contacts:
if not is_supported_address(address):
raise InvalidAddressError(f"{address} is not supported type address")
Expand Down Expand Up @@ -86,7 +45,7 @@ async def get_abuse_contacts(address: str) -> schemas.Contacts:
shared_hosting_provider = get_shared_hosting_provider(hostname)

registrar, hosting_provider = await asyncio.gather(
get_contact(domain), get_contact(ip_address)
get_optional_whois_contact(domain), get_optional_whois_contact(ip_address)
)

return schemas.Contacts(
Expand Down
7 changes: 7 additions & 0 deletions abuse_whois/matchers/whois/__init__.py
Expand Up @@ -66,3 +66,10 @@ async def get_whois_contact(
return None

return get_whois_abuse_contact(whois_record)


async def get_optional_whois_contact(hostname: str | None) -> Contact | None:
    """Look up the WHOIS contact for *hostname*, tolerating a missing value.

    Args:
        hostname: Value forwarded to ``get_whois_contact``, or ``None`` when
            there is nothing to look up.

    Returns:
        The result of ``get_whois_contact(hostname)``, or ``None`` when
        *hostname* is ``None``.
    """
    if hostname is not None:
        return await get_whois_contact(hostname)

    return None
4 changes: 2 additions & 2 deletions abuse_whois/matchers/whois/rule.py
@@ -1,5 +1,5 @@
from abuse_whois.schemas import BaseRule
from abuse_whois.whois import get_whois_record
from ...schemas import BaseRule
from ...whois import get_whois_record


class WhoisRule(BaseRule):
Expand Down
5 changes: 2 additions & 3 deletions abuse_whois/matchers/whois/rules.py
@@ -1,9 +1,8 @@
import pathlib
from functools import lru_cache

from abuse_whois import settings
from abuse_whois.utils import glob_rules, load_yaml

from ... import settings
from ...utils import glob_rules, load_yaml
from .rule import WhoisRule

# Path of the "rules" directory that sits next to this module.
# NOTE(review): presumably holds the bundled WHOIS rule files; confirm against the loader.
DEFAULT_RULE_DIRECTORY: pathlib.Path = pathlib.Path(__file__).parent / "./rules"
Expand Down
38 changes: 37 additions & 1 deletion abuse_whois/utils.py
@@ -1,15 +1,20 @@
import asyncio
import pathlib
import socket
from collections.abc import Callable
from contextlib import contextmanager
from functools import lru_cache
from typing import cast
from urllib.parse import urlparse

import tldextract
import validators
import yaml
from asyncer import asyncify
from cachetools import TTLCache, cached
from starlette.datastructures import CommaSeparatedStrings

from abuse_whois import settings
from . import settings


def _is_x(v: str, *, validator: Callable[[str], bool]) -> bool:
Expand Down Expand Up @@ -74,6 +79,37 @@ def get_hostname(value: str) -> str:
return parsed.hostname or value


@contextmanager
def with_socket_timeout(timeout: float):
    """Run the managed block with a temporary default socket timeout.

    The default that was active on entry is saved and restored on exit,
    whether or not the block raised. If ``socket.timeout`` or ``ValueError``
    is raised while the temporary timeout is in effect (including by
    ``socket.setdefaulttimeout`` itself), it is translated into
    ``asyncio.TimeoutError`` and chained to the original exception so the
    underlying cause is preserved.

    Args:
        timeout: Default timeout, in seconds, to apply while the block runs.

    Raises:
        asyncio.TimeoutError: If ``socket.timeout`` or ``ValueError`` is
            raised while the temporary timeout is in effect.
    """
    # NOTE(review): the default socket timeout is module-wide state that
    # applies to newly created socket objects; confirm concurrent callers
    # cannot interfere and that the wrapped call actually honours it.
    saved_timeout = socket.getdefaulttimeout()
    try:
        socket.setdefaulttimeout(timeout)
        yield
    except (socket.timeout, ValueError) as exc:
        # Chain the original error rather than dropping it.
        raise asyncio.TimeoutError(
            f"{timeout} seconds have passed but there is no response"
        ) from exc
    finally:
        socket.setdefaulttimeout(saved_timeout)


# Cache for hostname lookups; size and entry lifetime come from settings.
_IP_ADDRESS_LOOKUP_CACHE = TTLCache(
    maxsize=settings.IP_ADDRESS_LOOKUP_CACHE_SIZE,
    ttl=settings.IP_ADDRESS_LOOKUP_CACHE_TTL,
)


@cached(cache=_IP_ADDRESS_LOOKUP_CACHE)
def _resolve(
    hostname: str, *, timeout: float = float(settings.IP_ADDRESS_LOOKUP_TIMEOUT)
) -> str:
    """Resolve *hostname* to an IP address with ``socket.gethostbyname``.

    The call is made inside ``with_socket_timeout(timeout)``, and results are
    kept in a TTL cache whose limits are read from ``settings``.

    Args:
        hostname: Host name to resolve.
        timeout: Timeout in seconds passed to ``with_socket_timeout``;
            defaults to ``settings.IP_ADDRESS_LOOKUP_TIMEOUT``.

    Returns:
        The address string returned by ``socket.gethostbyname``.
    """
    with with_socket_timeout(timeout):
        return socket.gethostbyname(hostname)


# Awaitable variant of ``_resolve`` produced by ``asyncer.asyncify``.
resolve = asyncify(_resolve)


def load_yaml(path: str | pathlib.Path) -> dict:
    """Parse the YAML file at *path* with ``yaml.safe_load``.

    The file is opened as UTF-8 explicitly so parsing does not depend on the
    platform's default text encoding.

    Args:
        path: Location of the YAML file to read.

    Returns:
        The parsed document, typed as ``dict``. The ``cast`` is for the type
        checker only; no runtime validation of the document shape is done.
    """
    with open(path, encoding="utf-8") as f:
        return cast(dict, yaml.safe_load(f))
Expand Down

0 comments on commit 10ff14c

Please sign in to comment.