# Batch effect detection with the BatchDetect class

In [None]:
# quietly install the local BatchDetect package (path is relative to the current working directory)
!pip -q install ./../../BatchDetect

## Reading metadata

In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from tqdm import tqdm

import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F

from histaugan.model import EfficientHistAuGAN

# use the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Create the patch-level metadata dataframe from the clinical table and the folder structure.
clini_table = pd.read_excel('/lustre/groups/peng/datasets/histology_data/clini_tables/TCGA-CRC-DX_CLINI.xlsx')

# metadata with columns: file, label (isMSIH), dataset (submission site)
base_dir = Path('/lustre/groups/shared/users/peng_marr/BatchDetect/')
# Patches live at BatchDetectTCGA/<submission site>/<TCGA* folder>/<patch>.jpeg.
# NOTE(review): glob returns filesystem order, not sorted order; the CTransPath
# feature cache is aligned to `metadata` by row position, so this order must stay
# stable between runs.
patch_list = list(base_dir.glob('BatchDetectTCGA/*/TCGA*/*.jpeg'))
print('Number of patches:', len(patch_list))

submission_site = [patch.parent.parent.name for patch in patch_list]

# Group the clinical labels by patient once instead of scanning the whole table
# for every patch. Each patient must still match exactly one clinical row.
labels_by_patient = clini_table.groupby('PATIENT')['isMSIH'].agg(list).to_dict()
label = []
for patch in patch_list:
    # the first 12 characters of the patch's parent folder are matched against PATIENT
    patient = patch.parent.name[:12]
    matches = labels_by_patient.get(patient, [])
    if len(matches) != 1:
        raise ValueError(
            f'Expected exactly one clinical row for patient {patient!r} '
            f'(from {patch}), found {len(matches)}'
        )
    label.append(matches[0])

metadata = pd.DataFrame({'file': patch_list, 'label': label, 'dataset': submission_site})

In [None]:
# unique submission sites and the number of patches from each
np.unique(submission_site, return_counts=True)

In [None]:
# display the patch-level metadata table (columns: file, label, dataset)
metadata

## Features

In [None]:
class DatasetGenerator(Dataset):
    """Map-style dataset returning ``(image, label, idx)`` for each metadata row.

    Args:
        metadata: DataFrame with a ``file`` column (path to an image PIL can open)
            and a ``label`` column. A copy with a fresh 0..N-1 index is stored, so
            ``idx`` is the row position in ``metadata``.
        transform: Callable applied to the RGB PIL image; it must return a tensor.
            ``None`` (the default) uses ``transforms.ToTensor()``.
    """

    def __init__(self, metadata, transform=None):
        self.metadata = metadata.copy().reset_index(drop=True)
        # Build the default per instance instead of sharing one object created at
        # class-definition time. Previously a None transform skipped the transform
        # and then crashed because a PIL image has no .float().
        self.transform = transforms.ToTensor() if transform is None else transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_path = self.metadata.loc[idx, "file"]
        # The context manager closes the file handle. convert("RGB") forces the pixel
        # data to load before closing and guarantees 3 channels (e.g. for grayscale
        # or CMYK JPEGs).
        with Image.open(file_path) as img:
            image = img.convert("RGB")

        label = self.metadata.loc[idx, "label"]

        image = self.transform(image)
        return image.float(), label, idx

In [None]:
# Cached CTransPath features of HistAuGAN-augmented patches, one row per metadata row.
# NOTE: '*' is a glob metacharacter; the name is kept so existing caches are still found.
X_features_path = base_dir / 'ctranspath_features_efficient_histaugan*2.csv'

if X_features_path.exists():
    # to_csv (below) writes the index as the first column; read it back as the
    # index so it does not appear as an extra feature column.
    X_features = pd.read_csv(X_features_path, index_col=0)
    # The cache stores no file paths, so rows are matched to `metadata` purely by
    # position; at least make sure the row counts agree.
    if len(X_features) != len(metadata):
        raise ValueError(
            f'Cached features in {X_features_path} have {len(X_features)} rows, '
            f'but metadata has {len(metadata)}; delete the cache to recompute.'
        )
