# Embed known binder sequences with each fold's fine-tuned language model, and apply existing scaling and PCA transformations

See `scripts/off_peak.run_embedding_fine_tuned.and_scale.py`.

Recall that we have a separate fine-tuned language model for each train-smaller set, so treat this as an extension of the test set: for each test fold ID, apply the language model, scaling, and PCA transformations trained on that fold's train-smaller set.

In [1]:
import numpy as np
import pandas as pd
import joblib
import choosegpu
from malid import config, apply_embedding, interpretation
from malid.datamodels import GeneLocus

In [2]:
# Embed with GPU
# Request GPU access through choosegpu for the embedding steps below.
# As the final expression of the cell, the call's return value is displayed as the cell output.
choosegpu.configure_gpu(enable=True)

['GPU-079f8d58-9984-1b40-b487-b558c5a6393c']

In [3]:
# Display the name of the currently configured embedder (shown as the cell output).
config.embedder.name

'unirep_fine_tuned'

In [4]:
def process(gene_locus):
    """Embed the known-binder reference dataset for one gene locus in every fold.

    For each fold ID in ``config.all_fold_ids``, the reference sequences are
    embedded with the embedding model loaded for that fold, passed through the
    transformations loaded for that fold, and stored in a dict keyed by fold
    ID. The dict is then written with ``joblib.dump`` to
    ``config.paths.scaled_anndatas_dir / gene_locus.name /
    "known_binders.embedded.in.all.folds.joblib"``.

    Args:
        gene_locus: A single ``GeneLocus`` value; checked with
            ``GeneLocus.validate_single_value``.

    Returns:
        None. The result is written to disk and progress is printed.
    """
    print(gene_locus)
    GeneLocus.validate_single_value(gene_locus)

    # The second return value (cluster centroids by supergroup) is not used here.
    df, _ = interpretation.load_reference_dataset(gene_locus)
    print(df.shape)

    # Several sequences may have been merged into the same global cluster:
    # keep only the first entry of each cluster.
    df = df.groupby("global_resulting_cluster_ID").head(n=1).copy()
    print(df.shape)

    # Note: we don't have v_mut or isotype for CoV-AbDab, so fill in
    # placeholder values when those columns are absent.
    if "isotype_supergroup" not in df.columns:
        df["isotype_supergroup"] = "IGHG"
    if "v_mut" not in df.columns:
        df["v_mut"] = 0.0

    # Use the reference dataset's name as both participant and specimen label.
    dataset_name = interpretation.reference_dataset_name[gene_locus]
    df["participant_label"] = dataset_name
    df["specimen_label"] = dataset_name
    df["disease"] = "Covid19"
    df["disease_subtype"] = "Covid19 - known binder"

    embedded = {}
    for fold_id in config.all_fold_ids:
        # Work on a per-fold copy so the fold suffix is not accumulated
        # across iterations.
        fold_df = df.copy()
        fold_df["participant_label"] += f"_{fold_id}"
        fold_df["specimen_label"] += f"_{fold_id}"

        # Make adata: embed with this fold's model...
        adata = apply_embedding.run_embedding_model(
            embedder=apply_embedding.load_embedding_model(
                gene_locus=gene_locus, fold_id=fold_id
            ),
            df=fold_df,
            gene_locus=gene_locus,
            fold_id=fold_id,
        )
        # ...then apply the transformations saved for the same fold.
        adata = apply_embedding.transform_embedded_anndata(
            transformations_to_apply=apply_embedding.load_transformations(
                gene_locus=gene_locus, fold_id=fold_id
            ),
            adata=adata,
        )

        embedded[fold_id] = adata
        print(fold_id, adata)

    joblib.dump(
        embedded,
        config.paths.scaled_anndatas_dir
        / gene_locus.name
        / "known_binders.embedded.in.all.folds.joblib",
    )

In [5]:
# Run the embed / transform / save pipeline once for every gene locus in config.gene_loci_used.
for gene_locus in config.gene_loci_used:
    process(gene_locus)

GeneLocus.BCR


(6844, 17)
(6781, 17)


2022-12-28 17:35:25,404 - absl - INFO - Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 


2022-12-28 17:35:29,267 - absl - INFO - Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Host Interpreter CUDA


2022-12-28 17:35:29,274 - absl - INFO - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'








2022-12-28 17:36:07,869 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)





