### Dependencies

In [5]:
import os
import pickle
import sys
sys.path.append('../')

import h5py
import numpy as np
import pandas as pd
from PIL import Image
import umap
import umap.plot
from tqdm import tqdm

import torch
from torch.utils.data.dataset import Dataset
import torch.multiprocessing
import torchvision
from torchvision import transforms
from pl_bolts.models.self_supervised import resnets
from pl_bolts.utils.semi_supervised import Identity

from DINO.vision_transformer import *
from Classification.models.resnet_custom import resnet50_baseline

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU so that
# every later `.to(device)` call still works on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Switch torch.multiprocessing to the 'file_system' sharing strategy
# (presumably to avoid "too many open files" errors with DataLoader workers - confirm).
torch.multiprocessing.set_sharing_strategy('file_system')

### How To Use
1. Download CRC-100K, BCSS, and BreastPathQ to the "patch_datasets" path
2. Place the pretrained models in the "assets_dir" path
3. Run the "create_embeddings" function to extract and save features for each patch dataset
4. Pre-extracted features for each patch dataset are available in "embeddings_patch_library" on Google Drive

### Helper Functions for: Image Normalization, Patch Dataset Loading, and Patch Embedding Saving

In [3]:
def eval_transforms(pretrained=False):
    """Build the preprocessing pipeline applied to every patch at evaluation time.

    Args:
        pretrained: when truthy, normalize with mean (0.485, 0.456, 0.406) and
            std (0.229, 0.224, 0.225) (the usual ImageNet statistics); otherwise
            normalize every channel with mean 0.5 and std 0.5.

    Returns:
        A ``transforms.Compose`` that applies ``ToTensor`` followed by ``Normalize``.
    """
    channel_mean = (0.485, 0.456, 0.406) if pretrained else (0.5, 0.5, 0.5)
    channel_std = (0.229, 0.224, 0.225) if pretrained else (0.5, 0.5, 0.5)
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

def torchvision_ssl_encoder(name: str, pretrained: bool = False, return_all_feature_maps: bool = False):
    """Instantiate a ResNet from ``pl_bolts`` and replace its ``fc`` head with ``Identity``.

    Args:
        name: attribute name of the constructor to look up on ``resnets``.
        pretrained: forwarded to the constructor.
        return_all_feature_maps: forwarded to the constructor.

    Returns:
        The constructed model with ``fc`` set to an ``Identity`` instance.
    """
    build_backbone = getattr(resnets, name)
    encoder = build_backbone(
        pretrained=pretrained,
        return_all_feature_maps=return_all_feature_maps,
    )
    encoder.fc = Identity()
    return encoder

In [4]:
class CSVDataset(Dataset):
    """Patch dataset whose image files and targets are listed in a CSV file.

    The CSV must contain the columns ``slide``, ``rid`` and ``y``. Row ``i`` refers to the
    image ``<dataroot><slide>_<rid>.tif`` and carries the target ``y``.
    """

    def __init__(self, dataroot, csv_path, transforms_eval):
        """Read ``csv_path`` and add an ``img_path`` column derived from ``dataroot``.

        ``dataroot`` is used as a plain string prefix (no separator is inserted), so it
        should end with a path separator.
        """
        table = pd.read_csv(csv_path)
        slide_text = table['slide'].astype(str)
        rid_text = table['rid'].astype(str)
        table['img_path'] = dataroot + slide_text + "_" + rid_text + '.tif'
        self.csv = table
        self.transforms = transforms_eval

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.csv.index)

    def __getitem__(self, index):
        """Return ``(transformed image, y)`` for the row labelled ``index``."""
        picture = Image.open(self.csv['img_path'][index])
        sample = self.transforms(picture)
        return sample, self.csv['y'][index]

