# CLAM

NOTE: Some of the descriptions and images are taken from: https://github.com/mahmoodlab/CLAM

![img](https://github.com/mahmoodlab/CLAM/raw/master/docs/CLAM2.jpg)

## TL;DR:

+ CLAM is a high-throughput and interpretable method for data-efficient whole slide image (WSI) classification using slide-level labels without any ROI extraction or patch-level annotations, and is capable of handling multi-class subtyping problems. When tested on three different WSI datasets, the trained models adapted to independent test cohorts of WSI resections and biopsies, as well as to smartphone microscopy images (photomicrographs).
+ paper: https://arxiv.org/abs/2004.09666

## How to apply CLAM to the STRIP AI dataset?

+ I prepared four notebooks covering preprocessing, training, and inference:

### pre-process

+ (1) image generation: https://www.kaggle.com/code/fx6300/clam-strip-ai-image-generation
+ <b>&gt; THIS NOTEBOOK &lt;</b> (2) feature extraction: https://www.kaggle.com/code/fx6300/clam-strip-ai-feature-extraction

### train

+ (3) train: https://www.kaggle.com/code/fx6300/clam-strip-ai-train

### inference

+ (4) inference: https://www.kaggle.com/code/fx6300/clam-strip-ai-inference

## How to visualize the attention generated by CLAM?

+ I prepared an example:
  + https://www.kaggle.com/fx6300/clam-strip-ai-attention-heatmap

## NOTE

+ The source code from CLAM (https://github.com/mahmoodlab/CLAM) is licensed under GPLv3 and available for non-commercial academic purposes.

In [None]:
import os
import gc
import cv2
import time
import random
import string
import joblib
import numpy as np 
import pandas as pd 
import torch
from torch import nn
import seaborn as sns
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from tqdm import trange
import warnings
import torch.nn.functional as F
import h5py
import torch.utils.model_zoo as model_zoo
# Silence every warning category for the rest of the notebook session.
warnings.simplefilter("ignore")

In [None]:
# Run-mode switches (not referenced by the cells visible in this notebook).
debug = False
generate_new = True

# Competition image directories: index 0 is train, index 1 is test.
dirs = [
    "../input/mayo-clinic-strip-ai/train/",
    "../input/mayo-clinic-strip-ai/test/",
]

# Slide metadata tables shipped with the competition data.
train_df = pd.read_csv("../input/mayo-clinic-strip-ai/train.csv")
test_df = pd.read_csv("../input/mayo-clinic-strip-ai/test.csv")

# Image size in pixels: 8 x 512 along each side.
IMG_HEIGHT = IMG_WIDTH = 512 * 8

In [None]:
# Public names actually defined in this file. The previous list named
# torchvision-style constructors ('ResNet', 'resnet18', ...) that do not
# exist here, which makes `from <module> import *` raise AttributeError.
__all__ = ['Bottleneck_Baseline', 'ResNet_Baseline', 'resnet50_baseline',
           'load_pretrained_weights', 'model_urls']

# Torchvision ImageNet checkpoint URLs keyed by architecture name.
# load_pretrained_weights looks names up here; resnet50_baseline uses
# 'resnet50'. The other entries are kept for reference.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

class Bottleneck_Baseline(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand.

    Maps ``inplanes`` input channels to ``planes * expansion`` output
    channels; the 3x3 convolution carries ``stride``. If ``downsample`` is
    given it is applied to the block input to build the shortcut, otherwise
    the input itself is added back before the final ReLU.

    The submodule names (conv1/bn1/conv2/bn2/conv3/bn3/downsample) are what
    the checkpoint loading in this file relies on to match weights by key,
    so they must not be renamed.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        # Registration order is fixed on purpose: it defines state_dict
        # ordering and the order in which default init draws from the RNG.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch: identity, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)

class ResNet_Baseline(nn.Module):
    """ResNet trunk truncated after ``layer3``, used as a tile feature extractor.

    Runs the stem (7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool) and
    three residual stages with 64, 128 and 256 base planes. It then
    global-average-pools and flattens to ``(batch, 256 * block.expansion)``.
    Only ``layers[0]``..``layers[2]`` are used; any further entries are ignored.

    Args:
        block: residual block class with an ``expansion`` attribute, built as
            ``block(inplanes, planes, stride, downsample)`` or
            ``block(inplanes, planes)``.
        layers: number of blocks in each stage.
    """

    def __init__(self, block, layers):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.inplanes = 64
        # Keep this registration order: it fixes state_dict keys and the
        # module traversal order used by the init loop below.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage; the first block may project the shortcut."""
        out_planes = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != out_planes
        downsample = (
            nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else None
        )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        pooled = self.avgpool(x)
        return pooled.view(pooled.size(0), -1)

def resnet50_baseline(pretrained=False):
    """Build the truncated ResNet-50 trunk (stem + stages 1-3, no layer4/fc).

    Args:
        pretrained (bool): when True, load the ImageNet checkpoint registered
            under ``'resnet50'`` in ``model_urls`` via ``load_pretrained_weights``.

    Returns:
        ResNet_Baseline: a model whose forward pass yields 1024-dim pooled
        features (256 base planes x Bottleneck expansion 4).
    """
    model = ResNet_Baseline(Bottleneck_Baseline, [3, 4, 6, 3])
    return load_pretrained_weights(model, 'resnet50') if pretrained else model

def load_pretrained_weights(model, name):
    """Load the checkpoint registered as ``model_urls[name]`` into ``model`` by key.

    ``strict=False`` lets checkpoint entries the model does not have be
    skipped (for the truncated ResNet_Baseline these are e.g. ``layer4.*`` and
    ``fc.*``). The opposite case is different: weights the *model* expects but
    the checkpoint lacks would stay randomly initialised and silently yield
    meaningless features. That case is therefore reported as an error.

    Args:
        model: ``nn.Module`` whose parameter names match the checkpoint keys.
        name: key into ``model_urls`` (e.g. ``'resnet50'``).

    Returns:
        The same ``model`` object, with weights loaded in place.

    Raises:
        KeyError: if ``name`` is not a key of ``model_urls``.
        RuntimeError: if the checkpoint is missing any model parameter or
            buffer other than BatchNorm ``num_batches_tracked`` counters.
    """
    # Map tensors to CPU so loading does not depend on the device the
    # checkpoint was saved from.
    pretrained_dict = model_zoo.load_url(model_urls[name], map_location='cpu')
    result = model.load_state_dict(pretrained_dict, strict=False)
    # Older checkpoints may predate BatchNorm's num_batches_tracked counter;
    # it only matters for training-time momentum, so its absence is tolerated.
    missing = [key for key in result.missing_keys
               if not key.endswith('num_batches_tracked')]
    if missing:
        raise RuntimeError(
            f"Checkpoint '{name}' is missing {len(missing)} weight(s) required "
            f"by the model, e.g. {missing[:5]}"
        )
    return model

In [None]:
class ImgDataset(Dataset):
    """Tile-level dataset over pre-cut JPEG tiles of each slide.

    Every row of ``df`` (one slide, identified by ``image_id``) expands to
    ``tiles_per_image`` consecutive items. Each item reads the tile file
    ``<tile dir><image_id>_<pos>.jpg`` with ``pos`` in ``[0, tiles_per_image)``.
    Tiles come from the train directory when ``df`` has a ``label`` column and
    from the test directory otherwise.

    Each item is ``(image, label, image_id, pos)``:
      * ``image``: the array returned by ``cv2.imread``, transposed to
        channels-first.
      * ``label``: 0 for ``"CE"``, 1 for ``"LAA"``, or ``UNLABELED`` (-1)
        when ``df`` has no ``label`` column.

    NOTE(review): cv2 loads images in BGR order and no mean/std normalisation
    is applied before the ImageNet-pretrained backbone. The downstream
    notebooks presumably rely on this exact preprocessing, so it is left
    unchanged; confirm before altering.
    """

    # Tile directories indexed by ``self.train`` (False: test, True: train).
    TILE_DIRS = {
        False: "../input/4096-tiles-v4/test-4096-tiles-v4/4096-tiles-v4/",
        True: "../input/4096-tiles-v4/train-4096-tiles-v4/4096-tiles-v4/",
    }
    LABELS = {"CE": 0, "LAA": 1}
    # Placeholder label for unlabeled (test) slides: ``None`` cannot be
    # batched by the DataLoader's default collate function.
    UNLABELED = -1

    def __init__(self, df, tiles_per_image=64):
        self.df = df
        self.train = 'label' in df.columns
        self.tiles_per_image = tiles_per_image

    def __len__(self):
        return len(self.df) * self.tiles_per_image

    def __getitem__(self, _index):
        index, pos = divmod(_index, self.tiles_per_image)
        row = self.df.iloc[index]
        image_id = row.image_id
        path = f"{self.TILE_DIRS[self.train]}{image_id}_{pos}.jpg"
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read tile image: {path}")
        label = self.LABELS[row.label] if self.train else self.UNLABELED
        return image.transpose(2, 0, 1), label, image_id, pos

In [None]:
def apply_model(model, train_loader, output_dir, device="cuda"):
    """Run ``model`` over every batch of ``train_loader`` and store the outputs in HDF5.

    Each batch is indexed as ``(images, labels, image_ids, positions)``, as
    produced by ``ImgDataset`` through the default collate function; labels
    are not used. Each tile's model output is written as a float32 dataset
    named after the tile position (``"0"``, ``"1"``, ...) into
    ``<output_dir>/<image_id>.h5``, which is opened in append mode.

    Args:
        model: feature extractor, called as ``model(images)`` under ``torch.no_grad()``.
        train_loader: iterable of batches as described above.
        output_dir: existing directory that receives one ``.h5`` file per image id.
        device: device the model and images are moved to (default ``"cuda"``).

    Raises:
        Whatever h5py's ``create_dataset`` raises if a dataset with the same
        position name already exists in the target file (e.g. when re-running
        into a non-empty ``output_dir``).
    """
    model.to(device)
    model.eval()
    for item in tqdm(train_loader, leave=False):
        images, image_ids, positions = item[0], item[2], item[3]
        with torch.no_grad():
            # One device->host copy per batch instead of one per tile.
            features = model(images.to(device).float()).cpu().numpy()

        # Group batch rows by image id so each HDF5 file is opened once per
        # batch; `with` guarantees it is closed even if a write fails.
        rows_by_id = {}
        for row, image_id in enumerate(image_ids):
            rows_by_id.setdefault(image_id, []).append(row)
        for image_id, rows in rows_by_id.items():
            with h5py.File(f"{output_dir}/{image_id}.h5", 'a') as f:
                for row in rows:
                    # The default collate turns positions into a tensor;
                    # formatting a 0-d tensor directly would give "tensor(3)".
                    f.create_dataset(str(int(positions[row])), data=features[row], dtype=np.float32)

        del images, features, image_ids, positions
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Extract tile features for every training slide, one slide per DataLoader,
# writing <output_dir>/<image_id>.h5 via apply_model.
batch_size = 32
output_dir = "./my_features-v4"
os.makedirs(output_dir, exist_ok=True)

# apply_model only runs the backbone in eval mode under no_grad, so its
# weights never change between slides: build and load it once.
model = resnet50_baseline(pretrained=True)
for train_idx in train_df.index:
    train = train_df.iloc[[train_idx]]
    train_loader = DataLoader(
        ImgDataset(train),
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
    )
    apply_model(model, train_loader, output_dir)
    del train, train_loader
    gc.collect()
    torch.cuda.empty_cache()

del model
gc.collect()
torch.cuda.empty_cache()