In [1]:
import polars as pl
from tcr_format_parsers.common.MHCCodeConverter import (
    HLACodeWebConverter,
    shorten_to_fullname,
    is_fullname,
    DQA_FOR,
    DPA_FOR,
)

# All CRESTA inputs live under a single directory; keep the prefix in one place
# so relocating the data only needs one edit.
CRESTA_INPUT_DIR = "/tgen_labs/altin/alphafold3/runs/manucript_inp/CRESTA/input"

supp = pl.read_csv(f"{CRESTA_INPUT_DIR}/SuppTable1.csv")

# Dtypes enforced on the joined table (feature distances as Float32, contact
# counts as Int32, the cognate label as Boolean).
schema_overrides = {
    "peptideMin": pl.Float32,
    "TCRCDRsMin": pl.Float32,
    "HLAhelicesMin": pl.Float32,
    "peptide_TCR": pl.Int32,
    "HLA_TCR": pl.Int32,
    "cognate": pl.Boolean,
}

manu_tcr = pl.read_parquet(f"{CRESTA_INPUT_DIR}/extract.parquet")
manu_feat = pl.read_parquet(f"{CRESTA_INPUT_DIR}/extract_feat.parquet")

# The CSV's header is replaced by these names on read.
pHLAseqs = pl.read_csv(
    f"{CRESTA_INPUT_DIR}/pHLAseqs.csv",
    new_columns=["targetpMHC", "peptide", "mhc_2_seq", "mhc_1_seq"],
)

# Features + TCRs are matched on (TCR id, target, donor), whose key columns are
# named differently in the two files; pMHC sequences are then attached per target.
manu = (
    manu_feat.join(
        manu_tcr,
        left_on=["tcrids", "targetpMHC", "donor"],
        right_on=["tcr_id", "targetpMHC", "donor_num"],
        how="inner",
    )
    .join(pHLAseqs, on="targetpMHC", how="inner")
    .with_columns(
        [pl.col(col).cast(dtype) for col, dtype in schema_overrides.items()]
    )
)

# Inner joins drop unmatched rows silently; fail fast instead of exporting an
# empty table further down.
assert manu.height > 0, "feature/TCR/pHLA joins produced no rows"


def get_hla_name(target):
    """Map a target code from the ``targetpMHC`` column to an HLA allele name.

    Parameters
    ----------
    target : str
        Target code, e.g. ``"may"``. Matching is exact (case-sensitive).

    Returns
    -------
    str or None
        The beta-chain allele name (e.g. ``"DRB1*15:03"``), or ``None`` when
        the code is not one of the targets listed below.
    """
    allele_by_codes = (
        (("may", "aav"), "DRB1*15:03"),
        (("vdl", "hea"), "DRB1*11:01"),
        (("ist", "eqq", "dah", "alp"), "DQB1*06:02"),
        (("nia",), "DRB1*01:02"),
        (("vrf",), "DRB5*01:01"),
    )
    for codes, allele in allele_by_codes:
        if target in codes:
            return allele
    # Unknown targets map to None (becomes null in the polars column).
    return None


def infer_hla_chain(mhc_name):
    """Infer the alpha/beta chain allele names for an MHC class II beta allele.

    Only names starting with ``"DRB"`` or ``"DQB"`` are handled:

    - DRB: paired with the fixed alpha chain ``"DRA1*01:02"``.
    - DQB: paired with the alpha chain looked up in ``DQA_FOR``.

    Parameters
    ----------
    mhc_name : str or None
        Beta-chain allele name such as ``"DRB1*15:03"``; it is expanded with
        ``shorten_to_fullname`` and validated with ``is_fullname``.

    Returns
    -------
    dict
        A new dict with keys ``mhc_1_name`` (alpha), ``mhc_2_name`` (beta) and
        ``mhc_name_inferred``. All three values are ``None`` when the name is
        missing or not a string, is not DRB/DQB, does not expand to a valid
        full name, or (DQB only) has no entry in ``DQA_FOR``.
    """
    nullchains = {
        "mhc_1_name": None,
        "mhc_2_name": None,
        "mhc_name_inferred": None,
    }

    # get_hla_name returns None for unmapped targets; guard so a direct call
    # does not raise AttributeError on None.startswith.
    # NOTE(review): DPA_FOR is imported but DPB alleles are not handled here;
    # confirm that is intended.
    if not isinstance(mhc_name, str) or not mhc_name.startswith(("DRB", "DQB")):
        return nullchains

    fullname = shorten_to_fullname(mhc_name)
    if not is_fullname(fullname):
        return nullchains

    if mhc_name.startswith("DRB"):
        # Every DRB allele is paired with the same fixed DRA chain.
        a_chain = "DRA1*01:02"
    else:
        if fullname not in DQA_FOR:
            # Unknown DQA partner for this DQB chain: leave all names null.
            return nullchains
        a_chain = DQA_FOR[fullname]

    return {
        "mhc_1_name": a_chain,
        "mhc_2_name": fullname,
        "mhc_name_inferred": "chain_1",
    }

