# CLAM

NOTE: Some of the descriptions or images are cited from: https://github.com/mahmoodlab/CLAM

![img](https://github.com/mahmoodlab/CLAM/raw/master/docs/CLAM2.jpg)

## TL;DR:

+ CLAM is a high-throughput and interpretable method for data-efficient whole slide image (WSI) classification using slide-level labels without any ROI extraction or patch-level annotations, and is capable of handling multi-class subtyping problems. Tested on three different WSI datasets, trained models adapt to independent test cohorts of WSI resections and biopsies as well as smartphone microscopy images (photomicrographs).
+ paper: https://arxiv.org/abs/2004.09666

## How to apply CLAM to the STRIP AI dataset?

+ I prepared four notebooks covering pre-processing, training, and inference:

### pre-process

+ <b>&gt; THIS NOTEBOOK &lt;</b> (1) image generation: https://www.kaggle.com/code/fx6300/clam-strip-ai-image-generation
+ (2) feature extraction: https://www.kaggle.com/code/fx6300/clam-strip-ai-feature-extraction

### train

+ (3) train: https://www.kaggle.com/code/fx6300/clam-strip-ai-train

### inference

+ (4) inference: https://www.kaggle.com/code/fx6300/clam-strip-ai-inference

## How to visualize the attention generated by CLAM?

+ I prepared an example:
  + https://www.kaggle.com/fx6300/clam-strip-ai-attention-heatmap

## NOTE

+ The source code from CLAM (https://github.com/mahmoodlab/CLAM) is licensed under GPLv3 and available for non-commercial academic purposes.

In [None]:
!conda install ../input/how-to-use-pyvips-offline/*.tar.bz2
!pip install tiling
!pip install rembg

In [None]:
import os
import gc
import cv2
import time
import random
import joblib
import numpy as np 
import pandas as pd 
from tqdm.notebook import tqdm
import warnings
from PIL import Image
import pyvips
import skimage.exposure
from tiling import ConstStrideTiles
from rembg import remove
warnings.filterwarnings("ignore")

In [None]:
# Smoke-test switch: when True only the first 10 training rows are kept.
debug = False
# Set to False to skip the image-generation cell further down.
generate_new = True
# Folders holding the raw .tif slides: index 0 = train, index 1 = test.
dirs = [
    "../input/mayo-clinic-strip-ai/train/",
    "../input/mayo-clinic-strip-ai/test/",
]
test_df = pd.read_csv("../input/mayo-clinic-strip-ai/test.csv")
train_df = pd.read_csv("../input/mayo-clinic-strip-ai/train.csv")
# Cap the training table at 10 rows (debug) or 1000 rows.
train_df = train_df.head(10 if debug else 1000)

In [None]:
def tiling(img, tile_size=(512, 512), stride=(512, 512)):
    """Cut a pyvips image into a regular grid of crops.

    Args:
        img: pyvips image to tile (in this notebook a 4096 x 4096 canvas).
        tile_size: ``(width, height)`` of every tile, in pixels.
        stride: ``(x, y)`` step between consecutive tile origins, in pixels.

    Returns:
        list of pyvips images, one per tile, in ``ConstStrideTiles`` order.

    NOTE(review): with ``include_nodata=True`` the tiler presumably emits
    extents that run past the border when the image size is not a multiple
    of the stride, and ``pyvips.Image.crop`` rejects out-of-bounds areas.
    The 4096 / 512 setup used here divides evenly, so all extents fit.
    """
    # ConstStrideTiles takes image_size as (width, height); passing
    # (height, width) only worked because the canvas is square.
    tiles = ConstStrideTiles(
        image_size=(img.width, img.height),
        tile_size=tile_size,
        stride=stride,
        origin=(0, 0),
        scale=1.0,
        include_nodata=True,
    )
    print("Number of tiles: %i" % len(tiles))
    # extent is (x, y, width, height); out_size is not used here.
    return [img.crop(*extent) for extent, _out_size in tiles]

In [None]:
def resize_image(img: pyvips.Image, width: int) -> Image.Image:
    """Shrink a pyvips image with ``thumbnail_image`` and return it as PIL.

    Args:
        img: source pyvips image (``thumbnail_image`` is a pyvips method, so
            a PIL image cannot be passed here).
        width: target size for ``pyvips.Image.thumbnail_image``. With no
            height given, libvips fits the result inside a ``width`` x
            ``width`` box while keeping the aspect ratio.

    Returns:
        PIL image whose pixel data is in BGR order (``pyvips2pil`` swaps
        the channels).
    """
    return pyvips2pil(img.thumbnail_image(width))

In [None]:
def resize_image_with_canvas(img: pyvips.Image, width: int) -> Image.Image:
    """Thumbnail ``img`` and paste it top-left onto a white square canvas.

    Args:
        img: source pyvips image.
        width: side length of the square canvas; also the thumbnail target.

    Returns:
        ``width`` x ``width`` PIL image in "RGB" mode. The thumbnail occupies
        the top-left corner and the rest is white. Pixel data keeps the BGR
        order produced by ``pyvips2pil``; white is unaffected by the swap.
    """
    canvas = Image.new("RGB", (width, width), "white")
    # Same thumbnail step as resize_image, reused instead of duplicated.
    canvas.paste(resize_image(img, width), (0, 0))
    return canvas

In [None]:
def bg_trim(_img: Image.Image, threshold: int = 10) -> Image.Image:
    """Drop rows and columns that are (almost) pure white.

    A row or column is removed when the sum of ``255 - value`` over all of
    its pixels and channels is below ``threshold``. Rows and columns are both
    judged on the untrimmed input.

    Args:
        _img: PIL image or uint8 array shaped (H, W) or (H, W, C). In this
            notebook it receives the ndarray returned by ``my_remove``.
        threshold: a line whose total "darkness" is below this is background.

    Returns:
        PIL image built from the trimmed pixels. If every row or every column
        is background, the input is returned untrimmed rather than as an
        empty image.
    """
    arr = np.asarray(_img)
    h, w = arr.shape[:2]
    # 255 - uint8 stays within uint8, but summing uint8 values wraps at 256.
    # Accumulating with the builtin sum let tissue rows whose darkness
    # happened to be near a multiple of 256 be treated as white. Use an
    # int64 accumulator instead.
    darkness = 255 - arr
    row_ink = darkness.reshape(h, -1).sum(axis=1, dtype=np.int64)
    col_ink = darkness.sum(axis=0, dtype=np.int64).reshape(w, -1).sum(axis=1)
    keep_rows = row_ink >= threshold
    keep_cols = col_ink >= threshold
    if not keep_rows.any() or not keep_cols.any():
        return Image.fromarray(arr)
    return Image.fromarray(np.ascontiguousarray(arr[keep_rows][:, keep_cols]))

In [None]:
def pyvips2pil(img: pyvips.Image) -> Image:
    """Convert a pyvips image to PIL, reversing the channel order.

    The swap is deliberate: the PIL image carries BGR data, so an array
    taken from it can be handed directly to ``cv2.imwrite``.
    """
    as_array = img.numpy()
    bgr_pixels = cv2.cvtColor(as_array, cv2.COLOR_RGB2BGR)
    return Image.fromarray(bgr_pixels)

In [None]:
def pil2pyvips(img: Image) -> pyvips.Image:
    """Convert a BGR-ordered PIL image back into an RGB pyvips image.

    This is the inverse of the channel swap done by ``pyvips2pil``.
    """
    bgr_pixels = np.array(img)
    return pyvips.Image.new_from_array(
        cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    )

In [None]:
def my_remove(img: Image.Image) -> np.ndarray:
    """Segment the foreground with rembg and flatten it onto white.

    Args:
        img: PIL image whose channels hold BGR data (as produced by
            ``pyvips2pil``).

    Returns:
        uint8 ndarray (H, W, 3) in BGR order. Each pixel is blended with a
        white background using rembg's alpha mask. This is an ndarray, not a
        PIL image.
    """
    # rembg, like PIL, treats the channels as RGB. Undo the notebook-wide
    # BGR order first so the model sees true colours.
    rgb_input = Image.fromarray(cv2.cvtColor(np.array(img), cv2.COLOR_BGR2RGB))
    rgba = np.asarray(remove(rgb_input).convert("RGBA"), dtype=np.float64)
    alpha = rgba[..., 3:] / 255
    # Alpha-composite over white; astype truncates, as before.
    rgb = (255 * (1 - alpha) + rgba[..., :3] * alpha).astype(np.uint8)
    # Return to BGR so downstream cv2.imwrite calls get the expected order.
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

In [None]:
def _generate_split(df, src_dir, out_dir):
    """Preprocess every slide in ``df`` and write the canvas and its tiles.

    For each ``image_id`` the pipeline is:
      1. read ``<src_dir><image_id>.tif`` and thumbnail it to 4096 px;
      2. remove the background with rembg and trim white borders;
      3. re-fit the result onto a 4096 x 4096 white canvas;
      4. write the canvas to ``<out_dir>/<image_id>.jpg``;
      5. write each grid tile to ``<out_dir>/<image_id>_<k>.jpg``.
    Pixel arrays stay in BGR order throughout so ``cv2.imwrite`` stores the
    true colours.
    """
    os.makedirs(out_dir, exist_ok=True)
    for img_id in tqdm(df["image_id"]):
        slide = pyvips.Image.new_from_file(src_dir + img_id + ".tif", access='sequential')
        img = resize_image(slide, 4096)  # pyvips -> PIL
        img = my_remove(img)  # PIL -> ndarray (BGR)
        img = bg_trim(img)  # ndarray -> PIL
        img = resize_image_with_canvas(pil2pyvips(img), 4096)  # PIL -> pyvips -> PIL
        canvas = np.array(img)
        cv2.imwrite(f"{out_dir}/{img_id}.jpg", canvas)
        tiles = tiling(pyvips.Image.new_from_array(canvas))  # retain BGR order
        for k, tile in enumerate(tiles):
            cv2.imwrite(f"{out_dir}/{img_id}_{k}.jpg", tile.numpy())  # BGR
        # Free the 4096 px intermediates before the next slide.
        del tiles
        gc.collect()


if generate_new:
    # Test slides first, then train, matching the original processing order.
    _generate_split(test_df, dirs[1], "./test/4096-tiles-v4")
    _generate_split(train_df, dirs[0], "./train/4096-tiles-v4")