diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile deleted file mode 100644 index 03f0227..0000000 --- a/.devcontainer/Dockerfile +++ /dev/null @@ -1,23 +0,0 @@ -# Base image for PyEasyHLA development. - -# This image is the basis of the project's devcontainer as well. - -ARG PYTHON_VERSION="3.13.3-bookworm" - -FROM python:${PYTHON_VERSION} AS base - -RUN apt update -y && apt upgrade -y - -# Install the vendored software. -ARG INSTANTCLIENT_BASIC="instantclient-basic-linux.x64-23.7.0.25.01.zip" -# The value of ORACLE_HOME depends on the instant client used as it will -# install to a path like ".../instantclient_23_7" where the version numbers -# will vary. -ARG ORACLE_HOME="/opt/oracle/instantclient_23_7" - -ENV ORACLE_HOME=${ORACLE_HOME} \ - LD_LIBRARY_PATH=${ORACLE_HOME}:$LD_LIBRARY_PATH - -COPY vendor/${INSTANTCLIENT_BASIC} /tmp/vendor/ -RUN unzip /tmp/vendor/${INSTANTCLIENT_BASIC} -d /opt/oracle &&\ - rm -rf /tmp/vendor diff --git a/.github/workflows/python.yml b/.github/workflows/test.yml similarity index 51% rename from .github/workflows/python.yml rename to .github/workflows/test.yml index 763c17d..b2c9537 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/test.yml @@ -22,10 +22,10 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.x"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 @@ -33,28 +33,29 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: - - apt update && apt install yamllint - - pip install uv + run: | + apt update && apt install yamllint + pip install uv - name: Check code - run: - - yamllint . - - uv run mypy --check . - - uv run ruff check . + continue-on-error: true + run: | + yamllint . + uv run mypy --check . + uv run ruff check . - name: Run tests run: uv run pytest --junitxml=pytest.xml - # TODO: Look into github actions, these are out of date - # - name: Upload coverage data - # uses: actions/upload-artifact@v3 - # with: - # name: coverage-data - # path: coverage.xml - - # - name: Publish Test Report - # uses: mikepenz/action-junit-report@v3 - # if: success() || failure() - # with: - # report_paths: unit_test.xml +# TODO: Look into github actions, these are out of date +# - name: Upload coverage data +# uses: actions/upload-artifact@v3 +# with: +# name: coverage-data +# path: coverage.xml + +# - name: Publish Test Report +# uses: mikepenz/action-junit-report@v3 +# if: success() || failure() +# with: +# report_paths: unit_test.xml diff --git a/.yamllint.yml b/.yamllint.yml index 8e49526..d2118fb 100644 --- a/.yamllint.yml +++ b/.yamllint.yml @@ -1,6 +1,7 @@ ignore: - .git/* - .venv/* + - src/easyhla/default_data/hla_standards.yaml extends: default diff --git a/pyproject.toml b/pyproject.toml index b056006..bf39676 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "uv-dynamic-versioning"] build-backend = "hatchling.build" [project] @@ -33,6 +33,9 @@ dependencies = [ "pyyaml>=6.0.2", "requests>=2.32.3", "typer>=0.15.2", + "types-pyyaml>=6.0.12.20250516", + "types-requests>=2.32.4.20250611", + "uv-dynamic-versioning>=0.8.2", ] [dependency-groups] @@ -54,9 +57,9 @@ dev = [ ] [project.urls] -Documentation = "https://github.com/unknown/easyhla#readme" -Issues = "https://github.com/unknown/easyhla/issues" -Source = "https://github.com/unknown/easyhla" +Documentation = "https://github.com/cfe-lab/pyeasyhla/blob/main/README.md" +Issues = "https://github.com/cfe-lab/pyeasyhla/issues" +Source = "https://github.com/cfe-lab/pyeasyhla" [project.scripts] clinical_hla = "easyhla.clinical_hla:main" @@ -72,18 +75,27 @@ database = [ ] [tool.hatch.version] -path = "src/easyhla/__about__.py" +source = "uv-dynamic-versioning" [tool.hatch.build] include = [ - "src/easyhla/*.py", - "src/easyhla/default_data/*.csv", - "src/easyhla/default_data/hla_nuc.fasta.mtime", + "src/easyhla/__about__.py", + "src/easyhla/__init__.py", + "src/easyhla/__main__.py", + "src/easyhla/easyhla.py", + "src/easyhla/interpret_from_json_lib.py", + "src/easyhla/interpret_from_json.py", + "src/easyhla/models.py", + "src/easyhla/py.typed", + "src/easyhla/update_alleles.py", + "src/easyhla/update_frequency_file_lib.py", + "src/easyhla/update_frequency_file.py", + "src/easyhla/utils.py", + "src/easyhla/default_data/hla_standards.yaml", + "src/easyhla/default_data/hla_frequencies.csv", ] exclude = [ - "tools", - "tests/output", - "tests/input", + "tests", ] skip-excluded-dirs = true directory = "output" @@ -91,9 +103,20 @@ directory = "output" [tool.hatch.build.targets.wheel] packages = ["src/easyhla"] +[tool.hatch.build.hooks.version] +path = "src/easyhla/_version.py" +template = ''' +__version__ = "{version}" +''' + [tool.uv] package = true +[tool.uv-dynamic-versioning] +vcs = "git" +style = "semver" +fallback-version = "0.0.0" + [tool.pytest.ini_options] pythonpath = "src" minversion = "6.0" @@ -147,3 +170,4 @@ match = "src/**/*.py" [tool.mypy] plugins = ["numpy.typing.mypy_plugin"] ignore_missing_imports = true +exclude = ["scripts/"] diff --git a/src/easyhla/__about__.py b/src/easyhla/__about__.py deleted file mode 100644 index e728ace..0000000 --- a/src/easyhla/__about__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.0.0-dev" diff --git a/src/easyhla/bblab.py b/src/easyhla/bblab.py index d244591..1d3be51 100644 --- a/src/easyhla/bblab.py +++ b/src/easyhla/bblab.py @@ -5,8 +5,9 @@ from pathlib import Path from typing import Any, Optional -import Bio import typer +from Bio.Seq import MutableSeq, Seq +from Bio.SeqIO import parse from .bblab_lib import ( EXON_AND_OTHER_EXON, @@ -14,8 +15,9 @@ HLAMismatchRow, pair_exons, ) -from .easyhla import DATE_FORMAT, EXON_NAME, EasyHLA +from .easyhla import DATE_FORMAT, EasyHLA from .models import HLAInterpretation, HLASequence +from .utils import EXON_NAME logger = logging.Logger(__name__, logging.ERROR) @@ -49,21 +51,21 @@ def log_and_print( def report_unmatched_sequences( - unmatched: dict[EXON_NAME, dict[str, Bio.SeqIO.SeqRecord]], + unmatched: dict[EXON_NAME, dict[str, Seq | MutableSeq | None]], to_stdout: bool = False, ) -> None: """ Report exon sequences that did not have a matching exon. :param unmatched: unmatched exon sequences, grouped by which exon they represent - :type unmatched: dict[EXON_NAME, dict[str, Bio.SeqIO.SeqRecord]] + :type unmatched: dict[EXON_NAME, dict[str, Seq]] :param to_stdout: ..., defaults to None :type to_stdout: Optional[bool], optional """ for exon, other_exon in EXON_AND_OTHER_EXON: - for entry in unmatched[exon]: + for sequence_id in unmatched[exon].keys(): log_and_print( - f"No matching {other_exon} for {entry.description}", + f"No matching {other_exon} for {sequence_id}", to_stdout=to_stdout, ) @@ -79,6 +81,8 @@ def process_from_file_to_files( ): if threshold and threshold < 0: raise RuntimeError("Threshold must be >=0 or None!") + elif threshold is None: + threshold = 0 rows: list[HLAInterpretationRow] = [] mismatch_rows: list[HLAMismatchRow] = [] @@ -93,13 +97,13 @@ def process_from_file_to_files( ) matched_sequences: list[HLASequence] - unmatched: dict[EXON_NAME, dict[str, Bio.SeqIO.SeqRecord]] + unmatched: dict[EXON_NAME, dict[str, Seq | MutableSeq | None]] with open(filename, "r", encoding="utf-8") as f: matched_sequences, unmatched = pair_exons( - Bio.SeqIO.parse(f, "fasta"), + parse(f, "fasta"), locus.value, - list(hla_alg.standards.values())[0], + list(hla_alg.hla_standards[locus.value].values())[0], ) for hla_sequence in matched_sequences: @@ -133,10 +137,10 @@ def process_from_file_to_files( row: HLAInterpretationRow = HLAInterpretationRow.summary_row(result) rows.append(row) - mismatch_rows.extend(result.mismatch_rows()) + mismatch_rows.extend(HLAMismatchRow.mismatch_rows(result)) npats += 1 - nseqs += hla_sequence.num_seqs + nseqs += hla_sequence.num_sequences_used report_unmatched_sequences(unmatched, to_stdout=to_stdout) @@ -171,11 +175,11 @@ def process_from_file_to_files( ), ) mismatch_csv.writeheader() - mismatch_csv.writerows([dict[row] for row in mismatch_rows]) + mismatch_csv.writerows([dict(row) for row in mismatch_rows]) log_and_print( f"{npats} patients, {nseqs} sequences processed.", - log_level=logger.INFO, + log_level=logging.INFO, to_stdout=to_stdout, ) diff --git a/src/easyhla/bblab_lib.py b/src/easyhla/bblab_lib.py index 62177b0..593153f 100644 --- a/src/easyhla/bblab_lib.py +++ b/src/easyhla/bblab_lib.py @@ -3,7 +3,7 @@ from typing import TypedDict import numpy as np -from Bio.Seq import Seq +from Bio.Seq import MutableSeq, Seq from Bio.SeqIO import SeqRecord from pydantic import BaseModel @@ -36,7 +36,7 @@ def pair_exons_helper( sequence_record: SeqRecord, - unmatched: dict[EXON_NAME, dict[str, Seq]], + unmatched: dict[EXON_NAME, dict[str, Seq | MutableSeq | None]], ) -> tuple[str, bool, bool, str, str]: """ Helper that attempts to match the given sequence with a "partner" exon. @@ -55,7 +55,7 @@ def pair_exons_helper( - exon3 sequence """ # The `id`` field is expected to hold the sample name. - samp: str = sequence_record.id + samp: str = sequence_record.id or "" is_exon: bool = False matched: bool = False exon2: str = "" @@ -98,7 +98,7 @@ def pair_exons( sequence_records: Iterable[SeqRecord], locus: HLA_LOCUS, example_standard: HLAStandard, -) -> tuple[list[HLASequence], dict[EXON_NAME, dict[str, Seq]]]: +) -> tuple[list[HLASequence], dict[EXON_NAME, dict[str, Seq | MutableSeq | None]]]: """ Pair exons in the given input sequences. @@ -109,7 +109,7 @@ def pair_exons( sequences and attempt to match them up. """ matched_sequences: list[HLASequence] = [] - unmatched: dict[EXON_NAME, dict[str, Seq]] = { + unmatched: dict[EXON_NAME, dict[str, Seq | MutableSeq | None]] = { "exon2": {}, "exon3": {}, } @@ -118,7 +118,7 @@ def pair_exons( # Skip over any sequences that aren't the right length or contain # bad bases. try: - check_length(locus, str(sr.seq), sr.id) + check_length(locus, str(sr.seq), sr.id or "") except BadLengthException: continue @@ -147,21 +147,21 @@ def pair_exons( exon3_bin = pad_short(example_standard.sequence, nuc2bin(exon3), "exon3") matched_sequences.append( HLASequence( - two=(int(x) for x in exon2_bin), + two=tuple(int(x) for x in exon2_bin), intron=(), - three=(int(x) for x in exon3_bin), + three=tuple(int(x) for x in exon3_bin), name=identifier, locus=locus, num_sequences_used=2, ) ) else: - seq_numpy: np.array = pad_short( + seq_numpy: np.ndarray = pad_short( example_standard.sequence, nuc2bin(sr.seq), # type: ignore None, ) - seq: tuple[int] = tuple(int(x) for x in seq_numpy) + seq: tuple[int, ...] = tuple(int(x) for x in seq_numpy) matched_sequences.append( HLASequence( two=seq[:EXON2_LENGTH], diff --git a/src/easyhla/clinical_hla.py b/src/easyhla/clinical_hla.py index 23d2343..f289c4f 100644 --- a/src/easyhla/clinical_hla.py +++ b/src/easyhla/clinical_hla.py @@ -6,7 +6,7 @@ import logging import os from datetime import datetime -from typing import Final, Optional, TypedDict +from typing import Final, Literal, Optional, TypedDict, cast from sqlalchemy import create_engine, event from sqlalchemy.engine import Engine @@ -36,38 +36,15 @@ ) # Database connection parameters: -HLA_DB_USER: Final[str] = os.environ.get("HLA_DB_USER") -HLA_DB_PASSWORD: Final[str] = os.environ.get("HLA_DB_PASSWORD") +HLA_DB_USER: Final[Optional[str]] = os.environ.get("HLA_DB_USER") +HLA_DB_PASSWORD: Final[Optional[str]] = os.environ.get("HLA_DB_PASSWORD") HLA_DB_HOST: Final[str] = os.environ.get("HLA_DB_HOST", "192.168.67.7") -HLA_DB_PORT: Final[int] = os.environ.get("HLA_DB_PORT", 1521) +HLA_DB_PORT: Final[int] = int(os.environ.get("HLA_DB_PORT", 1521)) HLA_DB_SERVICE_NAME: Final[str] = os.environ.get("HLA_DB_SERVICE_NAME", "cfe") -HLA_ORACLE_LIB_PATH: Final[str] = os.environ.get("HLA_ORACLE_LIB_PATH") - -# These are the "configuration files" that the algorithm uses; these are or may -# be updated, in which case you specify the path to the new version in the -# environment. -HLA_STANDARDS: Final[str] = os.environ.get("HLA_STANDARDS") -HLA_FREQUENCIES: Final[str] = os.environ.get("HLA_FREQUENCIES") - - -def prepare_interpretation_for_serialization( - interpretation: HLAInterpretation, - locus: HLA_LOCUS, - processing_datetime: datetime, -) -> HLASequenceA | HLASequenceB | HLASequenceC: - """ - Prepare an HLA interpretation for output. - """ - if locus == "A": - return HLASequenceA.build_from_interpretation( - interpretation, processing_datetime - ) - elif locus == "B": - return HLASequenceB.build_from_interpretation( - interpretation, processing_datetime - ) - return HLASequenceC.build_from_interpretation(interpretation, processing_datetime) +HLA_ORACLE_LIB_PATH: Final[str] = os.environ.get( + "HLA_ORACLE_LIB_PATH", "/opt/oracle/instant_client" +) class SequencesByLocus(TypedDict): @@ -91,10 +68,10 @@ def interpret_sequences( def clinical_hla_driver( input_dir: str, + hla_a_results: str, + hla_b_results: str, + hla_c_results: str, db_engine: Optional[Engine] = None, - hla_a_results: Optional[str] = None, - hla_b_results: Optional[str] = None, - hla_c_results: Optional[str] = None, standards_path: Optional[str] = None, frequencies_path: Optional[str] = None, ) -> None: @@ -105,7 +82,8 @@ def clinical_hla_driver( "C": [], } for locus in ("B", "C"): - sequences[locus] = read_bc_sequences(input_dir, locus, logger) + b_or_c: Literal["B", "C"] = cast(Literal["B", "C"], locus) + sequences[b_or_c] = read_bc_sequences(input_dir, b_or_c, logger) # Perform interpretations: interpretations: dict[HLA_LOCUS, list[HLAInterpretation]] = { @@ -116,7 +94,9 @@ def clinical_hla_driver( processing_datetime: datetime = datetime.now() easyhla: EasyHLA = EasyHLA.use_config(standards_path, frequencies_path) for locus in ("A", "B", "C"): - interpretations[locus] = interpret_sequences(easyhla, sequences[locus]) + interpretations[cast(HLA_LOCUS, locus)] = interpret_sequences( + easyhla, sequences[cast(HLA_LOCUS, locus)] + ) # Prepare the interpretations for output: seqs_for_db: SequencesByLocus = { @@ -124,17 +104,20 @@ def clinical_hla_driver( "B": [], "C": [], } - for locus in ("A", "B", "C"): - # Each locus has a slightly different schema in the database, so we - # customize for each one. - for interp in interpretations[locus]: - seqs_for_db[locus].append( - prepare_interpretation_for_serialization( - interp, - locus, - processing_datetime, - ) - ) + # This next bit looks repetitive but mypy didn't like my solution for doing + # this in a loop (because each one is a different type). + for interp in interpretations["A"]: + seqs_for_db["A"].append( + HLASequenceA.build_from_interpretation(interp, processing_datetime) + ) + for interp in interpretations["B"]: + seqs_for_db["B"].append( + HLASequenceB.build_from_interpretation(interp, processing_datetime) + ) + for interp in interpretations["C"]: + seqs_for_db["C"].append( + HLASequenceC.build_from_interpretation(interp, processing_datetime) + ) # First, write to the output files: output_files: dict[HLA_LOCUS, str] = { @@ -148,19 +131,23 @@ def clinical_hla_driver( "C": HLASequenceC.CSV_HEADER, } for locus in ("A", "B", "C"): - if len(seqs_for_db[locus]) > 0: - with open(output_files[locus], "w") as f: + if len(seqs_for_db[cast(HLA_LOCUS, locus)]) > 0: + with open(output_files[cast(HLA_LOCUS, locus)], "w") as f: result_csv: csv.DictWriter = csv.DictWriter( - f, fieldnames=csv_headers[locus], extrasaction="ignore" + f, + fieldnames=csv_headers[cast(HLA_LOCUS, locus)], + extrasaction="ignore", ) result_csv.writeheader() - result_csv.writerows(dataclasses.asdict(x) for x in seqs_for_db[locus]) + result_csv.writerows( + dataclasses.asdict(x) for x in seqs_for_db[cast(HLA_LOCUS, locus)] + ) # Finally, write to the DB. if db_engine is not None: with Session(db_engine) as session: for locus in ("A", "B", "C"): - session.add_all(seqs_for_db[locus]) + session.add_all(seqs_for_db[cast(HLA_LOCUS, locus)]) session.commit() @@ -246,10 +233,10 @@ def schema_workaround(dbapi_connection, _): clinical_hla_driver( args.input_dir, - db_engine, args.hla_a_results, args.hla_b_results, args.hla_c_results, + db_engine, args.hla_standards, args.hla_frequencies, ) diff --git a/src/easyhla/clinical_hla_lib.py b/src/easyhla/clinical_hla_lib.py index e57576d..afe2eea 100644 --- a/src/easyhla/clinical_hla_lib.py +++ b/src/easyhla/clinical_hla_lib.py @@ -76,7 +76,7 @@ def get_common_serialization_fields( "alleles_all": ap.stringify(), "ambiguous": str(ap.is_ambiguous()), "homozygous": str(ap.is_homozygous()), - "mismatch_count": interpretation.lowest_mismatch_count(), + "mismatch_count": mismatch_count, "mismatches": mismatches_str, "enterdate": processing_datetime, } @@ -94,7 +94,7 @@ class HLASequenceA(HLADBBase): alleles_all: Mapped[Optional[str]] = mapped_column(String) ambiguous: Mapped[Optional[str]] = mapped_column(String) homozygous: Mapped[Optional[str]] = mapped_column(String) - mismatch_count: Mapped[Optional[str]] = mapped_column(Integer) + mismatch_count: Mapped[Optional[int]] = mapped_column(Integer) mismatches: Mapped[Optional[str]] = mapped_column(String) seq: Mapped[Optional[str]] = mapped_column(String) enterdate: Mapped[Optional[datetime]] = mapped_column(DateTime) @@ -140,7 +140,7 @@ class HLASequenceB(HLADBBase): alleles_all: Mapped[Optional[str]] = mapped_column(String) ambiguous: Mapped[Optional[str]] = mapped_column(String) homozygous: Mapped[Optional[str]] = mapped_column(String) - mismatch_count: Mapped[Optional[str]] = mapped_column(Integer) + mismatch_count: Mapped[Optional[int]] = mapped_column(Integer) mismatches: Mapped[Optional[str]] = mapped_column(String) b5701: Mapped[Optional[str]] = mapped_column(String) b5701_dist: Mapped[Optional[int]] = mapped_column(Integer) @@ -201,7 +201,7 @@ class HLASequenceC(HLADBBase): alleles_all: Mapped[Optional[str]] = mapped_column(String) ambiguous: Mapped[Optional[str]] = mapped_column(String) homozygous: Mapped[Optional[str]] = mapped_column(String) - mismatch_count: Mapped[Optional[str]] = mapped_column(Integer) + mismatch_count: Mapped[Optional[int]] = mapped_column(Integer) mismatches: Mapped[Optional[str]] = mapped_column(String) seqa: Mapped[Optional[str]] = mapped_column(String) seqb: Mapped[Optional[str]] = mapped_column(String) @@ -348,7 +348,7 @@ def identify_bc_sequence_files( if sample_match is None: logger.info(f'Skipping file "{filename}".') continue - sample_name: str = sample_match.group(1) + sample_name = sample_match.group(1) sample_exon: EXON_NAME = ( "exon2" if sample_match.group(2).upper() == "A" else "exon3" ) diff --git a/src/easyhla/easyhla.py b/src/easyhla/easyhla.py index dcabac8..519c8a3 100644 --- a/src/easyhla/easyhla.py +++ b/src/easyhla/easyhla.py @@ -4,7 +4,7 @@ from datetime import datetime from io import TextIOBase from operator import attrgetter -from typing import Final, Optional, TypedDict +from typing import Final, Optional, TypedDict, cast import numpy as np import yaml @@ -127,7 +127,7 @@ def read_hla_standards(standards_io: TextIOBase) -> LoadedStandards: } @staticmethod - def load_default_hla_standards() -> dict[str, HLAStandard]: + def load_default_hla_standards() -> LoadedStandards: """ Load HLA Standards from reference file. @@ -258,7 +258,7 @@ def combine_standards_stepper( # Keep track of matches we've already found: combos: dict[tuple[int, ...], int] = {} - current_rejection_threshold: int = float("inf") + current_rejection_threshold: int | float = float("inf") for std_ai, std_a in enumerate(matching_stds): if std_a.mismatch > current_rejection_threshold: continue @@ -269,8 +269,8 @@ def combine_standards_stepper( # "Mush" the two standards together to produce something # that looks like what you get when you sequence HLA. std_bin = np.array(std_b.sequence) | np.array(std_a.sequence) - allele_pair: tuple[str, str] = tuple( - sorted((std_a.allele, std_b.allele)) + allele_pair: tuple[str, str] = cast( + tuple[str, str], tuple(sorted((std_a.allele, std_b.allele))) ) # There could be more than one combined standard with the @@ -284,7 +284,7 @@ def combine_standards_stepper( else: seq_mask = np.full_like(std_bin, fill_value=15) # Note that seq is implicitly cast to a NumPy array: - mismatches: int = np.count_nonzero((std_bin ^ seq) & seq_mask != 0) + mismatches = np.count_nonzero((std_bin ^ seq) & seq_mask != 0) combos[combined_std_bin] = mismatches # cache this value if mismatches > current_rejection_threshold: @@ -330,7 +330,7 @@ def combine_standards( combos: dict[tuple[int, ...], tuple[int, list[tuple[str, str]]]] = {} - fewest_mismatches: int = float("inf") + fewest_mismatches: int | float = float("inf") for ( combined_std_bin, mismatches, @@ -346,7 +346,7 @@ def combine_standards( # criteria. result: dict[HLACombinedStandard, int] = {} - cutoff: int = max(fewest_mismatches, mismatch_threshold) + cutoff: int | float = max(fewest_mismatches, mismatch_threshold) for combined_std_bin, mismatch_count_and_pair_list in combos.items(): mismatch_count: int pair_list: list[tuple[str, str]] diff --git a/src/easyhla/interpret_from_json.py b/src/easyhla/interpret_from_json.py index d930659..c143b19 100644 --- a/src/easyhla/interpret_from_json.py +++ b/src/easyhla/interpret_from_json.py @@ -38,10 +38,7 @@ def main(): hla_input.hla_std_path, hla_input.hla_freq_path, ) - interp: HLAInterpretation = easyhla.interpret( - hla_input.hla_sequence(), - hla_input.locus, - ) + interp: HLAInterpretation = easyhla.interpret(hla_input.hla_sequence()) print(HLAResult.build_from_interpretation(interp).model_dump_json()) diff --git a/src/easyhla/interpret_from_json_lib.py b/src/easyhla/interpret_from_json_lib.py index a8a8472..c710fe0 100644 --- a/src/easyhla/interpret_from_json_lib.py +++ b/src/easyhla/interpret_from_json_lib.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field -from .__about__ import __version__ +from ._version import __version__ from .models import ( AllelePairs, HLACombinedStandard, @@ -68,7 +68,7 @@ def hla_sequence(self) -> HLASequence: exon3_str = self.seq1[-276:] else: exon2_str = self.seq1 - exon3_str = self.seq2 + exon3_str = self.seq2 or "" num_sequences_used: int = 1 if self.locus == "A" else 2 return HLASequence( diff --git a/src/easyhla/models.py b/src/easyhla/models.py index a5a3a7b..ed2a868 100644 --- a/src/easyhla/models.py +++ b/src/easyhla/models.py @@ -1,7 +1,7 @@ import re from collections.abc import Iterable from operator import itemgetter -from typing import ClassVar, Final, Optional, Self +from typing import Final, Optional import numpy as np from pydantic import BaseModel, ConfigDict @@ -61,7 +61,7 @@ def sequence_np(self) -> np.ndarray: return np.array(self.sequence) @classmethod - def from_raw_standard(cls, raw_standard: HLARawStandard) -> Self: + def from_raw_standard(cls, raw_standard: HLARawStandard) -> "HLAStandard": return cls( allele=raw_standard.allele, two=nuc2bin(raw_standard.exon2), @@ -134,8 +134,10 @@ def __lt__(self, other: "HLAProteinPair") -> bool: ) return me_tuple < other_tuple - UNMAPPED: ClassVar[Final[str]] = "unmapped" - DEPRECATED: ClassVar[Final[str]] = "deprecated" + # Note: originally these were annotated as ClassVar[Final[str]] but this + # isn't supported in versions of Python prior to 3.13. + UNMAPPED: Final[str] = "unmapped" + DEPRECATED: Final[str] = "deprecated" class NonAlleleException(Exception): def __init__( @@ -153,7 +155,7 @@ def __init__( @classmethod def from_frequency_entry( cls, raw_first_allele: str, raw_second_allele: str - ) -> Optional[Self]: + ) -> Optional["HLAProteinPair.NonAlleleException"]: first_unmapped: bool = False first_deprecated: bool = False second_unmapped: bool = False @@ -182,7 +184,7 @@ def from_frequency_entry( cls, raw_first_allele: str, raw_second_allele: str, - ) -> Self: + ) -> "HLAProteinPair": any_problems: Optional[HLAProteinPair.NonAlleleException] = ( HLAProteinPair.NonAlleleException.from_frequency_entry( raw_first_allele, raw_second_allele diff --git a/src/easyhla/update_alleles.py b/src/easyhla/update_alleles.py index 32416c9..85250c1 100644 --- a/src/easyhla/update_alleles.py +++ b/src/easyhla/update_alleles.py @@ -8,7 +8,7 @@ import time from datetime import datetime from io import StringIO -from typing import Final, Optional, TypedDict +from typing import Final, Optional, TypedDict, cast import Bio import requests @@ -169,7 +169,7 @@ def get_commit_hash( return None -def get_from_git(tag: str) -> tuple[str, datetime, str]: +def get_from_git(tag: str) -> tuple[str, datetime, Optional[str]]: alleles_str: str retrieval_datetime: datetime for i in range(5): @@ -185,7 +185,7 @@ def get_from_git(tag: str) -> tuple[str, datetime, str]: else: break - commit_hash: str + commit_hash: Optional[str] for i in range(5): try: commit_hash = get_commit_hash(tag) @@ -271,7 +271,7 @@ def main(): logger.info(f"Retrieving alleles from tag {args.tag}....") alleles_str: str retrieval_datetime: datetime - commit_hash: str + commit_hash: Optional[str] alleles_str, retrieval_datetime, commit_hash = get_from_git(args.tag) logger.info( f"Alleles (version {args.tag}, commit hash {commit_hash}) retrieved at " @@ -301,10 +301,12 @@ def main(): logger.info("Identifying identical HLA alleles....") standards_for_saving: StoredHLAStandards = StoredHLAStandards( tag=args.tag, - commit_hash=commit_hash, + commit_hash=commit_hash or "", last_updated=retrieval_datetime, standards={ - locus: group_identical_alleles(raw_standards[locus]) + cast(HLA_LOCUS, locus): group_identical_alleles( + raw_standards[cast(HLA_LOCUS, locus)] + ) for locus in ("A", "B", "C") }, ) diff --git a/src/easyhla/update_frequency_file.py b/src/easyhla/update_frequency_file.py index 2e21938..eecc417 100644 --- a/src/easyhla/update_frequency_file.py +++ b/src/easyhla/update_frequency_file.py @@ -49,8 +49,8 @@ def main(): with args.name_mapping: old_to_new: dict[OldName, NewName] deprecated: list[str] - deprecated_maps_to_other: list[tuple[str, str]] - mapping_overrides_deprecated: list[tuple[str, str]] + deprecated_maps_to_other: list[tuple[str, NewName]] + mapping_overrides_deprecated: list[tuple[str, NewName]] ( old_to_new, deprecated, diff --git a/src/easyhla/update_frequency_file_lib.py b/src/easyhla/update_frequency_file_lib.py index a3a1d6f..b559e6a 100644 --- a/src/easyhla/update_frequency_file_lib.py +++ b/src/easyhla/update_frequency_file_lib.py @@ -3,7 +3,7 @@ from collections import Counter from dataclasses import dataclass from io import TextIOBase -from typing import Final, Optional, Self, TypedDict +from typing import Final, Optional, TypedDict, cast from .easyhla import EasyHLA from .models import HLAProteinPair @@ -25,16 +25,17 @@ class OldName: field_2: str @classmethod - def from_string(cls, old_name_str: str) -> Self: + def from_string(cls, old_name_str: str) -> "OldName": """ Build an instance directly from an entry in the nomenclature mapping. The old names look like "A*010507N" for loci other than C; HLA-C old names look like "Cw*010203N". """ - locus: str = old_name_str[0] - if locus not in ("A", "B", "C"): + raw_locus: str = old_name_str[0] + if raw_locus not in ("A", "B", "C"): raise OtherLocusException() + locus: HLA_LOCUS = cast(HLA_LOCUS, raw_locus) raw_coordinates: str = old_name_str[2:] if locus == "C": @@ -45,7 +46,9 @@ def from_string(cls, old_name_str: str) -> Self: return cls(locus, field_1, field_2) @classmethod - def from_old_frequency_format(cls, locus: HLA_LOCUS, four_digit_code: str) -> Self: + def from_old_frequency_format( + cls, locus: HLA_LOCUS, four_digit_code: str + ) -> "OldName": """ Build an instance from an entry in the old frequency file format. @@ -70,7 +73,7 @@ class NewName: field_2: str @classmethod - def from_string(cls, new_name_str: str) -> Self: + def from_string(cls, new_name_str: str) -> "NewName": """ Build an instance directly from an entry in the nomenclature mapping. @@ -80,9 +83,10 @@ def from_string(cls, new_name_str: str) -> Self: return cls(None, "", "") coords: list[str] = new_name_str.split(":") - locus: str = coords[0][0] - if locus not in ("A", "B", "C"): + raw_locus: str = coords[0][0] + if raw_locus not in ("A", "B", "C"): raise OtherLocusException() + locus: HLA_LOCUS = cast(HLA_LOCUS, raw_locus) field_1: str = coords[0][2:] field_2_match: Optional[re.Match] = re.match(r"(\d+)[a-zA-Z]*", coords[1]) if field_2_match is None: @@ -129,8 +133,8 @@ def parse_nomenclature( remapping_lines: list[str] = remapping_str.split("\n")[2:-1] deprecated: list[str] = [] - deprecated_maps_to_other: list[tuple[str, str]] = [] - mapping_overrides_deprecated: list[tuple[str, str]] = [] + deprecated_maps_to_other: list[tuple[str, NewName]] = [] + mapping_overrides_deprecated: list[tuple[str, NewName]] = [] remapping: dict[OldName, NewName] = {} for remapping_line in remapping_lines: @@ -173,8 +177,9 @@ class FrequencyRowDict(TypedDict): c_second: str -FREQUENCY_FIELDS: Final[tuple[str, str, str, str, str, str]] = sum( - EasyHLA.FREQUENCY_LOCUS_COLUMNS.values(), () +FREQUENCY_FIELDS: Final[tuple[str, str, str, str, str, str]] = cast( + tuple[str, str, str, str, str, str], + sum(EasyHLA.FREQUENCY_LOCUS_COLUMNS.values(), ()), ) @@ -192,7 +197,7 @@ def update_old_frequencies( and Counters that represent the alleles that are unmapped in the new naming scheme and the alleles that are deprecated in the new naming scheme. """ - old_frequencies_csv: csv.reader = csv.reader(old_frequencies_file) + old_frequencies_csv = csv.reader(old_frequencies_file) # Report to the user any frequencies that are either unmapped or # deprecated. @@ -203,7 +208,7 @@ def update_old_frequencies( for row in old_frequencies_csv: loci: tuple[HLA_LOCUS, HLA_LOCUS, HLA_LOCUS] = ("A", "B", "C") - updated: FrequencyRowDict = {x: None for x in FREQUENCY_FIELDS} + updated: dict[str, str] = {} for idx in range(6): locus: HLA_LOCUS = loci[int(idx / 2)] column_name: str = FREQUENCY_FIELDS[idx] @@ -219,6 +224,15 @@ def update_old_frequencies( deprecated_alleles_seen[(locus, row[idx])] += 1 updated[column_name] = new_name_str - updated_frequencies.append(updated) + updated_frequencies.append( + FrequencyRowDict( + a_first=updated["a_first"], + a_second=updated["a_second"], + b_first=updated["b_first"], + b_second=updated["b_second"], + c_first=updated["c_first"], + c_second=updated["c_second"], + ) + ) return updated_frequencies, unmapped_alleles, deprecated_alleles_seen diff --git a/src/easyhla/utils.py b/src/easyhla/utils.py index c29b24e..e7a6d55 100644 --- a/src/easyhla/utils.py +++ b/src/easyhla/utils.py @@ -4,7 +4,7 @@ from collections import defaultdict from collections.abc import Iterable, Sequence from datetime import datetime -from typing import Final, Literal, Optional, Self +from typing import Final, Literal, Optional, cast import numpy as np from Bio.SeqIO import SeqRecord @@ -272,21 +272,21 @@ def pad_short( exon3_std_bin: np.ndarray = np.array(std_bin[-EXON3_LENGTH:]) if exon == "exon2": left_pad, right_pad = calc_padding( - exon2_std_bin, + cast(Sequence[int], exon2_std_bin), seq_bin, ) elif exon == "exon3": left_pad, right_pad = calc_padding( - exon3_std_bin, + cast(Sequence[int], exon3_std_bin), seq_bin, ) else: # i.e. this is a full sequence possibly with intron left_pad, _ = calc_padding( - exon2_std_bin, + cast(Sequence[int], exon2_std_bin), seq_bin[: int(EXON2_LENGTH / 2)], ) _, right_pad = calc_padding( - exon3_std_bin, + cast(Sequence[int], exon3_std_bin), seq_bin[-int(EXON3_LENGTH / 2) :], ) return np.concatenate( @@ -300,7 +300,7 @@ def pad_short( def get_acceptable_match( sequence: str, reference: str, mismatch_threshold: int = 20 -) -> tuple[int, Optional[str]]: +) -> tuple[int, str]: """ Get an "acceptable match" between the sequence and reference. @@ -316,7 +316,7 @@ def get_acceptable_match( raise ValueError("sequence must be at least as long as the reference") score: int = len(reference) - best_match: Optional[str] = None + best_match: str = sequence[0 : len(reference)] ref_np: np.ndarray = np.array(list(reference)) for shift in range(len(sequence) - len(reference) + 1): @@ -389,8 +389,12 @@ def collate_standards( checked to see if it has acceptable matches for both exon2 and exon3. """ output_status_updates: bool = False + actual_report_interval: int = 1000 + actual_logger: logging.Logger if logger is not None and report_interval is not None and report_interval > 0: output_status_updates = True + actual_report_interval = cast(int, report_interval) + actual_logger = cast(logging.Logger, logger) standards: dict[HLA_LOCUS, list[HLARawStandard]] = { "A": [], @@ -398,23 +402,23 @@ def collate_standards( "C": [], } for idx, allele_sr in enumerate(allele_srs, start=1): - if output_status_updates and idx % report_interval == 0: - logger.info(f"Processing sequence {idx} of {len(allele_srs)}....") + if output_status_updates and idx % actual_report_interval == 0: + actual_logger.info(f"Processing sequence {idx} of {len(allele_srs)}....") # The FASTA headers look like: # >HLA:HLA00001 A*01:01:01:01 1098 bp allele_name: str = allele_sr.description.split(" ")[1] - locus: HLA_LOCUS = allele_name[0] - - if locus not in ("A", "B", "C"): + raw_locus: str = allele_name[0] + if raw_locus not in ("A", "B", "C"): continue + locus: HLA_LOCUS = cast(HLA_LOCUS, raw_locus) - exon2_match: tuple[int, Optional[str]] = get_acceptable_match( + exon2_match: tuple[int, str] = get_acceptable_match( str(allele_sr.seq), exon_references[locus]["exon2"], mismatch_threshold=acceptable_match_search_threshold, ) - exon3_match: tuple[int, Optional[str]] = get_acceptable_match( + exon3_match: tuple[int, str] = get_acceptable_match( str(allele_sr.seq), exon_references[locus]["exon3"], mismatch_threshold=acceptable_match_search_threshold, @@ -431,7 +435,7 @@ def collate_standards( ) ) elif logger is not None: - logger.info( + actual_logger.info( f'Rejecting "{allele_name}": {exon2_match[0]} exon2 mismatches,' f" {exon3_match[0]} exon3 mismatches." ) @@ -447,7 +451,10 @@ class GroupedAllele(BaseModel): exon3: str alleles: list[str] - @computed_field + # Due to this issue: + # https://github.com/python/mypy/issues/1362 + # we need the special mypy instruction here. + @computed_field # type: ignore[misc] @property def name(self) -> str: """ @@ -542,7 +549,7 @@ class StoredHLAStandards(BaseModel): checksum: Optional[str] = None @model_validator(mode="after") - def compute_compare_checksum(self) -> Self: + def compute_compare_checksum(self) -> "StoredHLAStandards": checksum: str = compute_stored_standard_checksum( self.tag, self.commit_hash, diff --git a/tests/bblab_lib_test.py b/tests/bblab_lib_test.py index be4161a..d62cdd8 100644 --- a/tests/bblab_lib_test.py +++ b/tests/bblab_lib_test.py @@ -1,5 +1,5 @@ import pytest -from Bio.Seq import Seq +from Bio.Seq import MutableSeq, Seq from Bio.SeqIO import SeqRecord from easyhla.bblab_lib import ( @@ -352,7 +352,7 @@ ) def test_pair_exons_helper( sr: SeqRecord, - unmatched: dict[EXON_NAME, dict[str, Seq]], + unmatched: dict[EXON_NAME, dict[str, Seq | MutableSeq | None]], expected_id: str, expected_is_exon: bool, expected_matched: bool, @@ -730,7 +730,7 @@ def test_pair_exons( expected_unmatched: dict[EXON_NAME, dict[str, Seq]], ): paired_seqs: list[HLASequence] - unmatched: dict[EXON_NAME, dict[str, Seq]] + unmatched: dict[EXON_NAME, dict[str, Seq | MutableSeq | None]] current_standard: HLARawStandard = HLA_STANDARDS[locus] fake_standard: HLAStandard = HLAStandard( diff --git a/tests/clinical_hla_lib_test.py b/tests/clinical_hla_lib_test.py index 0068be5..42d2980 100644 --- a/tests/clinical_hla_lib_test.py +++ b/tests/clinical_hla_lib_test.py @@ -1,6 +1,7 @@ from datetime import datetime from pathlib import Path -from typing import Final +from typing import Final, Literal +from unittest.mock import MagicMock import pytest from pytest_mock import MockerFixture @@ -508,14 +509,14 @@ def test_sanitize_sequences_bad_length( ): with pytest.raises(BadLengthException) as e: sanitize_sequence(raw_contents, locus, sample_name) - assert e.expected_length == expected_length_str - assert e.actual_length == actual_length + assert e.value.expected_length == expected_length_str + assert e.value.actual_length == actual_length def test_sanitize_sequences_invalid_character(): raw_contents: str = "A" * 100 + "_" + "C" * 150 sample_name: str = "E12345_exon3_short" - with pytest.raises(InvalidBaseException) as e: + with pytest.raises(InvalidBaseException): sanitize_sequence(raw_contents, "C", sample_name) @@ -704,7 +705,7 @@ def test_read_a_sequences( dummy_path: Path = tmp_path / filename dummy_path.write_text(file_contents) - mock_logger: mocker.MagicMock = mocker.MagicMock() + mock_logger: MagicMock = mocker.MagicMock() result: list[HLASequence] = read_a_sequences(str(tmp_path), mock_logger) @@ -878,7 +879,7 @@ def test_read_a_sequences( ) def test_identify_bc_sequences( filenames: list[str], - locus: HLA_LOCUS, + locus: Literal["B", "C"], expected_result: dict[str, dict[EXON_NAME, str]], expected_logger_calls: list[str], tmp_path: Path, @@ -888,7 +889,7 @@ def test_identify_bc_sequences( dummy_path: Path = tmp_path / filename dummy_path.write_text("ACGT") - mock_logger: mocker.MagicMock = mocker.MagicMock() + mock_logger: MagicMock = mocker.MagicMock() result: dict[str, dict[EXON_NAME, str]] = identify_bc_sequence_files( str(tmp_path), locus, mock_logger @@ -1032,7 +1033,7 @@ def test_identify_bc_sequences( ) def test_read_bc_sequences( raw_sequences: dict[str, str], - locus: HLA_LOCUS, + locus: Literal["B", "C"], expected_sequences: list[HLASequence], expected_logger_calls: list[str], tmp_path: Path, @@ -1042,7 +1043,7 @@ def test_read_bc_sequences( dummy_path: Path = tmp_path / filename dummy_path.write_text(file_contents) - mock_logger: mocker.MagicMock = mocker.MagicMock() + mock_logger: MagicMock = mocker.MagicMock() result: list[HLASequence] = read_bc_sequences(str(tmp_path), locus, mock_logger) diff --git a/tests/conftest.py b/tests/conftest.py index 7704be1..7371831 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,10 @@ from typing import Tuple from easyhla.easyhla import EasyHLA +from easyhla.utils import bin2nuc, nuc2bin -def make_comparison(easyhla: EasyHLA, ref_seq: str, test_seq: str) -> str: +def make_comparison(ref_seq: str, test_seq: str) -> str: """ Compares two sequences for differences @@ -16,7 +17,7 @@ def make_comparison(easyhla: EasyHLA, ref_seq: str, test_seq: str) -> str: :return: A sequence where mismatches are replaced with '_' :rtype: str """ - ref, test = easyhla.nuc2bin(ref_seq.strip()), easyhla.nuc2bin(test_seq.strip()) + ref, test = nuc2bin(ref_seq.strip()), nuc2bin(test_seq.strip()) masked_seq = [] for i in range(max(len(ref), len(test))): @@ -30,8 +31,8 @@ def make_comparison(easyhla: EasyHLA, ref_seq: str, test_seq: str) -> str: side_is_short = "test" elif len(ref) < len(test): side_is_short = "ref" - return easyhla.bin2nuc(masked_seq) + f" [{side_is_short} is short]" # type: ignore - return easyhla.bin2nuc(masked_seq) # type: ignore + return bin2nuc(masked_seq) + f" [{side_is_short} is short]" # type: ignore + return bin2nuc(masked_seq) # type: ignore def compare_ref_vs_test( @@ -74,7 +75,7 @@ def compare_ref_vs_test( "INTRON", "EXON3", ]: - comparison = make_comparison(easyhla, _ref, _test) + comparison = make_comparison(_ref, _test) if "_" in comparison: print( ">>>", diff --git a/tests/easyhla_test.py b/tests/easyhla_test.py index ab92496..1a3cb1e 100644 --- a/tests/easyhla_test.py +++ b/tests/easyhla_test.py @@ -1,9 +1,10 @@ import os -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from datetime import datetime from io import StringIO from pathlib import Path -from typing import Optional +from typing import Optional, cast +from unittest.mock import MagicMock, _Call import numpy as np import pytest @@ -110,7 +111,7 @@ def easyhla(): HLA_STANDARDS[locus] ) } - for locus in ("A", "B", "C") + for locus in cast(tuple[HLA_LOCUS, HLA_LOCUS, HLA_LOCUS], ("A", "B", "C")) } dummy_loaded_standards: LoadedStandards = { "tag": "v0.1.0-dummy-test", @@ -508,9 +509,9 @@ def easyhla(): ], ) def test_combine_standards_stepper( - sequence: Iterable[int], + sequence: Sequence[int], matching_standards: list[HLAStandardMatch], - thresholds: list[Optional[int]], + thresholds: list[int], exp_result: list[tuple[tuple[int, ...], int, tuple[str, str]]], ): for threshold in thresholds: @@ -1135,7 +1136,7 @@ def test_get_mismatches_good_cases( ): for locus in locuses: result: list[HLAMismatch] = EasyHLA.get_mismatches( - tuple(std_bin), np.array(seq_bin), locus + tuple(std_bin), cast(Sequence[int], np.array(seq_bin)), locus ) assert result == expected_result @@ -1176,7 +1177,11 @@ def test_get_mismatches_errors( ): for locus in ["A", "B", "C"]: with pytest.raises(ValueError) as excinfo: - EasyHLA.get_mismatches(tuple(std_bin), np.array(seq_bin), locus) + EasyHLA.get_mismatches( + tuple(std_bin), + cast(Sequence[int], np.array(seq_bin)), + cast(HLA_LOCUS, locus), + ) assert expected_error in str(excinfo.value) @@ -1429,11 +1434,11 @@ def test_interpret_good_cases( easyhla.hla_standards = standards # Spy on the internals to make sure they're called correctly. - get_matching_standards_spy: mocker.MagicMock = mocker.spy( + get_matching_standards_spy: MagicMock = mocker.spy( easyhla, "get_matching_standards" ) - combine_standards_spy: mocker.MagicMock = mocker.spy(easyhla, "combine_standards") - get_mismatches_spy: mocker.MagicMock = mocker.spy(easyhla, "get_mismatches") + combine_standards_spy: MagicMock = mocker.spy(easyhla, "combine_standards") + get_mismatches_spy: MagicMock = mocker.spy(easyhla, "get_mismatches") result: HLAInterpretation = easyhla.interpret(sequence, threshold=threshold) assert result == expected_interpretation @@ -1442,7 +1447,7 @@ def test_interpret_good_cases( # and the comparison fails; we have to manually convert the value to a list # to be able to compare them. get_matching_standards_spy.assert_called_once() - gms_call_args: mocker.call = get_matching_standards_spy.call_args + gms_call_args: _Call = get_matching_standards_spy.call_args assert len(gms_call_args.args) == 2 assert len(gms_call_args.kwargs) == 0 assert gms_call_args.args[0] == sequence.sequence_for_interpretation @@ -1518,11 +1523,11 @@ def test_interpret_error_cases( easyhla.hla_standards[locus] = {std.allele: std for std in raw_standards[locus]} # Spy on the internals to make sure they're called correctly. - get_matching_standards_spy: mocker.MagicMock = mocker.spy( + get_matching_standards_spy: MagicMock = mocker.spy( easyhla, "get_matching_standards" ) - combine_standards_spy: mocker.MagicMock = mocker.spy(easyhla, "combine_standards") - get_mismatches_spy: mocker.MagicMock = mocker.spy(easyhla, "get_mismatches") + combine_standards_spy: MagicMock = mocker.spy(easyhla, "combine_standards") + get_mismatches_spy: MagicMock = mocker.spy(easyhla, "get_mismatches") with pytest.raises(EasyHLA.NoMatchingStandards): easyhla.interpret(sequence, threshold=threshold) @@ -1531,7 +1536,7 @@ def test_interpret_error_cases( # and the comparison fails; we have to manually convert the value to a list # to be able to compare them. get_matching_standards_spy.assert_called_once() - gms_call_args: mocker.call = get_matching_standards_spy.call_args + gms_call_args: _Call = get_matching_standards_spy.call_args assert len(gms_call_args.args) == 2 assert len(gms_call_args.kwargs) == 0 assert gms_call_args.args[0] == sequence.sequence_for_interpretation @@ -1746,10 +1751,10 @@ def test_read_hla_standards( # Also try reading it from a file. p = tmp_path / "hla_standards.yaml" p.write_text(standards_file_str) - dirname_return_mock: mocker.MagicMock = mocker.MagicMock() + dirname_return_mock: MagicMock = mocker.MagicMock() mocker.patch.object(os.path, "dirname", return_value=dirname_return_mock) mocker.patch.object(os.path, "join", return_value=str(p)) - load_result: list[HLAStandard] = EasyHLA.load_default_hla_standards() + load_result: LoadedStandards = EasyHLA.load_default_hla_standards() assert load_result == expected_result @@ -2081,10 +2086,12 @@ def test_read_hla_frequencies( # Now try loading these from a file. p = tmp_path / "hla_frequencies.csv" p.write_text(frequencies_str) - dirname_return_mock: mocker.MagicMock = mocker.MagicMock() + dirname_return_mock: MagicMock = mocker.MagicMock() mocker.patch.object(os.path, "dirname", return_value=dirname_return_mock) mocker.patch.object(os.path, "join", return_value=str(p)) - load_result: dict[HLAProteinPair, int] = EasyHLA.load_default_hla_frequencies() + load_result: dict[HLA_LOCUS, dict[HLAProteinPair, int]] = ( + EasyHLA.load_default_hla_frequencies() + ) assert load_result == expected_results @@ -2100,7 +2107,7 @@ def fake_loaded_standards(mocker: MockerFixture) -> LoadedStandards: def test_init_no_defaults( fake_loaded_standards: LoadedStandards, mocker: MockerFixture ): - fake_frequencies: mocker.MagicMock = mocker.MagicMock() + fake_frequencies: MagicMock = mocker.MagicMock() easyhla: EasyHLA = EasyHLA(fake_loaded_standards, fake_frequencies) assert easyhla.tag == fake_loaded_standards["tag"] @@ -2112,12 +2119,12 @@ def test_init_no_defaults( def test_init_all_defaults( fake_loaded_standards: LoadedStandards, mocker: MockerFixture ): - fake_frequencies: mocker.MagicMock = mocker.MagicMock() + fake_frequencies: MagicMock = mocker.MagicMock() - mocker.MagicMock = mocker.patch.object( + _: MagicMock = mocker.patch.object( EasyHLA, "load_default_hla_standards", return_value=fake_loaded_standards ) - mocker.MagicMock = mocker.patch.object( + __: MagicMock = mocker.patch.object( EasyHLA, "load_default_hla_frequencies", return_value=fake_frequencies ) @@ -2152,7 +2159,9 @@ def test_use_config_no_defaults( freq_path: Path = tmp_path / "hla_frequencies.csv" freq_path.write_text(fake_frequencies_str) - easyhla: EasyHLA = EasyHLA.use_config(standards_path, freq_path) + easyhla: EasyHLA = EasyHLA.use_config( + os.fspath(standards_path), os.fspath(freq_path) + ) assert easyhla.tag == fake_stored_standards.tag assert easyhla.last_updated == fake_stored_standards.last_updated assert easyhla.hla_standards == READ_HLA_STANDARDS_TYPICAL_CASE_OUTPUT @@ -2344,7 +2353,9 @@ def test_get_matching_standards( exp_result: list[HLAStandardMatch], ): result: list[HLAStandardMatch] = EasyHLA.get_matching_standards( - seq=sequence, hla_stds=hla_stds, mismatch_threshold=mismatch_threshold + seq=cast(Sequence[int], sequence), + hla_stds=hla_stds, + mismatch_threshold=mismatch_threshold, ) print(result) assert result == exp_result diff --git a/tests/models_test.py b/tests/models_test.py index 4a09d5d..d4821ec 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -103,7 +103,7 @@ def test_sequence_np( two: tuple[int, ...], three: tuple[int, ...], expected_tuple: tuple[int, ...], - expected_array: np.array, + expected_array: np.ndarray, ): hla_standard: HLAStandard = HLAStandard( allele="B*01:23:45", two=two, three=three diff --git a/tests/update_frequency_file_lib_test.py b/tests/update_frequency_file_lib_test.py index 0bb401d..4001894 100644 --- a/tests/update_frequency_file_lib_test.py +++ b/tests/update_frequency_file_lib_test.py @@ -295,8 +295,8 @@ def test_parse_nomenclature( result: tuple[ dict[OldName, NewName], list[str], - list[tuple[str, str]], - list[tuple[str, str]], + list[tuple[str, NewName]], + list[tuple[str, NewName]], ] = parse_nomenclature(fake_text_input) assert result[0] == expected_remapping diff --git a/tests/utils_test.py b/tests/utils_test.py index fe93bcf..7c7ee89 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -1,7 +1,8 @@ import hashlib from collections.abc import Iterable, Sequence from datetime import datetime -from typing import Optional +from typing import Optional, cast +from unittest.mock import MagicMock import numpy as np import pytest @@ -458,8 +459,8 @@ def test_check_length_hla_type_a( else: with pytest.raises(BadLengthException) as e: check_length("A", seq=sequence, name=name) - assert e.expected_length == expected_length - assert e.actual_length == sequence_length + assert e.value.expected_length == expected_length + assert e.value.actual_length == sequence_length CHECK_LENGTH_HLA_BC_TEST_CASES = [ @@ -797,8 +798,8 @@ def test_check_length_hla_type_b_and_c( else: with pytest.raises(BadLengthException) as e: check_length(locus, seq=sequence, name=name) - assert e.expected_length == expected_length - assert e.actual_length == sequence_length + assert e.value.expected_length == expected_length + assert e.value.actual_length == sequence_length @pytest.mark.parametrize( @@ -837,7 +838,9 @@ def test_calc_padding( ): std = np.array(standard) seq = np.array(sequence) - left_pad, right_pad = calc_padding(std, seq) + left_pad, right_pad = calc_padding( + cast(Sequence[int], std), cast(Sequence[int], seq) + ) assert left_pad == exp_left_pad assert right_pad == exp_right_pad @@ -985,7 +988,7 @@ def test_pad_short( "", 20, 0, - None, + "", id="empty_sequence_and_reference", ), pytest.param( @@ -1601,7 +1604,7 @@ def test_allele_coordinates_sort_key( ], ) def test_collate_standards( - srs: Iterable[SeqRecord], + srs: Sequence[SeqRecord], exon_references: dict[HLA_LOCUS, dict[EXON_NAME, str]], overall_mismatch_threshold: int, acceptable_match_search_threshold: int, @@ -1613,7 +1616,7 @@ def test_collate_standards( expected_logging_calls: list[str], mocker: MockerFixture, ): - mock_logger: Optional[mocker.MagicMock] = None + mock_logger: Optional[MagicMock] = None if use_logging: mock_logger = mocker.MagicMock() result: dict[HLA_LOCUS, list[HLARawStandard]] = collate_standards( @@ -1642,13 +1645,14 @@ def test_collate_standards( assert result[locus] == expected_results[locus] if use_logging: + actual_mock_logger: MagicMock = cast(MagicMock, mock_logger) if len(expected_logging_calls) > 0: - mock_logger.info.assert_has_calls( + actual_mock_logger.info.assert_has_calls( [mocker.call(x) for x in expected_logging_calls], any_order=False, ) else: - mock_logger.assert_not_called() + actual_mock_logger.assert_not_called() @pytest.mark.parametrize( @@ -1806,29 +1810,28 @@ def test_grouped_allele_get_group_name( def test_group_identical_alleles( raw_allele_infos: list[tuple[str, str, str]], use_logging: bool, - expected_result: dict[str, GroupedAllele], + expected_result: list[GroupedAllele], expected_logging_calls: list[str], mocker: MockerFixture, ): - mock_logger: Optional[mocker.MagicMock] = None + mock_logger: Optional[MagicMock] = None if use_logging: mock_logger = mocker.MagicMock() allele_infos: list[HLARawStandard] = [ HLARawStandard(allele=x[0], exon2=x[1], exon3=x[2]) for x in raw_allele_infos ] - result: dict[str, GroupedAllele] = group_identical_alleles( - allele_infos, mock_logger - ) + result: list[GroupedAllele] = group_identical_alleles(allele_infos, mock_logger) assert result == expected_result if use_logging: + actual_mock_logger: MagicMock = cast(MagicMock, mock_logger) if len(expected_logging_calls) > 0: - mock_logger.info.assert_has_calls( + actual_mock_logger.info.assert_has_calls( [mocker.call(x) for x in expected_logging_calls], any_order=False, ) else: - mock_logger.assert_not_called() + actual_mock_logger.assert_not_called() @pytest.mark.parametrize( @@ -2068,17 +2071,15 @@ def test_compute_stored_standard_checksum( def test_stored_hla_standards_error_case(): - stored_stds: StoredHLAStandards = StoredHLAStandards( - tag="0.1.0-dummy-test", - commit_hash="foobar", - last_updated=datetime(2025, 6, 3, 10, 10, 0), - standards={"A": [], "B": [], "C": []}, - ) - # The checksum should be # 221f472fd6986869e33e329d480bc3eca8f3fe6b801e35e2affbff7883735b33 # (checked manually): - stored_stds.checksum = "beeftank" with pytest.raises(ValueError) as e: - stored_stds.compute_compare_checksum() + StoredHLAStandards( + tag="0.1.0-dummy-test", + commit_hash="foobar", + last_updated=datetime(2025, 6, 3, 10, 10, 0), + standards={"A": [], "B": [], "C": []}, + checksum="beeftank", + ) assert "Checksum mismatch" in str(e.value) diff --git a/uv.lock b/uv.lock index e49e630..73f176d 100644 --- a/uv.lock +++ b/uv.lock @@ -342,6 +342,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973, upload-time = "2024-10-09T18:35:44.272Z" }, ] +[[package]] +name = "dunamai" +version = "1.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/2f/194d9a34c4d831c6563d2d990720850f0baef9ab60cb4ad8ae0eff6acd34/dunamai-1.25.0.tar.gz", hash = "sha256:a7f8360ea286d3dbaf0b6a1473f9253280ac93d619836ad4514facb70c0719d1", size = 46155, upload-time = "2025-07-04T19:25:56.082Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/41/04e2a649058b0713b00d6c9bd22da35618bb157289e05d068e51fddf8d7e/dunamai-1.25.0-py3-none-any.whl", hash = "sha256:7f9dc687dd3256e613b6cc978d9daabfd2bb5deb8adc541fc135ee423ffa98ab", size = 27022, upload-time = "2025-07-04T19:25:54.863Z" }, +] + [[package]] name = "easyhla" source = { editable = "." } @@ -354,6 +366,9 @@ dependencies = [ { name = "pyyaml" }, { name = "requests" }, { name = "typer" }, + { name = "types-pyyaml" }, + { name = "types-requests" }, + { name = "uv-dynamic-versioning" }, ] [package.optional-dependencies] @@ -392,6 +407,9 @@ requires-dist = [ { name = "requests", specifier = ">=2.32.3" }, { name = "sqlalchemy", marker = "extra == 'database'", specifier = ">=2.0.40" }, { name = "typer", specifier = ">=0.15.2" }, + { name = "types-pyyaml", specifier = ">=6.0.12.20250516" }, + { name = "types-requests", specifier = ">=2.32.4.20250611" }, + { name = "uv-dynamic-versioning", specifier = ">=0.8.2" }, ] provides-extras = ["database"] @@ -492,6 +510,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/e6/f9d759788518a6248684e3afeb3691f3ab0276d769b6217a1533362298c8/greenlet-3.2.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:d6668caf15f181c1b82fb6406f3911696975cc4c37d782e19cb7ba499e556189", size = 269897, upload-time = "2025-04-22T14:27:14.044Z" }, ] +[[package]] +name = "hatchling" +version = "1.27.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "pathspec" }, + { name = "pluggy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "trove-classifiers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8f/8a/cc1debe3514da292094f1c3a700e4ca25442489731ef7c0814358816bb03/hatchling-1.27.0.tar.gz", hash = "sha256:971c296d9819abb3811112fc52c7a9751c8d381898f36533bb16f9791e941fd6", size = 54983, upload-time = "2024-12-15T17:08:11.894Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/e7/ae38d7a6dfba0533684e0b2136817d667588ae3ec984c1a4e5df5eb88482/hatchling-1.27.0-py3-none-any.whl", hash = "sha256:d3a2f3567c4f926ea39849cdf924c7e99e6686c9c8e288ae1037c8fa2a5d937b", size = 75794, upload-time = "2024-12-15T17:08:10.364Z" }, +] + [[package]] name = "identify" version = "2.6.8" @@ -861,6 +895,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6f/ab/ed42acf15bab2e86e5c49fad4aa038315233c4c2d22f41b49faa4d837516/pandas_stubs-2.2.3.241126-py3-none-any.whl", hash = "sha256:74aa79c167af374fe97068acc90776c0ebec5266a6e5c69fe11e9c2cf51f2267", size = 158280, upload-time = "2024-11-26T15:05:59.428Z" }, ] +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, +] + [[package]] name = "platformdirs" version = "4.3.6" @@ -1385,6 +1428,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257, upload-time = "2024-11-27T22:38:35.385Z" }, ] +[[package]] +name = "tomlkit" +version = "0.13.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/18/0bbf3884e9eaa38819ebe46a7bd25dcd56b67434402b66a58c4b8e552575/tomlkit-0.13.3.tar.gz", hash = "sha256:430cf247ee57df2b94ee3fbe588e71d362a941ebb545dec29b53961d61add2a1", size = 185207, upload-time = "2025-06-05T07:13:44.947Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/75/8539d011f6be8e29f339c42e633aae3cb73bffa95dd0f9adec09b9c58e85/tomlkit-0.13.3-py3-none-any.whl", hash = "sha256:c89c649d79ee40629a9fda55f8ace8c6a1b42deb912b2a8fd8d942ddadb606b0", size = 38901, upload-time = "2025-06-05T07:13:43.546Z" }, +] + +[[package]] +name = "trove-classifiers" +version = "2025.8.6.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/21/707af14daa638b0df15b5d5700349e0abdd3e5140069f9ab6e0ccb922806/trove_classifiers-2025.8.6.13.tar.gz", hash = "sha256:5a0abad839d2ed810f213ab133d555d267124ddea29f1d8a50d6eca12a50ae6e", size = 16932, upload-time = "2025-08-06T13:26:26.479Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/44/323a87d78f04d5329092aada803af3612dd004a64b69ba8b13046601a8c9/trove_classifiers-2025.8.6.13-py3-none-any.whl", hash = "sha256:c4e7fc83012770d80b3ae95816111c32b085716374dccee0d3fbf5c235495f9f", size = 14121, upload-time = "2025-08-06T13:26:25.063Z" }, +] + [[package]] name = "typer" version = "0.15.2" @@ -1409,6 +1470,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/50/65ffad73746f1d8b15992c030e0fd22965fd5ae2c0206dc28873343b3230/types_pytz-2025.1.0.20250204-py3-none-any.whl", hash = "sha256:32ca4a35430e8b94f6603b35beb7f56c32260ddddd4f4bb305fdf8f92358b87e", size = 10059, upload-time = "2025-02-04T02:39:03.899Z" }, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20250516" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/22/59e2aeb48ceeee1f7cd4537db9568df80d62bdb44a7f9e743502ea8aab9c/types_pyyaml-6.0.12.20250516.tar.gz", hash = "sha256:9f21a70216fc0fa1b216a8176db5f9e0af6eb35d2f2932acb87689d03a5bf6ba", size = 17378, upload-time = "2025-05-16T03:08:04.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl", hash = "sha256:8478208feaeb53a34cb5d970c56a7cd76b72659442e733e268a94dc72b2d0530", size = 20312, upload-time = "2025-05-16T03:08:04.019Z" }, +] + +[[package]] +name = "types-requests" +version = "2.32.4.20250611" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/7f/73b3a04a53b0fd2a911d4ec517940ecd6600630b559e4505cc7b68beb5a0/types_requests-2.32.4.20250611.tar.gz", hash = "sha256:741c8777ed6425830bf51e54d6abe245f79b4dcb9019f1622b773463946bf826", size = 23118, upload-time = "2025-06-11T03:11:41.272Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/ea/0be9258c5a4fa1ba2300111aa5a0767ee6d18eb3fd20e91616c12082284d/types_requests-2.32.4.20250611-py3-none-any.whl", hash = "sha256:ad2fe5d3b0cb3c2c902c8815a70e7fb2302c4b8c1f77bdcd738192cdb3878072", size = 20643, upload-time = "2025-06-11T03:11:40.186Z" }, +] + [[package]] name = "typing-extensions" version = "4.12.2" @@ -1436,6 +1518,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680, upload-time = "2025-04-10T15:23:37.377Z" }, ] +[[package]] +name = "uv-dynamic-versioning" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dunamai" }, + { name = "hatchling" }, + { name = "jinja2" }, + { name = "pydantic" }, + { name = "tomlkit" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/9e/1cf1ddf02e5459076b6fe0e90e1315df461b94c0db6c09b07e5730a0e0fb/uv_dynamic_versioning-0.8.2.tar.gz", hash = "sha256:a9c228a46f5752d99cfead1ed83b40628385cbfb537179488d280853c786bf82", size = 41559, upload-time = "2025-05-02T05:08:30.843Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/55/a6cffd78511faebf208d4ba1f119d489680668f8d36114564c6f499054b9/uv_dynamic_versioning-0.8.2-py3-none-any.whl", hash = "sha256:400ade6b4a3fc02895c3d24dd0214171e4d60106def343b39ad43143a2615e8c", size = 8851, upload-time = "2025-05-02T05:08:29.33Z" }, +] + [[package]] name = "virtualenv" version = "20.29.2"