# Tutorial-3: Evaluating Modality Alignment
In this tutorial, we demonstrate how to calculate metrics for modality alignment. 

The modality alignment metrics are: modality averaged silhouette width (ASW; $y^{\text{ASW}}$), fraction of samples closer than the true match (FOSCTTM; $y^{\text{FOSCTTM}}$), label transfer F1 ($y^{\text{ltF1}}$), ATAC area under the receiver operating characteristic curve (AUROC; $y^{\text{AUROC}}$), RNA Pearson’s $r$ ($y^{\text{RNAr}}$) and ADT Pearson’s $r$ ($y^{\text{ADTr}}$). The first three ($y^{\text{ASW}}$, $y^{\text{FOSCTTM}}$ and $y^{\text{ltF1}}$) are defined in embedding space; the last three ($y^{\text{AUROC}}$, $y^{\text{RNAr}}$ and $y^{\text{ADTr}}$) are defined in feature space.
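
To make the embedding-space metrics more concrete, here is a minimal sketch of FOSCTTM under its textbook definition. It is not the code used by `scmidas`: the function name `foscttm`, the averaging over both directions, and the `n - 1` normalisation are assumptions made for illustration, and the values reported by `eval_mod` may use a different normalisation or orientation.

```python
import math

def foscttm(emb_a, emb_b):
    """Fraction of samples closer than the true match (illustrative sketch).

    emb_a[i] and emb_b[i] are the embeddings of the same cell i obtained
    from two different modalities. Lower is better under this definition.
    """
    n = len(emb_a)
    if n != len(emb_b) or n < 2:
        raise ValueError("need two equally sized lists with at least 2 cells")
    total = 0.0
    for i in range(n):
        # Distance between the two embeddings of the same cell.
        d_true = math.dist(emb_a[i], emb_b[i])
        # How many cells of the other modality lie strictly closer?
        closer_ab = sum(math.dist(emb_a[i], emb_b[j]) < d_true for j in range(n))
        closer_ba = sum(math.dist(emb_b[i], emb_a[j]) < d_true for j in range(n))
        total += (closer_ab + closer_ba) / (2 * (n - 1))
    return total / n

# Identical embeddings: no cell is closer than the true match.
aligned = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
print(foscttm(aligned, aligned))
```

With identical embeddings in both modalities the score is zero, since the true match is always at distance zero.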


Complete Tutorial-2 before running the code in this tutorial, because it builds on the training results obtained there. Unlike the common metrics, the modality alignment metrics cannot be computed from mosaic data; they require the full dataset. From the full dataset we predict the modality-specific embeddings and the translated counts, which are then evaluated.


## Dataset Preparation
First, we construct the full dataset (about 7 minutes):

In [1]:
import warnings
warnings.filterwarnings('ignore')

from scmidas.datasets import GenDataFromPath

In [2]:
data_path = [
    {
        "rna": "./data_3/processed/dogma_demo/subset_0/mat/rna.csv", 
        "adt": "./data_3/processed/dogma_demo/subset_0/mat/adt.csv",
        "atac": "./data_3/processed/dogma_demo/subset_0/mat/atac.csv"
     },
    {
        "rna": "./data_3/processed/dogma_demo/subset_1/mat/rna.csv", 
        "adt": "./data_3/processed/dogma_demo/subset_1/mat/adt.csv",
        "atac": "./data_3/processed/dogma_demo/subset_1/mat/atac.csv"
     },

    {
        "rna": "./data_3/processed/dogma_demo/subset_2/mat/rna.csv", 
        "adt": "./data_3/processed/dogma_demo/subset_2/mat/adt.csv",
        "atac": "./data_3/processed/dogma_demo/subset_2/mat/atac.csv"
     },
]
save_dir = "./data_3/processed/dogma_demo_transfer/"
remove_old = False
GenDataFromPath(data_path, save_dir, remove_old)  # generates the directory; preprocess/split_mat.py can be used instead

Spliting rna matrix: 7361 cells, 4054 features


100%|██████████| 7361/7361 [00:12<00:00, 601.43it/s]


Spliting adt matrix: 7361 cells, 208 features


100%|██████████| 7361/7361 [00:04<00:00, 1839.88it/s]


Spliting atac matrix: 7361 cells, 30521 features


100%|██████████| 7361/7361 [01:01<00:00, 119.99it/s]


Spliting rna matrix: 5897 cells, 4054 features


100%|██████████| 5897/5897 [00:10<00:00, 574.41it/s]


Spliting adt matrix: 5897 cells, 208 features


100%|██████████| 5897/5897 [00:03<00:00, 1618.78it/s]


Spliting atac matrix: 5897 cells, 30521 features


100%|██████████| 5897/5897 [00:47<00:00, 124.45it/s]


Spliting rna matrix: 10190 cells, 4054 features


100%|██████████| 10190/10190 [00:17<00:00, 584.73it/s]


Spliting adt matrix: 10190 cells, 208 features


100%|██████████| 10190/10190 [00:05<00:00, 1702.19it/s]


Spliting atac matrix: 10190 cells, 30521 features


100%|██████████| 10190/10190 [01:21<00:00, 124.73it/s]
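
The `data_path` list used in this section follows a regular pattern (one dictionary per subset, one CSV per modality), so it can also be generated instead of typed out. The helper below is a hypothetical convenience function, not part of `scmidas`; it only assumes the directory layout shown in the cell above.

```python
def build_data_paths(root, n_subsets, mods=("rna", "adt", "atac")):
    """Build one {modality: csv_path} dict per subset (hypothetical helper)."""
    return [
        {mod: f"{root}/subset_{i}/mat/{mod}.csv" for mod in mods}
        for i in range(n_subsets)
    ]

# Same three subsets as in the cell above.
data_path = build_data_paths("./data_3/processed/dogma_demo", 3)
for entry in data_path:
    print(entry)
```

The resulting list can be passed to `GenDataFromPath` in the same way as the hand-written one.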


## Pretrained Model Preparation
In this section, we initialize the model with the pretrained weights from Tutorial-2.

In [3]:
from scmidas.models import MIDAS
from scmidas.datasets import GetDataInfo
import scmidas.utils as utils
import scanpy as sc
import pandas as pd

sc.set_figure_params(figsize=(4, 4))

In [None]:
data = [GetDataInfo("./data_3/processed/dogma_demo_transfer/")]
model = MIDAS(data)
model.init_model(model_path="./result/dogma_demo/train/sp_00000500.pt", skip_s=True)  # skip_s is set to True to avoid a structure mismatch
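
Because this step depends on the checkpoint written in Tutorial-2, it can help to fail fast with a clear message if the file is missing. The sketch below uses only the standard library; `require_checkpoint` is a hypothetical helper, not an `scmidas` function.

```python
from pathlib import Path

def require_checkpoint(path):
    """Return the checkpoint path, or raise a clear error if it is missing."""
    ckpt = Path(path)
    if not ckpt.is_file():
        raise FileNotFoundError(
            f"Checkpoint not found: {ckpt}. Complete Tutorial-2 first."
        )
    return ckpt

# Example usage before initializing the model (path taken from the cell above):
# ckpt = require_checkpoint("./result/dogma_demo/train/sp_00000500.pt")
```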

## Prediction
In this section, we predict the embeddings for each modality and the translated counts, and also save the inputs, which are needed when calculating the correlations. This takes about 24 minutes.

In [7]:
model.predict(mod_latent=True, translate=True, input=True, save_dir='./result/dogma_demo/predict/', remove_old=False)

Predicting ...
Processing subset 0: ['atac', 'rna', 'adt']


100%|██████████| 29/29 [07:43<00:00, 15.99s/it]


Processing subset 1: ['atac', 'rna', 'adt']


100%|██████████| 24/24 [06:07<00:00, 15.32s/it]


Processing subset 2: ['atac', 'rna', 'adt']


100%|██████████| 40/40 [10:45<00:00, 16.15s/it]
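
The inputs are saved because the feature-space correlation metrics compare translated counts against the observed counts. As an illustration of that comparison, here is a minimal sketch of Pearson’s $r$ between two flat lists of values. How `scmidas` flattens or aggregates the matrices (per cell, per feature, or over the whole matrix) is not stated in this tutorial, so this shows only the underlying formula.

```python
import math

def pearson_r(x, y):
    """Pearson correlation coefficient between two equally long sequences."""
    n = len(x)
    if n != len(y) or n < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        return float("nan")  # correlation is undefined for constant input
    return cov / math.sqrt(var_x * var_y)

# Toy values (not real data): translated counts vs. observed counts.
observed = [1.0, 2.0, 3.0]
translated = [2.0, 4.0, 6.0]
print(pearson_r(observed, translated))
```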


## Evaluation
Evaluation takes about 50 minutes.

In [8]:
from scmidas.evaluation import eval_mod

In [10]:
pred = model.read_preds(mod_latent=True, translate=True, input=True, group_by="subset")
label_list = [
    pd.read_csv('./data/raw/atac+rna+adt/dogma/lll_ctrl/label_seurat/l1.csv', index_col=0).values.flatten(),
    pd.read_csv('./data/raw/atac+rna+adt/dogma/lll_stim/label_seurat/l1.csv', index_col=0).values.flatten(),
    pd.read_csv('./data/raw/atac+rna+adt/dogma/dig_ctrl/label_seurat/l1.csv', index_col=0).values.flatten(),
    ]
result = eval_mod(pred, label_list, model.masks)
result
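
The returned `result` is a nested dictionary keyed by metric, then subset, and (for some metrics) translation direction, as the printed output of this cell shows. A flat list of rows is often easier to tabulate or compare. The sketch below is a hypothetical helper that works on any nested dictionary; the toy values are made up and are not results from this tutorial.

```python
def flatten_result(nested, prefix=()):
    """Turn a nested dict into a list of (key_path, value) rows."""
    rows = []
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            rows.extend(flatten_result(value, path))
        else:
            rows.append((path, value))
    return rows

# Toy dictionary with the same shape as the evaluation output (made-up values).
toy = {
    "asw_mod": {"0": 0.84},
    "foscttm": {"0": {"adt_to_rna": 0.85, "rna_to_adt": 0.86}},
}
rows = flatten_result(toy)
for path, value in rows:
    print("/".join(path), value)
```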

Loading predicted variables ...


100%|██████████| 29/29 [00:00<00:00, 198.88it/s]
100%|██████████| 29/29 [00:00<00:00, 271.21it/s]
100%|██████████| 29/29 [00:00<00:00, 263.84it/s]
100%|██████████| 29/29 [00:00<00:00, 265.69it/s]
100%|██████████| 29/29 [00:00<00:00, 66.94it/s]
100%|██████████| 29/29 [00:06<00:00,  4.52it/s]
100%|██████████| 29/29 [00:00<00:00, 35.11it/s]
100%|██████████| 29/29 [00:47<00:00,  1.63s/it]
100%|██████████| 29/29 [00:50<00:00,  1.75s/it]
100%|██████████| 29/29 [00:07<00:00,  3.69it/s]
100%|██████████| 29/29 [00:00<00:00, 88.10it/s]
100%|██████████| 29/29 [00:06<00:00,  4.71it/s]
100%|██████████| 29/29 [00:43<00:00,  1.51s/it]
100%|██████████| 29/29 [00:13<00:00,  2.11it/s]
100%|██████████| 29/29 [00:01<00:00, 14.74it/s]
100%|██████████| 29/29 [00:00<00:00, 108.64it/s]
100%|██████████| 24/24 [00:00<00:00, 181.13it/s]
100%|██████████| 24/24 [00:00<00:00, 214.18it/s]
100%|██████████| 24/24 [00:00<00:00, 205.20it/s]
100%|██████████| 24/24 [00:00<00:00, 196.88it/s]
100%|██████████| 24/24 [00:00<0

Converting to numpy ...
Converting subset 0: s, joint
Converting subset 0: s, atac
Converting subset 0: s, rna
Converting subset 0: s, adt
Converting subset 0: z, joint
Converting subset 0: z, atac
Converting subset 0: z, rna
Converting subset 0: z, adt
Converting subset 0: x_trans, atac_to_adt
Converting subset 0: x_trans, atac_to_rna
Converting subset 0: x_trans, rna_to_adt
Converting subset 0: x_trans, rna_to_atac
Converting subset 0: x_trans, adt_to_atac
Converting subset 0: x_trans, adt_to_rna
Converting subset 0: x_trans, atac_rna_to_adt
Converting subset 0: x_trans, atac_adt_to_rna
Converting subset 0: x_trans, rna_adt_to_atac
Converting subset 0: x, atac
Converting subset 0: x, rna
Converting subset 0: x, adt
Converting subset 1: s, joint
Converting subset 1: s, atac
Converting subset 1: s, rna
Converting subset 1: s, adt
Converting subset 1: z, joint
Converting subset 1: z, atac
Converting subset 1: z, rna
Converting subset 1: z, adt
Converting subset 1: x_trans, atac_to_adt
C


mean silhouette per group:          silhouette_score
group                    
B                0.840390
CD4 T            0.830041
CD8 T            0.872103
DC               0.792754
Mono             0.822331
NK               0.876750
other            0.831161
other T          0.882571
calculating batch 2/3



mean silhouette per group:          silhouette_score
group                    
B                0.888698
CD4 T            0.913682
CD8 T            0.918087
DC               0.847209
Mono             0.885915
NK               0.935011
other            0.901474
other T          0.927728
calculating batch 3/3



mean silhouette per group:          silhouette_score
group                    
B                0.854576
CD4 T            0.858451
CD8 T            0.894141
DC               0.849380
Mono             0.850341
NK               0.843207
other            0.807512
other T          0.834351


{'asw_mod': {'0': 0.8435125604491478,
  '1': 0.9022253320396332,
  '2': 0.8489948476263821},
 'foscttm': {'0': {'adt_to_atac': 0.802176371216774,
   'adt_to_rna': 0.8414446413516998,
   'rna_to_adt': 0.8613581210374832,
   'rna_to_atac': 0.8992796763777733,
   'atac_to_adt': 0.7986521124839783,
   'atac_to_rna': 0.8574737459421158},
  '1': {'adt_to_atac': 0.8391045033931732,
   'adt_to_rna': 0.8537391573190689,
   'rna_to_adt': 0.8575942367315292,
   'rna_to_atac': 0.9999860794387132,
   'atac_to_adt': 0.8538448065519333,
   'atac_to_rna': 0.9999953118767735},
  '2': {'adt_to_atac': 0.8584581613540649,
   'adt_to_rna': 0.9867734359577298,
   'rna_to_adt': 0.9937428142875433,
   'rna_to_atac': 0.8907932043075562,
   'atac_to_adt': 0.8049696385860443,
   'atac_to_rna': 0.8308574855327606}},
 'f1': {'0': {'adt_to_atac': 0.8319521804102703,
   'adt_to_rna': 0.9011003939682108,
   'rna_to_adt': 0.9011003939682108,
   'rna_to_atac': 0.8725716614590409,
   'atac_to_adt': 0.875967939138704,
  