# Mal-ID Tutorial

Mal-ID uses B cell receptor (BCR) and T cell receptor (TCR) sequencing data to classify disease or immune state.

In the Mal-ID framework, we train three BCR and three TCR disease classifiers. Each one extracts features from immune receptor sequencing data in a different way.
Then we train an ensemble metamodel, a classifier that uses the predicted disease probabilities from the six base models to make a final prediction of disease status.
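The sketch below illustrates the stacking idea. It is not Mal-ID's implementation. It treats the metamodel's input for one sample as the six base models' class-probability vectors concatenated side by side. The class list matches the disease categories shown later in this tutorial. The model names and the uniform probabilities are made up for illustration.

```python
from typing import Dict, List

# Disease classes, in the category order shown later in this tutorial.
CLASSES = ["HIV", "Healthy/Background", "Lupus", "Covid19", "Influenza", "T1D"]


def metamodel_features(base_probs: Dict[str, List[float]]) -> List[float]:
    """Concatenate per-model class probabilities into one feature vector.

    Models are visited in sorted-name order so the feature layout is stable.
    """
    features: List[float] = []
    for model_name in sorted(base_probs):
        probs = base_probs[model_name]
        if len(probs) != len(CLASSES):
            raise ValueError(f"{model_name}: expected {len(CLASSES)} probabilities")
        features.extend(probs)
    return features


# Hypothetical outputs of the six base models for ONE sample.
uniform = [1 / len(CLASSES)] * len(CLASSES)
example = {f"{locus}_model{i}": uniform for locus in ("BCR", "TCR") for i in (1, 2, 3)}
x = metamodel_features(example)
print(len(x))  # 36 = 6 base models x 6 classes
```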

#### This tutorial covers:

1. Mal-ID configuration settings
2. Accessing sample metadata
3. Loading sequence data
4. Loading Models 1, 2, and 3 and making predictions with these base models
5. Loading the ensemble metamodel and making predictions

We will cover all of the components in this schematic:

<div>
<img src="../schematic.png" width="800"/>
</div>

This tutorial assumes you've already followed the "runbook": the commands in the readme to load the data and train the models.

**([Main repo here](https://github.com/maximz/malid))**

## Start with some necessary imports

In [1]:
import pandas as pd
from malid import config, io, helpers
from malid.trained_model_wrappers import (
    RepertoireClassifier,
    ConvergentClusterClassifier,
    VJGeneSpecificSequenceModelRollupClassifier,
    BlendingMetamodel,
)
from malid.datamodels import GeneLocus, TargetObsColumnEnum

## Review Mal-ID configuration settings

The inclusion criteria for the dataset — meaning which samples get divided into cross validation folds — are defined as a `CrossValidationSplitStrategy` object in `malid/datamodels.py`.

The default strategy is `CrossValidationSplitStrategy.in_house_peak_disease_timepoints`, which includes peak disease timepoints from our primary in-house dataset. Indeed, that's the active cross validation split strategy:

In [2]:
config.cross_validation_split_strategy

<CrossValidationSplitStrategy.in_house_peak_disease_timepoints: CrossValidationSplitStrategyValue(data_sources_keep=[<DataSource.in_house: 1>], stratify_by='disease', diseases_to_keep_all_subtypes=['Healthy/Background', 'HIV', 'Lupus', 'T1D'], subtypes_keep=['Covid19 - Sero-positive (ICU)', 'Covid19 - Sero-positive (Admit)', 'Covid19 - Acute 2', 'Covid19 - Admit', 'Covid19 - ICU', 'Influenza vaccine 2021 - day 7'], filter_specimens_func_by_study_name={'Covid19-buffycoat': <function acute_disease_choose_most_peak_timepoint at 0x7f1856a5edc0>, 'Covid19-Stanford': <function acute_disease_choose_most_peak_timepoint at 0x7f1856a5edc0>}, gene_loci_supported=<GeneLocus.BCR|TCR: 3>, exclude_study_names=['IBD pre-pandemic Yoni'], include_study_names=None, filter_out_specimens_funcs_global=[], study_names_for_held_out_set=None)>

The data is divided into three folds: `fold_id` can be 0, 1, or 2.

Each fold has a `train_smaller`, `validation`, and `test` set, referred to as a `fold_label`. (The `train_smaller` set is further subdivided into `train_smaller1` and `train_smaller2`.)

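Each sample appears in exactly one test set. We also make sure that, within a fold, all samples from the same person share the same `fold_label`. This means one person's samples are never split between training and evaluation.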

Finally, we define a special `fold_id=-1` "global" fold that does not have a `test` set. All the data is instead used in the `train_smaller` and `validation` fold labels. (The `train_smaller` to `validation` proportion is the same as for other fold IDs, but both sets are larger than usual.)
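Here is a minimal sketch of participant-grouped fold assignment, based on the grouping rule described above. It assumes a simple round-robin over shuffled participants. Mal-ID's real splitting is defined in its own code and also stratifies by disease, per the `stratify_by='disease'` setting shown earlier. The function name `assign_test_folds` is hypothetical.

```python
import random
from typing import Dict, List, Tuple


def assign_test_folds(
    samples: List[Tuple[str, str]], n_folds: int = 3, seed: int = 0
) -> Dict[str, int]:
    """Map each specimen to the fold whose test set it belongs to.

    `samples` holds (specimen_label, participant_label) pairs. All specimens
    from one participant share a fold, so no person is split across folds.
    """
    participants = sorted({participant for _, participant in samples})
    random.Random(seed).shuffle(participants)
    fold_of_participant = {p: i % n_folds for i, p in enumerate(participants)}
    return {specimen: fold_of_participant[p] for specimen, p in samples}


# Specimen/participant pairs taken from the metadata table in this tutorial.
samples = [
    ("M124-S014", "BFI-0000234"),
    ("M132-S014", "BFI-0000234"),
    ("M124-S042", "BFI-0002850"),
    ("M124-S041", "BFI-0002851"),
    ("M124-S012", "BFI-0002852"),
]
test_fold = assign_test_folds(samples)
# Every specimen lands in exactly one test fold...
assert set(test_fold) == {s for s, _ in samples}
# ...and both specimens from BFI-0000234 share that fold.
assert test_fold["M124-S014"] == test_fold["M132-S014"]
```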

<div>
<img src="../cross_validation.png" width="600"/>
</div>


The list of all fold IDs is therefore 0, 1, 2, and -1:

In [3]:
config.all_fold_ids

[0, 1, 2, -1]

The language model being used is ESM-2, applied to the CDR3 region:

In [4]:
config.embedder.name  # The name

'esm2_cdr3'

In [5]:
config.embedder.embedder_sequence_content  # The sequence region being embedded

<EmbedderSequenceContent.CDR3: 3>

Our models are configured to use both BCR and TCR data:

In [6]:
config.gene_loci_used

<GeneLocus.BCR|TCR: 3>

Just to clarify, `GeneLocus.BCR|TCR` means the union of BCR and TCR — both are active. This is the same as writing `GeneLocus.BCR | GeneLocus.TCR`:

In [7]:
GeneLocus.BCR | GeneLocus.TCR

<GeneLocus.BCR|TCR: 3>
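The `BCR|TCR` notation is how Python's standard `enum.Flag` prints a combination of members. The repr suggests that `GeneLocus` is a `Flag` subclass, but that is an assumption; check `malid/datamodels.py` to confirm. The standalone stand-in below, `GeneLocusDemo`, reproduces the behavior:

```python
from enum import Flag, auto


class GeneLocusDemo(Flag):
    """Hypothetical stand-in for malid.datamodels.GeneLocus."""

    BCR = auto()  # value 1
    TCR = auto()  # value 2


both = GeneLocusDemo.BCR | GeneLocusDemo.TCR
print(both.value)  # 3
print(GeneLocusDemo.BCR in both)  # True: membership test on a combined flag
```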

## Load metadata

Here's a Pandas DataFrame with all of the samples in our database.

The key fields are:

- `participant_label`: the patient ID
- `specimen_label`: the sample ID
- `disease`

In [8]:
metadata = helpers.get_all_specimen_info()
metadata

,participant_label,specimen_label,disease,specimen_time_point,participant_description,data_source,study_name,available_gene_loci,disease_subtype,age,...,symptoms_cmv,symptoms_healthy_in_resequencing_experiment,specimen_time_point_days,survived_filters,is_selected_for_cv_strategy,in_training_set,past_exposure,disease.separate_past_exposures,disease.rollup,test_fold_id
0,BFI-0000234,M124-S014,Healthy/Background,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,27.0,...,,,,True,True,True,False,Healthy/Background,Healthy/Background,1.0
1,BFI-0000234,M132-S014,Healthy/Background,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,27.0,...,,,,False,True,False,False,Healthy/Background,Healthy/Background,
2,BFI-0002850,M124-S042,Healthy/Background,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,26.0,...,,,,True,True,True,False,Healthy/Background,Healthy/Background,0.0
3,BFI-0002850,M132-S040,Healthy/Background,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,26.0,...,,,,False,True,False,False,Healthy/Background,Healthy/Background,
4,BFI-0002851,M124-S041,Healthy/Background,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,27.0,...,,,,True,True,True,False,Healthy/Background,Healthy/Background,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2540,ramesh-2015-ci_378,ramesh-2015-ci_378,CVID,,,DataSource.adaptive,ramesh-2015-ci,(((GeneLocus.TCR))),CVID,,...,,,,True,False,False,False,CVID,CVID,
2541,ramesh-2015-ci_386,ramesh-2015-ci_386,CVID,,,DataSource.adaptive,ramesh-2015-ci,(((GeneLocus.TCR))),CVID,,...,,,,True,False,False,False,CVID,CVID,
2542,ramesh-2015-ci_400,ramesh-2015-ci_400,CVID,,,DataSource.adaptive,ramesh-2015-ci,(((GeneLocus.TCR))),CVID,,...,,,,True,False,False,False,CVID,CVID,
2543,ramesh-2015-ci_441,ramesh-2015-ci_441,CVID,,,DataSource.adaptive,ramesh-2015-ci,(((GeneLocus.TCR))),CVID,,...,,,,True,False,False,False,CVID,CVID,


Each sample is identified by a `specimen_label` and has a boolean column named `in_training_set`, indicating whether a sample met the requirements for inclusion in the cross validation divisions and passed QC requirements (see the readme for more details).

Let's look at only the samples that passed those filters:

In [9]:
metadata = metadata[metadata["in_training_set"]]
metadata

,participant_label,specimen_label,disease,specimen_time_point,participant_description,data_source,study_name,available_gene_loci,disease_subtype,age,...,symptoms_cmv,symptoms_healthy_in_resequencing_experiment,specimen_time_point_days,survived_filters,is_selected_for_cv_strategy,in_training_set,past_exposure,disease.separate_past_exposures,disease.rollup,test_fold_id
0,BFI-0000234,M124-S014,Healthy/Background,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,27.0,...,,,,True,True,True,False,Healthy/Background,Healthy/Background,1.0
2,BFI-0002850,M124-S042,Healthy/Background,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,26.0,...,,,,True,True,True,False,Healthy/Background,Healthy/Background,0.0
4,BFI-0002851,M124-S041,Healthy/Background,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,27.0,...,,,,True,True,True,False,Healthy/Background,Healthy/Background,2.0
6,BFI-0002852,M124-S012,Healthy/Background,,Location: USA,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,26.0,...,,,,True,True,True,False,Healthy/Background,Healthy/Background,1.0
8,BFI-0002861,M124-S037,Healthy/Background,,Location: Malawi,DataSource.in_house,HIV,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",Healthy/Background - HIV Negative,34.0,...,,,,True,True,True,False,Healthy/Background,Healthy/Background,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2320,BFI-0010791,M491-S149,T1D,,Diabetes 35453-study biobank: pediatric; T1D +...,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,11.0,...,,,,True,True,True,False,T1D,T1D,0.0
2321,BFI-0010792,M491-S150,T1D,,Diabetes 35453-study biobank: adult; TID,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - adult,59.0,...,,,,True,True,True,False,T1D,T1D,1.0
2322,BFI-0010794,M491-S153,T1D,,Diabetes 35453-study biobank: pediatric; TID,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,16.0,...,,,,True,True,True,False,T1D,T1D,0.0
2323,BFI-0010800,M491-S160,T1D,,Diabetes 35453-study biobank: pediatric; T1D,DataSource.in_house,Diabetes biobank,"(((GeneLocus.BCR)), ((GeneLocus.TCR)))",T1D - pediatric,16.0,...,,,,True,True,True,False,T1D,T1D,0.0


The metadata also includes `disease_subtype`, `study_name`, `age`, `sex`, and `ethnicity_condensed` (ancestry) for a particular sample, as available.

## Load data

Let's choose a specific fold ID we will work with. As described above, our options are 0, 1, and 2 for the cross validation folds, and -1 for the global fold. Let's choose the first fold:

In [10]:
fold_id = 0

Classification targets are defined in an enumeration called `TargetObsColumnEnum`:

In [11]:
[t.name for t in TargetObsColumnEnum]

['disease',
 'disease_all_demographics_present',
 'ethnicity_condensed_healthy_only',
 'age_group_healthy_only',
 'age_group_binary_healthy_only',
 'age_group_pediatric_healthy_only',
 'sex_healthy_only',
 'covid_vs_healthy',
 'hiv_vs_healthy',
 't1d_vs_healthy',
 'lupus_vs_healthy',
 'flu_vs_healthy',
 'lupus_nephritis']

Each classification target is associated with a metadata field. Let's focus on `TargetObsColumnEnum.disease`, our main classification target, which predicts the `disease` metadata column:

In [12]:
target_obs_column = TargetObsColumnEnum.disease

This target is associated with the "disease" metadata field. You can tell by looking at the `obs_column_name` attribute:

In [13]:
target_obs_column.value.obs_column_name

'disease'
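Each `TargetObsColumnEnum` member wraps a value object that exposes `obs_column_name`. The sketch below shows that pattern. `TargetSpec` is a hypothetical dataclass standing in for Mal-ID's real value type, which has other fields not shown here (see `malid/datamodels.py`).

```python
from dataclasses import dataclass
from enum import Enum


@dataclass(frozen=True)
class TargetSpec:
    """Hypothetical stand-in for the value type behind TargetObsColumnEnum."""

    obs_column_name: str  # the metadata column this target predicts


class TargetDemo(Enum):
    disease = TargetSpec(obs_column_name="disease")


print(TargetDemo.disease.value.obs_column_name)  # disease
```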

Here are the values of that metadata field:

In [14]:
metadata["disease"].value_counts()

Healthy/Background    224
HIV                    98
Lupus                  98
T1D                    96
Covid19                63
Influenza              37
Name: disease, dtype: int64
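These classes are imbalanced. Healthy/Background makes up roughly 36% of the samples that passed filtering. The proportions below follow directly from the counts above:

```python
# Specimen counts per disease, copied from the value_counts() output above.
counts = {
    "Healthy/Background": 224,
    "HIV": 98,
    "Lupus": 98,
    "T1D": 96,
    "Covid19": 63,
    "Influenza": 37,
}
total = sum(counts.values())
for disease, n in counts.items():
    print(f"{disease:<20} {n:>4}  {n / total:.1%}")
print("total", total)  # total 616
```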

At this point, the data have been split into cross-validation folds and the sequences have been transformed into language model embeddings.

BCR and TCR data can be loaded with separate calls to `io.load_fold_embeddings()`:

In [15]:
adata_bcr = io.load_fold_embeddings(
    fold_id=fold_id,
    fold_label="test",  # Load the held out data
    gene_locus=GeneLocus.BCR,
    target_obs_column=target_obs_column,
)
adata_bcr

{"message": "Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20231027/in_house_peak_disease_timepoints/embedded/esm2_cdr3/anndatas_scaled/BCR/fold.0.test.h5ad -> /srv/scratch/maximz/cache/3d20a5c23fcc35b23cb822073f9cb2d6b0bebaee62c3a334176a736d.0.test.h5ad", "time": "2024-08-07T13:45:27.002580"}


Only considering the two last: ['.test', '.h5ad'].
Only considering the two last: ['.test', '.h5ad'].


AnnData object with n_obs × n_vars = 5669053 × 640
    obs: 'amplification_label', 'v_gene', 'j_gene', 'disease', 'fr1_seq_aa_q_trim', 'cdr1_seq_aa_q_trim', 'fr2_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'fr3_seq_aa_q_trim', 'cdr3_seq_aa_q_trim', 'post_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'extracted_isotype', 'isotype_supergroup', 'v_mut', 'num_reads', 'igh_or_tcrb_clone_id', 'total_clone_num_reads', 'num_clone_members', 'specimen_label', 'past_exposure', 'disease.separate_past_exposures', 'disease.rollup', 'v_family', 'sample_weight_isotype_rebalance', 'sample_weight_clone_size', 'study_name', 'participant_label', 'specimen_time_point', 'disease_subtype', 'study_name_condensed', 'age', 'sex', 'ethnicity_condensed', 'age_group', 'age_group_binary', 'age_group_pediatric', 'disease_severity', 'specimen_description', 'symptoms_IBD_Disease duration (yrs)', 'symptoms_IBD_Disease for ≥15 years?', 'symptoms_IBD_Disease location', 'symptoms_IBD_Estraintestinal manifestations?', 'symptoms_I

In [16]:
adata_tcr = io.load_fold_embeddings(
    fold_id=fold_id,
    fold_label="test",
    gene_locus=GeneLocus.TCR,  # Same as previous code block, except now we are loading TCR data
    target_obs_column=TargetObsColumnEnum.disease,
)
adata_tcr

{"message": "Reading network file from local machine cache: /users/maximz/code/boyd-immune-repertoire-classification/data/data_v_20231027/in_house_peak_disease_timepoints/embedded/esm2_cdr3/anndatas_scaled/TCR/fold.0.test.h5ad -> /srv/scratch/maximz/cache/41935c853da5d38d5d39cd74a38bd148951d1e8ded05e20393494c35.0.test.h5ad", "time": "2024-08-07T13:46:41.518369"}


Only considering the two last: ['.test', '.h5ad'].
Only considering the two last: ['.test', '.h5ad'].


AnnData object with n_obs × n_vars = 8358079 × 640
    obs: 'amplification_label', 'v_gene', 'j_gene', 'disease', 'fr1_seq_aa_q_trim', 'cdr1_seq_aa_q_trim', 'fr2_seq_aa_q_trim', 'cdr2_seq_aa_q_trim', 'fr3_seq_aa_q_trim', 'cdr3_seq_aa_q_trim', 'post_seq_aa_q_trim', 'cdr3_aa_sequence_trim_len', 'extracted_isotype', 'isotype_supergroup', 'v_mut', 'num_reads', 'igh_or_tcrb_clone_id', 'total_clone_num_reads', 'num_clone_members', 'specimen_label', 'past_exposure', 'disease.separate_past_exposures', 'disease.rollup', 'v_family', 'sample_weight_isotype_rebalance', 'sample_weight_clone_size', 'study_name', 'participant_label', 'specimen_time_point', 'disease_subtype', 'study_name_condensed', 'age', 'sex', 'ethnicity_condensed', 'age_group', 'age_group_binary', 'age_group_pediatric', 'disease_severity', 'specimen_description', 'symptoms_IBD_Disease duration (yrs)', 'symptoms_IBD_Disease for ≥15 years?', 'symptoms_IBD_Disease location', 'symptoms_IBD_Estraintestinal manifestations?', 'symptoms_I

In Mal-ID, we store data in [AnnData containers](https://anndata.readthedocs.io/en/latest/). We often use the variable name `adata` for these objects.

AnnData containers have a `.X` property with the language model embedding vector for each sequence, and a `.obs` property with the metadata for each sequence:

In [17]:
adata_bcr.X

array([[-0.614  ,  0.1725 , -0.325  , ...,  0.512  , -0.7983 , -0.5854 ],
       [ 0.1841 ,  0.9688 ,  1.127  , ...,  0.6943 ,  0.0789 , -0.02266],
       [ 0.732  , -1.323  , -1.272  , ...,  0.0448 ,  0.6924 , -0.761  ],
       ...,
       [-1.221  ,  1.755  ,  0.1779 , ...,  0.4229 ,  1.312  , -0.4229 ],
       [ 0.9478 , -0.7085 , -1.136  , ..., -1.077  ,  0.1272 ,  0.1335 ],
       [-0.8794 ,  0.4502 ,  1.741  , ...,  0.906  , -1.046  , -0.3447 ]],
      dtype=float16)

In [18]:
adata_bcr.obs

__null_dask_index__,amplification_label,v_gene,j_gene,disease,fr1_seq_aa_q_trim,cdr1_seq_aa_q_trim,fr2_seq_aa_q_trim,cdr2_seq_aa_q_trim,fr3_seq_aa_q_trim,cdr3_seq_aa_q_trim,...,symptoms_Lupus_sm +/-,symptoms_Lupus_sm_nRNP +/-,symptoms_cmv,symptoms_healthy_in_resequencing_experiment,symptoms_healthy_old_vs_new_batch_identifier,isotype_proportion:IGHG,isotype_proportion:IGHA,isotype_proportion:IGHD-M,fold_id,fold_label
487,M111-S037_cDNA_PCR,IGHV3-20,IGHJ4,HIV,AAS,GFTFDDYG,MSWVRQAPGKGLEWVSG,LNWNGGST,GYADSVKGRFTISRDNAKNSLYLQMNSLRAEDTALYYC,ARDLLPGSYPTYYFEY,...,,,,,,0.177207,0.140070,0.682723,0,test
1,M111-S037_cDNA_PCR,IGHV1-69,IGHJ4,HIV,AS,GGTFSSYA,ISWVRQAPGQGLEWMGR,IIPILGIA,NYAQKFQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYC,AREDSFGAAVEY,...,,,,,,0.177207,0.140070,0.682723,0,test
15397,M111-S037_cDNA_PCR,IGHV6-1,IGHJ4,HIV,IS,GDNVSTNSAA,WNWIRQSPSRGLEWLGR,THFQSRWLY,DYAESVRGRITINPDTSKNQITLQLKSMTPDDTGIYYC,ARDQRYPKYYFDY,...,,,,,,0.177207,0.140070,0.682723,0,test
3,M111-S037_cDNA_PCR,IGHV3-7,IGHJ5,HIV,AAS,GFSFRGYW,MTWVRQAPGKGLEWVAN,IKQDGSET,YYVDSVNGRFTISRDNAKNSLYLQMNSLRAEDTAVYFC,ARSGTYMGFWDP,...,,,,,,0.177207,0.140070,0.682723,0,test
9787,M111-S037_cDNA_PCR,IGHV3-23,IGHJ6,HIV,ATS,GFTFSSCA,MSWVRQAPGKGLEWVSA,ISAGSSAT,YYADSVKGRFTISRDNSKNTLFLQMNSLRAEDTAVYYC,AKGGADCTTAACRGSYYYMDV,...,,,,,,0.177207,0.140070,0.682723,0,test
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21465-75,M491-S160_cDNA_PCR,IGHV3-74,IGHJ5,T1D,AAS,GFTFSNYW,MHWVRQAPGKGRVWGSR,LNSDERST,SYADSVKGRFTISRDNAKNTLYLQMNSLRVEDTAVYYC,TKAPVGSCNDASCYPLDL,...,,,,,,0.051986,0.130453,0.817561,0,test
21486-72,M491-S160_cDNA_PCR,IGHV3-23,IGHJ5,T1D,AAS,GFTFSSYV,MRWVRQAPGKGLEWVSA,ITGSGGST,YYADSVKGRFTISRDNSKKTVSLQMNSLRPEDTALYYC,VRFDSASSFDS,...,,,,,,0.051986,0.130453,0.817561,0,test
21534-71,M491-S160_cDNA_PCR,IGHV3-7,IGHJ5,T1D,AAS,GFTFSTYW,MSWVRQAPGKGLEWVAN,IKQDGSEK,YYVDSVKGRFTISRDNAKNSLYLQMNNLRPEDTAVYYC,VKVERFCGTSSCAPFDP,...,,,,,,0.051986,0.130453,0.817561,0,test
21548-68,M491-S160_cDNA_PCR,IGHV3-7,IGHJ5,T1D,AAS,GFTFSRYW,MNWVRQAPGKGLEWVAD,IKQDGREK,YYVDSVKGRFTISRDNAKNSLYLQMNSLRAEETAVYYC,ASQSDYYDSSGYYAWFAP,...,,,,,,0.051986,0.130453,0.817561,0,test


`adata_bcr.X` has one row per sequence and one column per language model embedding dimension. `adata_bcr.obs` has one row per sequence and one column per metadata field. The same is true for `adata_tcr`.
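Each `.obs` row is a sequence, so per-sample quantities come from grouping rows by `specimen_label`. In pandas, the count of sequences per sample would be something like `adata_bcr.obs["specimen_label"].value_counts()`. The standard-library sketch below shows the same idea on a few rows taken from the preview above, reduced to two fields.

```python
from collections import Counter

# Sequence-level rows reduced to (specimen_label, v_gene); one tuple per sequence.
obs_rows = [
    ("M111-S037", "IGHV3-20"),
    ("M111-S037", "IGHV1-69"),
    ("M111-S037", "IGHV6-1"),
    ("M491-S160", "IGHV3-74"),
    ("M491-S160", "IGHV3-23"),
]
sequences_per_specimen = Counter(specimen for specimen, _ in obs_rows)
print(sequences_per_specimen)  # Counter({'M111-S037': 3, 'M491-S160': 2})
```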

In [19]:
adata_bcr.X.shape  # sequences x embedding dimensions

(5669053, 640)

In [20]:
adata_bcr.obs.shape  # sequences x metadata fields

(5669053, 76)

Sequence-level metadata fields include:

- The `participant_label` (patient ID) and `specimen_label` (sample ID) from which the sequence originated
- `v_gene`: IGHV or TRBV gene segment name
- `j_gene`: IGHJ or TRBJ gene segment name
- `cdr3_seq_aa_q_trim`: CDR3 amino acid sequence
- `v_mut`: Somatic hypermutation rate, specifically the fraction of V region nucleotides that are mutated (BCR only)
- `isotype_supergroup`: either `IGHG`, `IGHA`, `IGHD-M` (combining IgD and IgM), or `TCRB`. The name "supergroup" refers to the fact that we are overloading the definition of isotypes by combining IgD and IgM, and by labeling all TCR data as `isotype_supergroup="TCRB"` for convenience, even though TCRs don't have isotypes by definition. (Note that IgE is filtered out in the subsampling step, as are unmutated IgD and IgM sequences which represent naive B cells.)
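The sketch below shows the supergroup mapping described in the last bullet. The function name and the isotype strings it accepts (for example, subclasses like `IGHG1`) are assumptions made for illustration. It does not model the separate removal of unmutated IgD/IgM sequences, which depends on mutation status rather than isotype.

```python
from typing import Optional


def isotype_supergroup(isotype: Optional[str], is_tcr: bool) -> Optional[str]:
    """Hypothetical mapping from an extracted isotype to its supergroup.

    Returns None for sequences that would be dropped (e.g. IgE).
    """
    if is_tcr:
        return "TCRB"  # TCRs have no isotype; a placeholder label is used
    if isotype is None:
        return None
    if isotype.startswith("IGHG"):  # assumed: subclasses IGHG1..IGHG4 collapse to IGHG
        return "IGHG"
    if isotype.startswith("IGHA"):  # assumed: IGHA1 and IGHA2 collapse to IGHA
        return "IGHA"
    if isotype in ("IGHD", "IGHM"):
        return "IGHD-M"  # IgD and IgM are combined
    return None  # IGHE and anything unrecognized is filtered out


print(isotype_supergroup("IGHM", is_tcr=False))  # IGHD-M
```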

The `target_obs_column` argument to `io.load_fold_embeddings()` matters because the AnnData object is filtered down to samples that are "in scope" for that classification target. For example, with `target_obs_column=TargetObsColumnEnum.sex_healthy_only`, the AnnData will only contain samples from healthy individuals whose sex is known. (See `malid/datamodels.py` for the exact definition of each `TargetObsColumnEnum` option.)

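The sketch below gives some intuition for the kind of filter that `sex_healthy_only` implies, applied to plain dicts. The `disease` and `sex` values mirror the metadata table above, with `None` standing in for an unknown sex. The function is hypothetical, and the last sample label is invented. The real target definitions are in `malid/datamodels.py`.

```python
from typing import Dict, List, Optional

Sample = Dict[str, Optional[str]]


def in_scope_sex_healthy_only(samples: List[Sample]) -> List[Sample]:
    """Keep healthy samples whose sex is known (hypothetical filter)."""
    return [
        s
        for s in samples
        if s.get("disease") == "Healthy/Background" and s.get("sex") is not None
    ]


samples = [
    {"specimen_label": "M124-S042", "disease": "Healthy/Background", "sex": "F"},
    {"specimen_label": "M111-S037", "disease": "HIV", "sex": "F"},
    {"specimen_label": "M491-S146", "disease": "T1D", "sex": None},
    {"specimen_label": "EXAMPLE-unknown-sex", "disease": "Healthy/Background", "sex": None},
]
kept = in_scope_sex_healthy_only(samples)
print([s["specimen_label"] for s in kept])  # ['M124-S042']
```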
> *Aside*
>
> When we ran `io.load_fold_embeddings()` above, you may have noticed some log messages about caching the data. When the large AnnData object is first loaded, it's copied from its original file path — usually on a network-mounted file system — to scratch storage on the machine where you're running this notebook. This speeds up further data loading.
>
> Another type of caching happens silently: the imported data is cached in memory, so that repeated calls to `io.load_fold_embeddings` with the same arguments don't result in slow loads of the same data over and over. The data is cached before we run the filtering for a particular classification target, so we can still leverage the cached version even if we switch the `target_obs_column` argument. This helps in our training loop, so we can train the model for many classification targets seamlessly. The cache is capped at four entries, and older data is automatically removed from the cache to make room for what's being used now. (You can also manually clear the cache with `io.clear_cached_fold_embeddings()`, or disable the cache altogether by setting the environment variable `MALID_DISABLE_IN_MEMORY_CACHE=true` before loading any Mal-ID Python code.)
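The sketch below models the in-memory caching described in the aside. It uses an LRU cache with four slots, keyed only on the arguments that identify a file, with target-specific filtering applied after the cache lookup. `functools.lru_cache` stands in for whatever Mal-ID uses internally. `_read_from_disk`, `load_unfiltered`, and `load_for_target` are hypothetical names.

```python
import os
from functools import lru_cache

# Read once at import time, mirroring "set the variable before loading Mal-ID".
CACHE_DISABLED = os.environ.get("MALID_DISABLE_IN_MEMORY_CACHE", "").lower() == "true"
LOAD_CALLS = 0  # counts real (uncached) loads, for demonstration


def _read_from_disk(fold_id: int, fold_label: str, gene_locus: str) -> tuple:
    """Placeholder for an expensive AnnData read; returns (specimen, disease) rows."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    return (("M124-S042", "Healthy/Background"), ("M111-S037", "HIV"))


_cached_read = lru_cache(maxsize=4)(_read_from_disk)  # at most four entries kept


def load_unfiltered(fold_id: int, fold_label: str, gene_locus: str) -> tuple:
    reader = _read_from_disk if CACHE_DISABLED else _cached_read
    return reader(fold_id, fold_label, gene_locus)


def load_for_target(fold_id: int, fold_label: str, gene_locus: str, keep_disease=None) -> list:
    # The cache key excludes the target, so switching targets reuses cached data.
    rows = load_unfiltered(fold_id, fold_label, gene_locus)
    return [r for r in rows if keep_disease is None or r[1] == keep_disease]


load_for_target(0, "test", "BCR")
load_for_target(0, "test", "BCR", keep_disease="HIV")  # served from cache
print(LOAD_CALLS)  # 1 when caching is enabled
_cached_read.cache_clear()  # analogous to io.clear_cached_fold_embeddings()
```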

## Model 1

Model 1 uses overall summary statistics of the BCR or TCR repertoire to predict disease status.

This way of generating features can be tied into many classification algorithms, e.g. logistic regression or random forests. We try a bunch of classification algorithms and choose the one with highest performance on the validation set, which is not seen during training. The same is true for Models 2 and 3.
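The selection rule fits in a few lines. The candidate names below appear in `config.metamodel_base_model_names`. The validation scores are invented placeholders, not Mal-ID results, and this sketch does not show which metric Mal-ID actually uses for selection.

```python
from typing import Dict


def choose_model(validation_scores: Dict[str, float]) -> str:
    """Return the candidate with the highest validation-set score.

    Candidates are scanned in sorted-name order, so ties go to the name that
    sorts first and the choice is deterministic.
    """
    return max(sorted(validation_scores), key=lambda name: validation_scores[name])


# Placeholder scores, NOT real Mal-ID results.
scores = {"lasso_cv": 0.81, "ridge_cv": 0.79, "elasticnet_cv0.25": 0.83, "rf_multiclass": 0.80}
print(choose_model(scores))  # elasticnet_cv0.25
```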

We've recorded our choices of the chosen "model name" for Model 1, Model 2, and Model 3 in `config.metamodel_base_model_names`. The ensemble metamodel will use these versions of the base Mal-ID components:

In [21]:
config.metamodel_base_model_names

namespace(model_name_overall_repertoire_composition={<GeneLocus.BCR: 1>: 'elasticnet_cv0.25',
                                                     <GeneLocus.TCR: 2>: 'lasso_cv'},
          model_name_convergent_clustering={<GeneLocus.BCR: 1>: 'ridge_cv',
                                            <GeneLocus.TCR: 2>: 'lasso_cv'},
          base_sequence_model_name={<GeneLocus.BCR: 1>: 'rf_multiclass',
                                    <GeneLocus.TCR: 2>: 'ridge_cv_ovr'},
          base_sequence_model_subset_strategy=<SequenceSubsetStrategy.split_Vgene_and_isotype: 3>,
          aggregation_sequence_model_name={<GeneLocus.BCR: 1>: 'rf_multiclass_mean_aggregated_as_binary_ovr_reweighed_by_subset_frequencies',
                                           <GeneLocus.TCR: 2>: 'rf_multiclass_entropy_twenty_percent_cutoff_aggregated_as_binary_ovr_reweighed_by_subset_frequencies'})

Let's look specifically at the version of Model 1 chosen for BCR data:

In [22]:
config.metamodel_base_model_names.model_name_overall_repertoire_composition[
    GeneLocus.BCR
]

'elasticnet_cv0.25'

The model name is `elasticnet_cv0.25`, which is elastic net regularized logistic regression with an L1-L2 ratio of 0.25. (The exact definition is in `malid/train/model_definitions.py`.)
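In glmnet's parameterization, which the `LogitNet(alpha=0.25)` step shown below follows, `alpha` is the mixing weight on the L1 term. The penalty is lambda * ((1 - alpha)/2 * ||beta||_2^2 + alpha * ||beta||_1). The sketch below evaluates this penalty directly, so you can check that `alpha=1` reduces to the lasso and `alpha=0` to ridge. The coefficient vector is arbitrary.

```python
from typing import Sequence


def elastic_net_penalty(beta: Sequence[float], lam: float, alpha: float) -> float:
    """glmnet-style penalty: lam * ((1 - alpha)/2 * ||beta||_2^2 + alpha * ||beta||_1)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    l2_squared = sum(b * b for b in beta)
    l1 = sum(abs(b) for b in beta)
    return lam * ((1.0 - alpha) / 2.0 * l2_squared + alpha * l1)


beta = [1.0, -2.0, 0.5]  # arbitrary example coefficients
print(elastic_net_penalty(beta, lam=1.0, alpha=0.25))  # 2.84375
print(elastic_net_penalty(beta, lam=1.0, alpha=1.0))  # 3.5 (pure L1, lasso)
print(elastic_net_penalty(beta, lam=1.0, alpha=0.0))  # 2.625 (pure L2, ridge)
```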

> _Aside_
>
> Models with `_cv` in the name use internal (nested) cross validation to tune their hyperparameters.

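The sketch below shows the nested structure using plain index lists. Inner folds are carved out of each outer training set, so the outer test samples never influence hyperparameter tuning. It only illustrates that structure: the `LogitNet(n_splits=5)` step shown below does its own internal splitting, and Mal-ID's outer folds are the participant-grouped folds described earlier. `k_fold` is a hypothetical helper.

```python
from typing import Iterator, List, Tuple


def k_fold(indices: List[int], k: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train, held_out) pairs for k interleaved folds of `indices`."""
    for i in range(k):
        held_out = indices[i::k]
        held_set = set(held_out)
        yield [idx for idx in indices if idx not in held_set], held_out


samples = list(range(12))
for outer_train, outer_test in k_fold(samples, k=3):
    # Hyperparameters are tuned with inner folds drawn ONLY from outer_train...
    inner_splits = list(k_fold(outer_train, k=4))
    # ...so the outer test samples never influence tuning.
    for inner_train, inner_val in inner_splits:
        assert not set(outer_test) & (set(inner_train) | set(inner_val))
    print(len(outer_train), len(outer_test), [len(v) for _, v in inner_splits])
```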
Let's load this version of Model 1 for BCR data:

In [23]:
clf1 = RepertoireClassifier(
    fold_id=fold_id,
    # Load "elasticnet_cv0.25"
    model_name=config.metamodel_base_model_names.model_name_overall_repertoire_composition[
        GeneLocus.BCR
    ],
    fold_label_train="train_smaller",  # Indicates which part of the data the model was trained on
    gene_locus=GeneLocus.BCR,  # A different model is trained for each sequencing locus
    target_obs_column=target_obs_column,  # A different model is trained for each classification target
)
clf1

We've loaded a `RepertoireClassifier`, which is a wrapper around a scikit-learn model stored in `_inner`.

In this case, it's a wrapper around a scikit-learn Pipeline:

In [24]:
type(clf1._inner)

sklearn.pipeline.Pipeline

In [25]:
clf1._inner

The scikit-learn Pipeline passes V-J gene use counts through `log1p`, scaling, and PCA transformations. A `ColumnTransformer` step coordinates this separately for IgG, IgA, and the combined IgD/IgM group (`IGHD-M`).

The resulting PCs — along with the somatic hypermutation features marked "remainder" and "passthrough" — then go through standardization (`StandardScalerThatPreservesInputType`) and logistic regression (`GlmnetLogitNetWrapper`).
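The sketch below shows how the input columns could be routed to per-isotype branches. It is based only on the feature names printed below. The V-J features carry an isotype prefix such as `IGHG:pca_`, while the somatic hypermutation features do not and fall through to a remainder block. The helper is hypothetical, and the `IGHA` example name is assumed to follow the same pattern as the printed names. The real routing is configured in Mal-ID's pipeline definition.

```python
from collections import defaultdict
from typing import Dict, List

ISOTYPES = ("IGHG", "IGHA", "IGHD-M")


def route_columns(feature_names: List[str]) -> Dict[str, List[str]]:
    """Group columns into per-isotype V-J blocks plus a 'remainder' block."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for name in feature_names:
        for isotype in ISOTYPES:
            if name.startswith(f"{isotype}:pca_"):
                groups[isotype].append(name)  # -> log1p, scale, PCA for this isotype
                break
        else:
            groups["remainder"].append(name)  # e.g. SHM features, passed through
    return dict(groups)


names = [
    "v_mut_median_per_specimen:IGHG",
    "v_sequence_is_mutated:IGHA",
    "IGHG:pca_IGHV1-18|IGHJ1:IGHG",
    "IGHA:pca_IGHV1-2|IGHJ4:IGHA",
    "IGHD-M:pca_IGHV6-1|IGHJ3:IGHD-M",
]
for block, cols in route_columns(names).items():
    print(block, cols)
```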


Here are the original input feature names. Notice that there's a feature for each combination of V gene, J gene, and isotype; PCA then reduces these to a smaller set of features:

In [26]:
clf1.feature_names_in_  # Attribute access is passed through to the _inner scikit-learn Pipeline

array(['v_mut_median_per_specimen:IGHG', 'v_sequence_is_mutated:IGHG',
       'v_mut_median_per_specimen:IGHA', 'v_sequence_is_mutated:IGHA',
       'v_mut_median_per_specimen:IGHD-M', 'v_sequence_is_mutated:IGHD-M',
       'IGHG:pca_IGHV1-18|IGHJ1:IGHG', 'IGHG:pca_IGHV1-18|IGHJ2:IGHG',
       'IGHG:pca_IGHV1-18|IGHJ3:IGHG', 'IGHG:pca_IGHV1-18|IGHJ4:IGHG',
       'IGHG:pca_IGHV1-18|IGHJ5:IGHG', 'IGHG:pca_IGHV1-18|IGHJ6:IGHG',
       'IGHG:pca_IGHV1-24|IGHJ1:IGHG', 'IGHG:pca_IGHV1-24|IGHJ2:IGHG',
       'IGHG:pca_IGHV1-24|IGHJ3:IGHG', 'IGHG:pca_IGHV1-24|IGHJ4:IGHG',
       'IGHG:pca_IGHV1-24|IGHJ5:IGHG', 'IGHG:pca_IGHV1-24|IGHJ6:IGHG',
       'IGHG:pca_IGHV1-2|IGHJ1:IGHG', 'IGHG:pca_IGHV1-2|IGHJ2:IGHG',
       'IGHG:pca_IGHV1-2|IGHJ3:IGHG', 'IGHG:pca_IGHV1-2|IGHJ4:IGHG',
       'IGHG:pca_IGHV1-2|IGHJ5:IGHG', 'IGHG:pca_IGHV1-2|IGHJ6:IGHG',
       'IGHG:pca_IGHV1-3|IGHJ1:IGHG', 'IGHG:pca_IGHV1-3|IGHJ2:IGHG',
       'IGHG:pca_IGHV1-3|IGHJ3:IGHG', 'IGHG:pca_IGHV1-3|IGHJ4:IGHG',
       'IGHG

And now here is the reduced set of features coming out of the `ColumnTransformer` step. Notice how the V gene/J gene count features have turned into 15 PCs per isotype:

In [27]:
clf1.named_steps["columntransformer"].get_feature_names_out()

array(['log1p-scale-PCA_IGHG__pca0', 'log1p-scale-PCA_IGHG__pca1',
       'log1p-scale-PCA_IGHG__pca2', 'log1p-scale-PCA_IGHG__pca3',
       'log1p-scale-PCA_IGHG__pca4', 'log1p-scale-PCA_IGHG__pca5',
       'log1p-scale-PCA_IGHG__pca6', 'log1p-scale-PCA_IGHG__pca7',
       'log1p-scale-PCA_IGHG__pca8', 'log1p-scale-PCA_IGHG__pca9',
       'log1p-scale-PCA_IGHG__pca10', 'log1p-scale-PCA_IGHG__pca11',
       'log1p-scale-PCA_IGHG__pca12', 'log1p-scale-PCA_IGHG__pca13',
       'log1p-scale-PCA_IGHG__pca14', 'log1p-scale-PCA_IGHA__pca0',
       'log1p-scale-PCA_IGHA__pca1', 'log1p-scale-PCA_IGHA__pca2',
       'log1p-scale-PCA_IGHA__pca3', 'log1p-scale-PCA_IGHA__pca4',
       'log1p-scale-PCA_IGHA__pca5', 'log1p-scale-PCA_IGHA__pca6',
       'log1p-scale-PCA_IGHA__pca7', 'log1p-scale-PCA_IGHA__pca8',
       'log1p-scale-PCA_IGHA__pca9', 'log1p-scale-PCA_IGHA__pca10',
       'log1p-scale-PCA_IGHA__pca11', 'log1p-scale-PCA_IGHA__pca12',
       'log1p-scale-PCA_IGHA__pca13', 'log1p-scale-PCA
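A quick count follows from two assumptions: the truncated output continues the same naming pattern for IGHA and IGHD-M, and the six somatic hypermutation columns listed first in `feature_names_in_` are the only passthrough features. Under those assumptions, the classifier at the end of the pipeline sees 3 isotypes x 15 PCs + 6 = 51 features:

```python
ISOTYPES = ["IGHG", "IGHA", "IGHD-M"]
N_PCS = 15  # per isotype, per the output above

# Assumed naming pattern, extrapolated from the printed IGHG/IGHA names.
pca_features = [f"log1p-scale-PCA_{iso}__pca{i}" for iso in ISOTYPES for i in range(N_PCS)]
# The six SHM features, in the order they appear in feature_names_in_.
shm_features = [
    f"{stat}:{iso}"
    for iso in ISOTYPES
    for stat in ("v_mut_median_per_specimen", "v_sequence_is_mutated")
]
print(len(pca_features) + len(shm_features))  # 51
```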

Here is the final step of the pipeline. As expected from the `elasticnet_cv0.25` model name we specified when loading the trained model from disk, it's an elastic net logistic regression with an L1-L2 ratio of 0.25:

In [28]:
clf1.steps[-1]

('glmnetlogitnetwrapper',
 GlmnetLogitNetWrapper: LogitNet(alpha=0.25, n_splits=5, random_state=0,
          scoring=<function GlmnetLogitNetWrapper.deviance_scorer at 0x7f1af8cb08b0>,
          standardize=False, verbose=True))

**Mal-ID models have the following API:**

- **`clf.featurize(adata)`**: this function accepts an AnnData object and generates features specific for the model. The features and metadata are returned in a `FeaturizedData` container, which we'll explore below. (The features themselves are in the `.X` attribute of the `FeaturizedData` container.)
- **`clf.predict_proba(features)`**: this function accepts features and returns predicted class probabilities by running the model.
- **`clf.predict(features)`**: this function accepts features and returns predicted class labels by running the model.

**_Common pattern:_** `predicted_class_probabilities = clf.predict_proba(clf.featurize(adata).X)`.
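The sketch below illustrates the relationship between `predict_proba` and `predict`. It assumes the scikit-learn convention: probability columns follow `classes_` order, and `predict` picks the highest-probability class per row. That convention is an assumption here, not something this tutorial states. `MalIdLikeModel` and `ToyClassifier` are hypothetical stand-ins for a Mal-ID wrapper.

```python
from typing import List, Protocol, Sequence


class MalIdLikeModel(Protocol):
    """Hypothetical interface: class labels plus per-class probabilities."""

    classes_: List[str]

    def predict_proba(self, X: Sequence[Sequence[float]]) -> List[List[float]]: ...


def predict_from_proba(model: MalIdLikeModel, X: Sequence[Sequence[float]]) -> List[str]:
    """Pick the highest-probability class per row (scikit-learn's convention)."""
    labels = []
    for row in model.predict_proba(X):
        best = max(range(len(row)), key=row.__getitem__)
        labels.append(model.classes_[best])
    return labels


class ToyClassifier:
    """Stand-in model: 'probabilities' are just the row-normalized features."""

    classes_ = ["HIV", "Healthy/Background"]

    def predict_proba(self, X):
        return [[v / sum(row) for v in row] for row in X]


print(predict_from_proba(ToyClassifier(), [[0.9, 0.1], [0.2, 0.8]]))
# ['HIV', 'Healthy/Background']
```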

Let's walk through this with Model 1. First, let's generate features from the held-out BCR data:

In [29]:
featurized_model1_data = clf1.featurize(adata_bcr)
type(featurized_model1_data)

crosseval.featurized_data.FeaturizedData

We now have a `FeaturizedData` container. Let's unpack its contents:

In [30]:
# Features
featurized_model1_data.X

specimen_label,v_mut_median_per_specimen:IGHG,v_sequence_is_mutated:IGHG,v_mut_median_per_specimen:IGHA,v_sequence_is_mutated:IGHA,v_mut_median_per_specimen:IGHD-M,v_sequence_is_mutated:IGHD-M,IGHG:pca_IGHV1-18|IGHJ1:IGHG,IGHG:pca_IGHV1-18|IGHJ2:IGHG,IGHG:pca_IGHV1-18|IGHJ3:IGHG,IGHG:pca_IGHV1-18|IGHJ4:IGHG,...,IGHD-M:pca_IGHV6-1|IGHJ3:IGHD-M,IGHD-M:pca_IGHV6-1|IGHJ4:IGHD-M,IGHD-M:pca_IGHV6-1|IGHJ5:IGHD-M,IGHD-M:pca_IGHV6-1|IGHJ6:IGHD-M,IGHD-M:pca_IGHV7-4-1|IGHJ1:IGHD-M,IGHD-M:pca_IGHV7-4-1|IGHJ2:IGHD-M,IGHD-M:pca_IGHV7-4-1|IGHJ3:IGHD-M,IGHD-M:pca_IGHV7-4-1|IGHJ4:IGHD-M,IGHD-M:pca_IGHV7-4-1|IGHJ5:IGHD-M,IGHD-M:pca_IGHV7-4-1|IGHJ6:IGHD-M
M111-S037,0.057778,0.911347,0.088496,0.979737,0.026549,1.0,0.001832,0.001603,0.005955,0.023362,...,0.000177,0.000473,0.000059,0.000059,0.000118,0.000118,0.000827,0.003251,0.000827,0.000709
M124-S042,0.079295,0.959693,0.083333,0.990588,0.030973,1.0,0.000850,0.000567,0.003967,0.020686,...,0.000527,0.002636,0.000527,0.000527,0.000000,0.000000,0.000000,0.000105,0.000000,0.000000
M111-S016,0.074890,0.951521,0.076233,0.982593,0.026549,1.0,0.000000,0.000305,0.007324,0.024413,...,0.001624,0.001662,0.001082,0.000889,0.000077,0.000077,0.000850,0.001855,0.000618,0.000541
M111-S043,0.047414,0.874591,0.066079,0.939325,0.029787,1.0,0.001112,0.000953,0.006989,0.015248,...,0.001325,0.002261,0.000780,0.000546,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
M124-S036,0.077586,0.966753,0.074236,0.966128,0.030303,1.0,0.002915,0.002624,0.006122,0.027697,...,0.000000,0.000085,0.000085,0.000085,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
M491-S146,0.052863,0.837728,0.034632,0.691589,0.026667,1.0,0.000000,0.000000,0.002088,0.014614,...,0.000123,0.002330,0.000736,0.001103,0.000000,0.000000,0.000981,0.006253,0.003065,0.003188
M491-S148,0.057269,0.940083,0.070485,0.969115,0.026549,1.0,0.000000,0.001055,0.005274,0.021097,...,0.000660,0.001885,0.000377,0.000566,0.000094,0.000189,0.001225,0.006127,0.002168,0.002451
M491-S149,0.069565,0.977945,0.058296,0.929610,0.030568,1.0,0.000000,0.000962,0.005775,0.020693,...,0.000670,0.004019,0.002010,0.002345,0.000000,0.000000,0.000167,0.000000,0.000167,0.000000
M491-S153,0.065502,0.916667,0.070485,0.923849,0.026432,1.0,0.001368,0.000000,0.009576,0.019152,...,0.000553,0.001660,0.000830,0.000277,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [31]:
# Ground truth
featurized_model1_data.y

specimen_label
M111-S037                   HIV
M124-S042    Healthy/Background
M111-S016                   HIV
M111-S043                   HIV
M124-S036    Healthy/Background
                    ...        
M491-S146                   T1D
M491-S148                   T1D
M491-S149                   T1D
M491-S153                   T1D
M491-S160                   T1D
Name: disease, Length: 206, dtype: category
Categories (6, object): ['HIV', 'Healthy/Background', 'Lupus', 'Covid19', 'Influenza', 'T1D']

In [32]:
# Sample names
featurized_model1_data.sample_names

Index(['M111-S037', 'M124-S042', 'M111-S016', 'M111-S043', 'M124-S036',
       'M124-S040', 'M111-S035', 'M111-S025', 'M111-S011', 'M111-S036',
       ...
       'M491-S111', 'M491-S117', 'M491-S124', 'M491-S126', 'M491-S136',
       'M491-S146', 'M491-S148', 'M491-S149', 'M491-S153', 'M491-S160'],
      dtype='object', name='specimen_label', length=206)

In [33]:
# Sample metadata
featurized_model1_data.metadata

specimen_label,age,disease,disease.rollup,disease.separate_past_exposures,disease_severity,disease_subtype,ethnicity_condensed,isotype_proportion:IGHA,isotype_proportion:IGHD-M,isotype_proportion:IGHG,participant_label,past_exposure,sex,study_name
M111-S037,48.0,HIV,HIV,HIV,,HIV Broad Neutralizing,African,0.140070,0.682723,0.177207,BFI-0000254,False,F,HIV
M124-S042,26.0,Healthy/Background,Healthy/Background,Healthy/Background,,Healthy/Background - HIV Negative,,0.246533,0.547060,0.206407,BFI-0002850,False,F,HIV
M111-S016,22.0,HIV,HIV,HIV,,HIV Broad Neutralizing,African,0.101753,0.796830,0.101417,BFI-0002855,False,F,HIV
M111-S043,23.0,HIV,HIV,HIV,,HIV Non Neutralizing,African,0.228808,0.516170,0.255022,BFI-0002856,False,F,HIV
M124-S036,29.0,Healthy/Background,Healthy/Background,Healthy/Background,,Healthy/Background - HIV Negative,African,0.144986,0.660305,0.194710,BFI-0002862,False,F,HIV
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
M491-S146,15.0,T1D,T1D,T1D,,T1D - pediatric,,0.098597,0.850927,0.050476,BFI-0010788,False,,Diabetes biobank
M491-S148,14.0,T1D,T1D,T1D,,T1D - pediatric,,0.110664,0.815813,0.073523,BFI-0010790,False,,Diabetes biobank
M491-S149,11.0,T1D,T1D,T1D,,T1D - pediatric,,0.136773,0.640226,0.223001,BFI-0010791,False,,Diabetes biobank
M491-S153,16.0,T1D,T1D,T1D,,T1D - pediatric,,0.160569,0.697251,0.142180,BFI-0010794,False,,Diabetes biobank


Now run the model to predict the per-class probabilities:

In [34]:
clf1.predict_proba(featurized_model1_data.X)

array([[1.20373269e-05, 9.93021939e-01, 6.68838158e-03, 4.56024978e-05,
        2.21748447e-04, 1.02907163e-05],
       [9.59384622e-04, 5.99845787e-03, 9.50646515e-01, 1.68303975e-03,
        3.85957328e-02, 2.11686958e-03],
       [8.50644976e-03, 7.52307114e-01, 1.39096021e-01, 4.49604811e-02,
        4.66933810e-02, 8.43655264e-03],
       ...,
       [1.17617906e-02, 1.45660594e-03, 1.84622816e-01, 1.39845976e-02,
        1.82523485e-01, 6.05650705e-01],
       [2.04910658e-02, 2.52044988e-03, 3.15463558e-01, 2.78819744e-03,
        6.38096684e-01, 2.06400456e-02],
       [1.04655981e-02, 1.80573281e-03, 3.49655098e-01, 1.45640859e-02,
        4.69149578e-04, 6.23040335e-01]])

To make this easier to read, let's bring in row and column names:

In [35]:
# Row names
featurized_model1_data.sample_names

Index(['M111-S037', 'M124-S042', 'M111-S016', 'M111-S043', 'M124-S036',
       'M124-S040', 'M111-S035', 'M111-S025', 'M111-S011', 'M111-S036',
       ...
       'M491-S111', 'M491-S117', 'M491-S124', 'M491-S126', 'M491-S136',
       'M491-S146', 'M491-S148', 'M491-S149', 'M491-S153', 'M491-S160'],
      dtype='object', name='specimen_label', length=206)

In [36]:
# Column names
clf1.classes_

array(['Covid19', 'HIV', 'Healthy/Background', 'Influenza', 'Lupus',
       'T1D'], dtype=object)

In [37]:
# Table of predicted class probabilities
pd.DataFrame(
    clf1.predict_proba(featurized_model1_data.X),
    index=featurized_model1_data.sample_names,
    columns=clf1.classes_,
)

specimen_label,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
M111-S037,0.000012,9.930219e-01,0.006688,4.560250e-05,0.000222,0.000010
M124-S042,0.000959,5.998458e-03,0.950647,1.683040e-03,0.038596,0.002117
M111-S016,0.008506,7.523071e-01,0.139096,4.496048e-02,0.046693,0.008437
M111-S043,0.000105,9.961163e-01,0.003342,4.690368e-05,0.000294,0.000096
M124-S036,0.000226,2.429462e-01,0.704671,5.227891e-03,0.041741,0.005189
...,...,...,...,...,...,...
M491-S146,0.000003,1.615749e-07,0.000527,1.550379e-07,0.000238,0.999232
M491-S148,0.099963,4.009954e-02,0.477593,8.308356e-02,0.177256,0.122004
M491-S149,0.011762,1.456606e-03,0.184623,1.398460e-02,0.182523,0.605651
M491-S153,0.020491,2.520450e-03,0.315464,2.788197e-03,0.638097,0.020640


**We just generated a table of predicted class probabilities for Model 1, using `featurize()` and `predict_proba()`.**

Alternatively, we can use the model to generate a single predicted label for each sample:

In [38]:
clf1.predict(featurized_model1_data.X)

array(['HIV', 'Healthy/Background', 'HIV', 'HIV', 'Healthy/Background',
       'Healthy/Background', 'HIV', 'HIV', 'HIV', 'HIV',
       'Healthy/Background', 'T1D', 'Healthy/Background',
       'Healthy/Background', 'Healthy/Background', 'Healthy/Background',
       'Healthy/Background', 'Healthy/Background', 'Healthy/Background',
       'Healthy/Background', 'Healthy/Background', 'Healthy/Background',
       'Healthy/Background', 'Healthy/Background', 'Healthy/Background',
       'Healthy/Background', 'T1D', 'Healthy/Background',
       'Healthy/Background', 'Healthy/Background', 'Healthy/Background',
       'Healthy/Background', 'Healthy/Background', 'Healthy/Background',
       'Healthy/Background', 'Healthy/Background', 'Healthy/Background',
       'Healthy/Background', 'Healthy/Background', 'Healthy/Background',
       'Lupus', 'Healthy/Background', 'Healthy/Background', 'Covid19',
       'Healthy/Background', 'Healthy/Background', 'Covid19', 'HIV',
       'HIV', 'HIV', 'HIV', 'HI
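
For scikit-learn-style classifiers, the single predicted label is typically the class with the highest predicted probability: the row-wise argmax of `predict_proba()`, mapped through `classes_`. Here is a minimal plain-Python sketch of that relationship. It uses the first two probability rows printed by `clf1.predict_proba(...)` above. That Mal-ID's wrapper follows this convention is an assumption, although these two rows agree with the first two labels from `clf1.predict(...)`.

```python
# Class order as reported by clf1.classes_ above
classes = ["Covid19", "HIV", "Healthy/Background", "Influenza", "Lupus", "T1D"]

# First two rows of clf1.predict_proba(featurized_model1_data.X), copied from the output above
proba_rows = [
    [1.20373269e-05, 9.93021939e-01, 6.68838158e-03, 4.56024978e-05, 2.21748447e-04, 1.02907163e-05],
    [9.59384622e-04, 5.99845787e-03, 9.50646515e-01, 1.68303975e-03, 3.85957328e-02, 2.11686958e-03],
]


def argmax_labels(rows, class_names):
    """Return, for each row, the class name with the highest probability."""
    labels = []
    for row in rows:
        best_index = max(range(len(row)), key=lambda i: row[i])
        labels.append(class_names[best_index])
    return labels


print(argmax_labels(proba_rows, classes))  # prints ['HIV', 'Healthy/Background']
```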

## Model 2

Model 2 uses clustering to identify shared groups of sequences across individuals with the same disease. Then we predict disease using the number of disease-associated cluster hits per sample.
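
Here is a rough illustration of that featurization idea, not Mal-ID's actual implementation. Suppose each sequence has already been assigned to a cluster, or to none, and each disease-associated cluster carries a dominant disease label. Counting hits per sample and per label then gives a small table shaped like the `featurized_model2_data.X` output shown later in this section. All cluster IDs, assignments, and sample names here are hypothetical. Counting each cluster at most once per sample, rather than counting individual sequences, is also an assumption of this sketch.

```python
from collections import defaultdict

CLASSES = ["Covid19", "HIV", "Healthy/Background", "Influenza", "Lupus", "T1D"]

# Hypothetical mapping: cluster ID -> disease label that dominates that cluster
cluster_dominant_label = {"c1": "HIV", "c2": "HIV", "c3": "Lupus", "c4": "Covid19"}

# Hypothetical per-sequence assignments: (sample, cluster ID, or None if no cluster matched)
sequence_hits = [
    ("sample_A", "c1"),
    ("sample_A", "c1"),  # same cluster twice: counted once below
    ("sample_A", "c2"),
    ("sample_A", "c3"),
    ("sample_B", "c4"),
    ("sample_B", None),  # this sequence matched no disease-associated cluster
]


def count_cluster_hits(hits, dominant_label, classes):
    """Per sample, count distinct disease-associated clusters hit, grouped by dominant label."""
    clusters_per_sample = defaultdict(set)
    for sample, cluster in hits:
        if cluster is not None:
            clusters_per_sample[sample].add(cluster)
    features = {}
    for sample, clusters in clusters_per_sample.items():
        counts = dict.fromkeys(classes, 0)
        for cluster in clusters:
            counts[dominant_label[cluster]] += 1
        features[sample] = counts
    return features


features = count_cluster_hits(sequence_hits, cluster_dominant_label, CLASSES)
print(features["sample_A"])
```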

As before, let's first check which version is used in the ensemble metamodel. It's the ridge logistic regression version of Model 2:

In [39]:
# Chosen model names for Models 1, 2, and 3
config.metamodel_base_model_names

namespace(model_name_overall_repertoire_composition={<GeneLocus.BCR: 1>: 'elasticnet_cv0.25',
                                                     <GeneLocus.TCR: 2>: 'lasso_cv'},
          model_name_convergent_clustering={<GeneLocus.BCR: 1>: 'ridge_cv',
                                            <GeneLocus.TCR: 2>: 'lasso_cv'},
          base_sequence_model_name={<GeneLocus.BCR: 1>: 'rf_multiclass',
                                    <GeneLocus.TCR: 2>: 'ridge_cv_ovr'},
          base_sequence_model_subset_strategy=<SequenceSubsetStrategy.split_Vgene_and_isotype: 3>,
          aggregation_sequence_model_name={<GeneLocus.BCR: 1>: 'rf_multiclass_mean_aggregated_as_binary_ovr_reweighed_by_subset_frequencies',
                                           <GeneLocus.TCR: 2>: 'rf_multiclass_entropy_twenty_percent_cutoff_aggregated_as_binary_ovr_reweighed_by_subset_frequencies'})

In [40]:
# The model name chosen for Model 2 - BCR
config.metamodel_base_model_names.model_name_convergent_clustering[GeneLocus.BCR]

'ridge_cv'

Now let's load that version:

In [41]:
clf2 = ConvergentClusterClassifier(
    fold_id=fold_id,
    model_name=config.metamodel_base_model_names.model_name_convergent_clustering[
        GeneLocus.BCR
    ],
    fold_label_train="train_smaller1",  # The model was trained on train_smaller1, with hyperparameter tuning on train_smaller2.
    gene_locus=GeneLocus.BCR,
    target_obs_column=target_obs_column,
)
clf2

We get a `ConvergentClusterClassifier` object.

Just as with `RepertoireClassifier` for Model 1, `ConvergentClusterClassifier` is a wrapper around a scikit-learn Pipeline:

In [42]:
type(clf2._inner)

sklearn.pipeline.Pipeline

But this time, the Pipeline is simpler:

In [43]:
clf2._inner

The pipeline confirms that the expected feature names are present, standardizes the features, and then runs ridge logistic regression.
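
At prediction time, those three steps come down to a column check, a z-score transform, and a softmax over per-class linear scores. The plain-Python sketch below reproduces that arithmetic for one sample. The means, standard deviations, coefficients, and intercepts are made-up stand-ins for fitted values. The ridge (L2) penalty only shapes how the coefficients are fitted, so it does not appear at prediction time. We assume a multinomial (softmax) formulation here; a one-vs-rest fit would combine the scores differently.

```python
import math

# Hypothetical fitted values for a tiny 3-feature, 3-class model
EXPECTED_FEATURES = ["HIV", "Lupus", "T1D"]  # cluster hit counts per dominant label
MEANS = [5.0, 4.0, 0.5]
STDS = [4.0, 2.0, 0.5]
CLASSES = ["HIV", "Lupus", "T1D"]
COEFS = [  # one row of coefficients per class
    [1.2, -0.3, -0.2],
    [-0.4, 0.9, -0.1],
    [-0.5, -0.2, 1.1],
]
INTERCEPTS = [0.0, 0.1, -0.1]


def predict_proba_one(sample):
    """sample: dict mapping feature name -> value. Returns class -> probability."""
    # Step 1: confirm the expected feature names are present
    missing = [f for f in EXPECTED_FEATURES if f not in sample]
    if missing:
        raise ValueError(f"missing features: {missing}")
    # Step 2: standardize each feature with the stored mean and standard deviation
    z = [(sample[f] - m) / s for f, m, s in zip(EXPECTED_FEATURES, MEANS, STDS)]
    # Step 3: linear score per class, then softmax into probabilities
    scores = [sum(w * x for w, x in zip(row, z)) + b for row, b in zip(COEFS, INTERCEPTS)]
    max_score = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return dict(zip(CLASSES, (e / total for e in exps)))


probs = predict_proba_one({"HIV": 13.0, "Lupus": 11.0, "T1D": 0.0})
print(probs)
```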

To run the model, let's start by featurizing the dataset, just as we did with Model 1:

In [44]:
featurized_model2_data = clf2.featurize(adata_bcr)
type(featurized_model2_data)

crosseval.featurized_data.FeaturizedData

The `FeaturizedData` container can be unpacked the same way:

In [45]:
# Features (disease-associated cluster hit counts)
featurized_model2_data.X

cluster_dominant_label,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
specimen_label,,,,,,
M111-S037,0.0,13.0,1.0,1.0,11.0,0.0
M124-S042,0.0,0.0,2.0,4.0,5.0,0.0
M111-S016,1.0,19.0,1.0,2.0,8.0,0.0
M111-S043,0.0,24.0,1.0,0.0,4.0,0.0
M124-S036,0.0,5.0,1.0,2.0,3.0,0.0
...,...,...,...,...,...,...
M491-S146,0.0,3.0,1.0,0.0,3.0,0.0
M491-S148,0.0,0.0,2.0,1.0,2.0,0.0
M491-S149,0.0,1.0,4.0,5.0,1.0,1.0
M491-S153,0.0,1.0,1.0,0.0,5.0,1.0


In [46]:
# Ground truth
featurized_model2_data.y

specimen_label
M111-S037                   HIV
M124-S042    Healthy/Background
M111-S016                   HIV
M111-S043                   HIV
M124-S036    Healthy/Background
                    ...        
M491-S146                   T1D
M491-S148                   T1D
M491-S149                   T1D
M491-S153                   T1D
M491-S160                   T1D
Name: disease, Length: 205, dtype: category
Categories (6, object): ['HIV', 'Healthy/Background', 'Lupus', 'Covid19', 'Influenza', 'T1D']

In [47]:
# Sample names
featurized_model2_data.sample_names

Index(['M111-S037', 'M124-S042', 'M111-S016', 'M111-S043', 'M124-S036',
       'M124-S040', 'M111-S035', 'M111-S025', 'M111-S011', 'M111-S036',
       ...
       'M491-S111', 'M491-S117', 'M491-S124', 'M491-S126', 'M491-S136',
       'M491-S146', 'M491-S148', 'M491-S149', 'M491-S153', 'M491-S160'],
      dtype='object', name='specimen_label', length=205)

In [48]:
# Sample metadata
featurized_model2_data.metadata

specimen_label,age,disease,disease.rollup,disease.separate_past_exposures,disease_severity,disease_subtype,ethnicity_condensed,isotype_proportion:IGHA,isotype_proportion:IGHD-M,isotype_proportion:IGHG,participant_label,past_exposure,sex,study_name
M111-S037,48.0,HIV,HIV,HIV,,HIV Broad Neutralizing,African,0.140070,0.682723,0.177207,BFI-0000254,False,F,HIV
M124-S042,26.0,Healthy/Background,Healthy/Background,Healthy/Background,,Healthy/Background - HIV Negative,,0.246533,0.547060,0.206407,BFI-0002850,False,F,HIV
M111-S016,22.0,HIV,HIV,HIV,,HIV Broad Neutralizing,African,0.101753,0.796830,0.101417,BFI-0002855,False,F,HIV
M111-S043,23.0,HIV,HIV,HIV,,HIV Non Neutralizing,African,0.228808,0.516170,0.255022,BFI-0002856,False,F,HIV
M124-S036,29.0,Healthy/Background,Healthy/Background,Healthy/Background,,Healthy/Background - HIV Negative,African,0.144986,0.660305,0.194710,BFI-0002862,False,F,HIV
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
M491-S146,15.0,T1D,T1D,T1D,,T1D - pediatric,,0.098597,0.850927,0.050476,BFI-0010788,False,,Diabetes biobank
M491-S148,14.0,T1D,T1D,T1D,,T1D - pediatric,,0.110664,0.815813,0.073523,BFI-0010790,False,,Diabetes biobank
M491-S149,11.0,T1D,T1D,T1D,,T1D - pediatric,,0.136773,0.640226,0.223001,BFI-0010791,False,,Diabetes biobank
M491-S153,16.0,T1D,T1D,T1D,,T1D - pediatric,,0.160569,0.697251,0.142180,BFI-0010794,False,,Diabetes biobank


This time, though, the `FeaturizedData` container has some extra fields. Unlike the other components of Mal-ID, Model 2 abstains from prediction if none of the sequences in a sample match any disease-associated cluster. These abstentions are stored in the `FeaturizedData` container as well:

In [49]:
# These samples had no features generated — they are abstentions:
featurized_model2_data.abstained_sample_names

Index(['M418-S118'], dtype='object', name='specimen_label')

In [50]:
# Ground truth for abstained samples
featurized_model2_data.abstained_sample_y

specimen_label
M418-S118    Covid19
Name: disease, dtype: category
Categories (6, object): ['HIV', 'Healthy/Background', 'Lupus', 'Covid19', 'Influenza', 'T1D']

In [51]:
# Metadata for abstained samples
featurized_model2_data.abstained_sample_metadata

specimen_label,age,disease,disease.rollup,disease.separate_past_exposures,disease_severity,disease_subtype,ethnicity_condensed,isotype_proportion:IGHA,isotype_proportion:IGHD-M,isotype_proportion:IGHG,participant_label,past_exposure,sex,study_name
M418-S118,68.0,Covid19,Covid19,Covid19,ICU,Covid19 - ICU,Hispanic/Latino,0.497761,0.294776,0.207463,BFI-0009059,False,M,Covid19-Stanford
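
Conceptually, an abstention is a sample whose disease-associated cluster hit counts are all zero, which leaves the classifier nothing to work with. Below is a minimal sketch that splits samples into scored and abstained sets under that assumption. The real `featurize()` may detect abstentions at a different point. The sample names and counts are hypothetical.

```python
# Hypothetical per-sample hit counts: sample name -> {dominant label: number of cluster hits}
hit_counts = {
    "sample_with_hits": {"Covid19": 0, "HIV": 3, "Lupus": 1},
    "sample_without_hits": {"Covid19": 0, "HIV": 0, "Lupus": 0},
}


def split_abstentions(counts_by_sample):
    """Separate samples with at least one cluster hit from those with none (abstentions)."""
    scored, abstained = {}, []
    for sample, counts in counts_by_sample.items():
        if any(value > 0 for value in counts.values()):
            scored[sample] = counts
        else:
            abstained.append(sample)
    return scored, abstained


scored, abstained = split_abstentions(hit_counts)
print(sorted(scored), abstained)
```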


Let's run the model to predict the per-class probabilities for the samples that are not abstentions:

In [52]:
# Table of predicted class probabilities
pd.DataFrame(
    clf2.predict_proba(featurized_model2_data.X),
    index=featurized_model2_data.sample_names,
    columns=clf2.classes_,
)

specimen_label,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
M111-S037,0.093203,0.231276,0.203522,0.038482,0.261447,0.172070
M124-S042,0.110002,0.085210,0.311180,0.051944,0.202014,0.239650
M111-S016,0.094130,0.355081,0.177844,0.039048,0.189403,0.144494
M111-S043,0.072264,0.485880,0.155931,0.028994,0.128728,0.128202
M124-S036,0.116146,0.138989,0.255802,0.046437,0.190750,0.251876
...,...,...,...,...,...,...
M491-S146,0.118626,0.117655,0.262798,0.039239,0.195331,0.266350
M491-S148,0.113949,0.087438,0.323998,0.040867,0.172615,0.261133
M491-S149,0.090366,0.056158,0.267491,0.045139,0.087322,0.453525
M491-S153,0.101486,0.059672,0.144173,0.031095,0.134344,0.529229


The `clf2.predict_proba(clf2.featurize(adata_bcr).X)` pattern is the same as what we saw above for Model 1.
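
Because Models 1 and 2 share the `featurize()` → `predict_proba()` → `classes_` pattern, the table-building step can be wrapped in a small helper. `predict_proba_table` is a hypothetical convenience function, not part of Mal-ID. It returns nested dicts so that it runs without pandas. It relies only on the attributes used above (`X`, `sample_names`, `classes_`). Samples that Model 2 abstains on are absent from the result, because they do not appear in `sample_names`.

```python
from types import SimpleNamespace


def predict_proba_table(clf, data):
    """Hypothetical helper: featurize `data` with `clf` and return
    {sample_name: {class_name: probability}} for every non-abstained sample."""
    featurized = clf.featurize(data)
    probabilities = clf.predict_proba(featurized.X)
    return {
        sample: dict(zip(clf.classes_, row))
        for sample, row in zip(featurized.sample_names, probabilities)
    }


# Quick check with a stand-in object that mimics only the attributes used above
stub_clf = SimpleNamespace(
    classes_=["HIV", "Lupus"],
    featurize=lambda data: SimpleNamespace(X=[[1.0], [2.0]], sample_names=["s1", "s2"]),
    predict_proba=lambda X: [[0.25, 0.75] for _ in X],
)
print(predict_proba_table(stub_clf, data=None))
```

With the real objects, `predict_proba_table(clf2, adata_bcr)` should give the same numbers as the DataFrame above. Passing its result to `pd.DataFrame.from_dict(..., orient="index")` gives back a table in that layout.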

## Model 3

Model 3 uses language model embeddings of BCR and TCR sequences to predict disease state, in two stages:

* **Sequence stage**: Predict which type of patient a sequence comes from, based on the language model embedding for that sequence. _This is a sequence-level model._ It's actually trained separately for each V gene and isotype.
* **Aggregation stage**: Predict the disease status for an entire sample, based on the per-sequence predictions from the sequence stage. _This is a sample-level or person-level model._

Each stage is trained separately. The sequence stage is trained on the `train_smaller1` fold, and the aggregation stage is trained on the `train_smaller2` fold.

The selected model names are recorded in `config.metamodel_base_model_names` as `base_sequence_model_name` and `aggregation_sequence_model_name` for the two stages, respectively:

In [53]:
config.metamodel_base_model_names

namespace(model_name_overall_repertoire_composition={<GeneLocus.BCR: 1>: 'elasticnet_cv0.25',
                                                     <GeneLocus.TCR: 2>: 'lasso_cv'},
          model_name_convergent_clustering={<GeneLocus.BCR: 1>: 'ridge_cv',
                                            <GeneLocus.TCR: 2>: 'lasso_cv'},
          base_sequence_model_name={<GeneLocus.BCR: 1>: 'rf_multiclass',
                                    <GeneLocus.TCR: 2>: 'ridge_cv_ovr'},
          base_sequence_model_subset_strategy=<SequenceSubsetStrategy.split_Vgene_and_isotype: 3>,
          aggregation_sequence_model_name={<GeneLocus.BCR: 1>: 'rf_multiclass_mean_aggregated_as_binary_ovr_reweighed_by_subset_frequencies',
                                           <GeneLocus.TCR: 2>: 'rf_multiclass_entropy_twenty_percent_cutoff_aggregated_as_binary_ovr_reweighed_by_subset_frequencies'})

Let's load the selected versions of both stages.

Both stages can be loaded together through the `VJGeneSpecificSequenceModelRollupClassifier` class, which represents the aggregation stage:

In [54]:
clf3 = VJGeneSpecificSequenceModelRollupClassifier(
    # First, all the usual parameters, like fold ID, sequencing locus, and classification target:
    fold_id=fold_id,
    gene_locus=GeneLocus.BCR,
    target_obs_column=target_obs_column,
    #
    # Model 3 includes a sequence stage and an aggregation stage.
    # The aggregation stage is trained on top of the sequence stage, so to speak.
    # First, provide the sequence stage model name:
    base_sequence_model_name=config.metamodel_base_model_names.base_sequence_model_name[
        GeneLocus.BCR
    ],
    # The sequence stage was trained on train_smaller1:
    base_model_train_fold_label="train_smaller1",
    #
    # Next, provide the aggregation stage model name here:
    rollup_model_name=config.metamodel_base_model_names.aggregation_sequence_model_name[
        GeneLocus.BCR
    ],
    # The aggregation stage was trained on train_smaller2:
    fold_label_train="train_smaller2",
)
clf3

The sequence stage is automatically loaded and stored inside of the `VJGeneSpecificSequenceModelRollupClassifier`:

In [55]:
clf3.sequence_classifier

<malid.trained_model_wrappers.vj_gene_specific_sequence_classifier.VGeneIsotypeSpecificSequenceClassifier at 0x7f09ec36cdf0>

The sequence stage, a `VGeneIsotypeSpecificSequenceClassifier` object, is actually a collection of models trained separately for each V gene and isotype:

In [56]:
clf3.sequence_classifier.models_

{('IGHV1-2',
  'IGHA'): Pipeline(steps=[('standardscalerthatpreservesinputtype',
                  StandardScalerThatPreservesInputType()),
                 ('randomforestclassifier',
                  RandomForestClassifier(class_weight='balanced_subsample',
                                         n_jobs=2, random_state=0))]),
 ('IGHV3-21',
  'IGHD-M'): Pipeline(steps=[('standardscalerthatpreservesinputtype',
                  StandardScalerThatPreservesInputType()),
                 ('randomforestclassifier',
                  RandomForestClassifier(class_weight='balanced_subsample',
                                         n_jobs=2, random_state=0))]),
 ('IGHV4-34',
  'IGHG'): Pipeline(steps=[('standardscalerthatpreservesinputtype',
                  StandardScalerThatPreservesInputType()),
                 ('randomforestclassifier',
                  RandomForestClassifier(class_weight='balanced_subsample',
                                         n_jobs=2, random_state=0))]),
 ('

Now back to the aggregation stage, which accepts the following features:

In [57]:
clf3.feature_names_in_

['Influenza_IGHV3-23_IGHA',
 'Influenza_IGHV3-23_IGHD-M',
 'Influenza_IGHV3-23_IGHG',
 'Influenza_IGHV4-b_IGHA',
 'Influenza_IGHV4-b_IGHD-M',
 'Influenza_IGHV4-b_IGHG',
 'Influenza_IGHV3-7_IGHA',
 'Influenza_IGHV3-7_IGHD-M',
 'Influenza_IGHV3-7_IGHG',
 'Influenza_IGHV1-18_IGHA',
 'Influenza_IGHV1-18_IGHD-M',
 'Influenza_IGHV1-18_IGHG',
 'Influenza_IGHV4-59_IGHA',
 'Influenza_IGHV4-59_IGHD-M',
 'Influenza_IGHV4-59_IGHG',
 'Influenza_IGHV4-61_IGHA',
 'Influenza_IGHV4-61_IGHD-M',
 'Influenza_IGHV4-61_IGHG',
 'Influenza_IGHV1-69_IGHA',
 'Influenza_IGHV1-69_IGHD-M',
 'Influenza_IGHV1-69_IGHG',
 'Influenza_IGHV3-21_IGHA',
 'Influenza_IGHV3-21_IGHD-M',
 'Influenza_IGHV3-21_IGHG',
 'Influenza_IGHV3-48_IGHA',
 'Influenza_IGHV3-48_IGHD-M',
 'Influenza_IGHV3-48_IGHG',
 'Influenza_IGHV1-46_IGHA',
 'Influenza_IGHV1-46_IGHD-M',
 'Influenza_IGHV1-46_IGHG',
 'Influenza_IGHV4-39_IGHA',
 'Influenza_IGHV4-39_IGHD-M',
 'Influenza_IGHV4-39_IGHG',
 'Influenza_IGHV5-51_IGHA',
 'Influenza_IGHV5-51_IGHD-M',
 '

**How do we get features like `Influenza_IGHV3-23_IGHA`, which represents the average probability of the Influenza class across the IGHV3-23, IgA sequences in each sample?**

First, a vector of per-class predicted probabilities is generated for each sequence, using the model associated with the V gene and isotype the sequence belongs to.

Then the probabilities are aggregated across sequences from the same sample. Probabilities are only comparable between sequences scored by the same model, so the aggregation happens separately for each V gene and isotype group. For BCR, the aggregation strategy used is just an average:

In [58]:
clf3.aggregation_strategy

<AggregationStrategy.mean: 1>
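
Putting the last two paragraphs together, here is a simplified sketch of how the aggregation-stage features could be assembled. Each sequence is scored by the model for its (V gene, isotype) group. The per-class probabilities are then averaged within each sample and group, producing names in the `{class}_{V gene}_{isotype}` pattern seen in `clf3.feature_names_in_`. The stand-in "models" are plain functions with made-up outputs, and the sequences are hypothetical. The sketch also ignores any reweighting that the chosen BCR aggregation model's name (`rf_multiclass_mean_aggregated_as_binary_ovr_reweighed_by_subset_frequencies`) suggests.

```python
from collections import defaultdict
from statistics import mean

CLASSES = ["Covid19", "HIV"]


# Hypothetical per-group sequence models: embedding -> {class: probability}
def ighv3_23_iga_model(embedding):
    return {"Covid19": 0.8, "HIV": 0.2} if embedding[0] > 0 else {"Covid19": 0.3, "HIV": 0.7}


def ighv4_34_igg_model(embedding):
    return {"Covid19": 0.5, "HIV": 0.5}


models = {
    ("IGHV3-23", "IGHA"): ighv3_23_iga_model,
    ("IGHV4-34", "IGHG"): ighv4_34_igg_model,
}

# Hypothetical sequences: (sample, V gene, isotype, language model embedding)
sequences = [
    ("sample_A", "IGHV3-23", "IGHA", [1.0]),
    ("sample_A", "IGHV3-23", "IGHA", [-1.0]),
    ("sample_A", "IGHV4-34", "IGHG", [0.0]),
]


def aggregate_features(seqs, group_models, classes):
    """Score each sequence with its group's model, then average per sample, class, and group."""
    collected = defaultdict(list)  # (sample, class, V gene, isotype) -> probabilities
    for sample, v_gene, isotype, embedding in seqs:
        probs = group_models[(v_gene, isotype)](embedding)
        for cls in classes:
            collected[(sample, cls, v_gene, isotype)].append(probs[cls])
    features = defaultdict(dict)
    for (sample, cls, v_gene, isotype), values in collected.items():
        features[sample][f"{cls}_{v_gene}_{isotype}"] = mean(values)
    return dict(features)


features = aggregate_features(sequences, models, CLASSES)
print(features["sample_A"]["Covid19_IGHV3-23_IGHA"])  # mean of 0.8 and 0.3
```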

Once we have those features, the "aggregation stage" involves running this scikit-learn Pipeline:

In [59]:
clf3._inner

You may recognize the first two steps of the Pipeline from before:

* `MatchVariables` confirms the expected feature names are present.
* `StandardScalerThatPreservesInputType` standardizes the features.

But the third step, `BinaryOvRClassifierWithFeatureSubsettingByClass`, is new.

The aggregation stage is fitted in a one-versus-rest fashion, with one binary model per class (e.g., Covid-19 vs. rest):

In [60]:
clf3.named_steps["binaryovrclassifierwithfeaturesubsettingbyclass"].estimators_

[InnerEstimator(clf=RandomForestClassifier(class_weight='balanced_subsample', n_jobs=1,
                        random_state=0), negative_class='not Covid19', positive_class='Covid19'),
 InnerEstimator(clf=RandomForestClassifier(class_weight='balanced_subsample', n_jobs=1,
                        random_state=0), negative_class='not HIV', positive_class='HIV'),
 InnerEstimator(clf=RandomForestClassifier(class_weight='balanced_subsample', n_jobs=1,
                        random_state=0), negative_class='not Healthy/Background', positive_class='Healthy/Background'),
 InnerEstimator(clf=RandomForestClassifier(class_weight='balanced_subsample', n_jobs=1,
                        random_state=0), negative_class='not Influenza', positive_class='Influenza'),
 InnerEstimator(clf=RandomForestClassifier(class_weight='balanced_subsample', n_jobs=1,
                        random_state=0), negative_class='not Lupus', positive_class='Lupus'),
 InnerEstimator(clf=RandomForestClassifier(class_weight=

Additionally, the submodel for each class is trained _only with features corresponding to that class_.

For example, the sequence-level model generates predicted class probabilities `P(Covid-19)`, `P(HIV)`, `P(Lupus)`, and so forth for every sequence. But when making predictions of Covid-19 using the "Covid-19 vs rest" model, we should only look at the `P(Covid-19)` values. Sequence-level probabilities for the other classes, like `P(HIV)`, should have no bearing.
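
This feature subsetting can be expressed as a name filter: for each class, keep only the features whose name starts with that class label followed by an underscore. The sketch below assumes the `{class}_{V gene}_{isotype}` naming shown above. It mirrors the `positive_class` / `'not <class>'` labeling seen in the `estimators_` output. It illustrates the idea and is not the code of `BinaryOvRClassifierWithFeatureSubsettingByClass` itself.

```python
def features_for_class(feature_names, class_name):
    """Return the feature names belonging to one class, e.g. all 'Covid19_*' features."""
    prefix = f"{class_name}_"
    return [name for name in feature_names if name.startswith(prefix)]


def one_vs_rest_labels(y, positive_class):
    """Binary target for the '<class> vs rest' submodel."""
    return [label if label == positive_class else f"not {positive_class}" for label in y]


feature_names = [
    "Influenza_IGHV3-23_IGHA",
    "Covid19_IGHV3-23_IGHA",
    "Covid19_IGHV3-23_IGHD-M",
    "Healthy/Background_IGHV3-23_IGHA",
]
print(features_for_class(feature_names, "Covid19"))
# ['Covid19_IGHV3-23_IGHA', 'Covid19_IGHV3-23_IGHD-M']
print(one_vs_rest_labels(["Covid19", "HIV", "Lupus"], "Covid19"))
# ['Covid19', 'not Covid19', 'not Covid19']
```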

Let's confirm that only Covid-19 specific features enter the "Covid-19 vs rest" model. All of these are `P(Covid-19)` features:

In [61]:
clf3.named_steps["binaryovrclassifierwithfeaturesubsettingbyclass"].estimators_[
    0
].clf.feature_names_in_

array(['Covid19_IGHV3-23_IGHA', 'Covid19_IGHV3-23_IGHD-M',
       'Covid19_IGHV3-23_IGHG', 'Covid19_IGHV4-b_IGHA',
       'Covid19_IGHV4-b_IGHD-M', 'Covid19_IGHV4-b_IGHG',
       'Covid19_IGHV3-7_IGHA', 'Covid19_IGHV3-7_IGHD-M',
       'Covid19_IGHV3-7_IGHG', 'Covid19_IGHV1-18_IGHA',
       'Covid19_IGHV1-18_IGHD-M', 'Covid19_IGHV1-18_IGHG',
       'Covid19_IGHV4-59_IGHA', 'Covid19_IGHV4-59_IGHD-M',
       'Covid19_IGHV4-59_IGHG', 'Covid19_IGHV4-61_IGHA',
       'Covid19_IGHV4-61_IGHD-M', 'Covid19_IGHV4-61_IGHG',
       'Covid19_IGHV1-69_IGHA', 'Covid19_IGHV1-69_IGHD-M',
       'Covid19_IGHV1-69_IGHG', 'Covid19_IGHV3-21_IGHA',
       'Covid19_IGHV3-21_IGHD-M', 'Covid19_IGHV3-21_IGHG',
       'Covid19_IGHV3-48_IGHA', 'Covid19_IGHV3-48_IGHD-M',
       'Covid19_IGHV3-48_IGHG', 'Covid19_IGHV1-46_IGHA',
       'Covid19_IGHV1-46_IGHD-M', 'Covid19_IGHV1-46_IGHG',
       'Covid19_IGHV4-39_IGHA', 'Covid19_IGHV4-39_IGHD-M',
       'Covid19_IGHV4-39_IGHG', 'Covid19_IGHV5-51_IGHA',
       'Covid19

Now that we have Model 3 loaded, let's featurize the dataset to be able to use the model.

This means every sequence is run through its associated sequence-level classifier, then the sequence probabilities are aggregated into sample-level features:

In [62]:
featurized_model3_data = clf3.featurize(adata_bcr)
type(featurized_model3_data)

{"message": "Number of VJGeneSpecificSequenceModelRollupClassifier featurization matrix N/As due to specimens not having any sequences with particular V/J gene pairs: 8538 / 140904 = 6.06%", "time": "2024-08-07T13:54:29.165469"}


malid.trained_model_wrappers.vj_gene_specific_sequence_model_rollup_classifier.SubsetRollupClassifierFeaturizedData

> _Aside:_
>
> The log message about "VJGeneSpecificSequenceModelRollupClassifier featurization matrix N/As due to specimens not having any sequences with particular V/J gene pairs" is a bit misleading.
>
> It actually refers to missing values that arise when a sample has no sequences to score and aggregate for a particular V gene and isotype combination. For each such missing combination, the sample's `P(Covid-19)`, `P(HIV)`, `P(Lupus)`, and other per-class features are set to `1 / n_classes`.
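
Under that description, filling in the gaps is simple. Any `{class}_{V gene}_{isotype}` feature a sample lacks is set to the uniform probability `1 / n_classes`. The sketch below uses hypothetical class, group, and feature values. It also computes the fraction of filled entries. A ratio like the 8538 / 140904 in the log is presumably the same kind of count taken over the whole featurization matrix, but that is an assumption.

```python
from itertools import product

CLASSES = ["Covid19", "HIV", "Lupus"]
GROUPS = [("IGHV3-23", "IGHA"), ("IGHV4-34", "IGHG")]  # hypothetical (V gene, isotype) groups


def fill_missing_groups(sample_features, classes, groups):
    """Return a complete feature dict, using 1 / n_classes for absent (V gene, isotype) groups."""
    fill_value = 1 / len(classes)
    complete = {}
    for cls, (v_gene, isotype) in product(classes, groups):
        name = f"{cls}_{v_gene}_{isotype}"
        complete[name] = sample_features.get(name, fill_value)
    return complete


# Hypothetical sample with no IGHV4-34 IgG sequences
observed = {"Covid19_IGHV3-23_IGHA": 0.6, "HIV_IGHV3-23_IGHA": 0.3, "Lupus_IGHV3-23_IGHA": 0.1}
complete = fill_missing_groups(observed, CLASSES, GROUPS)
filled_fraction = sum(name not in observed for name in complete) / len(complete)
print(complete["HIV_IGHV4-34_IGHG"], filled_fraction)
```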

Let's review the features that were generated:

In [63]:
# Features (in this case, they've already been standardized)
featurized_model3_data.X

specimen_label,Influenza_IGHV3-23_IGHA,Influenza_IGHV3-23_IGHD-M,Influenza_IGHV3-23_IGHG,Influenza_IGHV4-b_IGHA,Influenza_IGHV4-b_IGHD-M,Influenza_IGHV4-b_IGHG,Influenza_IGHV3-7_IGHA,Influenza_IGHV3-7_IGHD-M,Influenza_IGHV3-7_IGHG,Influenza_IGHV1-18_IGHA,...,Covid19_IGHV4-30-4_IGHG,Covid19_IGHV1-3_IGHA,Covid19_IGHV1-3_IGHD-M,Covid19_IGHV1-3_IGHG,Covid19_IGHV5-a_IGHA,Covid19_IGHV5-a_IGHD-M,Covid19_IGHV5-a_IGHG,Covid19_IGHV7-4-1_IGHA,Covid19_IGHV7-4-1_IGHD-M,Covid19_IGHV7-4-1_IGHG
M111-S003,0.010139,-0.128048,-0.010928,-0.018978,-0.020769,-0.029884,-0.026154,-0.038170,-0.014322,-0.005550,...,-0.000056,0.000000,0.000068,0.000000,-0.003164,-0.002320,-0.002893,0.000000,0.000000,0.000000
M111-S007,0.032499,-0.014721,0.057095,-0.008114,-0.010291,-0.009460,0.037804,-0.012315,0.035763,0.013249,...,-0.000180,0.000000,-0.000212,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
M111-S009,-0.030703,-0.054979,-0.021037,-0.000240,-0.000024,-0.000127,0.005277,-0.030897,-0.015195,-0.003173,...,-0.000830,-0.004305,-0.001096,-0.001610,-0.006311,-0.001547,-0.001782,0.000000,-0.000266,0.000000
M111-S011,-0.083829,-0.054190,-0.017302,-0.035306,-0.037384,-0.019191,-0.079265,-0.045720,-0.011754,-0.080297,...,-0.006675,-0.002341,-0.002151,-0.006111,-0.002172,-0.005744,-0.006852,-0.004383,-0.002286,-0.013172
M111-S016,0.125323,0.033698,0.045521,0.000000,0.000000,-0.000200,0.091379,-0.008926,0.018937,-0.044331,...,-0.000388,-0.002148,-0.002241,-0.000831,0.000000,0.000000,0.000000,-0.005406,-0.002961,-0.007080
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
M64-107,-0.034607,-0.149861,0.124894,-0.006630,-0.012994,-0.009094,0.018204,-0.034933,-0.062698,0.014481,...,0.000000,-0.003967,-0.001134,-0.004544,0.000000,-0.000073,0.000000,0.000000,-0.000117,0.000000
M64-109,-0.003658,-0.124797,0.051024,-0.010079,-0.013158,-0.005915,-0.005914,-0.025667,0.056407,-0.077331,...,-0.028257,-0.002327,-0.003638,-0.016851,-0.015336,-0.006743,-0.011526,-0.000539,-0.000074,0.000000
M64-111,0.128735,-0.082885,0.007388,-0.008783,-0.007355,-0.000910,0.075859,-0.047132,-0.014003,0.074678,...,-0.004160,0.000058,0.002157,-0.016897,-0.008180,-0.004634,-0.025562,-0.001787,-0.001867,-0.005862
M64-112,-0.233664,-0.009533,0.016623,0.000000,-0.000212,-0.001248,-0.065664,-0.018313,0.002118,-0.057121,...,-0.000634,-0.005376,-0.001812,-0.011419,-0.008069,-0.002266,-0.010332,-0.009078,-0.002326,-0.019100


In [64]:
# Ground truth
featurized_model3_data.y

specimen_label
M111-S003                   HIV
M111-S007                   HIV
M111-S009                   HIV
M111-S011                   HIV
M111-S016                   HIV
                    ...        
M64-107      Healthy/Background
M64-109      Healthy/Background
M64-111      Healthy/Background
M64-112      Healthy/Background
M64-114      Healthy/Background
Name: disease, Length: 206, dtype: category
Categories (6, object): ['HIV', 'Healthy/Background', 'Lupus', 'Covid19', 'Influenza', 'T1D']

In [65]:
# Sample names
featurized_model3_data.sample_names

Index(['M111-S003', 'M111-S007', 'M111-S009', 'M111-S011', 'M111-S016',
       'M111-S020', 'M111-S021', 'M111-S023', 'M111-S025', 'M111-S028',
       ...
       'M64-086', 'M64-089', 'M64-090', 'M64-101', 'M64-103', 'M64-107',
       'M64-109', 'M64-111', 'M64-112', 'M64-114'],
      dtype='object', name='specimen_label', length=206)

In [66]:
# Sample metadata
featurized_model3_data.metadata

specimen_label,age,disease,disease.rollup,disease.separate_past_exposures,disease_severity,disease_subtype,ethnicity_condensed,isotype_proportion:IGHA,isotype_proportion:IGHD-M,isotype_proportion:IGHG,participant_label,past_exposure,sex,study_name
M111-S003,49.0,HIV,HIV,HIV,,HIV Broad Neutralizing,African,0.120568,0.741955,0.137476,BFI-0003466,False,M,HIV
M111-S007,48.0,HIV,HIV,HIV,,HIV Broad Neutralizing,African,0.114607,0.709431,0.175962,BFI-0003469,False,M,HIV
M111-S009,47.0,HIV,HIV,HIV,,HIV Broad Neutralizing,,0.174015,0.644778,0.181207,BFI-0003453,False,M,HIV
M111-S011,24.0,HIV,HIV,HIV,,HIV Broad Neutralizing,African,0.056556,0.721739,0.221704,BFI-0002875,False,F,HIV
M111-S016,22.0,HIV,HIV,HIV,,HIV Broad Neutralizing,African,0.101753,0.796830,0.101417,BFI-0002855,False,F,HIV
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
M64-107,39.0,Healthy/Background,Healthy/Background,Healthy/Background,,Healthy/Background - CMV+,Hispanic/Latino,0.086075,0.877653,0.036272,BFI-0003156,False,F,Healthy-StanfordBloodCenter
M64-109,19.0,Healthy/Background,Healthy/Background,Healthy/Background,,Healthy/Background - CMV-,Caucasian,0.067132,0.876106,0.056762,BFI-0003158,False,M,Healthy-StanfordBloodCenter
M64-111,57.0,Healthy/Background,Healthy/Background,Healthy/Background,,Healthy/Background - CMV+,Caucasian,0.084105,0.877510,0.038385,BFI-0003160,False,M,Healthy-StanfordBloodCenter
M64-112,18.0,Healthy/Background,Healthy/Background,Healthy/Background,,Healthy/Background - CMV-,Caucasian,0.070577,0.897266,0.032157,BFI-0003161,False,M,Healthy-StanfordBloodCenter


Finally, let's run the model to predict the per-class probabilities at a _sample level_:

In [67]:
# Table of predicted class probabilities
pd.DataFrame(
    clf3.predict_proba(featurized_model3_data.X),
    index=featurized_model3_data.sample_names,
    columns=clf3.classes_,
)

specimen_label,Covid19,HIV,Healthy/Background,Influenza,Lupus,T1D
M111-S003,0.01,0.31,0.14,0.00,0.07,0.04
M111-S007,0.01,0.23,0.44,0.03,0.03,0.07
M111-S009,0.61,0.00,0.01,0.00,0.23,0.02
M111-S011,0.01,0.36,0.11,0.00,0.21,0.15
M111-S016,0.01,0.27,0.40,0.05,0.02,0.14
...,...,...,...,...,...,...
M64-107,0.00,0.01,0.67,0.09,0.05,0.20
M64-109,0.05,0.04,0.84,0.00,0.04,0.09
M64-111,0.09,0.03,0.59,0.05,0.02,0.19
M64-112,0.12,0.01,0.54,0.07,0.08,0.16


## Ensemble metamodel combining Models 1, 2, and 3 and combining BCR and TCR

Finally, let's load the ensemble metamodel, which brings all the other model components together for a final prediction of disease status.

We'll specify `metamodel_name="ridge_cv"` to load the ridge logistic regression version that we highlight in the paper. (The featurization uses the base models chosen in `config.metamodel_base_model_names`, as we reviewed above, so the features are the same regardless of `metamodel_name`. This parameter only controls which classification algorithm the metamodel itself uses.)

We'll also specify `metamodel_flavor="default"`, which refers to combining Models 1, 2, and 3. Other `metamodel_flavor` options include:

- `subset_of_submodels_repertoire_stats` (Model 1 only)
- `subset_of_submodels_convergent_cluster_model` (Model 2 only)
- `subset_of_submodels_sequence_model` (Model 3 only)
- `subset_of_submodels_repertoire_stats_convergent_cluster_model` (Models 1 and 2 only)
- and so on.
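
The flavor names above follow a visible pattern: `subset_of_submodels_` followed by the base model names joined with underscores. As a hedged illustration only, the hypothetical helper below decodes a flavor string into the base models it names. It is not part of Mal-ID, and the base model names (`repertoire_stats`, `convergent_cluster_model`, `sequence_model`) are taken from the flavor names listed above.

```python
# Hypothetical helper (not part of Mal-ID): decode a metamodel_flavor string
# into the base model names it refers to, based only on the naming pattern above.
BASE_MODELS = ["repertoire_stats", "convergent_cluster_model", "sequence_model"]
PREFIX = "subset_of_submodels_"


def models_in_flavor(flavor: str) -> list[str]:
    if flavor == "default":
        # "default" combines Models 1, 2, and 3
        return list(BASE_MODELS)
    if not flavor.startswith(PREFIX):
        raise ValueError(f"Unrecognized flavor: {flavor}")
    rest = flavor[len(PREFIX):]
    found = []
    while rest:
        for model in BASE_MODELS:
            if rest.startswith(model):
                found.append(model)
                # Drop the matched name and the underscore separating it from the next one
                rest = rest[len(model):].lstrip("_")
                break
        else:
            raise ValueError(f"Unrecognized model name in flavor: {flavor}")
    return found


print(models_in_flavor("subset_of_submodels_repertoire_stats_convergent_cluster_model"))
```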

In [68]:
clf_metamodel = BlendingMetamodel.from_disk(
    fold_id=fold_id,
    target_obs_column=target_obs_column,
    metamodel_name="ridge_cv",  # Which metamodel version to use
    base_model_train_fold_name="train_smaller",  # The base components are fitted on the train_smaller set
    metamodel_fold_label_train="validation",  # The metamodel is fitted on the validation set
    gene_locus=GeneLocus.BCR | GeneLocus.TCR,  # Use BCR and TCR components together
    metamodel_flavor="default",  # Use Models 1 + 2 + 3
)

clf_metamodel

As we've seen with the other models, `BlendingMetamodel` wraps a scikit-learn Pipeline that confirms the expected features are present, standardizes them, and then runs ridge logistic regression:

In [69]:
clf_metamodel._inner
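
To make that structure concrete, here is a rough scikit-learn sketch of a pipeline with the same three stages, fitted on made-up toy data. It is an approximation under stated assumptions, not Mal-ID's actual Pipeline: `select_expected_features` and `make_ridge_metamodel` are hypothetical helpers, and `LogisticRegressionCV` (whose default penalty is L2, i.e. ridge) stands in for whatever estimator `ridge_cv` uses internally.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegressionCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def select_expected_features(X: pd.DataFrame, expected: list[str]) -> pd.DataFrame:
    """Raise if any expected feature column is missing; return columns in the expected order."""
    missing = [name for name in expected if name not in X.columns]
    if missing:
        raise ValueError(f"Missing metamodel features: {missing}")
    return X[expected]


def make_ridge_metamodel(cv: int = 3):
    """Standardize each feature, then fit L2-penalized logistic regression,
    choosing the regularization strength by internal cross-validation."""
    return make_pipeline(
        StandardScaler(),
        LogisticRegressionCV(Cs=10, cv=cv, max_iter=1000),  # default penalty is L2
    )


# Toy usage with made-up data (not Mal-ID outputs)
rng = np.random.default_rng(0)
feature_names = ["BCR:m1:A", "BCR:m1:B", "TCR:m1:A", "TCR:m1:B"]
y_toy = np.repeat(["A", "B", "C"], 30)
X_toy = pd.DataFrame(rng.random((90, 4)), columns=feature_names)
X_toy.loc[y_toy == "A", "BCR:m1:A"] += 1.0  # inject some signal for class A

model = make_ridge_metamodel().fit(select_expected_features(X_toy, feature_names), y_toy)
proba = model.predict_proba(select_expected_features(X_toy, feature_names))
```

Keeping the feature check outside the pipeline makes a missing base-model column fail loudly with its name, rather than surfacing as a shape error inside the scaler.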

The features are per-class predicted probabilities from each submodel. For example, `BCR:repertoire_stats:Covid19` is `P(Covid-19)` according to the BCR version of Model 1, the repertoire summary statistics model. Here's the full list of features:

In [70]:
clf_metamodel.feature_names_in_

['BCR:repertoire_stats:Covid19',
 'BCR:repertoire_stats:HIV',
 'BCR:repertoire_stats:Healthy/Background',
 'BCR:repertoire_stats:Influenza',
 'BCR:repertoire_stats:Lupus',
 'BCR:repertoire_stats:T1D',
 'BCR:convergent_cluster_model:Covid19',
 'BCR:convergent_cluster_model:HIV',
 'BCR:convergent_cluster_model:Healthy/Background',
 'BCR:convergent_cluster_model:Influenza',
 'BCR:convergent_cluster_model:Lupus',
 'BCR:convergent_cluster_model:T1D',
 'BCR:sequence_model:Covid19',
 'BCR:sequence_model:HIV',
 'BCR:sequence_model:Healthy/Background',
 'BCR:sequence_model:Influenza',
 'BCR:sequence_model:Lupus',
 'BCR:sequence_model:T1D',
 'TCR:repertoire_stats:Covid19',
 'TCR:repertoire_stats:HIV',
 'TCR:repertoire_stats:Healthy/Background',
 'TCR:repertoire_stats:Influenza',
 'TCR:repertoire_stats:Lupus',
 'TCR:repertoire_stats:T1D',
 'TCR:convergent_cluster_model:Covid19',
 'TCR:convergent_cluster_model:HIV',
 'TCR:convergent_cluster_model:Healthy/Background',
 'TCR:convergent_cluster_model
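
The naming scheme can be reproduced with a short sketch. It assumes the feature order is locus, then base model, then class in alphabetical order, which matches the entries printed above but is inferred from a truncated listing. `LOCI`, `BASE_MODELS`, `CLASSES`, and `parse_feature_name` are hypothetical names, not part of Mal-ID.

```python
import itertools

# Hypothetical constants mirroring the names printed above (not imported from Mal-ID)
LOCI = ["BCR", "TCR"]
BASE_MODELS = ["repertoire_stats", "convergent_cluster_model", "sequence_model"]
CLASSES = ["Covid19", "HIV", "Healthy/Background", "Influenza", "Lupus", "T1D"]

# One feature per (locus, base model, class): P(class) according to that base model
feature_names = [
    f"{locus}:{model}:{disease}"
    for locus, model, disease in itertools.product(LOCI, BASE_MODELS, CLASSES)
]


def parse_feature_name(name: str) -> tuple[str, str, str]:
    """Split 'LOCUS:model:class' into its parts; class names such as 'Healthy/Background' contain no colons."""
    locus, model, disease = name.split(":", 2)
    return locus, model, disease


print(len(feature_names))  # 36 = 2 loci x 3 base models x 6 classes
```

Filtering `feature_names` on the middle field, for example, lists the columns contributed by a single base model.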

As expected, there are three BCR and three TCR submodels:

In [71]:
clf_metamodel.metamodel_config.submodels

{<GeneLocus.BCR: 1>: {'repertoire_stats': RepertoireClassifier: Pipeline(steps=[('columntransformer',
                   ColumnTransformer(remainder='passthrough',
                                     transformers=[('log1p-scale-PCA_IGHG',
                                                    Pipeline(steps=[('log1p',
                                                                     FunctionTransformer(feature_names_out='one-to-one',
                                                                                         func=<ufunc 'log1p'>,
                                                                                         validate=True)),
                                                                    ('scale',
                                                                     StandardScalerThatPreservesInputType()),
                                                                    ('pca',
                                                                     PCA(n_compo

Let's featurize our input data to use with the metamodel. This time, the call to `featurize()` requires wrapping the input AnnData as a `dict[GeneLocus, AnnData]` — meaning a dictionary that maps from a sequencing locus to an AnnData object. This is because the BCR+TCR metamodel requires both AnnDatas at the same time to generate the features.

In [72]:
featurized_metamodel_data = clf_metamodel.featurize(
    {GeneLocus.BCR: adata_bcr, GeneLocus.TCR: adata_tcr}
)
type(featurized_metamodel_data)

{"message": "Metamodel featurization with data keys dict_keys([<GeneLocus.BCR: 1>, <GeneLocus.TCR: 2>]) and gene_locus GeneLocus.BCR|TCR: dropping specimens from GeneLocus.BCR anndata: {'M281redo-S061', 'M281redo-S018', 'M281redo-S058', 'M404-S014', 'M281redo-S054', 'M281redo-S056', 'M281redo-S026', 'M281redo-S045', 'M281redo-S041', 'M281redo-S048', 'M281redo-S019', 'M281redo-S040', 'M281redo-S014', 'M281redo-S017', 'M281redo-S007', 'M281redo-S051', 'M281redo-S008', 'M281redo-S029', 'M281redo-S023', 'M281redo-S020', 'M281redo-S034', 'M281redo-S057'}", "time": "2024-08-07T13:57:33.424991"}
{"message": "Number of VJGeneSpecificSequenceModelRollupClassifier featurization matrix N/As due to specimens not having any sequences with particular V/J gene pairs: 7470 / 125856 = 5.94%", "time": "2024-08-07T14:03:02.023751"}
{"message": "Number of VJGeneSpecificSequenceModelRollupClassifier featurization matrix N/As due to specimens not having any sequences with particular V/J gene pairs: 0 / 3091

crosseval.featurized_data.FeaturizedData

> _Aside:_
>
> There are two new log messages here worth highlighting:
>
> 1. The first line, with "dropping specimens from GeneLocus.BCR anndata": these samples are removed because they have BCR data only — no TCR data. The BCR+TCR metamodel only runs on samples that have both BCR and TCR data available.
> 2. The last line, with "Abstained specimens": Model 2 abstained from prediction for these samples, because none of the sequences in these samples matched any of the disease-associated clusters. The abstention propagates up into the metamodel.
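
The two behaviors described in the aside can be sketched with plain pandas. This is a simplified illustration under stated assumptions, not Mal-ID's implementation: each locus's base-model probabilities are assumed to be a DataFrame indexed by specimen label (keyed here by the strings `"BCR"` and `"TCR"` rather than `GeneLocus`), and an abstention is assumed to show up as a missing (NaN) probability. `combine_loci` and `split_abstentions` are hypothetical names.

```python
import numpy as np
import pandas as pd


def combine_loci(features_by_locus: dict[str, pd.DataFrame]) -> pd.DataFrame:
    """Keep only specimens present in every locus, then place each locus's feature columns side by side."""
    frames = list(features_by_locus.values())
    shared = frames[0].index
    for frame in frames[1:]:
        shared = shared.intersection(frame.index)
    return pd.concat([frame.loc[shared] for frame in frames], axis=1)


def split_abstentions(X: pd.DataFrame) -> tuple[pd.DataFrame, list[str]]:
    """Treat any specimen with a missing base-model probability as abstained."""
    abstained = X.isna().any(axis=1)
    return X.loc[~abstained], X.index[abstained].tolist()


# Toy data: s3 has BCR data only; s2's BCR base model abstained (NaN)
bcr = pd.DataFrame({"BCR:m2:A": [0.9, np.nan, 0.2]}, index=["s1", "s2", "s3"])
tcr = pd.DataFrame({"TCR:m2:A": [0.7, 0.4]}, index=["s1", "s2"])

combined = combine_loci({"BCR": bcr, "TCR": tcr})  # s3 is dropped
X_kept, abstained = split_abstentions(combined)  # s2 is set aside as abstained
```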

Let's unwrap the `FeaturizedData` container as usual. This time, the features are the predicted class probabilities from the base models:

In [73]:
# Features
featurized_metamodel_data.X

| specimen_label | BCR:repertoire_stats:Covid19 | BCR:repertoire_stats:HIV | BCR:repertoire_stats:Healthy/Background | BCR:repertoire_stats:Influenza | BCR:repertoire_stats:Lupus | BCR:repertoire_stats:T1D | BCR:convergent_cluster_model:Covid19 | BCR:convergent_cluster_model:HIV | BCR:convergent_cluster_model:Healthy/Background | BCR:convergent_cluster_model:Influenza | ... | TCR:convergent_cluster_model:Healthy/Background | TCR:convergent_cluster_model:Influenza | TCR:convergent_cluster_model:Lupus | TCR:convergent_cluster_model:T1D | TCR:sequence_model:Covid19 | TCR:sequence_model:HIV | TCR:sequence_model:Healthy/Background | TCR:sequence_model:Influenza | TCR:sequence_model:Lupus | TCR:sequence_model:T1D |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| M124-S066 | 0.004994 | 0.376223 | 0.576059 | 0.006669 | 0.002035 | 0.034018 | 0.112480 | 0.229213 | 0.191971 | 0.046151 | ... | 0.000029 | 1.947732e-07 | 1.300925e-10 | 0.000037 | 0.00 | 0.80 | 0.22 | 0.06 | 0.01 | 0.04 |
| M491-S071 | 0.000844 | 0.000456 | 0.011374 | 0.000062 | 0.462071 | 0.525192 | 0.103348 | 0.115077 | 0.291361 | 0.047232 | ... | 0.613858 | 1.544074e-04 | 4.305656e-03 | 0.307040 | 0.21 | 0.08 | 0.12 | 0.05 | 0.08 | 0.32 |
| M491-S017 | 0.206832 | 0.071715 | 0.206088 | 0.024512 | 0.050140 | 0.440714 | 0.114141 | 0.096168 | 0.323603 | 0.044744 | ... | 0.184054 | 1.429102e-02 | 1.826828e-02 | 0.537486 | 0.20 | 0.05 | 0.08 | 0.02 | 0.15 | 0.21 |
| M64-029 | 0.003075 | 0.000924 | 0.953790 | 0.023456 | 0.003283 | 0.015472 | 0.060075 | 0.058137 | 0.655783 | 0.023916 | ... | 0.993845 | 9.257881e-06 | 1.073829e-07 | 0.003435 | 0.03 | 0.21 | 0.59 | 0.00 | 0.01 | 0.03 |
| M64-026 | 0.005413 | 0.003711 | 0.941543 | 0.006129 | 0.038927 | 0.004276 | 0.093520 | 0.096803 | 0.435293 | 0.033410 | ... | 0.999362 | 7.960142e-07 | 2.164846e-06 | 0.000635 | 0.05 | 0.03 | 0.62 | 0.01 | 0.05 | 0.06 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| M124-S027 | 0.001228 | 0.827938 | 0.113491 | 0.001046 | 0.056235 | 0.000061 | 0.128990 | 0.126311 | 0.222758 | 0.044442 | ... | 0.217314 | 6.392670e-04 | 1.447242e-02 | 0.319662 | 0.02 | 0.61 | 0.34 | 0.03 | 0.00 | 0.03 |
| M64-068 | 0.007300 | 0.002270 | 0.849172 | 0.004033 | 0.028994 | 0.108232 | 0.033721 | 0.045151 | 0.767739 | 0.018526 | ... | 0.893119 | 2.317722e-09 | 4.543618e-07 | 0.106880 | 0.13 | 0.01 | 0.65 | 0.00 | 0.01 | 0.03 |
| M64-085 | 0.132419 | 0.207171 | 0.583211 | 0.000300 | 0.073702 | 0.003197 | 0.052490 | 0.114812 | 0.441408 | 0.026659 | ... | 0.423783 | 2.419594e-04 | 4.430096e-03 | 0.420812 | 0.06 | 0.04 | 0.56 | 0.01 | 0.04 | 0.09 |
| M433-S059 | 0.000164 | 0.000109 | 0.000842 | 0.988949 | 0.000588 | 0.009348 | 0.103840 | 0.090363 | 0.288975 | 0.107768 | ... | 0.000005 | 9.998619e-01 | 3.968216e-05 | 0.000082 | 0.02 | 0.36 | 0.00 | 0.27 | 0.00 | 0.25 |


In [74]:
# Ground truth
featurized_metamodel_data.y

array(['HIV', 'T1D', 'T1D', 'Healthy/Background', 'Healthy/Background',
       'T1D', 'Influenza', 'Healthy/Background', 'HIV', 'Covid19',
       'Healthy/Background', 'Covid19', 'HIV', 'HIV', 'HIV',
       'Healthy/Background', 'Healthy/Background', 'T1D', 'Lupus', 'HIV',
       'T1D', 'Healthy/Background', 'Covid19', 'Healthy/Background',
       'HIV', 'Healthy/Background', 'T1D', 'Influenza', 'Covid19',
       'Lupus', 'T1D', 'Healthy/Background', 'Healthy/Background', 'HIV',
       'Covid19', 'Lupus', 'HIV', 'Healthy/Background', 'Lupus',
       'Healthy/Background', 'HIV', 'Lupus', 'Lupus', 'Lupus',
       'Healthy/Background', 'Healthy/Background', 'Healthy/Background',
       'HIV', 'HIV', 'Lupus', 'Healthy/Background', 'Healthy/Background',
       'HIV', 'Lupus', 'Covid19', 'T1D', 'Healthy/Background',
       'Healthy/Background', 'Influenza', 'Influenza', 'T1D', 'HIV',
       'Influenza', 'HIV', 'Healthy/Background', 'T1D',
       'Healthy/Background', 'T1D', 'Healthy/Backgrou

In [75]:
# Sample names
featurized_metamodel_data.sample_names

Index(['M124-S066', 'M491-S071', 'M491-S017', 'M64-029', 'M64-026',
       'M491-S107', 'M433-S051', 'M64-078', 'M124-S072', 'M418-S010',
       ...
       'M491-S064', 'M491-S106', 'M64-077', 'M491-S057', 'M111-S048',
       'M124-S027', 'M64-068', 'M64-085', 'M433-S059', 'M491-S048'],
      dtype='object', name='specimen_label', length=181)

In [76]:
# Sample metadata
featurized_metamodel_data.metadata

| specimen_label | age | disease | disease.rollup | disease.separate_past_exposures | disease_severity | disease_subtype | ethnicity_condensed | isotype_proportion:IGHA | isotype_proportion:IGHD-M | isotype_proportion:IGHG | participant_label | past_exposure | sex | study_name |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| M124-S066 | 27.0 | HIV | HIV | HIV |  | HIV Non Neutralizing | African | 0.230346 | 0.551237 | 0.218417 | BFI-0003755 | False | F | HIV |
| M491-S071 | 18.0 | T1D | T1D | T1D |  | T1D - pediatric |  | 0.178105 | 0.596202 | 0.225693 | BFI-0010718 | False |  | Diabetes biobank |
| M491-S017 |  | T1D | T1D | T1D |  | T1D - adult |  | 0.244323 | 0.574685 | 0.180992 | BFI-0010665 | False |  | Diabetes biobank |
| M64-029 | 66.0 | Healthy/Background | Healthy/Background | Healthy/Background |  | Healthy/Background - CMV- | Caucasian | 0.209289 | 0.541179 | 0.249532 | BFI-0003078 | False | M | Healthy-StanfordBloodCenter |
| M64-026 | 68.0 | Healthy/Background | Healthy/Background | Healthy/Background |  | Healthy/Background - CMV+ | Asian | 0.254202 | 0.569415 | 0.176384 | BFI-0003075 | False | M | Healthy-StanfordBloodCenter |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| M124-S027 | 28.0 | Healthy/Background | Healthy/Background | Healthy/Background |  | Healthy/Background - HIV Negative | African | 0.368556 | 0.410265 | 0.221179 | BFI-0003724 | False | M | HIV |
| M64-068 | 54.0 | Healthy/Background | Healthy/Background | Healthy/Background |  | Healthy/Background - CMV+ | Asian | 0.119946 | 0.779775 | 0.100280 | BFI-0003117 | False | M | Healthy-StanfordBloodCenter_included-in-resequ... |
| M64-085 | 55.0 | Healthy/Background | Healthy/Background | Healthy/Background |  | Healthy/Background - CMV+ | Caucasian | 0.146258 | 0.743095 | 0.110647 | BFI-0003134 | False | M | Healthy-StanfordBloodCenter |
| M433-S059 | 23.0 | Influenza | Influenza | Influenza |  | Influenza vaccine 2021 - day 7 |  | 0.353308 | 0.433732 | 0.212960 | BFI-0009975 | False | F | Flu vaccine UPenn 2021 |
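
Because `X`, `y`, `sample_names`, and `metadata` are separate objects, it can be worth confirming they line up before analyzing them together. The hypothetical helper below (not part of Mal-ID) assumes, as the outputs above suggest, that `X` and `metadata` are DataFrames indexed by specimen label. For the rows shown, `y` also matches the `disease` column, so `label_column="disease"` is one plausible extra check.

```python
import numpy as np
import pandas as pd


def check_featurized_alignment(X, y, sample_names, metadata, label_column=None):
    """Raise ValueError if the pieces do not describe the same specimens in the same order."""
    names = pd.Index(sample_names)
    if not (len(y) == len(names) == X.shape[0] == metadata.shape[0]):
        raise ValueError("X, y, sample_names, and metadata have different lengths")
    # Index.equals compares values and order, ignoring the index name
    if not X.index.equals(names) or not metadata.index.equals(names):
        raise ValueError("X or metadata rows are not in sample_names order")
    if label_column is not None and list(metadata[label_column]) != list(y):
        raise ValueError(f"y does not match metadata[{label_column!r}]")


# Toy usage with made-up values
names = pd.Index(["s1", "s2"], name="specimen_label")
X_toy = pd.DataFrame({"feature": [0.1, 0.2]}, index=names)
meta_toy = pd.DataFrame({"disease": ["HIV", "T1D"]}, index=names)
y_toy = np.array(["HIV", "T1D"])
check_featurized_alignment(X_toy, y_toy, names, meta_toy, label_column="disease")  # passes silently
```

With the tutorial's objects, the call would pass `featurized_metamodel_data.X`, `.y`, `.sample_names`, and `.metadata` in that order.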


As with the other models, we'll run the final prediction:

In [77]:
# Table of predicted class probabilities
pd.DataFrame(
    clf_metamodel.predict_proba(featurized_metamodel_data.X),
    index=featurized_metamodel_data.sample_names,
    columns=clf_metamodel.classes_,
)

| specimen_label | Covid19 | HIV | Healthy/Background | Influenza | Lupus | T1D |
| --- | --- | --- | --- | --- | --- | --- |
| M124-S066 | 0.069313 | 0.475842 | 0.164825 | 0.097305 | 0.089503 | 0.103212 |
| M491-S071 | 0.109962 | 0.101471 | 0.114104 | 0.084598 | 0.206145 | 0.383721 |
| M491-S017 | 0.154769 | 0.124511 | 0.157118 | 0.096315 | 0.151228 | 0.316059 |
| M64-029 | 0.081137 | 0.121118 | 0.497514 | 0.099606 | 0.112362 | 0.088264 |
| M64-026 | 0.088994 | 0.118889 | 0.435706 | 0.117980 | 0.131329 | 0.107103 |
| ... | ... | ... | ... | ... | ... | ... |
| M124-S027 | 0.088240 | 0.416703 | 0.152640 | 0.144919 | 0.095025 | 0.102473 |
| M64-068 | 0.121266 | 0.114451 | 0.412180 | 0.095271 | 0.132779 | 0.124052 |
| M64-085 | 0.096077 | 0.148324 | 0.327083 | 0.170974 | 0.145739 | 0.111802 |
| M433-S059 | 0.021166 | 0.036373 | 0.033421 | 0.801682 | 0.033510 | 0.073848 |
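
To turn these probabilities into hard predictions, one common approach is to take the highest-probability class for each specimen and compare it with the ground truth. The sketch below does this with pandas on made-up toy values; `predicted_labels` is a hypothetical helper, and whether `BlendingMetamodel` provides its own label-prediction method is not covered here.

```python
import numpy as np
import pandas as pd


def predicted_labels(proba: pd.DataFrame) -> pd.Series:
    """Pick the class (column) with the highest predicted probability for each specimen (row)."""
    return proba.idxmax(axis=1)


# Toy probabilities and labels (made up, not Mal-ID results)
proba_toy = pd.DataFrame(
    [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.2, 0.5, 0.3]],
    index=["s1", "s2", "s3"],
    columns=["Covid19", "HIV", "T1D"],
)
y_true_toy = np.array(["Covid19", "T1D", "T1D"])

y_pred_toy = predicted_labels(proba_toy)
accuracy = float(np.mean(y_pred_toy.to_numpy() == y_true_toy))
# Rows: true label; columns: predicted label
confusion = pd.crosstab(
    pd.Series(y_true_toy, index=proba_toy.index, name="true"),
    y_pred_toy.rename("predicted"),
)
```

Applied to the tutorial, `predicted_labels` would take the DataFrame built in the cell above, and its values could be compared element-wise with `featurized_metamodel_data.y`.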


We'll end by calculating the AUROC for fold 0:

In [78]:
import sklearn.metrics

sklearn.metrics.roc_auc_score(
    y_true=featurized_metamodel_data.y,
    y_score=clf_metamodel.predict_proba(featurized_metamodel_data.X),
    multi_class="ovo",  # Multiclass AUC calculated in one-versus-one fashion
    average="weighted",  # Take class size-weighted average of the binary AUROC calculated for each pair of classes
)

0.9878390256152686
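
To see what `multi_class="ovo"` with `average="weighted"` computes, here is a sketch that rebuilds it by hand and checks it against scikit-learn on random toy data. As we understand scikit-learn's implementation, each pair of classes gets the mean of two binary AUROCs (each class of the pair treated as positive, scored by its own probability column, using only specimens from that pair), and the pair scores are averaged with weights equal to the fraction of specimens belonging to the pair. `weighted_ovo_auroc` is a hypothetical helper, not part of Mal-ID or scikit-learn.

```python
import itertools

import numpy as np
from sklearn.metrics import roc_auc_score


def weighted_ovo_auroc(y_true, y_score, classes):
    """One-vs-one multiclass AUROC, averaged over class pairs weighted by pair prevalence.

    `classes` must be sorted and match the column order of `y_score`.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    pair_aucs, pair_weights = [], []
    for i, j in itertools.combinations(range(len(classes)), 2):
        # Restrict to specimens whose true label is one of the two classes
        in_pair = (y_true == classes[i]) | (y_true == classes[j])
        # Score the pair in both directions: i vs. j using column i, then j vs. i using column j
        auc_i = roc_auc_score(y_true[in_pair] == classes[i], y_score[in_pair, i])
        auc_j = roc_auc_score(y_true[in_pair] == classes[j], y_score[in_pair, j])
        pair_aucs.append((auc_i + auc_j) / 2)
        # Weight: fraction of all specimens that belong to this pair of classes
        pair_weights.append(in_pair.mean())
    return float(np.average(pair_aucs, weights=pair_weights))


# Random toy data with unequal class sizes, so weighting matters
rng = np.random.default_rng(0)
classes = np.array(["Covid19", "HIV", "T1D"])
y_true_toy = rng.permutation(np.repeat(classes, [10, 20, 30]))
y_score_toy = rng.random((len(y_true_toy), len(classes)))
y_score_toy /= y_score_toy.sum(axis=1, keepdims=True)  # sklearn requires rows to sum to 1 here

manual = weighted_ovo_auroc(y_true_toy, y_score_toy, classes)
reference = roc_auc_score(y_true_toy, y_score_toy, multi_class="ovo", average="weighted")
```

For the tutorial's data, the equivalent call would pass `featurized_metamodel_data.y`, the metamodel's `predict_proba` output, and `clf_metamodel.classes_` (assuming, as is usual for scikit-learn classifiers, that `classes_` is sorted).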

See the readme for instructions about customizing Mal-ID for additional datasets and classification targets.