Setup: class II HLA chain inference


In [22]:
from tcr_format_parsers.common.MHCCodeConverter import (
    shorten_to_fullname,
    is_fullname,
    DQA_FOR,
    DPA_FOR,
)
import warnings


def infer_hla_chain(mhc_1_name, mhc_2_name):
    """Resolve the alpha/beta chain names of a human class II MHC.

    The prefix of the beta-chain name decides how the pair is resolved:

    * ``DRB``: the alpha chain is always reported as ``DRA*01:02``.
    * ``DQB`` / ``DPB``: the supplied alpha chain is kept when it expands to
      a full name; otherwise the partner is looked up in ``DQA_FOR`` /
      ``DPA_FOR`` using the full beta-chain name.

    Args:
        mhc_1_name: Alpha-chain name, or ``None`` when it is unknown.
        mhc_2_name: Beta-chain name, or ``None`` when it is unknown.

    Returns:
        A dict with the keys ``mhc_1_name``, ``mhc_2_name`` and
        ``mhc_name_inferred``. ``mhc_name_inferred`` is ``"chain_1"`` when
        the alpha chain was taken from the lookup table and ``"neither"``
        otherwise. All three values are ``None`` when the pair cannot be
        resolved (missing/unsupported beta chain, or no full name / partner
        found); a warning is emitted whenever a supported beta chain could
        not be resolved.
    """
    nullchains = {
        "mhc_1_name": None,
        "mhc_2_name": None,
        "mhc_name_inferred": None,
    }

    # A missing beta chain cannot be resolved; without this guard the
    # ``startswith`` calls below raise AttributeError on None.
    if mhc_2_name is None:
        return nullchains

    if mhc_2_name.startswith("DRB"):
        locus = "DR"
    elif mhc_2_name.startswith("DQB"):
        locus = "DQ"
    elif mhc_2_name.startswith("DPB"):
        locus = "DP"
    else:
        return nullchains

    b_chain = shorten_to_fullname(mhc_2_name)
    if not is_fullname(b_chain):
        # Warn for every locus (previously only DRB warned) so dropped rows
        # are never silent.
        warnings.warn(f"Could not find fullname for {mhc_2_name}")
        return nullchains

    if locus == "DR":
        return {
            # Use 01:02 since that is the sequence in UniProt; the mutation
            # is outside the top region.
            "mhc_1_name": "DRA*01:02",
            "mhc_2_name": b_chain,
            # Only one option - we don't count this as inferred.
            "mhc_name_inferred": "neither",
        }

    a_chain = (
        shorten_to_fullname(mhc_1_name) if mhc_1_name is not None else None
    )
    if mhc_1_name is not None and is_fullname(a_chain):
        # The caller supplied a usable alpha chain: nothing is inferred.
        return {
            "mhc_1_name": a_chain,
            "mhc_2_name": b_chain,
            "mhc_name_inferred": "neither",
        }

    # No usable alpha chain: fall back to the partner table for this locus.
    alpha_for = DQA_FOR if locus == "DQ" else DPA_FOR
    if b_chain in alpha_for:
        return {
            "mhc_1_name": alpha_for[b_chain],
            "mhc_2_name": b_chain,
            "mhc_name_inferred": "chain_1",
        }

    warnings.warn(f"Could not find {locus}A chain for {b_chain}")
    return nullchains

### 1. Import triad data from IMMREP25 fork


In [41]:
import polars as pl

# Read these two columns as text instead of relying on dtype inference.
schema_overrides = dict(references=pl.String, receptor_id=pl.String)

# One frame per source database (IEDB, VDJdb) and species / MHC-class subset.
# The targets on the left are filled in the same order as the paths below.
(
    iedb_human_I,
    iedb_human_II,
    iedb_mouse_I,
    iedb_mouse_II,
    vdjdb_human_I,
    vdjdb_human_II,
    vdjdb_mouse_I,
    vdjdb_mouse_II,
) = (
    pl.read_csv(raw_csv, schema_overrides=schema_overrides)
    for raw_csv in (
        "raw/HUMAN_I/immrep_IEDB.csv",
        "raw/HUMAN_II/immrep_IEDB.csv",
        "raw/MOUSE_I/immrep_IEDB.csv",
        "raw/MOUSE_II/immrep_IEDB.csv",
        "raw/HUMAN_I/vdjdb_pos_human_I.csv",
        "raw/HUMAN_II/vdjdb_pos_human_II.csv",
        "raw/MOUSE_I/vdjdb_pos_mouse_I.csv",
        "raw/MOUSE_II/vdjdb_pos_mouse_II.csv",
    )
)

In [137]:
import polars as pl
from tcr_format_parsers.common.MHCCodeConverter import (
    B2M_HUMAN_SEQ,
    HLACodeWebConverter,
)
from tcr_format_parsers.common.TriadUtils import (
    generate_job_name,
    FORMAT_COLS,
    generate_negatives,
)

# Converter whose get_sequence() is called below, once per row, to turn an
# MHC name into a sequence.
human_conv = HLACodeWebConverter()