In [2]:
import functools

converter = HLACodeWebConverter()


@functools.lru_cache(maxsize=None)
def lookup_mhc_sequence(allele_name):
    """Return ``converter.get_sequence(allele_name, top_only=True)``, cached.

    ``map_elements`` invokes this once per row, but only a few distinct allele
    names occur (get_hla_name maps to five), so the cache avoids repeating the
    same converter lookup. Assumes ``get_sequence`` is deterministic per name.
    """
    return converter.get_sequence(allele_name, top_only=True)


# Return dtype of infer_hla_chain; unnested into three string columns below.
tmp_struct = pl.Struct(
    {
        "mhc_1_name": pl.String,
        "mhc_2_name": pl.String,
        "mhc_name_inferred": pl.String,
    }
)


manu = (
    manu.with_columns(
        # Target code -> beta-chain allele name (null for unmapped targets).
        pl.col("targetpMHC")
        .map_elements(get_hla_name, return_dtype=pl.String)
        .alias("mhc_allele_name")
    )
    .with_columns(
        # Allele name -> struct of alpha/beta chain names.
        pl.col("mhc_allele_name")
        .map_elements(infer_hla_chain, return_dtype=tmp_struct)
        .alias("chains")
    )
    .unnest("chains")
    .with_columns(
        pl.lit("human").alias("mhc_1_species"),
        pl.lit("human").alias("mhc_2_species"),
        pl.lit("human").alias("tcr_1_species"),
        pl.lit("human").alias("tcr_2_species"),
        pl.lit("alpha").alias("mhc_1_chain"),
        pl.lit("beta").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("II").alias("mhc_class"),
        # NOTE(review): these overwrite the mhc_1_seq / mhc_2_seq columns that
        # the pHLAseqs.csv join supplied; rows whose chain names could not be
        # inferred presumably end up with null sequences. Confirm intended.
        pl.col("mhc_1_name")
        .map_elements(lookup_mhc_sequence, return_dtype=pl.String)
        .alias("mhc_1_seq"),
        pl.col("mhc_2_name")
        .map_elements(lookup_mhc_sequence, return_dtype=pl.String)
        .alias("mhc_2_seq"),
    )
)

In [3]:
from tcr_format_parsers.common.TCRUtils import hash_tcr_sequence
from tcr_format_parsers.common.TriadUtils import FORMAT_COLS, generate_job_name


# Columns exported in addition to the shared triad format columns.
CUSTOM_COLS = ["donor"]

# generate_job_name presumably adds the job_name column that FORMAT_COLS
# expects; then keep only the exported columns, in FORMAT_COLS order.
manu = generate_job_name(manu).select(FORMAT_COLS + CUSTOM_COLS)

In [4]:
from pathlib import Path

cresta_csv_path = Path(
    "/tgen_labs/altin/alphafold3/runs/manucript_inp/CRESTA/output/cresta.csv"
)
# Ensure output/ exists so a fresh Restart & Run All does not fail on a missing
# directory (the writer is not relied on to create parent directories).
cresta_csv_path.parent.mkdir(parents=True, exist_ok=True)
manu.write_csv(cresta_csv_path)

In [6]:
# Spot-check rows for one peptide restricted to DQB1*06:02.
# mhc_2_name holds the output of shorten_to_fullname, which may be a longer
# name than "DQB1*06:02", so an exact equality can miss every row; match the
# allele as a literal substring instead ("*" must not be treated as regex).
manu.filter(
    (pl.col("peptide") == "IAFASGFRA")
    & pl.col("mhc_2_name").str.contains("DQB1*06:02", literal=True)
)

job_name,cognate,peptide,mhc_class,mhc_1_chain,mhc_1_species,mhc_1_name,mhc_1_seq,mhc_2_chain,mhc_2_species,mhc_2_name,mhc_2_seq,tcr_1_chain,tcr_1_species,tcr_1_seq,tcr_2_chain,tcr_2_species,tcr_2_seq,donor
str,bool,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str
