### Setup: helper functions


In [1]:
import requests
from Bio import SeqIO
from io import StringIO
import polars as pl
from tcr_format_parsers.common.MHCCodeConverter import (
    HLASequenceDBConverter,
    H2SequenceDictConverter,
)
from tcr_format_parsers.common.TCRUtils import standardize_tcr


def format_pdb_df(df):
    """Condense a raw structure table to one annotated row per ``pdb`` entry.

    Processing order:

    1. Add ``mhc_class`` ("I" for ``mhc_type == "MH1"``, "II" for "MH2") and
       drop rows with any other ``mhc_type``.
    2. Drop rows where ``mhc_chain1`` or ``mhc_chain2`` is null.
    3. Group by ``pdb`` and keep the first value of each chain-id column,
       ``mhc_class`` and the four organism columns (renamed to
       ``mhc_1_species``, ``mhc_2_species``, ``tcr_1_species``,
       ``tcr_2_species``).
    4. Recode organisms ("homo sapiens" -> "human", "mus musculus" ->
       "mouse") and drop rows where any of the four is something else.
    5. Add chain-role labels (``mhc_1_chain``/``mhc_2_chain`` depend on the
       class; TCR chains are always "alpha"/"beta") and ``cognate = True``.
    6. Reduce ``antigen_chain`` to the stripped text before the first "|".

    Args:
        df: polars DataFrame containing the columns referenced above
            (``pdb``, ``mhc_type``, ``Achain``, ``Bchain``, ``mhc_chain1``,
            ``mhc_chain2``, ``antigen_chain`` and the ``*_organism`` columns).

    Returns:
        A new polars DataFrame with one row per ``pdb`` value.
    """

    def species_label(column):
        # Same recoding for every species column; anything that is neither
        # human nor mouse becomes null so it can be filtered out afterwards.
        organism = pl.col(column)
        return (
            pl.when(organism == "homo sapiens")
            .then(pl.lit("human"))
            .when(organism == "mus musculus")
            .then(pl.lit("mouse"))
            .otherwise(None)
            .alias(column)
        )

    # --- 1. MHC class from the raw type code --------------------------------
    mhc_type = pl.col("mhc_type")
    mhc_class = (
        pl.when(mhc_type == "MH1")
        .then(pl.lit("I"))
        .when(mhc_type == "MH2")
        .then(pl.lit("II"))
        .otherwise(None)
        .alias("mhc_class")
    )
    classed = df.with_columns(mhc_class).filter(
        pl.col("mhc_class").is_not_null()
    )

    # --- 2. both MHC chains must be annotated -------------------------------
    has_both_mhc_chains = (
        pl.col("mhc_chain1").is_not_null()
        & pl.col("mhc_chain2").is_not_null()
    )
    paired = classed.filter(has_both_mhc_chains)

    # --- 3. one row per PDB entry (first value wins) ------------------------
    carried_over = [
        "Bchain",
        "Achain",
        "mhc_chain1",
        "mhc_chain2",
        "antigen_chain",
        "mhc_class",
    ]
    organism_to_species = {
        "mhc_chain1_organism": "mhc_1_species",
        "mhc_chain2_organism": "mhc_2_species",
        "alpha_organism": "tcr_1_species",
        "beta_organism": "tcr_2_species",
    }
    first_values = [pl.col(name).first() for name in carried_over]
    first_values += [
        pl.col(source).first().alias(target)
        for source, target in organism_to_species.items()
    ]
    per_entry = paired.group_by("pdb").agg(first_values)

    # --- 4. recode species and require all four to be human or mouse --------
    species_columns = list(organism_to_species.values())
    all_species_known = pl.col(species_columns[0]).is_not_null()
    for name in species_columns[1:]:
        all_species_known = all_species_known & pl.col(name).is_not_null()
    recoded = per_entry.with_columns(
        [species_label(name) for name in species_columns]
    ).filter(all_species_known)

    # --- 5. chain-role labels and cognate flag -------------------------------
    is_class_two = pl.col("mhc_class") == "II"
    labelled = recoded.with_columns(
        pl.when(is_class_two)
        .then(pl.lit("alpha"))
        .otherwise(pl.lit("heavy"))
        .alias("mhc_1_chain"),
        pl.when(is_class_two)
        .then(pl.lit("beta"))
        .otherwise(pl.lit("light"))
        .alias("mhc_2_chain"),
        pl.lit(True).alias("cognate"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
    )

    # --- 6. keep only the first "|"-separated antigen chain id --------------
    first_antigen_chain = (
        pl.col("antigen_chain").str.split("|").list.first().str.strip_chars()
    )
    return labelled.with_columns(first_antigen_chain.alias("antigen_chain"))


# Struct dtype with one string field per sequence; the field names are the
# keys of the dict returned by get_fasta_seq, and format_seqs passes this as
# the return_dtype of its map_elements call.
SEQ_STRUCT = pl.Struct(
    {
        field_name: pl.String
        for field_name in (
            "peptide",
            "mhc_1_seq",
            "mhc_2_seq",
            "tcr_1_seq",
            "tcr_2_seq",
        )
    }
)


def parse_chain(chain):
    """Return the chain identifier encoded in one chain token of a FASTA header.

    A token is either a bare id (``"A"``, possibly with surrounding spaces
    left over from splitting on commas) or an id followed by an author
    annotation (``"B[auth C]"``). When the ``[auth ...]`` annotation is
    present, the id inside the brackets is returned in full; otherwise the
    token itself is returned with all spaces removed.

    Args:
        chain: One chain token as a string.

    Returns:
        The chain identifier as a string.
    """
    _, marker, remainder = chain.partition("[auth ")
    if marker:
        # Take everything up to the closing bracket rather than a single
        # character, so multi-character author ids (e.g. "AA") stay intact.
        return remainder.split("]", 1)[0].strip()
    return chain.replace(" ", "")


def parse_fasta_description(description):
    """List the chain identifiers named in a FASTA record description.

    The second ``|``-separated field of the description is expected to be
    either ``"Chain X"`` (one chain) or ``"Chains X, Y, ..."`` (several
    chains sharing the same sequence). Each chain token is normalised with
    ``parse_chain``.

    Args:
        description: Full description line of a FASTA record.

    Returns:
        A list of chain identifier strings, one per chain in the record.

    Raises:
        IndexError: If the description has no second ``|`` field or that
            field starts with neither ``"Chain "`` nor contains ``"Chains "``.
    """
    chain_token = description.split("|")[1]

    if chain_token.startswith("Chain "):
        # Wrap the single id in a list. Calling list() on the string would
        # split a multi-character id into its individual characters.
        return [parse_chain(chain_token.split("Chain ")[1])]

    chains = chain_token.split("Chains ")[1].split(",")
    return [parse_chain(chain) for chain in chains]


def get_fasta_seq(
    pdb_id,
    antigen_chain_id,
    mhc_chain1_id,
    mhc_chain2_id,
    Achain_id,
    Bchain_id,
    timeout=30,
):
    """Download the FASTA file of a PDB entry and pick out five chain sequences.

    The FASTA for ``pdb_id`` is fetched from
    ``https://www.rcsb.org/fasta/entry/<pdb_id>``. Every record is mapped to
    the chain ids listed in its description (via ``parse_fasta_description``)
    and the sequences of the requested chains are returned.

    Args:
        pdb_id: PDB entry id, appended to the download URL.
        antigen_chain_id: Chain id whose sequence is returned as ``peptide``.
        mhc_chain1_id: Chain id whose sequence is returned as ``mhc_1_seq``.
        mhc_chain2_id: Chain id whose sequence is returned as ``mhc_2_seq``.
        Achain_id: Chain id whose sequence is returned as ``tcr_1_seq``.
        Bchain_id: Chain id whose sequence is returned as ``tcr_2_seq``.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``, so a
            stalled connection fails instead of blocking forever.

    Returns:
        Dict with keys ``peptide``, ``mhc_1_seq``, ``mhc_2_seq``,
        ``tcr_1_seq`` and ``tcr_2_seq`` mapping to sequence strings.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the request exceeds ``timeout``.
        KeyError: If any requested chain id is absent from the FASTA file.
    """
    r = requests.get(
        "https://www.rcsb.org/fasta/entry/" + pdb_id, timeout=timeout
    )
    r.raise_for_status()

    # One FASTA record can cover several chains; index the sequence under
    # each of them. A later record overwrites an earlier one for the same id.
    seq_dict = {}
    for fasta in SeqIO.parse(StringIO(r.text), "fasta"):
        sequence = str(fasta.seq)
        for chain in parse_fasta_description(fasta.description):
            seq_dict[chain] = sequence

    requested = {
        "peptide": antigen_chain_id,
        "mhc_1_seq": mhc_chain1_id,
        "mhc_2_seq": mhc_chain2_id,
        "tcr_1_seq": Achain_id,
        "tcr_2_seq": Bchain_id,
    }

    # Report every missing chain together with the entry id; a bare
    # KeyError('X') raised from inside map_elements is hard to trace back.
    missing = [
        f"{field}={chain_id!r}"
        for field, chain_id in requested.items()
        if chain_id not in seq_dict
    ]
    if missing:
        raise KeyError(
            f"PDB entry {pdb_id!r}: chain id(s) not found in FASTA "
            f"({', '.join(missing)}); available chains: {sorted(seq_dict)}"
        )

    return {field: seq_dict[chain_id] for field, chain_id in requested.items()}


def format_seqs(df):
    """Append the five sequence columns defined by ``SEQ_STRUCT`` to ``df``.

    For every row, ``get_fasta_seq`` is called with the row's ``pdb`` id and
    chain ids; the returned dict becomes a struct column ``chain_seqs`` that
    is immediately unnested into separate columns.

    Args:
        df: polars DataFrame with columns ``pdb``, ``Bchain``, ``Achain``,
            ``antigen_chain``, ``mhc_chain1`` and ``mhc_chain2``.

    Returns:
        A new polars DataFrame with the struct fields added as columns.
    """

    def fetch_row_sequences(row):
        # `row` is the per-row struct value, indexed by field name.
        return get_fasta_seq(
            row["pdb"],
            row["antigen_chain"],
            row["mhc_chain1"],
            row["mhc_chain2"],
            row["Achain"],
            row["Bchain"],
        )

    lookup_fields = [
        "pdb",
        "Bchain",
        "Achain",
        "antigen_chain",
        "mhc_chain1",
        "mhc_chain2",
    ]
    lookup_struct = pl.struct([pl.col(name) for name in lookup_fields])
    sequences = lookup_struct.map_elements(
        fetch_row_sequences,
        return_dtype=SEQ_STRUCT,
    ).alias("chain_seqs")

    with_sequences = df.with_columns(sequences)
    return with_sequences.unnest("chain_seqs")


def infer_correct_mhc(row, human_conv, mouse_conv):
    """Replace both MHC sequences in ``row`` with converter lookups and add names.

    For each of the two MHC chains, the sequence is passed to
    ``get_mhc_allele`` of ``human_conv`` when the chain's species is
    "human" and of ``mouse_conv`` otherwise. The ``"seq"`` and ``"name"``
    entries of each result are written to ``mhc_<n>_seq`` and
    ``mhc_<n>_name`` in a copy of the row.

    Args:
        row: Mapping with keys ``mhc_{1,2}_seq``, ``mhc_{1,2}_species`` and
            ``mhc_{1,2}_chain``; must support ``.copy()``.
        human_conv: Converter used for human chains.
        mouse_conv: Converter used for every non-human chain.

    Returns:
        A polars DataFrame built from the updated copy of ``row``.
    """
    # (species key, sequence key, chain-role key, name key) for each chain.
    chain_keys = (
        ("mhc_1_species", "mhc_1_seq", "mhc_1_chain", "mhc_1_name"),
        ("mhc_2_species", "mhc_2_seq", "mhc_2_chain", "mhc_2_name"),
    )

    # Read both input sequences before running any lookup.
    submitted = [row[seq_key] for _, seq_key, _, _ in chain_keys]

    # Run both lookups before touching the output row.
    hits = []
    for (species_key, _, chain_key, _), sequence in zip(chain_keys, submitted):
        converter = human_conv if row[species_key] == "human" else mouse_conv
        hits.append(
            converter.get_mhc_allele(
                sequence, chain=row[chain_key], top_only=True
            )
        )

    updated = row.copy()
    for (_, seq_key, _, name_key), hit in zip(chain_keys, hits):
        updated[seq_key] = hit["seq"]
        updated[name_key] = hit["name"]

    return pl.DataFrame(updated)


def shorten_tcr_to_vregion(row):
    """Run ``standardize_tcr`` on both TCR sequences of ``row``.

    Each ``tcr_<n>_seq`` is replaced by the result of
    ``standardize_tcr(sequence, chain, species)`` using the matching
    ``tcr_<n>_chain`` and ``tcr_<n>_species`` values.
    NOTE(review): judging by the function name, this trims each chain to its
    V region; confirm against ``standardize_tcr``.

    Args:
        row: Mapping with keys ``tcr_{1,2}_seq``, ``tcr_{1,2}_chain`` and
            ``tcr_{1,2}_species``; must support ``.copy()``.

    Returns:
        A polars DataFrame built from the updated copy of ``row``.
    """
    trimmed = row.copy()
    tcr_keys = (
        ("tcr_1_seq", "tcr_1_chain", "tcr_1_species"),
        ("tcr_2_seq", "tcr_2_chain", "tcr_2_species"),
    )
    for seq_key, chain_key, species_key in tcr_keys:
        # Inputs are always read from the untouched original row.
        trimmed[seq_key] = standardize_tcr(
            row[seq_key], row[chain_key], row[species_key]
        )
    return pl.DataFrame(trimmed)

### 1. Import triads from STCRpred


In [2]:
import polars as pl
from pathlib import Path

# Read these two chain-id columns as strings instead of letting polars infer
# their dtype.
schema_overrides = {"Gchain": pl.String, "Dchain": pl.String}

# Cell values that are treated as missing when parsing the TSV files.
null_values = ["NA", "unknown"]

# Parsing options shared by all four tab-separated input tables.
tsv_options = dict(
    schema_overrides=schema_overrides,
    null_values=null_values,
    separator="\t",
)

# Load each species/class table and collapse it to one row per PDB entry.
pdb_human_I = format_pdb_df(pl.read_csv("raw/humanI.tsv", **tsv_options))

pdb_human_II = format_pdb_df(pl.read_csv("raw/humanII.tsv", **tsv_options))

pdb_mouse_I = format_pdb_df(pl.read_csv("raw/mouseI.tsv", **tsv_options))

pdb_mouse_II = format_pdb_df(pl.read_csv("raw/mouseII.tsv", **tsv_options))

In [3]:
# Attach the chain sequences to every table. format_seqs calls get_fasta_seq
# once per row, so this cell performs one HTTP request per PDB entry.
pdb_human_I, pdb_human_II, pdb_mouse_I, pdb_mouse_II = [
    format_seqs(table)
    for table in (pdb_human_I, pdb_human_II, pdb_mouse_I, pdb_mouse_II)
]

In [None]:
from tcr_format_parsers.common.MHCCodeConverter import (
    HLASequenceDBConverter,
    H2SequenceDictConverter,
)
from mdaf3.FeatureExtraction import split_apply_combine, serial_apply

# Converters handed to infer_correct_mhc: the human one is initialised from a
# local IMGT/HLA directory, the mouse one takes no arguments.
imgt_hla_dir = "/tgen_labs/altin/alphafold3/IMGTHLA"
human_conv = HLASequenceDBConverter(imgt_hla_dir)
mouse_conv = H2SequenceDictConverter()

# Row-wise transforms applied to the human class I table, in order, each
# paired with the extra positional arguments it needs.
row_steps = (
    (infer_correct_mhc, (human_conv, mouse_conv)),
    (shorten_tcr_to_vregion, ()),
)
for row_fn, extra_args in row_steps:
    pdb_human_I = split_apply_combine(pdb_human_I, row_fn, *extra_args)