class CSVDataset_BCSS(Dataset):
    """One split of a BCSS patch table.

    ``dataset_csv`` is a DataFrame whose index holds the image paths and which has the
    columns ``label`` and ``train``. Only rows whose ``train`` value equals ``is_train``
    are kept.
    """

    def __init__(self, dataset_csv, is_train=1, transforms_eval=None):
        """Select the requested split and remember the transform.

        Args:
            dataset_csv: DataFrame indexed by image path with ``label`` and ``train`` columns.
            is_train: value of the ``train`` column to keep.
            transforms_eval: callable applied to each opened image. ``None`` (the default)
                means ``eval_transforms()``. It is resolved here rather than in the
                signature so that defining the class does not require ``eval_transforms``
                to exist yet and does not build a pipeline at definition time.
        """
        if transforms_eval is None:
            transforms_eval = eval_transforms()
        self.csv = dataset_csv[dataset_csv['train'] == is_train]
        self.transforms = transforms_eval

    def __getitem__(self, index):
        """Return ``(transformed image, label)`` for the row at position ``index``."""
        img = Image.open(self.csv.index[index])
        # Read the column first: this avoids building a whole row Series per item and
        # keeps the label's own dtype even if other columns have a different one.
        return self.transforms(img), self.csv['label'].iloc[index]

    def __len__(self):
        """Number of rows in the selected split."""
        return self.csv.shape[0]

In [5]:
def save_embeddings(model, fname, dataloader, dataset=None, is_imagefolder=False,
                    save_patches=False, sprite_dim=128):
    """Run ``model`` over ``dataloader`` and pickle the embeddings and labels.

    Args:
        model: callable mapping a batch tensor (already moved to ``device``) to a tensor
            of per-sample embeddings.
        fname: output path without extension; the result is written to ``fname + '.pkl'``.
        dataloader: iterable yielding ``(batch, target)`` pairs of tensors.
        dataset: object exposing ``class_to_idx``; only used when ``is_imagefolder`` is true.
        is_imagefolder: when true, integer labels are mapped back to class names through
            ``dataset.class_to_idx``.
        save_patches: when true, a resized copy of every input image is stored as well.
        sprite_dim: size argument passed to ``.resize`` for the stored patches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    The pickle holds a dict with ``'embeddings'`` (one row per sample), ``'labels'``
    (one entry per sample) and, if requested, ``'patches'``.
    """
    embeddings, labels = [], []
    patches = []

    for batch, target in tqdm(dataloader):
        if save_patches:
            # NOTE(review): `tensor2im` is not defined or imported by name in the visible
            # cells (it may come from the star import - confirm). `.resize(sprite_dim)`
            # also receives an int; PIL's Image.resize expects a (width, height) tuple - verify.
            for img in batch:
                patches.append(tensor2im(input_image=img).resize(sprite_dim))

        with torch.no_grad():
            batch = batch.to(device)
            embeddings.append(model(batch).detach().cpu().numpy())
        labels.append(target.numpy())

    if not embeddings:
        raise ValueError('dataloader yielded no batches; nothing to save to %s.pkl' % fname)

    embeddings = np.vstack(embeddings)
    # Join along the sample axis. Stacking the per-batch label arrays as rows breaks when
    # the last batch is smaller and leaves a (num_batches, batch_size) array otherwise.
    labels = np.concatenate([np.atleast_1d(chunk) for chunk in labels], axis=0)
    if labels.ndim > 1 and all(dim == 1 for dim in labels.shape[1:]):
        # Targets shaped (B, 1) collapse to one scalar per sample.
        labels = labels.reshape(-1)

    if is_imagefolder:
        id2label = {idx: class_name for class_name, idx in dataset.class_to_idx.items()}
        labels = np.array([id2label.get(idx) for idx in labels.ravel()])

    asset_dict = {'embeddings': embeddings, 'labels': labels}
    if save_patches:
        asset_dict.update({'patches': patches})
    with open('%s.pkl' % (fname), 'wb') as handle:
        pickle.dump(asset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

        
def create_UMAP(library_path, enc_name, dataset, n=50, d=0.2):
    """Fit a UMAP on saved embeddings and write the labelled scatter plot as a JPEG.

    Args:
        library_path: directory holding ``<dataset>_<enc_name>.pkl``; the plot is written
            to its ``UMAPs`` sub-directory, which is created if missing.
        enc_name: encoder name used in the pickle file name.
        dataset: dataset/split prefix used in the pickle file name.
        n: ``n_neighbors`` passed to ``umap.UMAP``.
        d: ``min_dist`` passed to ``umap.UMAP``.
    """
    # matplotlib is not imported at the top of the notebook. Importing umap.plot
    # explicitly ensures the plotting submodule is loaded before it is used.
    import matplotlib.pyplot as plt
    import umap
    import umap.plot

    path = os.path.join(library_path, '%s_%s.pkl' % (dataset, enc_name))
    # NOTE: pickle.load can execute arbitrary code; only open files produced by this notebook.
    with open(path, 'rb') as handle:
        asset_dict = pickle.load(handle)
    embeddings, labels = asset_dict['embeddings'], asset_dict['labels']

    mapper = umap.UMAP(n_neighbors=n, min_dist=d).fit(embeddings)
    # No ax is passed, so umap.plot.points opens its own figure sized from width/height;
    # a separately created figure would stay empty and leak, so none is made here.
    umap.plot.points(mapper, labels=labels, width=600, height=600)
    plt.tight_layout()

    # savefig does not create missing directories.
    out_dir = os.path.join(library_path, 'UMAPs')
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, '%s_%s_umap_n%d_d%0.2f.jpg' % (dataset, enc_name, n, d)))