# Reshape the raw IEDB human class I table into the triad column layout.
iedb_human_I = (
    # Map the raw column names onto the peptide / tcr_* / mhc_* names used
    # in the rest of this cell.
    iedb_human_I.rename(
        {
            "Peptide": "peptide",
            "TCRb": "tcr_2_seq",
            "TCRa": "tcr_1_seq",
            "HLA": "mhc_1_name",
        }
    )
    .with_columns(
        # Constant annotations: every row is a human class I entry whose
        # second MHC chain is B2M.
        pl.lit("heavy").alias("mhc_1_chain"),
        pl.lit("light").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("human").alias("tcr_1_species"),
        pl.lit("human").alias("tcr_2_species"),
        pl.lit("human").alias("mhc_1_species"),
        pl.lit("human").alias("mhc_2_species"),
        pl.lit(B2M_HUMAN_SEQ).alias("mhc_2_seq"),
        pl.lit("B2M").alias("mhc_2_name"),
        pl.lit("I").alias("mhc_class"),
        # Split once on "HLA-" into a struct (field_0, field_1); field_1 is
        # renamed to mhc_1_name after the unnest below.
        pl.col("mhc_1_name").str.split_exact("HLA-", 1).alias("split_parts"),
        # Every row read from the source table is labelled as cognate.
        pl.lit(True).alias("cognate"),
        # Comma-separated strings become list columns.
        pl.col("receptor_id").str.split(",").alias("receptor_id"),
        pl.col("references").str.split(",").alias("references"),
    )
    # Drop the original name so the unnested field can take over "mhc_1_name".
    .select(pl.exclude("mhc_1_name"))
    .unnest("split_parts")
    .rename(
        {
            "field_0": "tmp",
            "field_1": "mhc_1_name",
        }
    )
    .select(pl.exclude("tmp"))
    .with_columns(
        # Row-wise Python call into the converter for the heavy-chain
        # sequence (called directly, without a try/except wrapper).
        pl.col("mhc_1_name")
        .map_elements(
            lambda x: human_conv.get_sequence(x, top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_1_seq")
    )
    # Keep only rows where both TCR chain sequences are present.
    .filter(
        pl.col("tcr_1_seq").is_not_null(), pl.col("tcr_2_seq").is_not_null()
    )
)

# generate_job_name is a project helper (presumably adds the job-name column
# expected in FORMAT_COLS - confirm in TriadUtils). Afterwards keep the
# standard columns plus provenance and drop duplicate rows.
iedb_human_I = (
    generate_job_name(iedb_human_I)
    .select(FORMAT_COLS + ["receptor_id", "references"])
    .unique()
)

# generate_negatives is a project helper (presumably derives non-cognate
# pairings from the positives - confirm in TriadUtils). The negatives get
# null provenance and cognate=False, in the same column order as the positives.
iedb_human_I_negs = (
    generate_negatives(iedb_human_I)
    .with_columns(
        pl.lit(None).alias("receptor_id"),
        pl.lit(None).alias("references"),
        pl.lit(False).alias("cognate"),
    )
    .select(FORMAT_COLS + ["receptor_id", "references"])
)

# Final table: positives followed by negatives.
iedb_human_I = pl.concat([iedb_human_I, iedb_human_I_negs])

In [None]:
# The CSV export is restricted to FORMAT_COLS; the parquet export writes the
# frame as-is, so it also carries the receptor_id and references columns.
iedb_human_I.select(FORMAT_COLS).write_csv("iedb/human_I/iedb_human_I.csv")
iedb_human_I.write_parquet("iedb/human_I/iedb_human_I.parquet")

In [145]:
# Reshape the raw IEDB human class II table into the triad column layout.
iedb_human_II = (
    # Map the raw column names onto the peptide / tcr_* / mhc_* names; the
    # raw HLA label is parked in mhc_1_name until it is split below.
    iedb_human_II.rename(
        {
            "Peptide": "peptide",
            "TCRb": "tcr_2_seq",
            "TCRa": "tcr_1_seq",
            "HLA": "mhc_1_name",
        }
    )
    .with_columns(
        # Constant annotations: every row is a human class II entry.
        pl.lit("alpha").alias("mhc_1_chain"),
        pl.lit("beta").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("human").alias("tcr_1_species"),
        pl.lit("human").alias("tcr_2_species"),
        pl.lit("human").alias("mhc_1_species"),
        pl.lit("human").alias("mhc_2_species"),
        pl.lit("II").alias("mhc_class"),
        pl.lit(True).alias("cognate"),
        # Comma-separated strings become list columns.
        pl.col("receptor_id").str.split(",").alias("receptor_id"),
        pl.col("references").str.split(",").alias("references"),
    )
    # Keep only rows where both TCR chain sequences are present.
    .filter(
        pl.col("tcr_1_seq").is_not_null(), pl.col("tcr_2_seq").is_not_null()
    )
)

# Split the HLA label into chain names and resolve them with infer_hla_chain.
iedb_human_II = (
    iedb_human_II.with_columns(
        pl.col("mhc_1_name").str.split("/").alias("split_parts")
    )
    .with_columns(
        # Two parts: treat them as (chain 1, chain 2). slice(4) drops the
        # first four characters of the first part (presumably the "HLA-"
        # prefix - confirm against the raw data).
        # NOTE(review): the second part is not sliced here, unlike in the
        # vdjdb_human_II cell - presumably IEDB only prefixes the first
        # chain; confirm.
        pl.when(pl.col("split_parts").list.len() == 2)
        .then(
            pl.struct(
                pl.col("split_parts")
                .list.get(0, null_on_oob=True)
                .str.slice(4)
                .alias("mhc_1_name"),
                pl.col("split_parts")
                .list.get(1, null_on_oob=True)
                .alias("mhc_2_name"),
            )
        )
        # Otherwise: the first part (minus its first four characters) is used
        # as chain 2 and chain 1 is left null for infer_hla_chain to fill in.
        .otherwise(
            pl.struct(
                pl.lit(None).alias("mhc_1_name"),
                pl.col("split_parts")
                .list.get(0)
                .str.slice(4)
                .alias("mhc_2_name"),
            )
        )
        .alias("mhc_struct")
    )
    # Drop the raw label so the unnested "chains" struct can supply
    # mhc_1_name / mhc_2_name.
    .select(pl.exclude("mhc_1_name"))
    .with_columns(
        # Each struct row reaches the lambda as a dict; infer_hla_chain
        # returns a dict with mhc_1_name, mhc_2_name and mhc_name_inferred.
        pl.col("mhc_struct")
        .map_elements(
            lambda x: infer_hla_chain(x["mhc_1_name"], x["mhc_2_name"]),
            return_dtype=pl.Struct,
        )
        .alias("chains")
    )
    .unnest("chains")
    # infer_hla_chain returns None for both names when it cannot resolve the
    # pair; those rows are dropped here.
    .filter(
        (pl.col("mhc_1_name").is_not_null())
        & (pl.col("mhc_2_name").is_not_null())
    )
    .with_columns(
        # Row-wise converter lookups for both chains (called directly,
        # without a try/except wrapper).
        pl.col("mhc_1_name")
        .map_elements(
            lambda x: human_conv.get_sequence(x, top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_1_seq"),
        pl.col("mhc_2_name")
        .map_elements(
            lambda x: human_conv.get_sequence(x, top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_2_seq"),
    )
)


# generate_job_name is a project helper (presumably adds the job-name column
# expected in FORMAT_COLS - confirm in TriadUtils). Afterwards keep the
# standard columns plus provenance and drop duplicate rows.
iedb_human_II = (
    generate_job_name(iedb_human_II)
    .select(FORMAT_COLS + ["receptor_id", "references"])
    .unique()
)

# generate_negatives is a project helper (presumably derives non-cognate
# pairings from the positives - confirm in TriadUtils). The negatives get
# null provenance and cognate=False.
iedb_human_II_negs = (
    generate_negatives(iedb_human_II)
    .with_columns(
        pl.lit(None).alias("receptor_id"),
        pl.lit(None).alias("references"),
        pl.lit(False).alias("cognate"),
    )
    .select(FORMAT_COLS + ["receptor_id", "references"])
)

# Final table: positives followed by negatives.
iedb_human_II = pl.concat([iedb_human_II, iedb_human_II_negs])



In [146]:
# The CSV export is restricted to FORMAT_COLS; the parquet export writes the
# frame as-is, so it also carries the receptor_id and references columns.
iedb_human_II.select(FORMAT_COLS).write_csv("iedb/human_II/iedb_human_II.csv")
iedb_human_II.write_parquet("iedb/human_II/iedb_human_II.parquet")

In [9]:
from tcr_format_parsers.common.MHCCodeConverter import (
    H2CodeDictConverter,
    H2_I_LIGHT_DICT,
)
from tcr_format_parsers.common.TriadUtils import (
    generate_job_name,
    FORMAT_COLS,
    generate_negatives,
)

# Converter whose get_sequence() is called below, once per row, to turn an
# H2 name into a heavy-chain sequence.
mouse_conv = H2CodeDictConverter()

# Reshape the raw IEDB mouse class I table into the triad column layout.
iedb_mouse_I = (
    # Map the raw column names onto the peptide / tcr_* / mhc_* names.
    iedb_mouse_I.rename(
        {
            "Peptide": "peptide",
            "TCRb": "tcr_2_seq",
            "TCRa": "tcr_1_seq",
            "HLA": "mhc_1_name",
        }
    )
    # Keep names that start with "H2-" but not with "H2-I" (the IEDB mouse
    # class II cell selects exactly the "H2-I" names) and that contain no
    # space.
    .filter(
        pl.col("mhc_1_name").str.starts_with("H2-"),
        ~pl.col("mhc_1_name").str.starts_with("H2-I"),
        ~pl.col("mhc_1_name").str.contains(" "),
    )
    .with_columns(
        # Constant annotations: every row is a mouse class I entry whose
        # second MHC chain is B2M.
        pl.lit("heavy").alias("mhc_1_chain"),
        pl.lit("light").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("mouse").alias("tcr_1_species"),
        pl.lit("mouse").alias("tcr_2_species"),
        pl.lit("mouse").alias("mhc_1_species"),
        pl.lit("mouse").alias("mhc_2_species"),
        pl.lit(H2_I_LIGHT_DICT["B2M"]).alias("mhc_2_seq"),
        pl.lit("B2M").alias("mhc_2_name"),
        pl.lit("I").alias("mhc_class"),
        pl.lit(True).alias("cognate"),
        # Comma-separated strings become list columns.
        pl.col("receptor_id").str.split(",").alias("receptor_id"),
        pl.col("references").str.split(",").alias("references"),
    )
    .with_columns(
        # Row-wise converter lookup of the heavy chain.
        # NOTE(review): called directly, without the try/except wrapper the
        # class II cell uses - an exception raised by get_sequence would
        # abort this cell; presumably all remaining names resolve.
        pl.col("mhc_1_name")
        .map_elements(
            lambda x: mouse_conv.get_sequence(x, "heavy", top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_1_seq")
    )
    # Keep rows with both TCR sequences present and free of "*" / "X".
    .filter(
        pl.col("tcr_1_seq").is_not_null(),
        pl.col("tcr_2_seq").is_not_null(),
        ~pl.col("tcr_1_seq").str.contains(r"\*|X"),
        ~pl.col("tcr_2_seq").str.contains(r"\*|X"),
    )
)

# generate_job_name is a project helper (presumably adds the job-name column
# expected in FORMAT_COLS - confirm in TriadUtils). Afterwards keep the
# standard columns plus provenance and drop duplicate rows.
iedb_mouse_I = (
    generate_job_name(iedb_mouse_I)
    .select(FORMAT_COLS + ["receptor_id", "references"])
    .unique()
)

# generate_negatives is a project helper (presumably derives non-cognate
# pairings from the positives - confirm in TriadUtils). The negatives get
# null provenance and cognate=False.
iedb_mouse_I_negs = (
    generate_negatives(iedb_mouse_I)
    .with_columns(
        pl.lit(None).alias("receptor_id"),
        pl.lit(None).alias("references"),
        pl.lit(False).alias("cognate"),
    )
    .select(FORMAT_COLS + ["receptor_id", "references"])
)

# Final table: positives followed by negatives.
iedb_mouse_I = pl.concat([iedb_mouse_I, iedb_mouse_I_negs])

In [10]:
# The CSV export is restricted to FORMAT_COLS; the parquet export writes the
# frame as-is, so it also carries the receptor_id and references columns.
iedb_mouse_I.select(FORMAT_COLS).write_csv("iedb/mouse_I/iedb_mouse_I.csv")
iedb_mouse_I.write_parquet("iedb/mouse_I/iedb_mouse_I.parquet")

In [19]:
from tcr_format_parsers.common.MHCCodeConverter import (
    H2CodeDictConverter,
    H2_I_LIGHT_DICT,
)
from tcr_format_parsers.common.TriadUtils import (
    generate_job_name,
    FORMAT_COLS,
    generate_negatives,
)
import warnings

mouse_conv = H2CodeDictConverter()


def mouse_conv_wrapper(x, chain, **kwargs):
    """Look up a sequence via ``mouse_conv``, tolerating ``ValueError``.

    Forwards ``x``, ``chain`` and any keyword arguments to
    ``mouse_conv.get_sequence``. A ``ValueError`` from the lookup is turned
    into a warning carrying the error text, and ``None`` is returned instead.
    """
    try:
        return mouse_conv.get_sequence(x, chain, **kwargs)
    except ValueError as err:
        warnings.warn(str(err))
    return None


# Reshape the raw IEDB mouse class II table into the triad column layout.
iedb_mouse_II = (
    # Map the raw column names onto the peptide / tcr_* / mhc_* names.
    iedb_mouse_II.rename(
        {
            "Peptide": "peptide",
            "TCRb": "tcr_2_seq",
            "TCRa": "tcr_1_seq",
            "HLA": "mhc_1_name",
        }
    )
    # Only names starting with "H2-I" are processed in this cell.
    .filter(
        pl.col("mhc_1_name").str.starts_with("H2-I"),
    )
    .with_columns(
        # Constant annotations: every row is a mouse class II entry.
        pl.lit("alpha").alias("mhc_1_chain"),
        pl.lit("beta").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("mouse").alias("tcr_1_species"),
        pl.lit("mouse").alias("tcr_2_species"),
        pl.lit("mouse").alias("mhc_1_species"),
        pl.lit("mouse").alias("mhc_2_species"),
        # Both chains share the same name; the chain is chosen by the
        # "alpha" / "beta" argument passed to the converter below.
        pl.col("mhc_1_name").alias("mhc_2_name"),
        pl.lit("II").alias("mhc_class"),
        pl.lit(True).alias("cognate"),
        # Comma-separated strings become list columns.
        pl.col("receptor_id").str.split(",").alias("receptor_id"),
        pl.col("references").str.split(",").alias("references"),
    )
    .with_columns(
        # mouse_conv_wrapper returns None (after warning) when the converter
        # raises ValueError, so unresolved names end up as null sequences.
        pl.col("mhc_1_name")
        .map_elements(
            lambda x: mouse_conv_wrapper(x, "alpha", top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_1_seq"),
        pl.col("mhc_2_name")
        .map_elements(
            lambda x: mouse_conv_wrapper(x, "beta", top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_2_seq"),
    )
    # Keep only rows where all four sequences are present.
    .filter(
        pl.col("tcr_1_seq").is_not_null(),
        pl.col("tcr_2_seq").is_not_null(),
        pl.col("mhc_1_seq").is_not_null(),
        pl.col("mhc_2_seq").is_not_null(),
    )
)

# generate_job_name is a project helper (presumably adds the job-name column
# expected in FORMAT_COLS - confirm in TriadUtils). Afterwards keep the
# standard columns plus provenance and drop duplicate rows.
iedb_mouse_II = (
    generate_job_name(iedb_mouse_II)
    .select(FORMAT_COLS + ["receptor_id", "references"])
    .unique()
)

# generate_negatives is a project helper (presumably derives non-cognate
# pairings from the positives - confirm in TriadUtils). The negatives get
# null provenance and cognate=False.
iedb_mouse_II_negs = (
    generate_negatives(iedb_mouse_II)
    .with_columns(
        pl.lit(None).alias("receptor_id"),
        pl.lit(None).alias("references"),
        pl.lit(False).alias("cognate"),
    )
    .select(FORMAT_COLS + ["receptor_id", "references"])
)

# Final table: positives followed by negatives.
iedb_mouse_II = pl.concat([iedb_mouse_II, iedb_mouse_II_negs])



In [20]:
# The CSV export is restricted to FORMAT_COLS; the parquet export writes the
# frame as-is, so it also carries the receptor_id and references columns.
iedb_mouse_II.select(FORMAT_COLS).write_csv("iedb/mouse_II/iedb_mouse_II.csv")
iedb_mouse_II.write_parquet("iedb/mouse_II/iedb_mouse_II.parquet")

In [None]:
import polars as pl
from tcr_format_parsers.common.MHCCodeConverter import (
    B2M_HUMAN_SEQ,
    HLACodeWebConverter,
)
from tcr_format_parsers.common.TriadUtils import (
    generate_job_name,
    FORMAT_COLS,
    generate_negatives,
)

human_conv = HLACodeWebConverter()


def get_sequence_wrapper(seq, **kwargs):
    """Look up a sequence via ``human_conv``, tolerating ``ValueError``.

    Forwards ``seq`` and any keyword arguments to
    ``human_conv.get_sequence``. A ``ValueError`` from the lookup is turned
    into a warning naming the offending input, and ``None`` is returned
    instead.
    """
    result = None
    try:
        result = human_conv.get_sequence(seq, **kwargs)
    except ValueError as e:
        warnings.warn(f"Error in sequence {seq}: {e}")
    return result


# Reshape the raw VDJdb human class I table into the triad column layout.
vdjdb_human_I = (
    # Map the raw column names onto the peptide / tcr_* / mhc_* names.
    vdjdb_human_I.rename(
        {
            "Peptide": "peptide",
            "TCRb": "tcr_2_seq",
            "TCRa": "tcr_1_seq",
            "HLA": "mhc_1_name",
        }
    )
    .with_columns(
        # Constant annotations: every row is a human class I entry whose
        # second MHC chain is B2M.
        pl.lit("heavy").alias("mhc_1_chain"),
        pl.lit("light").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("human").alias("tcr_1_species"),
        pl.lit("human").alias("tcr_2_species"),
        pl.lit("human").alias("mhc_1_species"),
        pl.lit("human").alias("mhc_2_species"),
        pl.lit(B2M_HUMAN_SEQ).alias("mhc_2_seq"),
        pl.lit("B2M").alias("mhc_2_name"),
        pl.lit("I").alias("mhc_class"),
        # Split once on "HLA-" into a struct (field_0, field_1); field_1 is
        # renamed to mhc_1_name after the unnest below.
        pl.col("mhc_1_name").str.split_exact("HLA-", 1).alias("split_parts"),
        pl.lit(True).alias("cognate"),
    )
    # Drop the original name so the unnested field can take over "mhc_1_name".
    .select(pl.exclude("mhc_1_name"))
    .unnest("split_parts")
    .rename(
        {
            "field_0": "tmp",
            "field_1": "mhc_1_name",
        }
    )
    .select(pl.exclude("tmp"))
    .with_columns(
        # get_sequence_wrapper returns None (after warning) when the
        # converter raises ValueError, so unresolved names become nulls.
        pl.col("mhc_1_name")
        .map_elements(
            lambda x: get_sequence_wrapper(x, top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_1_seq")
    )
    # Keep rows with both TCR sequences and a resolved heavy-chain sequence.
    .filter(
        pl.col("tcr_1_seq").is_not_null(),
        pl.col("tcr_2_seq").is_not_null(),
        pl.col("mhc_1_seq").is_not_null(),
    )
)

# generate_job_name is a project helper (presumably adds the job-name column
# expected in FORMAT_COLS - confirm in TriadUtils). Afterwards keep only the
# standard columns and drop duplicate rows.
vdjdb_human_I = generate_job_name(vdjdb_human_I).select(FORMAT_COLS).unique()

# generate_negatives is a project helper (presumably derives non-cognate
# pairings from the positives - confirm in TriadUtils); the negatives are
# labelled cognate=False.
vdjdb_human_I_negs = (
    generate_negatives(vdjdb_human_I)
    .with_columns(
        pl.lit(False).alias("cognate"),
    )
    .select(FORMAT_COLS)
)

# Final table: positives followed by negatives.
vdjdb_human_I = pl.concat([vdjdb_human_I, vdjdb_human_I_negs])

In [135]:
# The frame was already reduced to FORMAT_COLS when it was built, so the CSV
# and parquet exports hold the same columns.
vdjdb_human_I.select(FORMAT_COLS).write_csv("vdjdb/human_I/vdjdb_human_I.csv")
vdjdb_human_I.write_parquet("vdjdb/human_I/vdjdb_human_I.parquet")

In [25]:
# Reshape the raw VDJdb human class II table into the triad column layout.
# This cell relies on infer_hla_chain and get_sequence_wrapper being defined
# by earlier cells; running it before they exist fails with a NameError.
vdjdb_human_II = (
    # Map the raw column names onto the peptide / tcr_* / mhc_* names; the
    # raw HLA label is parked in mhc_1_name until it is split below.
    vdjdb_human_II.rename(
        {
            "Peptide": "peptide",
            "TCRb": "tcr_2_seq",
            "TCRa": "tcr_1_seq",
            "HLA": "mhc_1_name",
        }
    )
    .with_columns(
        # Constant annotations: every row is a human class II entry.
        pl.lit("alpha").alias("mhc_1_chain"),
        pl.lit("beta").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("human").alias("tcr_1_species"),
        pl.lit("human").alias("tcr_2_species"),
        pl.lit("human").alias("mhc_1_species"),
        pl.lit("human").alias("mhc_2_species"),
        pl.lit("II").alias("mhc_class"),
        pl.lit(True).alias("cognate"),
    )
    # Keep only rows where both TCR chain sequences are present.
    .filter(
        pl.col("tcr_1_seq").is_not_null(), pl.col("tcr_2_seq").is_not_null()
    )
)

# Split the HLA label into chain names and resolve them with infer_hla_chain.
vdjdb_human_II = (
    vdjdb_human_II.with_columns(
        pl.col("mhc_1_name").str.split("/").alias("split_parts")
    )
    .with_columns(
        # Two parts: treat them as (chain 1, chain 2). slice(4) drops the
        # first four characters of each part (presumably the "HLA-" prefix -
        # confirm against the raw data).
        pl.when(pl.col("split_parts").list.len() == 2)
        .then(
            pl.struct(
                pl.col("split_parts")
                .list.get(0, null_on_oob=True)
                .str.slice(4)
                .alias("mhc_1_name"),
                pl.col("split_parts")
                .list.get(1, null_on_oob=True)
                .str.slice(4)
                .alias("mhc_2_name"),
            )
        )
        # Otherwise: the first part (minus its first four characters) is used
        # as chain 2 and chain 1 is left null for infer_hla_chain to fill in.
        .otherwise(
            pl.struct(
                pl.lit(None).alias("mhc_1_name"),
                pl.col("split_parts")
                .list.get(0)
                .str.slice(4)
                .alias("mhc_2_name"),
            )
        )
        .alias("mhc_struct")
    )
    # Drop the raw label so the unnested "chains" struct can supply
    # mhc_1_name / mhc_2_name.
    .select(pl.exclude("mhc_1_name"))
    .with_columns(
        # Each struct row reaches the lambda as a dict; infer_hla_chain
        # returns a dict with mhc_1_name, mhc_2_name and mhc_name_inferred.
        pl.col("mhc_struct")
        .map_elements(
            lambda x: infer_hla_chain(x["mhc_1_name"], x["mhc_2_name"]),
            return_dtype=pl.Struct,
        )
        .alias("chains")
    )
    .unnest("chains")
    # infer_hla_chain returns None for both names when it cannot resolve the
    # pair; those rows are dropped here.
    .filter(
        (pl.col("mhc_1_name").is_not_null())
        & (pl.col("mhc_2_name").is_not_null())
    )
    .with_columns(
        # get_sequence_wrapper returns None (after warning) when the
        # converter raises ValueError, so unresolved names become nulls.
        pl.col("mhc_1_name")
        .map_elements(
            lambda x: get_sequence_wrapper(x, top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_1_seq"),
        pl.col("mhc_2_name")
        .map_elements(
            lambda x: get_sequence_wrapper(x, top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_2_seq"),
    )
    # Keep only rows where both MHC chain sequences were resolved.
    .filter(
        pl.col("mhc_1_seq").is_not_null(),
        pl.col("mhc_2_seq").is_not_null(),
    )
)


# generate_job_name is a project helper (presumably adds the job-name column
# expected in FORMAT_COLS - confirm in TriadUtils). Afterwards keep only the
# standard columns and drop duplicate rows.
vdjdb_human_II = generate_job_name(vdjdb_human_II).select(FORMAT_COLS).unique()

# generate_negatives is a project helper (presumably derives non-cognate
# pairings from the positives - confirm in TriadUtils); the negatives are
# labelled cognate=False.
vdjdb_human_II_negs = (
    generate_negatives(vdjdb_human_II)
    .with_columns(
        pl.lit(False).alias("cognate"),
    )
    .select(FORMAT_COLS)
)

# Final table: positives followed by negatives.
vdjdb_human_II = pl.concat([vdjdb_human_II, vdjdb_human_II_negs])

ComputeError: NameError: name 'get_sequence_wrapper' is not defined

In [169]:
# The frame was already reduced to FORMAT_COLS when it was built, so the CSV
# and parquet exports hold the same columns.
vdjdb_human_II.select(FORMAT_COLS).write_csv(
    "vdjdb/human_II/vdjdb_human_II.csv"
)
vdjdb_human_II.write_parquet("vdjdb/human_II/vdjdb_human_II.parquet")

In [11]:
from tcr_format_parsers.common.MHCCodeConverter import (
    H2CodeDictConverter,
    H2_I_LIGHT_DICT,
)
from tcr_format_parsers.common.TriadUtils import (
    generate_job_name,
    FORMAT_COLS,
    generate_negatives,
)

# Converter whose get_sequence() is called later in this cell, once per row,
# to turn an H2 name into a heavy-chain sequence.
mouse_conv = H2CodeDictConverter()


def fix_hla_name(x):
    """Rewrite an ``H-2``-prefixed MHC name into the ``H2-`` form.

    For names starting with ``H-2`` the first three characters are replaced
    by ``H2-`` and the final character is lower-cased. Every other name
    (including those already starting with ``H2-``) is returned unchanged.
    """
    if not x.startswith("H-2"):
        return x
    renamed = "H2-" + x[3:]
    return renamed[:-1] + renamed[-1:].lower()


# Reshape the raw VDJdb mouse class I table into the triad column layout.
vdjdb_mouse_I = (
    # Map the raw column names onto the peptide / tcr_* / mhc_* names.
    vdjdb_mouse_I.rename(
        {
            "Peptide": "peptide",
            "TCRb": "tcr_2_seq",
            "TCRa": "tcr_1_seq",
            "HLA": "mhc_1_name",
        }
    )
    .with_columns(
        # Normalise "H-2..." names to the "H2-..." form before filtering.
        pl.col("mhc_1_name")
        .map_elements(fix_hla_name, return_dtype=pl.String)
        .alias("mhc_1_name")
    )
    # Keep H2 names that do not start with "H2-I" and contain no space.
    # NOTE(review): after fix_hla_name no value starts with "H-2" any more,
    # so the second starts_with alternative looks redundant.
    .filter(
        (pl.col("mhc_1_name").str.starts_with("H2-"))
        | (pl.col("mhc_1_name").str.starts_with("H-2")),
        ~pl.col("mhc_1_name").str.starts_with("H2-I"),
        ~pl.col("mhc_1_name").str.contains(" "),
    )
    .with_columns(
        # Constant annotations: every row is a mouse class I entry whose
        # second MHC chain is B2M.
        pl.lit("heavy").alias("mhc_1_chain"),
        pl.lit("light").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("mouse").alias("tcr_1_species"),
        pl.lit("mouse").alias("tcr_2_species"),
        pl.lit("mouse").alias("mhc_1_species"),
        pl.lit("mouse").alias("mhc_2_species"),
        pl.lit(H2_I_LIGHT_DICT["B2M"]).alias("mhc_2_seq"),
        pl.lit("B2M").alias("mhc_2_name"),
        pl.lit("I").alias("mhc_class"),
        pl.lit(True).alias("cognate"),
    )
    .with_columns(
        # Row-wise converter lookup of the heavy chain (called directly,
        # without a try/except wrapper).
        pl.col("mhc_1_name")
        .map_elements(
            lambda x: mouse_conv.get_sequence(x, "heavy", top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_1_seq")
    )
    # Keep rows with both TCR sequences present and free of "*" / "X".
    .filter(
        pl.col("tcr_1_seq").is_not_null(),
        pl.col("tcr_2_seq").is_not_null(),
        ~pl.col("tcr_1_seq").str.contains(r"\*|X"),
        ~pl.col("tcr_2_seq").str.contains(r"\*|X"),
    )
)

# generate_job_name is a project helper (presumably adds the job-name column
# expected in FORMAT_COLS - confirm in TriadUtils). Afterwards keep only the
# standard columns and drop duplicate rows.
vdjdb_mouse_I = generate_job_name(vdjdb_mouse_I).select(FORMAT_COLS).unique()

# generate_negatives is a project helper (presumably derives non-cognate
# pairings from the positives - confirm in TriadUtils); the negatives are
# labelled cognate=False.
vdjdb_mouse_I_negs = (
    generate_negatives(vdjdb_mouse_I)
    .with_columns(
        pl.lit(False).alias("cognate"),
    )
    .select(FORMAT_COLS)
)

# Final table: positives followed by negatives.
vdjdb_mouse_I = pl.concat([vdjdb_mouse_I, vdjdb_mouse_I_negs])

In [12]:
# The frame was already reduced to FORMAT_COLS when it was built, so the CSV
# and parquet exports hold the same columns.
vdjdb_mouse_I.select(FORMAT_COLS).write_csv("vdjdb/mouse_I/vdjdb_mouse_I.csv")
vdjdb_mouse_I.write_parquet("vdjdb/mouse_I/vdjdb_mouse_I.parquet")

In [22]:
# Tally the raw HLA labels that do not contain the text "HLA".
vdjdb_mouse_II.filter(~pl.col("HLA").str.contains("HLA")).get_column(
    "HLA"
).value_counts()

HLA,count
str,u32
"""H-2Eb1/H-2Eb1""",7
"""H-2Aa/H-2Aa""",17


In [37]:
from tcr_format_parsers.common.MHCCodeConverter import (
    H2CodeDictConverter,
    H2_I_LIGHT_DICT,
)
from tcr_format_parsers.common.TriadUtils import (
    generate_job_name,
    FORMAT_COLS,
    generate_negatives,
)
import warnings

# Raw VDJdb mouse class II labels and the H2-I* names they are replaced with.
vdjdb_translate_dict = {
    "H-2Eb1": "H2-IEk",
    "H-2Aa": "H2-IAb",
}


def vdjdb_chain_translate(mhc_name):
    """Return the ``vdjdb_translate_dict`` entry for *mhc_name*.

    Names without an entry are returned unchanged.
    """
    return vdjdb_translate_dict.get(mhc_name, mhc_name)


def mouse_conv_wrapper(x, chain, **kwargs):
    """Look up a sequence via ``mouse_conv``, tolerating ``ValueError``.

    Forwards ``x``, ``chain`` and any keyword arguments to
    ``mouse_conv.get_sequence``. A ``ValueError`` from the lookup is turned
    into a warning carrying the error text, and ``None`` is returned instead.
    """
    try:
        return mouse_conv.get_sequence(x, chain, **kwargs)
    except ValueError as err:
        warnings.warn(str(err))
    return None


mouse_conv = H2CodeDictConverter()

# Reshape the raw VDJdb mouse class II table into the triad column layout.
vdjdb_mouse_II = (
    # Map the raw column names onto the peptide / tcr_* / mhc_* names; the
    # raw "<chain 1>/<chain 2>" label is parked in mhc_1_name until split.
    vdjdb_mouse_II.rename(
        {
            "Peptide": "peptide",
            "TCRb": "tcr_2_seq",
            "TCRa": "tcr_1_seq",
            "HLA": "mhc_1_name",
        }
    )
    # Rows whose label contains "HLA" are not mouse entries; drop them.
    .filter(
        ~pl.col("mhc_1_name").str.contains("HLA"),
    )
    .with_columns(
        # Constant annotations: every row is a mouse class II entry.
        pl.lit("alpha").alias("mhc_1_chain"),
        pl.lit("beta").alias("mhc_2_chain"),
        pl.lit("alpha").alias("tcr_1_chain"),
        pl.lit("beta").alias("tcr_2_chain"),
        pl.lit("mouse").alias("tcr_1_species"),
        pl.lit("mouse").alias("tcr_2_species"),
        pl.lit("mouse").alias("mhc_1_species"),
        pl.lit("mouse").alias("mhc_2_species"),
        # Take each chain from its own half of the label: index 0 for chain
        # 1, index 1 for chain 2 (same convention as the human class II
        # cells). Previously both chains were read from index 1. Both
        # expressions read the original, unsplit mhc_1_name column.
        pl.col("mhc_1_name").str.split("/").list.get(0).alias("mhc_1_name"),
        pl.col("mhc_1_name").str.split("/").list.get(1).alias("mhc_2_name"),
        pl.lit("II").alias("mhc_class"),
        pl.lit(True).alias("cognate"),
    )
    .with_columns(
        # Replace the raw VDJdb labels by their vdjdb_translate_dict entries;
        # labels without an entry pass through unchanged.
        pl.col("mhc_1_name")
        .map_elements(
            lambda x: vdjdb_chain_translate(x),
            return_dtype=pl.String,
        )
        .alias("mhc_1_name"),
        pl.col("mhc_2_name")
        .map_elements(
            lambda x: vdjdb_chain_translate(x),
            return_dtype=pl.String,
        )
        .alias("mhc_2_name"),
    )
    .with_columns(
        # mouse_conv_wrapper returns None (after warning) when the converter
        # raises ValueError, so unresolved names end up as null sequences.
        pl.col("mhc_1_name")
        .map_elements(
            lambda x: mouse_conv_wrapper(x, "alpha", top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_1_seq"),
        pl.col("mhc_2_name")
        .map_elements(
            lambda x: mouse_conv_wrapper(x, "beta", top_only=True),
            return_dtype=pl.String,
        )
        .alias("mhc_2_seq"),
    )
    # Keep only rows where all four sequences are present.
    .filter(
        pl.col("tcr_1_seq").is_not_null(),
        pl.col("tcr_2_seq").is_not_null(),
        pl.col("mhc_1_seq").is_not_null(),
        pl.col("mhc_2_seq").is_not_null(),
    )
)

# generate_job_name is a project helper (presumably adds the job-name column
# expected in FORMAT_COLS - confirm in TriadUtils). Afterwards keep only the
# standard columns and drop duplicate rows.
vdjdb_mouse_II = generate_job_name(vdjdb_mouse_II).select(FORMAT_COLS).unique()

# generate_negatives is a project helper (presumably derives non-cognate
# pairings from the positives - confirm in TriadUtils); the negatives are
# labelled cognate=False.
vdjdb_mouse_II_negs = (
    generate_negatives(vdjdb_mouse_II)
    .with_columns(
        pl.lit(False).alias("cognate"),
    )
    .select(FORMAT_COLS)
)

# Final table: positives followed by negatives.
vdjdb_mouse_II = pl.concat([vdjdb_mouse_II, vdjdb_mouse_II_negs])

In [38]:
# Unlike the other export cells, no .select(FORMAT_COLS) is applied before
# write_csv; the frame was already reduced to FORMAT_COLS when it was built,
# so the CSV and parquet exports hold the same columns.
vdjdb_mouse_II.write_csv("vdjdb/mouse_II/vdjdb_mouse_II.csv")
vdjdb_mouse_II.write_parquet("vdjdb/mouse_II/vdjdb_mouse_II.parquet")