# CLAM

NOTE: Some of the descriptions or images are cited from: https://github.com/mahmoodlab/CLAM

![img](https://github.com/mahmoodlab/CLAM/raw/master/docs/CLAM2.jpg)

## TL;DR:

+ CLAM is a high-throughput and interpretable method for data-efficient whole slide image (WSI) classification using slide-level labels without any ROI extraction or patch-level annotations, and is capable of handling multi-class subtyping problems. Tested on three different WSI datasets, the trained models adapt to independent test cohorts of WSI resections and biopsies as well as smartphone microscopy images (photomicrographs).
+ paper: https://arxiv.org/abs/2004.09666

## How to apply CLAM to the STRIP AI dataset?

+ I prepared four notebooks covering preprocessing, training and inference:

### pre-process

+ (1) image generation: https://www.kaggle.com/code/fx6300/clam-strip-ai-image-generation
+ (2) feature extraction: https://www.kaggle.com/code/fx6300/clam-strip-ai-feature-extraction

### train

+ <b>&gt; THIS NOTEBOOK &lt;</b> (3) train: https://www.kaggle.com/code/fx6300/clam-strip-ai-train

### inference

+ (4) inference: https://www.kaggle.com/code/fx6300/clam-strip-ai-inference

## How to visualize the attention generated by CLAM?

+ I prepared an example:
  + https://www.kaggle.com/fx6300/clam-strip-ai-attention-heatmap

## NOTE

+ The source code from CLAM (https://github.com/mahmoodlab/CLAM) is licensed under GPLv3 and available for non-commercial academic purposes.

In [None]:
# Install pyvips from the local conda packages (*.tar.bz2) under ../input/how-to-use-pyvips-offline.
!conda install ../input/how-to-use-pyvips-offline/*.tar.bz2
# Build and install smooth-topk, which provides topk.svm.SmoothTop1SVM (used below as CLAM's
# instance-level loss). The resulting egg is added to sys.path in the next cell.
!git clone https://github.com/oval-group/smooth-topk.git
!cd smooth-topk && python setup.py install

In [None]:
import sys
sys.path.append("/opt/conda/lib/python3.7/site-packages/topk-1.0-py3.7.egg")
import os
import gc
import cv2
import time
import random
import string
import joblib
import tifffile
import numpy as np 
import pandas as pd 
import torch
from torch import nn
from torchvision import models
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm
from tqdm import trange
from torch.optim import lr_scheduler
import warnings
import tempfile
from PIL import Image
import pyvips
import torch.nn.functional as F
import h5py
from topk.svm import SmoothTop1SVM
warnings.filterwarnings("ignore")

In [None]:
# ---- Notebook-wide configuration -------------------------------------------
debug = False

# Competition metadata tables (rows carry image_id and label columns).
train_df = pd.read_csv("../input/mayo-clinic-strip-ai/train.csv")
test_df = pd.read_csv("../input/mayo-clinic-strip-ai/test.csv")
dirs = [
    "../input/mayo-clinic-strip-ai/train/",
    "../input/mayo-clinic-strip-ai/test/",
]

# Tile edge length in pixels; image height and width are both six tiles.
TILE_SIZE = 256
IMG_HEIGHT = IMG_WIDTH = TILE_SIZE * 6

# Cross-validation folds to train (all five: 0..4).
TARGET_FOLD = list(range(5))

In [None]:
class ImgDataset(Dataset):
    """Slide-level dataset of pre-extracted patch features stored as ``<image_id>.h5`` files.

    Each item is ``(features, label, coords)``:
      * features: the HDF5 datasets named "0" .. str(n_patches - 1), stacked and reshaped
        to (n_patches, feature_dim)
      * label: 0 for "CE", 1 for "LAA"
      * coords: the patch indices 0 .. n_patches - 1 (int64)

    Args:
        df: DataFrame with ``image_id`` and ``label`` columns.
        data_dir: directory containing the ``.h5`` feature files.
        n_patches: number of patch feature datasets per file (default 64).
        feature_dim: feature dimension per patch (default 1024).
        **kwargs: accepted and ignored, kept for backward compatibility with existing callers.
    """

    LABEL_MAP = {"CE": 0, "LAA": 1}

    def __init__(self, df, data_dir, n_patches=64, feature_dim=1024, **kwargs):
        self.df = df
        self.data_dir = data_dir
        self.n_patches = n_patches
        self.feature_dim = feature_dim

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        label = self.LABEL_MAP[row.label]
        full_path = f"{self.data_dir}/{row.image_id}.h5"
        with h5py.File(full_path, 'r') as hdf5_file:
            # ``[()]`` reads each dataset fully into memory while the file is still open.
            features = torch.stack(
                [torch.as_tensor(hdf5_file[str(i)][()]) for i in range(self.n_patches)]
            ).view(self.n_patches, self.feature_dim)
        coords = torch.arange(self.n_patches)
        return features, label, coords

In [None]:
"""
Attention Network without Gating (2 fc layers)
args:
    L: input feature dimension
    D: hidden layer dimension
    dropout: whether to use dropout (p = 0.25)
    n_classes: number of classes 
"""

def initialize_weights(module):
    """Re-initialise the Linear and BatchNorm1d layers inside ``module`` in place.

    * nn.Linear: Xavier-normal weight, zero bias (if the layer has a bias).
    * nn.BatchNorm1d: weight 1, bias 0 (if the layer is affine).

    Layers created with ``bias=False`` / ``affine=False`` hold ``None`` for those
    tensors, so each one is checked before use instead of crashing.
    """
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm1d):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
class Attn_Net(nn.Module):
    """Attention network without gating (2 fc layers).

    Layout: Linear(L, D) -> Tanh [-> Dropout(0.25)] -> Linear(D, n_classes).

    Args:
        L: input feature dimension
        D: hidden layer dimension
        dropout: whether to insert Dropout(p = 0.25) after the Tanh
        n_classes: number of attention scores produced per instance
    """

    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
        super(Attn_Net, self).__init__()
        hidden = [nn.Linear(L, D), nn.Tanh()]
        hidden += [nn.Dropout(0.25)] if dropout else []
        self.module = nn.Sequential(*hidden, nn.Linear(D, n_classes))

    def forward(self, x):
        # Scores are N x n_classes; the input is passed through untouched for the caller.
        scores = self.module(x)
        return scores, x

"""
Attention Network with Sigmoid Gating (3 fc layers)
args:
    L: input feature dimension
    D: hidden layer dimension
    dropout: whether to use dropout (p = 0.25)
    n_classes: number of classes 
"""
class Attn_Net_Gated(nn.Module):
    """Attention network with sigmoid gating (3 fc layers).

    scores = attention_c( tanh-branch(x) * sigmoid-branch(x) )

    Args:
        L: input feature dimension
        D: hidden layer dimension
        dropout: whether to append Dropout(p = 0.25) to both branches
        n_classes: number of attention scores produced per instance
    """

    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
        super(Attn_Net_Gated, self).__init__()

        def make_branch(activation):
            layers = [nn.Linear(L, D), activation]
            if dropout:
                layers.append(nn.Dropout(0.25))
            return nn.Sequential(*layers)

        # Built in the order a, b, c so parameter names and init order stay the same.
        self.attention_a = make_branch(nn.Tanh())
        self.attention_b = make_branch(nn.Sigmoid())
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        gated = self.attention_a(x) * self.attention_b(x)
        scores = self.attention_c(gated)  # N x n_classes
        return scores, x

"""
args:
    gate: whether to use gated attention network
    size_arg: config for network size
    dropout: whether to use dropout (p = 0.25)
    k_sample: number of positive/neg patches to sample for instance-level training
    n_classes: number of classes 
    instance_loss_fn: loss function to supervise instance-level training
    subtyping: whether it's a subtyping problem
"""
class CLAM_SB(nn.Module):
    """CLAM with a single attention branch (the attention net emits one score per patch).

    Args:
        gate: use the gated attention network (Attn_Net_Gated) instead of Attn_Net
        size_arg: network size config; "small" -> [1024, 512, 256], "big" -> [1024, 512, 384]
                  i.e. (input feature dim, fc embedding dim, attention hidden dim)
        dropout: whether to use dropout (p = 0.25)
        k_sample: number of positive/negative patches to sample for instance-level training
                  (capped at the number of patches in the bag)
        n_classes: number of slide-level classes
        instance_loss_fn: loss function supervising instance-level training; None (default)
                          creates a fresh nn.CrossEntropyLoss() for this model
        subtyping: whether it's a subtyping problem (out-of-class branches are supervised too)
    """

    def __init__(self, gate = True, size_arg = "small", dropout = True, k_sample=8, n_classes=2,
        instance_loss_fn=None, subtyping=False):
        super(CLAM_SB, self).__init__()
        # A module instance as a default argument would be shared by every CLAM_SB using it.
        if instance_loss_fn is None:
            instance_loss_fn = nn.CrossEntropyLoss()
        self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        size = self.size_dict[size_arg]
        fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
        if dropout:
            fc.append(nn.Dropout(0.25))
        if gate:
            attention_net = Attn_Net_Gated(L = size[1], D = size[2], dropout = dropout, n_classes = 1)
        else:
            attention_net = Attn_Net(L = size[1], D = size[2], dropout = dropout, n_classes = 1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.classifiers = nn.Linear(size[1], n_classes)
        instance_classifiers = [nn.Linear(size[1], 2) for i in range(n_classes)]
        self.instance_classifiers = nn.ModuleList(instance_classifiers)
        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.n_classes = n_classes
        self.subtyping = subtyping

        initialize_weights(self)

    def relocate(self):
        """Move all sub-networks to CUDA when available, otherwise CPU."""
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.attention_net = self.attention_net.to(device)
        self.classifiers = self.classifiers.to(device)
        self.instance_classifiers = self.instance_classifiers.to(device)

    @staticmethod
    def create_positive_targets(length, device):
        """Return a long tensor of ``length`` ones on ``device``."""
        return torch.full((length, ), 1, device=device).long()

    @staticmethod
    def create_negative_targets(length, device):
        """Return a long tensor of ``length`` zeros on ``device``."""
        return torch.full((length, ), 0, device=device).long()

    def _num_samples(self, A):
        """Patches to sample from a 1 x N attention row: k_sample, capped at N.

        torch.topk raises when k exceeds the size of the dimension, so small bags are
        handled by sampling every patch instead of crashing.
        """
        return min(self.k_sample, A.shape[1])

    def inst_eval(self, A, h, classifier):
        """Instance-level evaluation for the in-the-class attention branch.

        The k highest-attended patches are labelled positive (1) and the k lowest-attended
        patches negative (0); ``classifier`` is scored on them with ``instance_loss_fn``.

        Returns:
            (instance_loss, predictions, targets) for the 2k sampled patches.
        """
        device = h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)
        k = self._num_samples(A)
        top_p_ids = torch.topk(A, k, dim=1)[1][-1]
        top_p = torch.index_select(h, dim=0, index=top_p_ids)
        top_n_ids = torch.topk(-A, k, dim=1)[1][-1]
        top_n = torch.index_select(h, dim=0, index=top_n_ids)
        p_targets = self.create_positive_targets(k, device)
        n_targets = self.create_negative_targets(k, device)

        all_targets = torch.cat([p_targets, n_targets], dim=0)
        all_instances = torch.cat([top_p, top_n], dim=0)
        logits = classifier(all_instances)
        all_preds = torch.topk(logits, 1, dim = 1)[1].squeeze(1)
        instance_loss = self.instance_loss_fn(logits, all_targets)
        return instance_loss, all_preds, all_targets

    def inst_eval_out(self, A, h, classifier):
        """Instance-level evaluation for an out-of-the-class attention branch.

        The k highest-attended patches are labelled negative (0) for ``classifier``.

        Returns:
            (instance_loss, predictions, targets) for the k sampled patches.
        """
        device = h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)
        k = self._num_samples(A)
        top_p_ids = torch.topk(A, k, dim=1)[1][-1]
        top_p = torch.index_select(h, dim=0, index=top_p_ids)
        p_targets = self.create_negative_targets(k, device)
        logits = classifier(top_p)
        p_preds = torch.topk(logits, 1, dim = 1)[1].squeeze(1)
        instance_loss = self.instance_loss_fn(logits, p_targets)
        return instance_loss, p_preds, p_targets

    def forward(self, h, label=None, instance_eval=False, return_features=False, attention_only=False):
        """Classify one bag of patch features.

        Args:
            h: bag of patch features, fed to ``attention_net`` (first layer expects size[0] features).
            label: slide label tensor; required when ``instance_eval`` is True.
            instance_eval: also compute the instance-level clustering loss.
            return_features: accepted for API compatibility; currently unused.
            attention_only: return only the raw (pre-softmax) 1 x N attention scores.

        Returns:
            (logits, Y_prob, Y_hat, total_inst_loss, A_raw); total_inst_loss is a 0-d tensor
            on ``h``'s device (0.0 when ``instance_eval`` is False).
        """
        device = h.device
        A, h = self.attention_net(h)  # A: N x 1, h: N x size[1]
        A = torch.transpose(A, 1, 0)  # 1 x N
        if attention_only:
            return A
        A_raw = A
        A = F.softmax(A, dim=1)  # softmax over N
        # Allocated on the input's device (not unconditionally .cuda()) so CPU runs/tracing work.
        total_inst_loss = torch.tensor(0.0, device=device)
        if instance_eval:
            all_preds = []
            all_targets = []
            inst_labels = F.one_hot(label, num_classes=self.n_classes).squeeze() #binarize label
            for i in range(len(self.instance_classifiers)):
                inst_label = inst_labels[i].item()
                classifier = self.instance_classifiers[i]
                if inst_label == 1: #in-the-class:
                    instance_loss, preds, targets = self.inst_eval(A, h, classifier)
                    all_preds.extend(preds.cpu().numpy())
                    all_targets.extend(targets.cpu().numpy())
                else: #out-of-the-class
                    if self.subtyping:
                        instance_loss, preds, targets = self.inst_eval_out(A, h, classifier)
                        all_preds.extend(preds.cpu().numpy())
                        all_targets.extend(targets.cpu().numpy())
                    else:
                        continue
                total_inst_loss += instance_loss

            if self.subtyping:
                total_inst_loss /= len(self.instance_classifiers)
        M = torch.mm(A, h)
        logits = self.classifiers(M)
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        Y_prob = F.softmax(logits, dim = 1)
        return logits, Y_prob, Y_hat, total_inst_loss, A_raw

In [None]:
def train_model(model, train_loader, val_loader, criterions, optimizer, num_epochs, fold):
    best_loss = 10000.0
    best_acc = 0
    for epoch in range(num_epochs):
        val_loss = 0
        model.cuda()
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            epoch_loss = 0.0
            epoch_mcll = 0.0
            epoch_acc = 0
            
            y_hat = []
            y = []
            for item in tqdm(train_loader if phase == "train" else val_loader, leave=False):
                images = item[0][0].cuda()
                classes = item[1].cuda().long()
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    logits, Y_prob, Y_hat, total_inst_loss, _ = model(images, label=classes[0], instance_eval=True)
                    loss1 = criterions[0](Y_prob, F.one_hot(classes, num_classes=2).cuda().float())
                    #loss2 = criterions[1](Y_prob, classes)
                    loss3 = total_inst_loss
                    loss = loss1 + loss3

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    y_hat += Y_hat
                    y += classes.data
                    epoch_loss += loss.item() * len(Y_hat)
                    # print(Y_hat)
                    # print(classes.data)
                    epoch_acc += torch.sum(Y_hat[0] == classes.data) / len(Y_hat[0])
            #epoch_mcll = weighted_multi_class_log_loss(s(torch.stack(y_hat).cpu()[:,:2]).detach(), torch.stack(y))
            data_size = len(train_loader if phase == "train" else val_loader)
            epoch_loss = epoch_loss / data_size
            epoch_acc = epoch_acc.double() / data_size
            #epoch_mcll = epoch_mcll / data_size
            #epoch_mcll_val = epoch_mcll if phase == "val" else 100.0
            if phase == 'val':
                val_loss = epoch_loss

            print(f'Epoch {epoch + 1}/{num_epochs} | {phase:^5} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')
        
        if epoch_acc > best_acc:
            traced = torch.jit.trace(model.cpu(), torch.rand(64, 1024))
            traced.save('model_fold{}.pth'.format(fold))
            best_acc = epoch_acc
            del traced
        elif abs(epoch_acc - best_acc) < 1e6 and best_loss > val_loss:
            traced = torch.jit.trace(model.cpu(), torch.rand(64, 1024))
            traced.save('model_fold{}.pth'.format(fold))
            best_loss = val_loss
            del traced
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
def weighted_multi_class_log_loss(y_hat, y, w):
    """Class-weighted multi-class log loss.

        loss = -sum_c w[c] / N_c * sum_i y[i, c] * log(clip(y_hat[i, c]))

    Args:
        y_hat: (N, C) predicted class probabilities.
        y: (N, C) one-hot targets (integer or float); moved to ``y_hat``'s device.
        w: (C,) per-class weights (tensor or sequence); moved to ``y_hat``'s device.

    Returns:
        A 0-d tensor on ``y_hat``'s device.

    Probabilities are clipped to [1e-15, 1 - 1e-15]. The per-class sample count N_c is
    computed in floating point and clamped to >= 1e-15, so a class absent from the batch
    contributes 0 instead of NaN. Any number of classes C is supported.
    """
    device = y_hat.device
    log_p = torch.log(torch.clamp(y_hat, 1e-15, 1.0 - 1e-15))
    y = y.to(device=device, dtype=log_p.dtype)
    w = torch.as_tensor(w, device=device)
    class_counts = torch.clamp(y.sum(dim=0), min=1e-15)
    return -torch.sum(w * y * log_p / class_counts)


# Quick sanity check (runs on GPU when available, otherwise CPU).
_sanity_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weighted_multi_class_log_loss(
    torch.tensor([[0.5, 0.5], [0.8, 0.2], [0.0, 1.0], [0.0, 1.0]], device=_sanity_device),
    F.one_hot(torch.tensor([0, 1, 0, 0]), 2).to(_sanity_device),
    torch.tensor([0.5, 0.5], device=_sanity_device),
)

In [None]:
class WeightedMultiClassLogLoss(torch.nn.Module):
    """nn.Module wrapper around ``weighted_multi_class_log_loss`` with fixed class weights.

    Args:
        weights: per-class weights passed to the loss on every call. They are stored as a
            plain attribute, not registered as a parameter or buffer.
    """

    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, targets):
        """Return the weighted log loss of probabilities ``inputs`` against ``targets``."""
        loss = weighted_multi_class_log_loss(inputs, targets, self.weights)
        return loss

In [None]:
# 5-fold stratified CV on 90% of the labelled slides; the other 10% (stratified by label)
# is split off as a hold-out set.
# NOTE(review): this rebinds test_df (read from test.csv earlier) to the hold-out split.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

train09_df, test_df = train_test_split(train_df, test_size=0.1, random_state=42, stratify=train_df.label)

FEATURE_DIR = "../input/my-features-v4/my_features-v4"

for fold, (train_idx, val_idx) in enumerate(skf.split(train09_df.index, train09_df.label)):
    if fold not in TARGET_FOLD:
        continue

    # Smooth top-1 SVM as the instance-level loss; subtyping CLAM sampling 4 patches per branch.
    loss_fn = SmoothTop1SVM(n_classes=2).cuda()
    model = CLAM_SB(instance_loss_fn=loss_fn, subtyping=True, k_sample=4)

    train = train09_df.iloc[train_idx, :]
    val = train09_df.iloc[val_idx, :]

    # Inverse-frequency class weights (CE, LAA), normalised to sum to 1.
    ce_size = len(train[train['label'] == "CE"])
    laa_size = len(train[train['label'] == "LAA"])
    pre_weights = [(ce_size + laa_size) / ce_size, (ce_size + laa_size) / laa_size]
    weights = torch.tensor(pre_weights / np.sum(pre_weights)).cuda()

    batch_size = 1
    loader_kwargs = dict(batch_size=batch_size, num_workers=1)
    train_loader = DataLoader(ImgDataset(train, FEATURE_DIR), shuffle=True, **loader_kwargs)
    val_loader = DataLoader(ImgDataset(val, FEATURE_DIR), shuffle=False, **loader_kwargs)

    criterions = [WeightedMultiClassLogLoss(weights), nn.CrossEntropyLoss()]
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    train_model(model, train_loader, val_loader, criterions, optimizer, num_epochs=50, fold=fold)

    # Release per-fold objects before the next fold.
    del model, train, val, train_loader, val_loader, criterions, optimizer, loader_kwargs
    gc.collect()
    torch.cuda.empty_cache()