0 AnnData object with n_obs × n_vars = 6781 × 1900
    obs: 'CDRH3', 'j_gene', 'VHorVHH', 'Binds to', "Doesn't Bind to", 'Neutralising Vs', 'Not Neutralising Vs', 'Protein + Epitope', 'Origin', 'cdr3_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'v_gene', 'cdr1_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'cluster_id_within_clustering_group', 'global_resulting_cluster_ID', 'num_clone_members', 'isotype_supergroup', 'v_mut', 'participant_label', 'specimen_label', 'disease', 'disease_subtype'
    uns: 'embedded', 'embedded_fine_tuned_on_fold_id', 'embedded_fine_tuned_on_gene_locus'
    obsm: 'X_pca'


2022-12-28 17:36:23,823 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)





1 AnnData object with n_obs × n_vars = 6781 × 1900
    obs: 'CDRH3', 'j_gene', 'VHorVHH', 'Binds to', "Doesn't Bind to", 'Neutralising Vs', 'Not Neutralising Vs', 'Protein + Epitope', 'Origin', 'cdr3_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'v_gene', 'cdr1_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'cluster_id_within_clustering_group', 'global_resulting_cluster_ID', 'num_clone_members', 'isotype_supergroup', 'v_mut', 'participant_label', 'specimen_label', 'disease', 'disease_subtype'
    uns: 'embedded', 'embedded_fine_tuned_on_fold_id', 'embedded_fine_tuned_on_gene_locus'
    obsm: 'X_pca'


2022-12-28 17:36:40,193 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)





2 AnnData object with n_obs × n_vars = 6781 × 1900
    obs: 'CDRH3', 'j_gene', 'VHorVHH', 'Binds to', "Doesn't Bind to", 'Neutralising Vs', 'Not Neutralising Vs', 'Protein + Epitope', 'Origin', 'cdr3_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'v_gene', 'cdr1_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'cluster_id_within_clustering_group', 'global_resulting_cluster_ID', 'num_clone_members', 'isotype_supergroup', 'v_mut', 'participant_label', 'specimen_label', 'disease', 'disease_subtype'
    uns: 'embedded', 'embedded_fine_tuned_on_fold_id', 'embedded_fine_tuned_on_gene_locus'
    obsm: 'X_pca'


2022-12-28 17:36:56,512 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)





-1 AnnData object with n_obs × n_vars = 6781 × 1900
    obs: 'CDRH3', 'j_gene', 'VHorVHH', 'Binds to', "Doesn't Bind to", 'Neutralising Vs', 'Not Neutralising Vs', 'Protein + Epitope', 'Origin', 'cdr3_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'v_gene', 'cdr1_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'cluster_id_within_clustering_group', 'global_resulting_cluster_ID', 'num_clone_members', 'isotype_supergroup', 'v_mut', 'participant_label', 'specimen_label', 'disease', 'disease_subtype'
    uns: 'embedded', 'embedded_fine_tuned_on_fold_id', 'embedded_fine_tuned_on_gene_locus'
    obsm: 'X_pca'


GeneLocus.TCR


  train_sequences_df = pd.read_csv(



(37591, 58)
(37591, 58)


2022-12-28 17:37:44,220 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:38:07,089 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:38:29,099 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:38:48,491 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)





0 AnnData object with n_obs × n_vars = 37591 × 1900
    obs: 'rownum', 'TCR BioIdentity', 'TCR Nucleotide Sequence', 'Experiment', 'ORF Coverage', 'Amino Acids', 'Start Index in Genome', 'End Index in Genome', 'source', 'ORF', 'ORF Genebank ID', 'Amino Acid', 'Subject', 'Cell Type', 'Target Type', 'Cohort', 'Age', 'Gender', 'Race', 'HLA-A', 'HLA-A.1', 'HLA-B', 'HLA-B.1', 'HLA-C', 'HLA-C.1', 'DPA1', 'DPA1.1', 'DPB1', 'DPB1.1', 'DQA1', 'DQA1.1', 'DQB1', 'DQB1.1', 'DRB1', 'DRB1.1', 'DRB3', 'DRB3.1', 'DRB4', 'DRB4.1', 'DRB5', 'DRB5.1', 'cdr3_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'v_segment', 'j_segment', 'productive', 'extracted_isotype', 'isotype_supergroup', 'v_gene', 'j_gene', 'cdr1_seq_aa_q', 'cdr2_seq_aa_q', 'cdr1_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'v_mut', 'cluster_id_within_clustering_group', 'global_resulting_cluster_ID', 'num_clone_members', 'participant_label', 'specimen_label', 'disease', 'disease_subtype'
    uns: 'embedded', 'embedded_fine_tuned_on_fold_id', 'embedde

2022-12-28 17:39:05,877 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:39:16,117 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:39:26,309 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:39:35,855 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)





