Skip to content

Commit

Permalink
ref: hack around mypy's mistreatment of descriptors in unions (#53308)
Browse files Browse the repository at this point in the history
working around python/mypy#5570

this is especially important as more things are being moved to unions of
`HC | Model`
  • Loading branch information
asottile-sentry authored and armenzg committed Jul 24, 2023
1 parent c037c27 commit 3231cb3
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 40 deletions.
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ module = [
"sentry.api.bases.user",
"sentry.api.decorators",
"sentry.api.endpoints.accept_organization_invite",
"sentry.api.endpoints.accept_project_transfer",
"sentry.api.endpoints.artifact_lookup",
"sentry.api.endpoints.auth_config",
"sentry.api.endpoints.auth_login",
Expand Down Expand Up @@ -633,7 +632,6 @@ module = [
"sentry.models.authprovider",
"sentry.models.commit",
"sentry.models.group",
"sentry.models.groupassignee",
"sentry.models.grouphistory",
"sentry.models.groupowner",
"sentry.models.groupsnooze",
Expand Down Expand Up @@ -718,7 +716,6 @@ module = [
"sentry.projectoptions.defaults",
"sentry.queue.command",
"sentry.quotas.redis",
"sentry.ratelimits.utils",
"sentry.receivers.outbox",
"sentry.receivers.outbox.control",
"sentry.receivers.releases",
Expand Down Expand Up @@ -795,7 +792,6 @@ module = [
"sentry.services.hybrid_cloud.actor",
"sentry.services.hybrid_cloud.auth.impl",
"sentry.services.hybrid_cloud.integration.impl",
"sentry.services.hybrid_cloud.integration.service",
"sentry.services.hybrid_cloud.log.impl",
"sentry.services.hybrid_cloud.notifications.impl",
"sentry.services.hybrid_cloud.organizationmember_mapping.impl",
Expand Down Expand Up @@ -961,7 +957,6 @@ module = [
"sentry.web.frontend.project_event",
"sentry.web.frontend.react_page",
"sentry.web.frontend.reactivate_account",
"sentry.web.frontend.restore_organization",
"sentry.web.frontend.setup_wizard",
"sentry.web.frontend.shared_group_details",
"sentry.web.frontend.twofactor",
Expand Down
2 changes: 1 addition & 1 deletion src/sentry/api/bases/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def convert_args(
raise ResourceDoesNotExist

organization_context = organization_service.get_organization_by_slug(
slug=organization_slug, only_visible=False, user_id=request.user.id # type: ignore
slug=organization_slug, only_visible=False, user_id=request.user.id
)
if organization_context is None:
raise ResourceDoesNotExist
Expand Down
2 changes: 1 addition & 1 deletion src/sentry/integrations/discord/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_client(self) -> DiscordClient:
org_integration_id = self.org_integration.id if self.org_integration else None

return DiscordClient(
integration_id=self.model.id, # type:ignore
integration_id=self.model.id,
org_integration_id=org_integration_id,
)

Expand Down
2 changes: 1 addition & 1 deletion src/sentry/models/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def create_group_activity(
send_notification: bool = True,
) -> Activity:
if user:
user_id = user.id # type: ignore[assignment]
user_id = user.id
activity_args = {
"project_id": group.project_id,
"group": group,
Expand Down
8 changes: 4 additions & 4 deletions src/sentry/relay/config/metric_extraction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union

from sentry import features
from sentry.api.endpoints.project_transaction_threshold import DEFAULT_THRESHOLD
Expand Down Expand Up @@ -198,10 +198,10 @@ def _threshold_to_rules(
"inner": [
{
"op": "gt",
"name": _TRANSACTION_METRICS_TO_RULE_FIELD[cast(int, threshold.metric)],
"name": _TRANSACTION_METRICS_TO_RULE_FIELD[threshold.metric],
# The frustration threshold is always four times the threshold
# (see https://docs.sentry.io/product/performance/metrics/#apdex)
"value": cast(int, threshold.threshold) * 4,
"value": threshold.threshold * 4,
},
*extra_conditions,
],
Expand All @@ -216,7 +216,7 @@ def _threshold_to_rules(
"inner": [
{
"op": "gt",
"name": _TRANSACTION_METRICS_TO_RULE_FIELD[cast(int, threshold.metric)],
"name": _TRANSACTION_METRICS_TO_RULE_FIELD[threshold.metric],
"value": threshold.threshold,
},
*extra_conditions,
Expand Down
88 changes: 63 additions & 25 deletions tests/tools/mypy_helpers/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,32 @@
from __future__ import annotations

import pathlib
import os.path
import subprocess
import sys
from typing import Callable
import tempfile

import pytest

def call_mypy(src: str, *, plugins: list[str] | None = None) -> tuple[int, str]:
if plugins is None:
plugins = ["tools.mypy_helpers.plugin"]
with tempfile.TemporaryDirectory() as tmpdir:
cfg = os.path.join(tmpdir, "mypy.toml")
with open(cfg, "w") as f:
f.write(f"[tool.mypy]\nplugins = {plugins!r}\n")

@pytest.fixture
def call_mypy(tmp_path: pathlib.Path) -> Callable[[str], tuple[int, str]]:
cfg = """\
[tool.mypy]
plugins = ["tools.mypy_helpers.plugin"]
"""
cfg_path = tmp_path.joinpath("mypy.toml")
cfg_path.write_text(cfg)

def _call_mypy(contents: str) -> tuple[int, str]:
ret = subprocess.run(
(
*(sys.executable, "-m", "mypy"),
*("--config", str(cfg_path)),
*("-c", contents),
*("--config", cfg),
*("-c", src),
),
capture_output=True,
encoding="UTF-8",
)
return ret.returncode, ret.stdout

return _call_mypy


def test_invalid_get_connection_call(call_mypy):
def test_invalid_get_connection_call():
code = """
from django.db.transaction import get_connection
Expand All @@ -48,7 +42,7 @@ def test_invalid_get_connection_call(call_mypy):
assert out == expected


def test_ok_get_connection(call_mypy):
def test_ok_get_connection():
code = """
from django.db.transaction import get_connection
Expand All @@ -59,7 +53,7 @@ def test_ok_get_connection(call_mypy):
assert ret == 0


def test_invalid_transaction_atomic(call_mypy):
def test_invalid_transaction_atomic():
code = """
from django.db import transaction
Expand All @@ -78,7 +72,7 @@ def test_invalid_transaction_atomic(call_mypy):
assert out == expected


def test_ok_transaction_atomic(call_mypy):
def test_ok_transaction_atomic():
code = """
from django.db import transaction
Expand All @@ -89,7 +83,7 @@ def test_ok_transaction_atomic(call_mypy):
assert ret == 0


def test_ok_transaction_on_commit(call_mypy):
def test_ok_transaction_on_commit():
code = """
from django.db import transaction
Expand All @@ -102,7 +96,7 @@ def completed():
assert ret == 0


def test_invalid_transaction_on_commit(call_mypy):
def test_invalid_transaction_on_commit():
code = """
from django.db import transaction
Expand All @@ -120,7 +114,7 @@ def completed():
assert out == expected


def test_invalid_transaction_set_rollback(call_mypy):
def test_invalid_transaction_set_rollback():
code = """
from django.db import transaction
Expand All @@ -135,11 +129,55 @@ def test_invalid_transaction_set_rollback(call_mypy):
assert out == expected


def test_ok_transaction_set_rollback(call_mypy):
def test_ok_transaction_set_rollback():
code = """
from django.db import transaction
transaction.set_rollback(True, "default")
"""
ret, _ = call_mypy(code)
assert ret == 0


def test_field_descriptor_hack():
code = """\
from __future__ import annotations
from django.db import models
class M1(models.Model):
f: models.Field[int, int] = models.IntegerField()
class C:
f: int
def f(inst: C | M1 | M2) -> int:
return inst.f
# should also work with field subclasses
class F(models.Field[int, int]):
pass
class M2(models.Model):
f = F()
def g(inst: C | M2) -> int:
return inst.f
"""

# should be an error with default plugins
# mypy may fix this at some point hopefully: python/mypy#5570
ret, out = call_mypy(code, plugins=[])
assert ret
assert (
out
== """\
<string>:12: error: Incompatible return value type (got "Union[int, Field[int, int]]", expected "int") [return-value]
<string>:22: error: Incompatible return value type (got "Union[int, F]", expected "int") [return-value]
Found 2 errors in 1 file (checked 1 source file)
"""
)

# should be fixed with our special plugin
ret, _ = call_mypy(code)
assert ret == 0
38 changes: 36 additions & 2 deletions tools/mypy_helpers/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from typing import Callable

from mypy.nodes import ARG_POS
from mypy.plugin import FunctionSigContext, Plugin
from mypy.nodes import ARG_POS, TypeInfo
from mypy.plugin import FunctionSigContext, MethodSigContext, Plugin
from mypy.types import CallableType, FunctionLike, Instance


Expand Down Expand Up @@ -48,12 +48,46 @@ def replace_transaction_atomic_sig_callback(ctx: FunctionSigContext) -> Callable
}


def field_descriptor_no_overloads(ctx: MethodSigContext) -> FunctionLike:
# ignore the class / non-model instance descriptor overloads
signature = ctx.default_signature
# replace `def __get__(self, inst: Model, owner: Any) -> _GT:`
# with `def __get__(self, inst: Any, owner: Any) -> _GT:`
if str(signature.arg_types[0]) == "django.db.models.base.Model":
return signature.copy_modified(arg_types=[signature.arg_types[1]] * 2)
else:
return signature


class SentryMypyPlugin(Plugin):
def get_function_signature_hook(
self, fullname: str
) -> Callable[[FunctionSigContext], FunctionLike] | None:
return _FUNCTION_SIGNATURE_HOOKS.get(fullname)

def get_method_signature_hook(
self, fullname: str
) -> Callable[[MethodSigContext], FunctionLike] | None:
if fullname == "django.db.models.fields.Field":
return field_descriptor_no_overloads

clsname, _, methodname = fullname.rpartition(".")
if methodname != "__get__":
return None

clsinfo = self.lookup_fully_qualified(clsname)
if clsinfo is None or not isinstance(clsinfo.node, TypeInfo):
return None

fieldinfo = self.lookup_fully_qualified("django.db.models.fields.Field")
if fieldinfo is None:
return None

if fieldinfo.node in clsinfo.node.mro:
return field_descriptor_no_overloads
else:
return None


def plugin(version: str) -> type[SentryMypyPlugin]:
return SentryMypyPlugin
2 changes: 1 addition & 1 deletion tools/mypy_helpers/remove_unneeded_type_ignores.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def main() -> int:
cmd = (sys.executable, "-m", "tools.mypy_helpers.mypy_without_ignores", *sys.argv[1:])
out = subprocess.run(cmd, stdout=subprocess.PIPE)
for line in out.stdout.decode().splitlines():
if line.endswith('Unused "type: ignore" comment'):
if line.endswith("[unused-ignore]"):
fname, n, *_ = line.split(":")

subprocess.check_call(
Expand Down

0 comments on commit 3231cb3

Please sign in to comment.