def _embed_split(model, data, fname, batch_size, num_workers, save_patches, sprite_dim,
                 is_imagefolder=False):
    """Wrap ``data`` in an unshuffled DataLoader and save its embeddings to ``fname``.pkl."""
    loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False,
                                         num_workers=num_workers)
    save_embeddings(model=model, fname=fname, dataloader=loader,
                    dataset=data if is_imagefolder else None,
                    save_patches=save_patches, sprite_dim=sprite_dim,
                    is_imagefolder=is_imagefolder)


def create_embeddings(embeddings_dir, enc_name, dataset, save_patches=False, sprite_dim=128,
                      patch_datasets='path/to/patch/datasets', assets_dir ='../ckpts/pretrain/',
                      disentangle=-1, stage=-1):
    """Extract and save train/val embeddings of one patch dataset with one encoder.

    Args:
        embeddings_dir: directory the ``.pkl`` files are written to.
        enc_name: ``'resnet50_trunc'`` or any name containing ``'dino'``; a DINO name is
            also the checkpoint file stem looked up in ``assets_dir``.
        dataset: one of ``'kather100k'``, ``'kather100knonorm'``, ``'breastpathq'``, ``'bcss'``.
        save_patches: forwarded to ``save_embeddings``.
        sprite_dim: forwarded to ``save_embeddings``.
        patch_datasets: root directory containing the downloaded patch datasets.
        assets_dir: directory containing the DINO checkpoints (``<enc_name>.pt``).
        disentangle: currently unused.
        stage: ``-1`` uses the model's forward output. Any other value is passed to
            ``get_intermediate_layers`` of a DINO model and adds ``_s<stage>`` to the
            output file names.

    Raises:
        ValueError: if ``enc_name`` or ``dataset`` is not supported.
        AssertionError: if the DINO checkpoint file does not exist.
    """
    print(enc_name)

    supported_datasets = ('kather100k', 'kather100knonorm', 'breastpathq', 'bcss')
    if dataset not in supported_datasets:
        # Fail before loading any weights instead of silently doing nothing.
        raise ValueError('Unknown dataset %r; expected one of %s' % (dataset, supported_datasets))

    if enc_name == 'resnet50_trunc':
        model = resnet50_baseline(pretrained=True)
        eval_t = eval_transforms(pretrained=True)
        print("Loading ResNet50")
    elif 'dino' in enc_name:
        ckpt_path = os.path.join(assets_dir, enc_name+'.pt')
        assert os.path.isfile(ckpt_path), 'checkpoint not found: %s' % ckpt_path
        model = vit_small(patch_size=16)
        state_dict = torch.load(ckpt_path, map_location="cpu")['teacher']
        # Strip the wrapper prefixes so the keys match the bare ViT.
        state_dict = {k.replace("module.", "").replace("backbone.", ""): v
                      for k, v in state_dict.items()}
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(missing_keys, unexpected_keys)
        eval_t = eval_transforms(pretrained=False)
        print("Loading DINO")
    else:
        # Without this the function would only fail later with an unbound `model`.
        raise ValueError("Unknown encoder %r; expected 'resnet50_trunc' or a name containing 'dino'"
                         % (enc_name,))

    model = model.to(device)
    model.eval()

    if 'dino' in enc_name and stage != -1:
        backbone = model
        # Concatenate token 0 (presumably the CLS token - confirm) of every output
        # returned by get_intermediate_layers(x, stage).
        model = lambda x: torch.cat(
            [tokens[:, 0] for tokens in backbone.get_intermediate_layers(x, stage)], dim=-1)

    _dis = ''
    _stage = '_s%d' % stage if stage != -1 else ''
    tag = '%s%s%s' % (enc_name, _dis, _stage)
    common = dict(save_patches=save_patches, sprite_dim=sprite_dim)

    if dataset in ('kather100k', 'kather100knonorm'):
        train_dir = 'NCT-CRC-HE-100K/' if dataset == 'kather100k' else 'NCT-CRC-HE-100K-NONORM/'
        # (split, sub-directory, batch size); both variants use CRC-VAL-HE-7K for validation.
        for split, subdir, batch_size in (('train', train_dir, 10), ('val', 'CRC-VAL-HE-7K/', 1)):
            folder = torchvision.datasets.ImageFolder(os.path.join(patch_datasets, subdir),
                                                      transform=eval_t)
            fname = os.path.join(embeddings_dir, '%s_%s_%s' % (dataset, split, tag))
            _embed_split(model, folder, fname, batch_size=batch_size, num_workers=4,
                         is_imagefolder=True, **common)

    elif dataset == 'breastpathq':
        train_dataroot = os.path.join(patch_datasets, 'BreastPathQ/breastpathq/datasets/train/')
        val_dataroot = os.path.join(patch_datasets, 'BreastPathQ/breastpathq/datasets/validation/')
        train_csv = os.path.join(patch_datasets, 'BreastPathQ/breastpathq/datasets/train_labels.csv')
        val_csv = os.path.join(patch_datasets, 'BreastPathQ/breastpathq/datasets/val_labels.csv')

        # Build both datasets first so a missing CSV fails before any embedding work starts.
        train_dataset = CSVDataset(dataroot=train_dataroot, csv_path=train_csv, transforms_eval=eval_t)
        val_dataset = CSVDataset(dataroot=val_dataroot, csv_path=val_csv, transforms_eval=eval_t)

        _embed_split(model, train_dataset, os.path.join(embeddings_dir, 'breastq_train_%s' % tag),
                     batch_size=1, num_workers=4, **common)
        _embed_split(model, val_dataset, os.path.join(embeddings_dir, 'breastq_val_%s' % tag),
                     batch_size=1, num_workers=4, **common)

    elif dataset == 'bcss':
        dataroot = os.path.join(patch_datasets, 'BCSS/40x/patches/All/')
        csv_path = os.path.join(patch_datasets, 'BCSS/40x/patches/summary.csv')

        # Parsing kept exactly as-is: it depends on the layout of summary.csv.
        dataset_csv = pd.read_csv(csv_path, sep=' ')['filename,train'].str.split(',', expand=True).astype(int)
        dataset_csv.columns = ['label', 'train']
        dataset_csv = dataset_csv[dataset_csv['label'].isin([0,1,2,3])]
        dataset_csv.index = [os.path.join(dataroot, fname+'.png') for fname in dataset_csv.index]

        train_dataset = CSVDataset_BCSS(dataset_csv=dataset_csv, is_train=1, transforms_eval=eval_t)
        val_dataset = CSVDataset_BCSS(dataset_csv=dataset_csv, is_train=0, transforms_eval=eval_t)

        _embed_split(model, train_dataset, os.path.join(embeddings_dir, 'bcss_train_%s' % tag),
                     batch_size=1, num_workers=1, **common)
        _embed_split(model, val_dataset, os.path.join(embeddings_dir, 'bcss_val_%s' % tag),
                     batch_size=1, num_workers=1, **common)

### Embedding Library Construction

In [None]:
library_path = './embeddings_patch_library/'
os.makedirs(library_path, exist_ok=True)

models = ['resnet50_trunc', 'vits_tcga_brca_dino', 'vits_tcga_pancancer_dino']

# Process one dataset at a time, running every encoder on it. The pan-cancer DINO model
# is additionally run with stage=4.
for dataset_name in ('kather100knonorm', 'kather100k', 'bcss', 'breastpathq'):
    for enc_name in models:
        create_embeddings(embeddings_dir=library_path, enc_name=enc_name, dataset=dataset_name)
        if enc_name == 'vits_tcga_pancancer_dino':
            create_embeddings(embeddings_dir=library_path, enc_name=enc_name,
                              dataset=dataset_name, stage=4)