1 AnnData object with n_obs × n_vars = 37591 × 1900
    obs: 'rownum', 'TCR BioIdentity', 'TCR Nucleotide Sequence', 'Experiment', 'ORF Coverage', 'Amino Acids', 'Start Index in Genome', 'End Index in Genome', 'source', 'ORF', 'ORF Genebank ID', 'Amino Acid', 'Subject', 'Cell Type', 'Target Type', 'Cohort', 'Age', 'Gender', 'Race', 'HLA-A', 'HLA-A.1', 'HLA-B', 'HLA-B.1', 'HLA-C', 'HLA-C.1', 'DPA1', 'DPA1.1', 'DPB1', 'DPB1.1', 'DQA1', 'DQA1.1', 'DQB1', 'DQB1.1', 'DRB1', 'DRB1.1', 'DRB3', 'DRB3.1', 'DRB4', 'DRB4.1', 'DRB5', 'DRB5.1', 'cdr3_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'v_segment', 'j_segment', 'productive', 'extracted_isotype', 'isotype_supergroup', 'v_gene', 'j_gene', 'cdr1_seq_aa_q', 'cdr2_seq_aa_q', 'cdr1_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'v_mut', 'cluster_id_within_clustering_group', 'global_resulting_cluster_ID', 'num_clone_members', 'participant_label', 'specimen_label', 'disease', 'disease_subtype'
    uns: 'embedded', 'embedded_fine_tuned_on_fold_id', 'embedde

2022-12-28 17:39:52,086 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:40:02,313 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:40:12,611 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:40:21,992 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)





2 AnnData object with n_obs × n_vars = 37591 × 1900
    obs: 'rownum', 'TCR BioIdentity', 'TCR Nucleotide Sequence', 'Experiment', 'ORF Coverage', 'Amino Acids', 'Start Index in Genome', 'End Index in Genome', 'source', 'ORF', 'ORF Genebank ID', 'Amino Acid', 'Subject', 'Cell Type', 'Target Type', 'Cohort', 'Age', 'Gender', 'Race', 'HLA-A', 'HLA-A.1', 'HLA-B', 'HLA-B.1', 'HLA-C', 'HLA-C.1', 'DPA1', 'DPA1.1', 'DPB1', 'DPB1.1', 'DQA1', 'DQA1.1', 'DQB1', 'DQB1.1', 'DRB1', 'DRB1.1', 'DRB3', 'DRB3.1', 'DRB4', 'DRB4.1', 'DRB5', 'DRB5.1', 'cdr3_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'v_segment', 'j_segment', 'productive', 'extracted_isotype', 'isotype_supergroup', 'v_gene', 'j_gene', 'cdr1_seq_aa_q', 'cdr2_seq_aa_q', 'cdr1_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'v_mut', 'cluster_id_within_clustering_group', 'global_resulting_cluster_ID', 'num_clone_members', 'participant_label', 'specimen_label', 'disease', 'disease_subtype'
    uns: 'embedded', 'embedded_fine_tuned_on_fold_id', 'embedde

2022-12-28 17:40:38,215 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:40:48,713 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:40:59,027 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)


2022-12-28 17:41:08,891 - malid.embedders.unirep - INFO - Finished batch (unirep_fine_tuned)





-1 AnnData object with n_obs × n_vars = 37591 × 1900
    obs: 'rownum', 'TCR BioIdentity', 'TCR Nucleotide Sequence', 'Experiment', 'ORF Coverage', 'Amino Acids', 'Start Index in Genome', 'End Index in Genome', 'source', 'ORF', 'ORF Genebank ID', 'Amino Acid', 'Subject', 'Cell Type', 'Target Type', 'Cohort', 'Age', 'Gender', 'Race', 'HLA-A', 'HLA-A.1', 'HLA-B', 'HLA-B.1', 'HLA-C', 'HLA-C.1', 'DPA1', 'DPA1.1', 'DPB1', 'DPB1.1', 'DQA1', 'DQA1.1', 'DQB1', 'DQB1.1', 'DRB1', 'DRB1.1', 'DRB3', 'DRB3.1', 'DRB4', 'DRB4.1', 'DRB5', 'DRB5.1', 'cdr3_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'v_segment', 'j_segment', 'productive', 'extracted_isotype', 'isotype_supergroup', 'v_gene', 'j_gene', 'cdr1_seq_aa_q', 'cdr2_seq_aa_q', 'cdr1_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'v_mut', 'cluster_id_within_clustering_group', 'global_resulting_cluster_ID', 'num_clone_members', 'participant_label', 'specimen_label', 'disease', 'disease_subtype'
    uns: 'embedded', 'embedded_fine_tuned_on_fold_id', 'embedd