### Dependencies

In [19]:
import os
import pickle
import sys

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset

### How To Use
1. Follow the README in "2-Weakly-Supervised-Train-Val" for extracting instance-level embeddings for ResNet-50, ViT-16, and ViT-256 features
1. Pre-extracted mean features for each slide dataset are available in "embeddings_slide_library" on Google Drive
2. Download "embeddings_slide_library" to this containing directory
3. Run

### Helper Functions for Representating each WSI using the Mean Instance-Level Embeddings

In [2]:
def series_intersection(s1, s2):
    r"""
    Takes the intersection of two pandas.Series (pd.Series) objects.
    
    Args:
        - s1 (pd.Series): pd.Series object.
        - s2 (pd.Series): pd.Series object.
    Return:
        - pd.Series: Intersection of s1 and s2.
    """
    return pd.Series(list(set(s1) & set(s2)))

def save_embeddings_mean(fname, dataset):
    r"""
    Saves+Pickle each WSI in a CSVDataset Object as the average of its instance-level embeddings
    
    Args:
        - fname (str): Save path+name for the pickle object.
        - dataset (torch.utils.data.dataset): CSVDataset object that iterates+loads each WSI in a folder
    
    Return:
        - None
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
    embeddings, labels = [], []

    for batch, target in dataloader:
        with torch.no_grad():
            embeddings.append(batch.squeeze(dim=0).mean(dim=0).numpy())
            labels.append(target.numpy())
            
    embeddings = np.vstack(embeddings)
    labels = np.vstack(labels).squeeze()

    asset_dict = {'embeddings': embeddings, 'labels': labels}

    with open('%s.pkl' % (fname), 'wb') as handle:
        pickle.dump(asset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)


class CSVDataset(Dataset):
    r"""
    torch.utils.data.dataset Object that iterates+loads each WSI in a folder.
    
    Args:
        - dataroot (str): Path to wsi_labels.csv.
        - tcga_csv (pd.DataFrame): Clinical CSV (as a pd.DataFrame object) for a TCGA Study
        - pt_path (str): Path to folder of saved instance-level feature embeddings for each WSI.
        - splits_csv (pd.DataFrame): DataFrame which contains slide_ids for train / val / test
        - label_col (str): Which column to use as labels in tcga_csv
        - label_dict (dict): Dictionary for categorizing labels
    Return:
        - None
    """
    def __init__(self, dataroot, tcga_csv, pt_path, splits_csv=None,
                 label_col='oncotree_code', label_dict={'LUSC':0, 'LUAD':1}):
        
        self.csv = pd.read_csv(os.path.join(dataroot, 'wsi_labels.csv'))
        self.csv['slide_path'] = pt_path+self.csv['slide_id']
        self.csv = self.csv.set_index('slide_id', drop=True).drop(['Unnamed: 0'], axis=1)
        self.csv.index = self.csv.index.str[:-3]
        self.csv.index.name = None
        self.csv = self.csv.join(tcga_csv, how='inner')
        if splits_csv is not None:
            self.csv = self.csv.loc[series_intersection(splits_csv.dropna(), self.csv.index)]
            
        self.label_col = label_col
        self.label_dict = label_dict
        
        ### If using DINO Features, subset and use only the last 384-dim features.
        if 'dino_pt_patch_features' in pt_path:
            self.truncate = True
        else:
            self.truncate = False
            
    def __getitem__(self, index):
        x = torch.load(self.csv['slide_path'][index])
        if self.truncate:
            x = x[:,(1536-384):1536]
        label = torch.Tensor([self.label_dict[self.csv[self.label_col][index]]]).to(torch.long)
        return x, label
    
    def __len__(self):
        return self.csv.shape[0]

In [7]:
r"""
Script for saving mean WSI features for each feature type in each task
"""

dataroot = 'path/to/dataroot'

for enc_name in ['resnet50mean', 'vit16mean', 'vit256mean']
    for study in ['tcga_brca', 'tcga_kidney', 'tcga_lung']:
        tcga_csv = pd.read_csv("../Classification/dataset_csv/%s_subset.csv" % study, index_col=2)['oncotree_code']
        tcga_csv.index = tcga_csv.index.str[:-4]
        tcga_csv.index.name = None

        saveroot = os.path.join('path/to/saved/features', enc_name) + '/'
        os.makedirs(saveroot, exist_ok=True)

        if enc_name == 'vit256mean':
            pt_path = os.path.join(dataroot, 'vits_tcga_pancancer_dino_pt_global/')
        elif enc_name == 'vit16mean':
            extracted_dir = '%s/extracted_mag20x_patch256_fp/vits_tcga_pancancer_dino_pt_patch_features/' % study
            pt_path = os.path.join(dataroot, extracted_dir)
        elif enc_name == 'resnet50mean':
            extracted_dir = '%s/extracted_mag20x_patch256_fp/resnet50_trunc_pt_patch_features/' % study
            pt_path = os.path.join(dataroot, extracted_dir)

        splits_folder = '../Classification/splits/10foldcv_subtype/%s/' % study
        if study == 'tcga_brca':
            label_dict={'IDC':0, 'ILC':1}
            tcga_csv = tcga_csv[tcga_csv.str.contains('IDC|ILC')]
        elif study == 'tcga_kidney':
            label_dict={'CCRCC':0, 'PRCC':1, 'CHRCC': 2}
        elif study == 'tcga_lung':
            label_dict={'LUSC':0, 'LUAD':1}

        for i in tqdm(range(10)):
            splits_csv = pd.read_csv(os.path.join(splits_folder, 'splits_%d.csv' % i), index_col=0)
            train = CSVDataset(dataroot=dataroot, tcga_csv=tcga_csv, pt_path=pt_path,
                                       splits_csv=splits_csv['train'], label_dict=label_dict)
            test = CSVDataset(dataroot=dataroot, tcga_csv=tcga_csv, pt_path=pt_path,
                                      splits_csv=splits_csv['test'], label_dict=label_dict)

            save_embeddings_mean(saveroot+'/%s_%s_class_split_train_%d' % (study, enc_name, i), train)
            save_embeddings_mean(saveroot+'/%s_%s_class_split_test_%d' % (study, enc_name, i), test)

100%|██████████| 10/10 [35:44<00:00, 214.49s/it]
100%|██████████| 10/10 [1:11:46<00:00, 430.68s/it]
100%|██████████| 10/10 [46:58<00:00, 281.81s/it]


### 10-Fold CV Evaluation of Mean WSI Embeddings

In [18]:
def get_results(saveroot, study='tcga_lung', enc_name='vit256mean', prop=1.0):
    r"""
    Runs 10-fold CV for KNN of mean WSI embeddings
    
    Args:
        - saveroot (str): Path to mean WSI embeddings for each feature type.
        - study (str): Which TCGA study (Choices: tcga_brca, tcga_lung, tcga_kidney)
        - enc_name (str): Which encoder to use (Choices: resnet50mean, vit16mean, vit256mean)
        - prop (float): Proportion of training dataset to use
    Return:
        - aucs_knn_all (pd.DataFrame): AUCs for 10-fold CV evaluation
    """
    aucs_knn_all = {}

    for i in range(10):
        train_fname = os.path.join(saveroot, enc_name, '%s_%s_split_train_%d.pkl' % (study, enc_name, i)) 
        with open(train_fname, 'rb') as handle:
            asset_dict = pickle.load(handle)
            train_embeddings, train_labels = asset_dict['embeddings'], asset_dict['labels']
            
            if prop < 1:
                sample_inds = pd.DataFrame(range(train_embeddings.shape[0])).sample(frac=0.1, random_state=1).index
                train_embeddings = train_embeddings[sample_inds]
                train_labels = train_labels[sample_inds]

        val_fname = os.path.join(saveroot, enc_name, '%s_%s_split_test_%d.pkl' % (study, enc_name, i)) 
        with open(val_fname, 'rb') as handle:
            asset_dict = pickle.load(handle)
            val_embeddings, val_labels = asset_dict['embeddings'], asset_dict['labels']

        le = LabelEncoder().fit(train_labels)
        train_labels = le.transform(train_labels)
        val_labels = le.transform(val_labels)

        ### K-NN Evaluation
        clf = KNeighborsClassifier().fit(train_embeddings, train_labels)
        y_score = clf.predict_proba(val_embeddings)
        y_pred = clf.predict(val_embeddings)
        aucs, f1s = [], []
        if len(np.unique(val_labels)) > 2:
            for j, label in enumerate(np.unique(val_labels)):
                label_class = np.array(val_labels == label, int)
            aucs.append(sklearn.metrics.roc_auc_score(val_labels, y_score, average='macro', multi_class='ovr'))
        else:
            aucs.append(sklearn.metrics.roc_auc_score(val_labels, y_score[:,1]))
        aucs_knn_all[i] = aucs

    aucs_knn_all = pd.DataFrame(aucs_knn_all).T
    return aucs_knn_all

In [17]:
r"""
Script for runnign 10-fold CV for each feature type for each TCGA study.
"""
    
results_all = []
saveroot = './embeddings_slide_library/'

for enc_name in tqdm(['resnet50mean', 'vit16mean', 'vit256mean']):
    results_row = []
    for study in ['tcga_brca', 'tcga_lung', 'tcga_kidney']:
        for prop in [0.25, 1.0]:
            aucs = get_results(saveroot, study, enc_name, prop)
            aucs = '%0.3f +/- %0.3f' % (aucs.mean(), aucs.std())
            results_row.append([aucs])
    
    results_all.append(pd.DataFrame(results_row).T)
    
results_df = pd.concat(results_all)
results_df.index = ['resnet50mean', 'vit16mean', 'vit256mean']
results_df.columns = [0.25, 1.0, 0.25, 1.0, 0.25, 1.0]
results_df.index = ['', '', '']
results_df.insert(0, 'Pretrain', ['ImageNet', 'DINO', 'DINO'])
results_df.insert(1, 'Arch', ['ResNet-50','ViT-16', 'ViT-256'])
print(results_df.to_latex())
results_df

100%|██████████| 3/3 [00:01<00:00,  2.32it/s]

\begin{tabular}{lllllllll}
\toprule
{} &  Pretrain &       Arch &             0.25 &              1.0 &             0.25 &              1.0 &             0.25 &              1.0 \\
\midrule
{} &  ImageNet &  ResNet-50 &  0.638 +/- 0.089 &  0.667 +/- 0.070 &  0.696 +/- 0.055 &  0.794 +/- 0.035 &  0.862 +/- 0.030 &  0.951 +/- 0.016 \\
{} &      DINO &     ViT-16 &  0.605 +/- 0.092 &  0.725 +/- 0.083 &  0.622 +/- 0.067 &  0.742 +/- 0.045 &  0.848 +/- 0.032 &  0.899 +/- 0.027 \\
{} &      DINO &    ViT-256 &  0.682 +/- 0.055 &  0.775 +/- 0.042 &  0.773 +/- 0.048 &  0.889 +/- 0.027 &  0.916 +/- 0.022 &  0.974 +/- 0.016 \\
\bottomrule
\end{tabular}






Unnamed: 0,Pretrain,Arch,0.25,1.0,0.25.1,1.0.1,0.25.2,1.0.2
,ImageNet,ResNet-50,0.638 +/- 0.089,0.667 +/- 0.070,0.696 +/- 0.055,0.794 +/- 0.035,0.862 +/- 0.030,0.951 +/- 0.016
,DINO,ViT-16,0.605 +/- 0.092,0.725 +/- 0.083,0.622 +/- 0.067,0.742 +/- 0.045,0.848 +/- 0.032,0.899 +/- 0.027
,DINO,ViT-256,0.682 +/- 0.055,0.775 +/- 0.042,0.773 +/- 0.048,0.889 +/- 0.027,0.916 +/- 0.022,0.974 +/- 0.016
