In [1]:
import polars as pl
from tcr_format_parsers.common.MHCCodeConverter import (
    B2M_HUMAN_SEQ,
    HLACodeWebConverter,
)
from tcr_format_parsers.common.TriadUtils import (
    generate_job_name,
    FORMAT_COLS,
)

# Raw challenge test set. Its columns (Peptide, TCRa, TCRb, HLA, ID, ...) are
# mapped onto the triad format schema below; kept separate from `test_dat` so the
# raw frame is never overwritten.
test_raw = pl.read_csv(
    "/tgen_labs/altin/alphafold3/runs/challenge_inp/test/test.csv"
)

# Extra (non-FORMAT_COLS) columns carried through: the input "ID" column and a
# flag marking rows that come from the original file (True) vs. generated ones.
CUSTOM_COLS = ["ID", "original"]


conv = HLACodeWebConverter()

# Allele names with the "HLA-" prefix removed, e.g. "HLA-A*02:01" -> "A*02:01".
# strip_prefix leaves names that already lack the prefix untouched; the previous
# split_exact approach turned those into null, which then silently yielded a
# null mhc_1_seq.
hla_alleles = (
    test_raw.get_column("HLA").str.strip_prefix("HLA-").drop_nulls().unique()
)
# Look each distinct allele up once rather than once per row: get_sequence may
# be expensive (the converter's name suggests a web lookup — unverified).
hla_to_seq = {
    allele: conv.get_sequence(allele, top_only=True) for allele in hla_alleles
}

test_dat = (
    test_raw.rename(
        {
            "Peptide": "peptide",
            "TCRb": "tcr_2_seq",
            "TCRa": "tcr_1_seq",
            "HLA": "mhc_1_name",
        }
    )
    .with_columns(
        pl.col("mhc_1_name").str.strip_prefix("HLA-"),
        pl.lit("heavy").alias("mhc_1_chain"),
        pl.lit("light").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("human").alias("tcr_1_species"),
        pl.lit("human").alias("tcr_2_species"),
        pl.lit("human").alias("mhc_1_species"),
        pl.lit("human").alias("mhc_2_species"),
        pl.lit(B2M_HUMAN_SEQ).alias("mhc_2_seq"),
        pl.lit("B2M").alias("mhc_2_name"),
        pl.lit("I").alias("mhc_class"),
        # Binding labels are unknown for the challenge test set.
        pl.lit(None).alias("cognate"),
        pl.lit(True).alias("original"),
    )
    # Separate with_columns so the lookup sees the prefix-stripped name.
    .with_columns(
        pl.col("mhc_1_name")
        .map_elements(lambda allele: hla_to_seq[allele], return_dtype=pl.String)
        .alias("mhc_1_seq"),
    )
)

# Fail loudly instead of writing rows that have no MHC heavy-chain sequence.
unresolved_alleles = (
    test_dat.filter(pl.col("mhc_1_seq").is_null())
    .get_column("mhc_1_name")
    .unique()
    .to_list()
)
assert not unresolved_alleles, (
    f"no MHC sequence resolved for alleles: {unresolved_alleles}"
)

test_dat = generate_job_name(test_dat)
test_dat = test_dat.select(FORMAT_COLS + CUSTOM_COLS)

In [2]:
from tcr_format_parsers.common.TriadUtils import (
    FORMAT_TCR_COLS,
    FORMAT_ANTIGEN_COLS,
)


# Build every TCR x pMHC combination that does NOT already appear in the test
# set; these are appended as extra, unlabelled jobs flagged original=False.
pair_cols = FORMAT_TCR_COLS + FORMAT_ANTIGEN_COLS

unique_tcrs = test_dat.select(FORMAT_TCR_COLS).unique()
unique_pmhc = test_dat.select(FORMAT_ANTIGEN_COLS).unique()
observed_pairs = test_dat.select(pair_cols).unique()

unseen = unique_tcrs.join(unique_pmhc, how="cross").join(
    observed_pairs, on=pair_cols, how="anti"
)

# polars joins do not match null keys by default, so a null in any key column
# would let observed pairs slip back in as "unseen" duplicates. Every observed
# pair occupies exactly one cell of the cross product, so the count is exact.
expected_unseen = unique_tcrs.height * unique_pmhc.height - observed_pairs.height
assert unseen.height == expected_unseen, (
    f"anti-join kept {unseen.height} rows, expected {expected_unseen}; "
    "check the TCR/pMHC key columns for nulls"
)

# Type the null placeholders like test_dat's columns so the vertical concat
# below does not depend on implicit Null-dtype coercion.
unseen = unseen.with_columns(
    pl.lit(None, dtype=test_dat.schema.get("cognate", pl.Null)).alias("cognate"),
    pl.lit(False).alias("original"),
    pl.lit(None, dtype=test_dat.schema["ID"]).alias("ID"),
)
unseen = generate_job_name(unseen)
unseen = unseen.select(FORMAT_COLS + CUSTOM_COLS)

test_dat = pl.concat([test_dat, unseen])

In [11]:
from pathlib import Path

# Unstandardized output: TCR chain sequences exactly as read from the input CSV.
unstd_csv_path = Path(
    "/tgen_labs/altin/alphafold3/runs/challenge_inp/test/output/test_unstd.csv"
)
# write_csv raises if the output directory is missing; create it on first run.
unstd_csv_path.parent.mkdir(parents=True, exist_ok=True)
test_dat.write_csv(unstd_csv_path)

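## Standardized run

Append the human TRAC and TRBC1 constant-region sequences (from `TCRUtils`) to the α (`tcr_1_seq`) and β (`tcr_2_seq`) chains, then write the result next to the unstandardized file.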


In [4]:
from tcr_format_parsers.common.TCRUtils import (
    HUMAN_TRAC_SEQ,
    HUMAN_TRAC_TOP,
    HUMAN_TRBC1_SEQ,
    HUMAN_TRBC1_TOP,
)

# Constant-region sequences appended to each chain. Indexing by *_TOP presumably
# selects the top/reference allele entry — TODO confirm in TCRUtils.
trac_seq = HUMAN_TRAC_SEQ[HUMAN_TRAC_TOP]
trbc1_seq = HUMAN_TRBC1_SEQ[HUMAN_TRBC1_TOP]

# Vectorised concatenation instead of a per-row Python lambda. A null chain
# stays null, matching map_elements, which skipped nulls.
test_dat_std = test_dat.with_columns(
    pl.concat_str([pl.col("tcr_1_seq"), pl.lit(trac_seq)]).alias("tcr_1_seq"),
    pl.concat_str([pl.col("tcr_2_seq"), pl.lit(trbc1_seq)]).alias("tcr_2_seq"),
)

In [9]:
from pathlib import Path

# Standardized output: TCR chains with the constant regions appended.
std_csv_path = Path(
    "/tgen_labs/altin/alphafold3/runs/challenge_inp/test/output/test_std.csv"
)
# write_csv raises if the output directory is missing; create it if needed.
std_csv_path.parent.mkdir(parents=True, exist_ok=True)
test_dat_std.write_csv(std_csv_path)