else:
    feature_dim = 768  # size of the feature extractor output once its head is replaced by nn.Identity
    # float dtype so the freshly computed table matches what read_csv returns for the cache
    X_features = pd.DataFrame(
        index=metadata.index,
        columns=[f'X{i + 1}' for i in range(feature_dim)],
        dtype=float,
    )

    # load feature extractor (CTransPath weights into a Swin-T with a conv stem)
    from swin_transformer import swin_tiny_patch4_window7_224, ConvStem

    feature_extractor = swin_tiny_patch4_window7_224(embed_layer=ConvStem, pretrained=False)
    feature_extractor.head = nn.Identity()  # drop the classification head, keep the embedding

    # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only machines;
    # the model is moved to `device` right after loading the weights.
    ctranspath = torch.load('/home/haicu/sophia.wagner/models/ctranspath.pth', map_location='cpu')
    feature_extractor.load_state_dict(ctranspath['model'], strict=True)
    feature_extractor.to(device)
    feature_extractor.eval()

    # load efficient histaugan model
    checkpoint_dir = Path('/lustre/groups/peng/workspace/sophia.wagner/logs/histaugan_lightning/checkpoints')
    run = 'l1_a_cc+correct_adv_cls+attr_VAE+128'
    model_name = 'Efficient-HistAuGAN-epoch=01-l1_cc_loss_val=0.72.ckpt'

    model = EfficientHistAuGAN.load_from_checkpoint(checkpoint_dir / run / model_name)
    model = model.to(device)
    model.eval()
    opts = model.opts

    # Patches are fed at native resolution; ToTensor scales to [0, 1] and
    # Normalize(0.5, 0.5) maps that to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = DatasetGenerator(metadata, transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    # Augment each patch with HistAuGAN, then extract CTransPath features.
    # NOTE(review): z_random is not seeded, so recomputed features differ between runs.
    with torch.no_grad(), torch.cuda.amp.autocast():
        for image, _, idx in tqdm(dataloader):
            image = image.to(device)
            # the encoder runs on a 128x128 copy (F.interpolate defaults to nearest-neighbour)
            lowres = F.interpolate(image, size=(128, 128))
            z_content, (mu, _) = model.encoder(lowres)
            # random attribute code, standard normal scaled by 2
            z_random = torch.randn_like(mu) * 2
            image = model.generator(image, z_content, z_random)
            # swin_tiny_patch4_window7_224 is built for 224x224 inputs
            image = F.interpolate(image, size=(224, 224))
            # cast out of autocast's reduced precision before moving to numpy
            features = feature_extractor(image).float().cpu().numpy().reshape((len(idx), feature_dim))
            # idx holds row positions in the dataset (whose index was reset), so write by position
            X_features.iloc[idx.numpy()] = features

    X_features.to_csv(X_features_path)

    del dataset
    del dataloader
    del feature_extractor

## Let's see if there is a batch effect in the data

In [None]:
from batchdetect.batchdetect import BatchDetect

# BatchDetect receives the per-patch annotations (label, dataset) and the feature table
bd = BatchDetect(
    metadata[["label", "dataset"]],
    X_features,
)

## Visualizations

In [None]:
# low-dimensional view of the features via PCA
bd.low_dim_visualization("pca")

In [None]:
# low-dimensional view of the features via t-SNE
bd.low_dim_visualization("tsne")

In [None]:
# low-dimensional view of the features via UMAP
bd.low_dim_visualization("umap")

## ANOVA test of principal components vs. labels

In [None]:
# ANOVA of principal components vs. the annotations passed to BatchDetect
bd.prince_plot()

## Classification test of RF vs. a random classifier

In [None]:
# classifier vs. random-baseline test on the features, scored with macro-averaged F1
bd.classification_test(scorer="f1_macro")