In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
!pip install -e ./../../BatchDetect

In [3]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch

# Run model inference on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Patch-level methods evaluation

Evaluate the performance of patch-level methods, such as stain normalization or stain augmentation techniques.

## Reading metadata

In [4]:
# Cohort name; used to build the data folder path and to gate the CRC-specific sample-type mapping.
dataset = 'CRC'  # TODO make work for other datasets

In [5]:
# Build the patch-level metadata table from the clinical table and the folder structure.
# The table is cached in `metadata.csv`; delete that file to rebuild it.
base_dir = Path(f'/lustre/groups/shared/users/peng_marr/BatchDetect/BatchDetect{dataset}')
clini_table = pd.read_csv(base_dir.parent / f'BatchDetect{dataset}_clini.csv')

# every clinical column except the patient key and age is used as a label (or use a custom list)
labels = list(clini_table.columns)
labels.remove('PATIENT')
labels.remove('AGE')

metadata_path = base_dir / 'metadata.csv'
if metadata_path.exists():
    metadata = pd.read_csv(metadata_path)
else:
    # NOTE: glob order is filesystem-dependent. The cached feature CSVs have no patch key and are
    # row-aligned with this table, so regenerate them whenever metadata.csv is rebuilt.
    patch_list = list(base_dir.glob('**/*.jpeg'))
    print('Number of patches:', len(patch_list))
    if not patch_list:
        # fail instead of caching an empty metadata.csv that later runs would silently reuse
        raise FileNotFoundError(f'No *.jpeg patches found under {base_dir}')

    # columns: file (str, same type as when read back from the CSV cache),
    # dataset (submission site = grandparent folder), then one column per clinical label
    metadata = pd.DataFrame({
        'file': [str(patch) for patch in patch_list],
        'dataset': [patch.parent.parent.name for patch in patch_list],
    })

    # Index the clinical table by patient once instead of scanning it for every patch and label.
    clini_by_patient = clini_table.set_index('PATIENT')
    if not clini_by_patient.index.is_unique:
        raise ValueError('clini_table contains duplicate PATIENT entries')
    # patient id = file-name prefix before the first "_"
    patient_ids = [patch.name.split('_')[0] for patch in patch_list]
    missing = set(patient_ids).difference(clini_by_patient.index)
    if missing:
        raise KeyError(f'{len(missing)} patient(s) not found in clini_table, e.g. {sorted(missing)[:5]}')

    for label_col in labels:
        metadata[label_col] = clini_by_patient.loc[patient_ids, label_col].to_numpy()
    metadata.to_csv(metadata_path, index=False)

In [6]:
def _sample_type(file_path, source_dataset):
    """Return the tissue preparation ("FFPE" or "frozen") of one patch, or NaN if unknown.

    Every patch from the "CPTAC" dataset is frozen. Otherwise the type code is the last
    "-"-separated token of the patch's parent folder name, taken before the first ".":
    codes starting with "DX" are FFPE, codes starting with "TS" or "BS" are frozen.
    """
    if source_dataset == "CPTAC":
        return "frozen"
    code = Path(file_path).parent.name.split(".")[0].split("-")[-1]
    if code.startswith("DX"):
        return "FFPE"
    if code.startswith(("TS", "BS")):
        return "frozen"
    # a code that already reads "frozen"/"FFPE" is kept; anything else is unknown (NaN)
    return code if code in ("frozen", "FFPE") else np.nan


# for TCGA-CRC cohorts
if dataset == 'CRC':
    metadata["type"] = [
        _sample_type(file_path, source)
        for file_path, source in zip(metadata["file"].astype(str), metadata["dataset"])
    ]
    # guard keeps the cell idempotent: re-running it must not append "type" a second time
    if "type" not in labels:
        labels = labels + ["type"]

In [None]:
# inspect the patch metadata table
metadata

## Features
Compute patch features with the selected extractor, or load them from the CSV cache if it already exists.

In [109]:
method = "original"  # no batch correction method is applied
# feature extractor, one of: first_and_second_order, resnet, ctranspath, h_optimus_0,
# h0_mini, h_optimus_1, uni2, uni, conch (also part of the cached feature file name)
features = 'first_and_second_order'

In [None]:
from batchdetect.image import first_and_second_order, resnet, ctranspath, h_optimus_0, h0_mini, h_optimus_1, uni2, conch, uni


def _extract_conch(metadata):
    """Embed every patch in ``metadata["file"]`` with the CONCH image encoder, one image per pass.

    ``conch()`` provides the model and its preprocessing transform; embeddings come from
    ``encode_image(..., proj_contrast=False, normalize=False)``.
    """
    from PIL import Image
    from tqdm import tqdm

    model, transform = conch()
    embeddings = []
    # mixed precision only on CUDA; on CPU inference stays fp32, as with torch.cuda.amp.autocast()
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=device.type == 'cuda'):
        for f in tqdm(metadata["file"].tolist()):
            # close the file handle as soon as the RGB copy exists
            with Image.open(f) as img:
                image = img.convert('RGB')
            image = transform(image).unsqueeze(0).to(device)
            embeddings.append(model.encode_image(image, proj_contrast=False, normalize=False).squeeze().cpu().numpy())
    return pd.DataFrame(np.stack(embeddings, axis=0))


# extractors returning per-patch arrays that are stacked along a new first axis
_STACKED_EXTRACTORS = {
    'resnet': resnet,
    'ctranspath': ctranspath,
    'h_optimus_0': h_optimus_0,
    'h_optimus_1': h_optimus_1,
    'uni2': uni2,
    'uni': uni,
}


def _extract_features(name, metadata):
    """Compute patch features for ``metadata`` with the extractor called ``name``.

    Returns:
        DataFrame with one row per patch.
    Raises:
        ValueError: if ``name`` is not a supported extractor, so a typo fails loudly instead of
        leaving ``df_features`` undefined or stale from an earlier run.
    """
    if name == 'first_and_second_order':
        # used as-is, without stacking
        return first_and_second_order(metadata)
    if name in _STACKED_EXTRACTORS:
        return pd.DataFrame(np.stack(_STACKED_EXTRACTORS[name](metadata), axis=0))
    if name == 'h0_mini':
        # keep only index 0 of the second axis (presumably the CLS token - TODO confirm)
        return pd.DataFrame(np.stack(h0_mini(metadata), axis=0)[:, 0, :])
    if name == 'conch':
        return _extract_conch(metadata)
    supported = ['first_and_second_order', *_STACKED_EXTRACTORS, 'h0_mini', 'conch']
    raise ValueError(f'Unknown features {name!r}; expected one of {supported}')


df_features_path = base_dir / f'{method}_{features}_features.csv'

if df_features_path.exists():
    df_features = pd.read_csv(df_features_path)
else:
    df_features = _extract_features(features, metadata)
    df_features.to_csv(df_features_path, index=False)

# df_features carries no patch key, so it must be row-aligned with metadata
if len(df_features) != len(metadata):
    raise ValueError(
        f'{df_features_path.name} has {len(df_features)} rows but metadata has {len(metadata)}; '
        'delete the stale feature cache and re-run'
    )

## Let's see if there is a batch effect in the data

In [None]:
from batchdetect.batchdetect import BatchDetect

# Hand BatchDetect the label columns plus the submission site ("dataset") and the patch features.
bd = BatchDetect(metadata[[*labels, "dataset"]], df_features)

### Visualizations

In [None]:
# PCA view of the features passed to BatchDetect
bd.low_dim_visualization("pca")

In [None]:
# t-SNE view; NOTE(review): t-SNE is stochastic - results vary between runs unless BatchDetect fixes a seed (not visible here)
bd.low_dim_visualization("tsne")

In [117]:
# UMAP view; NOTE(review): UMAP is stochastic - results vary between runs unless BatchDetect fixes a seed (not visible here)
bd.low_dim_visualization("umap")

### ANOVA test of principal components vs. labels

In [None]:
# ANOVA of principal components vs. labels
bd.prince_plot()

### Classification test: random forest (RF) vs. a random classifier

In [None]:
# random forest vs. random classifier, scored with macro-averaged F1
bd.classification_test(scorer="f1_macro")

### Clustering metrics

In [None]:
from batchdetect.metrics import mean_local_diversity, silhouette_score

targets = [*labels, "dataset"]

metrics = [mean_local_diversity, silhouette_score]
# e.g. "mean_local_diversity" -> "Mean Local Diversity"
metrics_labels = [m.__name__.replace('_', ' ').title() for m in metrics]

# each metric returns scores indexable by target name; build the table in one shot so the
# metric columns get a numeric dtype (growing an empty frame via .loc leaves them as object)
scores = {label: m(metadata, targets, df_features) for m, label in zip(metrics, metrics_labels)}
result_df = pd.DataFrame({
    "Target": targets,
    **{label: [res[t] for t in targets] for label, res in scores.items()},
})


In [None]:
# one row per target, one column per clustering metric
result_df

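| Metric | Range | Value indicating highest mixture |
|:-------------|:--------------:|:--------------:|
| Mean local diversity | [0, 1] | 1 |
| Silhouette score | [-1, 1] | ≈ 0 |

A silhouette score near 0 means the groups of a target overlap (are well mixed). Values toward 1 mean patches cluster by that target; for "dataset", that indicates a batch effect.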

In [None]:
# label columns used as targets (includes "type" for the CRC cohort)
labels