# CLAM

NOTE: Some of the descriptions or images are cited from: https://github.com/mahmoodlab/CLAM

![img](https://github.com/mahmoodlab/CLAM/raw/master/docs/CLAM2.jpg)

## TL;DR:

+ CLAM is a high-throughput, interpretable method for data-efficient whole slide image (WSI) classification that uses only slide-level labels, without any ROI extraction or patch-level annotations, and can handle multi-class subtyping problems. Evaluated on three different WSI datasets, CLAM-trained models adapt to independent test cohorts of WSI resections and biopsies, as well as to smartphone microscopy images (photomicrographs).
+ paper: https://arxiv.org/abs/2004.09666

## How to apply CLAM to the STRIP AI dataset?

+ I prepared four notebooks covering preprocessing, training, and inference:

### pre-process

+ (1) image generation: https://www.kaggle.com/code/fx6300/clam-strip-ai-image-generation
+ (2) feature extraction: https://www.kaggle.com/code/fx6300/clam-strip-ai-feature-extraction

### train

+ (3) train: https://www.kaggle.com/code/fx6300/clam-strip-ai-train

### inference

+ <b>&gt; THIS NOTEBOOK &lt;</b> (4) inference: https://www.kaggle.com/code/fx6300/clam-strip-ai-inference

## How to visualize the attention generated by CLAM?

+ I prepared an example:
  + https://www.kaggle.com/fx6300/clam-strip-ai-attention-heatmap

## NOTE

+ The source code from CLAM (https://github.com/mahmoodlab/CLAM) is licensed under GPLv3 and available for non-commercial academic purposes.

In [None]:
# Install pyvips from the conda package archives bundled in the
# "how-to-use-pyvips-offline" input dataset (local files, no download).
!conda install ../input/how-to-use-pyvips-offline/*.tar.bz2

In [None]:
# Replace the preinstalled opencv-python with the opencv-contrib wheel
# shipped in the "opencv-contrib" input dataset.
!yes | pip uninstall opencv-python
!pip install ../input/opencv-contrib/opencv_contrib_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# `tiling` provides ConstStrideTiles, which the tiling() helper uses to cut images into tiles.
!pip install ../input/tiling030/tiling-0.3.0-py2.py3-none-any.whl

# rembg's requirements, installed one by one from local wheels; --no-deps
# stops pip from resolving or fetching any further requirements.
!pip install --no-deps ../input/my-rembg/asyncer-*.whl
!pip install --no-deps ../input/my-rembg/filetype-*.whl
!pip install --no-deps ../input/my-rembg/onnxruntime-*.whl
!pip install --no-deps ../input/my-rembg/PyMatting-*.whl
!pip install --no-deps ../input/my-rembg/watchdog-*.whl
!pip install --no-deps ../input/my-rembg/click-*.whl
!pip install --no-deps ../input/my-rembg/fastapi-*.whl
!pip install --no-deps ../input/my-rembg/opencv_python_headless-*.whl
!pip install --no-deps ../input/my-rembg/uvicorn-*.whl

# Pillow and scikit-image wheels are intentionally skipped; presumably the
# versions already present in the environment are used instead.
#!pip install --no-deps ../input/my-rembg/Pillow-*.whl #skip
#!pip install --no-deps ../input/my-rembg/scikit_image-*.whl #skip

# gdown and python-multipart are shipped as source trees: copy each into
# /kaggle/working, install it with `setup.py install`, then delete the copy.
# The commented-out pip commands are an alternative install route kept for reference.
!cp -rp /kaggle/input/my-rembg/gdown-4.5.1/gdown-4.5.1/ /kaggle/working
!cd /kaggle/working/gdown-4.5.1 && python3 setup.py install
#!python3 -m pip install -e file:///kaggle/working/gdown-4.5.1 --no-index -f file:///kaggle/input/my-rembg/ --prefix /opt/conda/lib/python3.7/site-packages
!rm -r /kaggle/working/gdown-4.5.1

!cp -rp /kaggle/input/my-rembg/python-multipart-0.0.5/python-multipart-0.0.5/ /kaggle/working
!cd /kaggle/working/python-multipart-0.0.5 && python3 setup.py install
#!python3 -m pip install -e file:///kaggle/working/python-multipart-0.0.5 --no-index -f file:///kaggle/input/my-rembg/ --prefix /opt/conda/lib/python3.7/site-packages
!rm -r /kaggle/working/python-multipart-0.0.5

# Finally rembg itself, after its requirements above.
!pip install --no-deps ../input/my-rembg/rembg-2.0.25-py3-none-any.whl

In [None]:
# Sanity check: show what pip reports for the two source-installed packages.
!pip show gdown
!pip show python-multipart

In [None]:
# Place the u2net ONNX weights in ~/.u2net (created if missing); presumably
# this is where rembg looks for the model, so nothing has to be downloaded.
!test -d ~/.u2net || mkdir ~/.u2net
!cp -f ../input/u2net-onnx/u2net.onnx ~/.u2net/

In [None]:
import sys
sys.path.append("/opt/conda/lib/python3.7/site-packages/gdown-4.5.1-py3.7.egg")
sys.path.append("/opt/conda/lib/python3.7/site-packages/python_multipart-0.0.5-py3.7.egg")
import os
import gc
import cv2
import copy
import time
import random
import string
import joblib
import tifffile
import numpy as np 
import pandas as pd 
import torch
from torch import nn
import seaborn as sns
from torchvision import models
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from torch.optim import lr_scheduler
import warnings
import openslide
from openslide.deepzoom import DeepZoomGenerator
import tempfile
from PIL import Image
import pyvips
from scipy.special import expit
from tiling import ConstStrideTiles
import h5py
from rembg import remove
warnings.filterwarnings("ignore")
gc.enable()

In [None]:
# Run flags. NOTE(review): neither flag is read by the cells visible here.
debug, generate_new = False, True
# Slide folders: index 0 = train, index 1 = test (image generation uses dirs[1]).
dirs = [
    "../input/mayo-clinic-strip-ai/train/",
    "../input/mayo-clinic-strip-ai/test/",
]
test_df = pd.read_csv("../input/mayo-clinic-strip-ai/test.csv")
test_df  # displayed as the cell output

In [None]:
def tiling(img):
    """Cut ``img`` into a regular grid of 512x512 crops.

    A ``ConstStrideTiles`` grid (512 px tiles, 512 px stride, origin (0, 0),
    scale 1.0, ``include_nodata=True``) is laid over the image. Each tile
    extent ``(x, y, width, height)`` is cropped with ``img.crop``. The number
    of tiles is printed.

    Args:
        img: object exposing ``height``, ``width`` and
            ``crop(left, top, width, height)``. The image-generation cell
            passes a pyvips image.

    Returns:
        list of crops, in the iteration order of ``ConstStrideTiles``.
    """
    # NOTE(review): the `tiling` package appears to document image_size as
    # (width, height); (height, width) is passed here. This only matters for
    # non-square images (the image-generation cell feeds a square canvas) —
    # confirm before reusing on non-square input.
    grid = ConstStrideTiles(
        image_size=(img.height, img.width),
        tile_size=(512, 512),
        stride=(512, 512),
        origin=(0, 0),
        scale=1.0,
        include_nodata=True,
    )
    print("Number of tiles: %i" % len(grid))
    return [img.crop(x, y, width, height) for (x, y, width, height), _out_size in grid]

In [None]:
def resize_image(img: Image, width) -> Image:
    """Shrink an image with ``thumbnail_image(width)`` and return it as PIL.

    NOTE(review): ``img`` is annotated as a PIL ``Image``, but it must be a
    pyvips image: ``thumbnail_image`` is a pyvips method, and the call site
    passes the result of ``pyvips.Image.new_from_file``. The conversion goes
    through ``pyvips2pil``, which swaps the channels (RGB -> BGR).
    """
    thumbnail = img.thumbnail_image(width)
    return pyvips2pil(thumbnail)

In [None]:
def resize_image_with_canvas(img: Image, width) -> Image:
    """Thumbnail an image and paste it onto a white ``width`` x ``width`` canvas.

    The thumbnail (converted to PIL via ``pyvips2pil``) is anchored at the
    top-left corner (0, 0). Canvas area it does not cover stays white.
    ``img`` must be a pyvips image despite the PIL annotation, because
    ``thumbnail_image`` is called on it.
    """
    thumbnail = pyvips2pil(img.thumbnail_image(width))
    square = Image.new("RGB", (width, width), "white")
    square.paste(thumbnail, (0, 0))
    return square

In [None]:
def bg_trim(_img: Image) -> Image:
    """Drop rows and columns that are (almost) pure white.

    A row or column is removed when the total of ``255 - value`` over all of
    its pixels and channels is below 10. The row mask and the column mask
    are both computed on the untrimmed input before anything is removed.

    Args:
        _img: image-like object accepted by ``np.array``. At the call site
            this is the array returned by ``my_remove``. Both 2-D (single
            channel) and 3-D (height, width, channels) inputs are supported.

    Returns:
        PIL image built from the trimmed array (dtype unchanged).
    """
    img = np.array(_img)
    # Widen before summing: accumulating uint8 values wraps modulo 256, which
    # could make a non-white row/column look "white" and get it deleted.
    deficit = 255 - img.astype(np.int64)
    if deficit.ndim == 2:
        # Treat a single-channel image as (height, width, 1).
        deficit = deficit[:, :, np.newaxis]
    keep_rows = deficit.sum(axis=(1, 2)) >= 10
    keep_cols = deficit.sum(axis=(0, 2)) >= 10
    return Image.fromarray(img[keep_rows][:, keep_cols])

In [None]:
def pyvips2pil(img: pyvips.Image) -> Image:
    """Convert a pyvips image to PIL, swapping channel order RGB -> BGR.

    The swap is intentional: the resulting PIL image holds BGR data. This
    matches what ``cv2.imwrite`` expects when the image is later saved via
    ``np.array``.
    """
    rgb_pixels = img.numpy()
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    return Image.fromarray(bgr_pixels)

In [None]:
def pil2pyvips(img: Image) -> pyvips.Image:
    """Convert a PIL image (or array) to pyvips, swapping channel order BGR -> RGB.

    This undoes the RGB -> BGR swap applied by ``pyvips2pil``.
    """
    bgr_pixels = np.array(img)
    rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    return pyvips.Image.new_from_array(rgb_pixels)

In [None]:
def my_remove(img: Image) -> Image:
    """Remove the background with rembg and flatten the cut-out onto white.

    ``remove`` returns an image with four bands. The first three are treated
    as B, G, R (the call site feeds the BGR-ordered output of ``pyvips2pil``)
    and the fourth as alpha. Each colour band is blended with a white
    background as ``255 * (1 - a) + c * a``, where ``a = alpha / 255``. The
    result is truncated to uint8.

    NOTE(review): despite the ``-> Image`` annotation, this returns the NumPy
    array produced by ``cv2.merge``, not a PIL image.
    """
    cutout = remove(img)
    blue, green, red, alpha_band = cutout.split()
    alpha = np.array(alpha_band) / 255
    white_part = 255 * (1 - alpha)

    def _flatten(band):
        # Alpha-blend one colour band over a white background.
        return (white_part + np.array(band) * alpha).astype(np.uint8)

    return cv2.merge((_flatten(blue), _flatten(green), _flatten(red)))

In [None]:
class ImgDataset(Dataset):
    """Serve the pre-cut JPEG tiles of each slide as individual samples.

    Every row of ``df`` contributes ``n_tiles`` consecutive samples. Flat
    index ``k`` maps to row ``k // n_tiles`` and tile position
    ``k % n_tiles``, read from ``<tile_dir><image_id>_<pos>.jpg``.

    Args:
        df: frame with an ``image_id`` column. If it also has a ``label``
            column, tiles are read from ``./train/4096-tiles-v4/``;
            otherwise from ``./test/4096-tiles-v4/``.
        n_tiles: number of tiles stored per image. The default 64 is the
            8x8 grid of 512-px tiles cut from the 4096-px canvas in this
            notebook.
    """

    # Indexed by the boolean ``self.train``: False -> test dir, True -> train dir.
    _TILE_DIRS = ("./test/4096-tiles-v4/", "./train/4096-tiles-v4/")

    def __init__(self, df, n_tiles=64):
        self.df = df
        self.train = 'label' in df.columns
        self.n_tiles = n_tiles

    def __len__(self):
        return len(self.df) * self.n_tiles

    def __getitem__(self, _index):
        """Return ``(image, image_id, pos)``.

        ``image`` is the ``cv2.imread`` result transposed from (H, W, C) to
        (C, H, W).

        Raises:
            FileNotFoundError: if the tile cannot be read. ``cv2.imread``
                returns ``None`` instead of raising, which previously
                surfaced as an obscure AttributeError.
        """
        index, pos = divmod(_index, self.n_tiles)
        image_id = self.df.iloc[index].image_id
        path = self._TILE_DIRS[self.train] + image_id + f"_{pos}" + ".jpg"
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"could not read tile image: {path}")
        return image.transpose(2, 0, 1), image_id, pos

In [None]:
class FeatureDataset(Dataset):
    """One sample per slide: its 64 stored tile features plus tile positions.

    Features are read from ``<data_dir>/<image_id>.h5``, which is expected to
    hold one dataset per tile, named ``"0"`` ... ``"63"``.

    Args:
        df: frame with ``image_id`` and ``patient_id`` columns.
        data_dir: directory holding the ``.h5`` feature files.
        **kwargs: accepted and ignored.
    """

    def __init__(self, df, data_dir, **kwargs):
        self.df = df
        self.data_dir = data_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(features viewed as (64, 1024), coords 0..63, patient_id)``."""
        record = self.df.iloc[index]
        patient_id = record.patient_id
        h5_path = f"{self.data_dir}/{record.image_id}.h5"
        with h5py.File(h5_path, 'r') as store:
            per_tile = [torch.tensor(store[str(pos)]) for pos in range(64)]
            features = torch.stack(per_tile).view(64, 1024)
            coords = torch.tensor(list(range(64))).view(64)
        return features, coords, patient_id

In [None]:
def apply_model(model, test_loader, output_dir):
    """Run ``model`` over tile batches and store one feature vector per tile.

    For each sample of every batch, the model output row is written as a
    float32 dataset named after the tile position. The dataset goes into
    ``<output_dir>/<image_id>.h5``, opened in append mode so that all tiles
    of a slide accumulate in the same file. A dataset left over from a
    previous run is overwritten.

    Args:
        model: callable module; it is moved to CUDA and put in eval mode.
        test_loader: yields ``(images, image_ids, positions)`` batches, as
            produced from ``ImgDataset``.
        output_dir: existing directory for the ``.h5`` files.
    """
    model.cuda()
    model.eval()
    for item in tqdm(test_loader, leave=False):
        images = item[0].cuda().float()
        image_ids = item[1]
        image_poses = item[2]
        with torch.no_grad():
            # One device->host copy per batch instead of one per row.
            output = model(images).cpu().numpy()
        for image_id, pos, feature in zip(image_ids, image_poses, output):
            key = f"{pos}"
            # Context manager closes the file even if the write fails.
            with h5py.File(f"{output_dir}/{image_id}.h5", 'a') as f:
                # Replace a tile left by an earlier run instead of failing on
                # an already-existing dataset name.
                if key in f:
                    del f[key]
                f.create_dataset(key, data=feature, dtype=np.float32)
        del images, image_ids, image_poses, output
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
def predict(model, dataloader):
    """Score each slide with the CLAM model and collect attention maps.

    Only element 0 of every batch is used, so a batch size of 1 is assumed.
    The model's second output is softmaxed over its first two columns and
    row 0 is kept. The fifth output (attention) is reshaped to 8x8.

    If the forward pass or the post-processing raises, the error is printed
    and a uniform fallback is recorded instead: softmax([1, 1]) = [0.5, 0.5],
    with an all-ones 8x8 attention map. Every patient therefore still gets
    exactly one entry.

    Returns:
        ``(probabilities array of shape (N, 2), list of N patient ids,
        list of N 8x8 attention arrays)``
    """
    model.cuda()
    model.eval()
    softmax = nn.Softmax(dim=1)
    outputs, attentions, ids = [], [], []
    for item in tqdm(dataloader, leave=False):
        ids.append(item[2][0])
        try:
            images = item[0][0].cuda().float()
            with torch.no_grad():  # inference only: skip building an autograd graph
                _, logits, _, _, attention = model(images)
            prob = softmax(logits.cpu()[:, :2])[0].detach().numpy()
            attn = attention[0].view(8, 8).cpu().detach().numpy()
            del logits, images
        except Exception as e:  # deliberate best-effort so the submission stays complete
            print(e)
            prob = softmax(torch.tensor([[1, 1]]).float())[0].detach().numpy()
            attn = torch.ones(8, 8).detach().cpu().numpy()
        # Append only after both values exist. A failure half-way through
        # (e.g. in the attention reshape) can then no longer leave `outputs`
        # one entry longer than `ids`.
        outputs.append(prob)
        attentions.append(attn)
        gc.collect()
        torch.cuda.empty_cache()
    return np.array(outputs), ids, attentions

In [None]:
# image generation
# For each test slide:
#   1. thumbnail it to 4096 px and remove the background;
#   2. trim white borders and re-fit the result onto a white 4096x4096 canvas;
#   3. save the full image as <image_id>.jpg;
#   4. save its 512x512 tiles as <image_id>_<k>.jpg (the files ImgDataset reads).
os.makedirs("./test/4096-tiles-v4", exist_ok=True)
for img_id in tqdm(test_df.image_id):
    img = pyvips.Image.new_from_file(dirs[1] + img_id + ".tif", access='sequential')
    img = resize_image(img, 4096) # pyvips -> PIL (channels swapped to BGR by pyvips2pil)
    img = my_remove(img) # PIL -> numpy array (cv2.merge result)
    img = bg_trim(img) # array -> PIL
    img = resize_image_with_canvas(pil2pyvips(img), 4096) # PIL -> pyvips -> PIL
    full_path = f"./test/4096-tiles-v4/{img_id}.jpg"
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(full_path, np.array(img)):
        raise OSError(f"failed to write {full_path}")
    imgs = tiling(pyvips.Image.new_from_array(img)) # retain BGR order
    # Separate loop variable so the outer loop variable is not shadowed.
    for tile_idx, tile in enumerate(imgs):
        tile_path = f"./test/4096-tiles-v4/{img_id}_{tile_idx}.jpg"
        if not cv2.imwrite(tile_path, tile.numpy()): # BGR
            raise OSError(f"failed to write {tile_path}")
    del imgs
    gc.collect()

In [None]:
# feature extraction
# Load the ResNet50 feature extractor once (it is loop-invariant). Then embed
# the tiles of one test image at a time into ./my_features/<image_id>.h5.
output_dir = "./my_features"
os.makedirs(output_dir, exist_ok=True)
batch_size = 32
resnet50_baseline = torch.jit.load("../input/resnet50-baselinepth/resnet50_baseline.pth")
# iloc is positional, so iterate positions rather than index labels; labels
# need not be 0..n-1 for an arbitrary frame.
for row_pos in range(len(test_df)):
    test = test_df.iloc[[row_pos]]
    test_loader = DataLoader(
        ImgDataset(test),
        batch_size=batch_size,
        shuffle=False,
        num_workers=1
    )
    apply_model(resnet50_baseline, test_loader, output_dir)
    del test, test_loader
    gc.collect()
    torch.cuda.empty_cache()
del resnet50_baseline
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Run every CLAM checkpoint over the stored slide features. For each model,
# keep a frame of the mean CE/LAA probability per id.
model_paths = [
    '../input/clambaseline/model_fold0.pth',
]

prob = pd.DataFrame()
df_list = []
for model_id, model_path in enumerate(model_paths):
    torch.cuda.empty_cache()
    model = torch.jit.load(model_path)
    torch.cuda.empty_cache()

    batch_size = 1  # predict() only reads element 0 of each batch
    test_loader = DataLoader(
        FeatureDataset(test_df, "./my_features"),
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
    )
    anss, ids, attentions = predict(model, test_loader)
    del model, test_loader
    gc.collect()
    torch.cuda.empty_cache()
    mydf = (
        pd.DataFrame({"CE": anss[:, 0], "LAA": anss[:, 1], "id": ids})
        .groupby("id")
        .mean()
    )
    df_list.append(mydf)

In [None]:
# Average the per-model CE probabilities with equal weights and take LAA as
# its complement. CE is then written back as 1 - LAA.
laa_probs = 1.0 - sum(1.0 / len(df_list) * frame.CE for frame in df_list)
submission = pd.DataFrame({
    "patient_id": df_list[0].index,
    "CE": 1.0 - np.asarray(laa_probs.to_list()),
    "LAA": laa_probs.to_list(),
})
submission

In [None]:
# Write the predictions to submission.csv without the DataFrame index column.
submission.to_csv(path_or_buf="submission.csv", index=False)