In [1]:
from pathlib import Path
from summarynb import show, table, chunks
from malid.external.summarynb_extras import plaintext
from malid import config, logger
from malid.train import train_metamodel
from malid.datamodels import (
    GeneLocus,
    TargetObsColumnEnum,
    map_cross_validation_split_strategy_to_default_target_obs_column,
)
import pandas as pd
from IPython.display import display, Markdown
from typing import Optional, List

In [2]:
# Name of the embedder configured for this run (displayed as the cell output)
config.embedder.name

'esm2_cdr3'

In [3]:
# Look up the default classification target for the configured cross-validation
# split strategy (e.g. TargetObsColumnEnum.disease), then display it.
active_split_strategy = config.cross_validation_split_strategy
default_target_obs_column = (
    map_cross_validation_split_strategy_to_default_target_obs_column[active_split_strategy]
)
default_target_obs_column

<TargetObsColumnEnum.disease: TargetObsColumn(obs_column_name='disease', is_target_binary_for_repertoire_composition_classifier=False, available_for_cross_validation_split_strategies={<CrossValidationSplitStrategy.adaptive_peak_disease_timepoints_leave_some_cohorts_out: CrossValidationSplitStrategyValue(data_sources_keep=[<DataSource.adaptive: 2>], stratify_by='disease', diseases_to_keep_all_subtypes=['Covid19', 'Healthy/Background'], subtypes_keep=[], filter_specimens_func_by_study_name={}, gene_loci_supported=<GeneLocus.TCR: 2>, exclude_study_names=[], include_study_names=['emerson-2017-natgen_train', 'immunecode-NIH', 'immunecode-ISB', 'emerson-2017-natgen_validation', 'immunecode-HUniv'], filter_out_specimens_funcs_global=[], study_names_for_held_out_set=['emerson-2017-natgen_validation', 'immunecode-HUniv'])>, <CrossValidationSplitStrategy.in_house_peak_disease_leave_another_lupus_cohort_out: CrossValidationSplitStrategyValue(data_sources_keep=[<DataSource.in_house: 1>], stratify_

# Metamodel per-class OvR ROC AUC scores

Abstentions are consistent across metamodel flavors from the same gene locus: e.g., the model1-only metamodel is forced to abstain wherever the model2-only metamodel abstained.

Color scales are consistent across metamodel names.

Each plot also includes a row for the OvO score averaged across class pairs (the score we normally report), again computed with abstentions made consistent across flavors.

Note that BCR and BCR+TCR are not comparable, because the sample sizes are different. (TCR and BCR+TCR have the same sample size because there are no TCR-only cohorts, but still are not comparable because abstentions are not forced to be identical across GeneLocus settings.)

In [4]:
# Show the per-class OvR ROC AUC figure for each metamodel architecture,
# all for the default classification target.
per_class_roc_auc_metamodel_names = [
    "lasso_cv",
    "elasticnet_cv",
    "ridge_cv",
    "rf_multiclass",
    "linearsvm_ovr",
]
for metamodel_name in per_class_roc_auc_metamodel_names:
    per_class_roc_auc_fname = (
        config.paths.second_stage_blending_metamodel_output_dir
        / f"{default_target_obs_column.name}.roc_auc_per_class.{metamodel_name}.png"
    )
    show(per_class_roc_auc_fname, max_width=1200)

# Metamodel pairwise ROC AUC scores

Each image's header shows that model's multiclass ROC AUC (weighted one-vs-one), averaged across cross-validation folds.

In [5]:
def _load_mean_weighted_ovo_roc_auc(scores_fname: str) -> pd.Series:
    """Load each model's weighted OvO ROC AUC from a compare_model_scores TSV, keeping only the mean.

    The "ROC-AUC (weighted OvO) per fold" column holds "mean +/- stddev" strings;
    only the text before " +/-" is kept. The returned Series is indexed by the
    TSV's first column (model names).

    Raises:
        FileNotFoundError: if ``scores_fname`` does not exist.
    """
    return (
        pd.read_csv(scores_fname, sep="\t", index_col=0)[
            "ROC-AUC (weighted OvO) per fold"
        ]
        # Literal (non-regex) split: "+" and "/" would otherwise be regex syntax
        .str.split(" +/-", regex=False)
        .str[0]
    )


def pairwise_summary(
    gene_locus: GeneLocus,
    target_obs_column: TargetObsColumnEnum,
    base_model_train_fold_name: str = "train_smaller",
    metamodel_fold_label_train: str = "validation",
) -> None:
    """Display pairwise ROC AUC images for every metamodel flavor of one gene locus / target.

    For each flavor, shows one high-res pairwise ROC AUC image per model in
    ``config.model_names_to_train``. Each image is headed by that model's mean weighted
    OvO ROC AUC across cross-validation folds ("N/A" if the model is absent from the
    scores file).

    Args:
        gene_locus: gene locus (or combination, e.g. BCR|TCR) whose metamodels to summarize.
        target_obs_column: classification target whose metamodels to summarize.
        base_model_train_fold_name: fold name the base models were trained on.
        metamodel_fold_label_train: fold label the metamodel was trained on.

    Nothing is raised for missing results. If flavors cannot be enumerated, a warning is
    logged and the function returns. Flavors whose score file is missing are skipped with
    a warning.
    """
    display(Markdown(f"## {gene_locus}, {target_obs_column}"))
    try:
        # Flavors are enumerated using the first fold only
        # (presumably identical across folds — TODO confirm).
        flavors = train_metamodel.get_metamodel_flavors(
            gene_locus=gene_locus,
            target_obs_column=target_obs_column,
            fold_id=config.all_fold_ids[0],
            base_model_train_fold_name=base_model_train_fold_name,
            use_stubs_instead_of_submodels=True,
        )
    except Exception as err:
        # Deliberate best-effort: log and move on so the rest of the notebook still renders.
        logger.warning(
            f"Failed to generate metamodel flavors for {gene_locus}, {target_obs_column}: {err}"
        )
        return

    model_names = config.model_names_to_train
    for metamodel_flavor in flavors:
        _output_suffix = (
            Path(gene_locus.name)
            / target_obs_column.name
            / metamodel_flavor
            / f"{base_model_train_fold_name}_applied_to_{metamodel_fold_label_train}_model"
        )
        results_output_prefix = (
            config.paths.second_stage_blending_metamodel_output_dir / _output_suffix
        )
        highres_results_output_prefix = (
            config.paths.high_res_outputs_dir / "metamodel" / _output_suffix
        )

        try:
            multiclass_scores = _load_mean_weighted_ovo_roc_auc(
                f"{results_output_prefix}.compare_model_scores.test_set_performance.tsv"
            )
        except FileNotFoundError as err:
            # This can happen because use_stubs_instead_of_submodels=True above means some non-existent flavors will be generated
            logger.warning(
                f"File not found for {gene_locus}, {target_obs_column}, flavor {metamodel_flavor}: {err}"
            )
            continue

        display(Markdown(f"#### {metamodel_flavor}"))

        show(
            [
                f"{highres_results_output_prefix}.pairwise_roc_auc_scores.{model_name}.png"
                for model_name in model_names
            ],
            max_width=400,
            headers=[
                # .get instead of .loc: a model missing from the TSV should not abort the whole summary
                f"{model_name}: {multiclass_scores.get(model_name, 'N/A')}"
                for model_name in model_names
            ],
        )

In [6]:
# Pairwise ROC AUC summary for the gene loci used in this run, on the default target
pairwise_summary(
    target_obs_column=default_target_obs_column,
    gene_locus=config.gene_loci_used,
)

## GeneLocus.BCR|TCR, TargetObsColumnEnum.disease

#### default

dummy_most_frequent: 0.500,dummy_stratified: 0.494,lasso_cv: 0.980,elasticnet_cv0.75: 0.978,elasticnet_cv: 0.981,elasticnet_cv0.25: 0.984,ridge_cv: 0.986,logisticregression_unregularized: 0.975,rf_multiclass: 0.981,linearsvm_ovr: 0.967
,,,,,,,,,


#### subset_of_submodels_repertoire_stats

dummy_most_frequent: 0.500,dummy_stratified: 0.502,lasso_cv: 0.968,elasticnet_cv0.75: 0.968,elasticnet_cv: 0.967,elasticnet_cv0.25: 0.967,ridge_cv: 0.971,logisticregression_unregularized: 0.932,rf_multiclass: 0.970,linearsvm_ovr: 0.962
,,,,,,,,,


#### subset_of_submodels_convergent_cluster_model

dummy_most_frequent: 0.500,dummy_stratified: 0.494,lasso_cv: 0.930,elasticnet_cv0.75: 0.928,elasticnet_cv: 0.925,elasticnet_cv0.25: 0.924,ridge_cv: 0.922,logisticregression_unregularized: 0.920,rf_multiclass: 0.924,linearsvm_ovr: 0.930
,,,,,,,,,


#### subset_of_submodels_sequence_model

dummy_most_frequent: 0.500,dummy_stratified: 0.495,lasso_cv: 0.971,elasticnet_cv0.75: 0.971,elasticnet_cv: 0.972,elasticnet_cv0.25: 0.974,ridge_cv: 0.975,logisticregression_unregularized: 0.942,rf_multiclass: 0.965,linearsvm_ovr: 0.966
,,,,,,,,,


#### subset_of_submodels_repertoire_stats_convergent_cluster_model

dummy_most_frequent: 0.500,dummy_stratified: 0.494,lasso_cv: 0.973,elasticnet_cv0.75: 0.975,elasticnet_cv: 0.977,elasticnet_cv0.25: 0.978,ridge_cv: 0.981,logisticregression_unregularized: 0.963,rf_multiclass: 0.979,linearsvm_ovr: 0.969
,,,,,,,,,


#### subset_of_submodels_repertoire_stats_sequence_model

dummy_most_frequent: 0.500,dummy_stratified: 0.502,lasso_cv: 0.975,elasticnet_cv0.75: 0.978,elasticnet_cv: 0.980,elasticnet_cv0.25: 0.980,ridge_cv: 0.982,logisticregression_unregularized: 0.968,rf_multiclass: 0.977,linearsvm_ovr: 0.965
,,,,,,,,,


#### subset_of_submodels_convergent_cluster_model_sequence_model

dummy_most_frequent: 0.500,dummy_stratified: 0.494,lasso_cv: 0.976,elasticnet_cv0.75: 0.981,elasticnet_cv: 0.983,elasticnet_cv0.25: 0.981,ridge_cv: 0.982,logisticregression_unregularized: 0.970,rf_multiclass: 0.976,linearsvm_ovr: 0.965
,,,,,,,,,


#### isotype_counts_only

dummy_most_frequent: 0.500,dummy_stratified: 0.502,lasso_cv: 0.635,elasticnet_cv0.75: 0.637,elasticnet_cv: 0.646,elasticnet_cv0.25: 0.651,ridge_cv: 0.664,logisticregression_unregularized: 0.671,rf_multiclass: 0.679,linearsvm_ovr: 0.683
,,,,,,,,,
