# Mayo Clinic - STRIP AI Kaggle Competition

Competition link: https://www.kaggle.com/competitions/mayo-clinic-strip-ai/ 

## Overview

### Goal of the Competition

The goal of this competition is to classify the blood clot origins in ischemic stroke. Using whole slide digital pathology images, you'll build a model that differentiates between the two major acute ischemic stroke (AIS) etiology subtypes: cardiac and large artery atherosclerosis.

Your work will enable healthcare providers to better identify the origins of blood clots in deadly strokes, making it easier for physicians to prescribe the best post-stroke therapeutic management and reducing the likelihood of a second stroke.

### Context

Stroke remains the second-leading cause of death worldwide. Each year in the United States, over 700,000 individuals experience an ischemic stroke caused by a blood clot blocking an artery to the brain. A second stroke (23% of total events are recurrent) worsens the chances of the patient’s survival. However, subsequent strokes may be mitigated if physicians can determine stroke etiology, which influences the therapeutic management following stroke events.

During the last decade, mechanical thrombectomy has become the standard of care treatment for acute ischemic stroke from large vessel occlusion. As a result, retrieved clots became amenable to analysis. Healthcare professionals are currently attempting to apply deep learning-based methods to predict ischemic stroke etiology and clot origin. However, unique data formats, image file sizes, as well as the number of available pathology slides create challenges you could lend a hand in solving.

The Mayo Clinic is a nonprofit American academic medical center focused on integrated health care, education, and research. Stroke Thromboembolism Registry of Imaging and Pathology (STRIP) is a uniquely large multicenter project led by Mayo Clinic Neurovascular Lab with the aim of histopathologic characterization of thromboemboli of various etiologies and examining clot composition and its relation to mechanical thrombectomy revascularization.

To decrease the chances of subsequent strokes, the Mayo Clinic Neurovascular Research Laboratory encourages data scientists to improve artificial intelligence-based etiology classification so that physicians are better equipped to prescribe the correct treatment. New computational and artificial intelligence approaches could help save the lives of stroke survivors and help us better understand the world's second-leading cause of death.



## Dataset Description

The dataset for this competition comprises over a thousand high-resolution whole-slide digital pathology images. Each slide depicts a blood clot from a patient who had experienced an acute ischemic stroke.

The slides comprising the training and test sets depict clots with an etiology (that is, origin) known to be either CE (Cardioembolic) or LAA (Large Artery Atherosclerosis). We include a set of supplemental slides with either an unknown etiology or an etiology other than CE or LAA.

Your task is to classify the etiology (CE or LAA) of the slides in the test set for each patient.

### File and Data Field Descriptions
* **train/** - A folder containing images in the TIFF format to be used as training data.
* **test/** - A folder containing images to be used as test data. The actual test data comprises about 280 images.
* **other/** - A supplemental set of images with either an unknown etiology or an etiology other than CE or LAA.
* **train.csv** - Contains annotations for images in the train/ folder.
    * **image_id** - A unique identifier for this instance having the form {patient_id}_{image_num}. Corresponds to the image {image_id}.tif.
    * **center_id** - Identifies the medical center where the slide was obtained.
    * **patient_id** - Identifies the patient from whom the slide was obtained.
    * **image_num** - Enumerates images of clots obtained from the same patient.
    * **label** - The etiology of the clot, either CE or LAA. This field is the classification target.
* **test.csv** - Annotations for images in the test/ folder. Has the same fields as train.csv excluding label.
* **other.csv** - Annotations for images in the other/ folder. Has the same fields as train.csv, except that center_id is unavailable for these images.
    * **label** - The etiology of the clot, either Unknown or Other.
    * **other_specified** - The specific etiology, when known, in case the etiology is labeled as Other.
* **sample_submission.csv** - A sample submission file in the correct format. See the Evaluation page for more details. Note in particular that you should make one prediction per patient_id, not per image_id.
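
Because the submission needs one row per patient_id while a model typically scores individual images, per-image probabilities have to be combined per patient. The following is a minimal sketch in plain Python, assuming a simple mean over each patient's images; the competition text does not prescribe an aggregation rule, and the function name `aggregate_by_patient` and the sample IDs are hypothetical.

```python
from collections import defaultdict

def aggregate_by_patient(image_ids, ce_probs, laa_probs):
    """Average per-image probabilities into one (CE, LAA) pair per patient.

    Assumes image_id has the form {patient_id}_{image_num}, as in train.csv.
    """
    grouped = defaultdict(list)
    for image_id, ce, laa in zip(image_ids, ce_probs, laa_probs):
        # Split on the last underscore so the image number is dropped.
        patient_id = image_id.rsplit("_", 1)[0]
        grouped[patient_id].append((ce, laa))
    return {
        pid: (
            sum(ce for ce, _ in rows) / len(rows),
            sum(laa for _, laa in rows) / len(rows),
        )
        for pid, rows in grouped.items()
    }

# Hypothetical IDs and probabilities, for illustration only.
result = aggregate_by_patient(
    ["aaa111_0", "aaa111_1", "bbb222_0"],
    [0.6, 0.8, 0.3],
    [0.4, 0.2, 0.7],
)
```

Other aggregation rules (for example a maximum or a weighted mean) would fit the same interface; the mean is used here only because it is the simplest choice.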


## Evaluation

Submissions are evaluated using a weighted multi-class logarithmic loss. The overall effect is such that each class is roughly equally important for the final score.

Each image has been labeled with an etiology class, either CE or LAA. For each image, you must submit a probability for each class. The formula is then:

$$\text{Log Loss} = - \left( \frac{\sum^{M}_{i=1} w_{i} \cdot \sum_{j=1}^{N_{i}} \frac{y_{ij}}{N_{i}} \cdot \ln  p_{ij} }{\sum^{M}_{i=1} w_{i}} \right)$$

where $M$ is the number of classes, $N_{i}$ is the number of images in class $i$, $w_{i}$ is the weight of class $i$, $\ln$ is the natural logarithm, $y_{ij}$ is 1 if image $j$ belongs to class $i$ and 0 otherwise, and $p_{ij}$ is the predicted probability that image $j$ belongs to class $i$.

The submitted probabilities for a given image are not required to sum to one because they are rescaled prior to being scored (each row is divided by the row sum). In order to avoid the extremes of the log function, each predicted probability is replaced with $\max(\min(p,1-10^{-15}),10^{-15})$.
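
To make the formula above concrete, here is a small plain-Python sketch of the metric. It is not the official scoring code: the class encoding (0 = CE, 1 = LAA), the equal default weights, the order of operations (rescale each row first, then clip), and the function name `weighted_log_loss` are all assumptions made for illustration.

```python
import math

def weighted_log_loss(y_true, probs, weights=(0.5, 0.5), eps=1e-15):
    """Weighted multi-class log loss following the formula above.

    y_true:  list of class indices (0 = CE, 1 = LAA; assumed encoding).
    probs:   list of per-class probability rows, one row per image.
    weights: per-class weights w_i (equal weights assumed here).
    """
    total = 0.0
    for cls, w in enumerate(weights):
        rows = [p for y, p in zip(y_true, probs) if y == cls]
        if not rows:
            continue  # a class with no images contributes nothing
        class_sum = 0.0
        for row in rows:
            p = row[cls] / sum(row)          # rescale the row to sum to one
            p = max(min(p, 1 - eps), eps)    # clip away from 0 and 1
            class_sum += math.log(p)
        total += w * class_sum / len(rows)   # mean log-probability of the class
    return -total / sum(weights)

# Predicting 0.5 for every image gives ln(2), regardless of class balance.
baseline = weighted_log_loss([0, 0, 0, 1], [[0.5, 0.5]] * 4)
```

Because each class is averaged separately before weighting, a constant 50/50 prediction scores $\ln 2 \approx 0.693$ even when the classes are imbalanced, which makes it a convenient reference point.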



## Timeline

* July 6, 2022 - Start Date.
* September 28, 2022 - Entry Deadline. You must accept the competition rules before this date in order to compete.
* September 28, 2022 - Team Merger Deadline. This is the last day participants may join or merge teams.
* October 5, 2022 - Final Submission Deadline.



# Training

In [None]:
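%pip install torch
%pip install torchvision
# skimage and cv2 are import names; the PyPI packages are scikit-image and opencv-python
%pip install scikit-image
%pip install opencv-python
%pip install pyvips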

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import skimage.io as io
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import glob
import os
import pyvips
import random

In [None]:
random.seed(6)

## Prepare Images

Load dataset

In [None]:
# Load CSV file with image paths, labels
trainc = pd.read_csv('../input/mayo-clinic-strip-ai/train.csv')
# Convert text labels to binary labels 
trainc['label'] = [0 if i == 'CE' else 1 for i in trainc['label']]
# Remove files of 1 GB or more to prevent memory issues
trainc = trainc[[os.path.getsize(f"../input/mayo-clinic-strip-ai/train/{x}.tif") < 1_000_000_000 for x in trainc['image_id']]]

In [None]:
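# shuffle, then split into train (80%), validation (10%), and test (10%) datasets
trainc = trainc.sample(frac=1, random_state=12)
# the second cut point must be 0.9, otherwise the validation split is empty
train_end, val_end = int(0.8 * len(trainc)), int(0.9 * len(trainc))
train_ds, val_ds, test_ds = trainc.iloc[:train_end], trainc.iloc[train_end:val_end], trainc.iloc[val_end:]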

Create a custom dataset that breaks each large image into tiles

In [None]:
class CustomDataset(Dataset):
    def __init__(self, list_IDs, labels, dire="../input/mayo-clinic-strip-ai/train/", dim=(5, 175, 175, 3)):
        '''Initializes dataset.'''
        self.root = dire
        self.file_list = list(list_IDs)
        # labels is None at inference time, when no annotations are available
        self.labels = None if labels is None else np.asarray(labels)
        self.dim = dim
    def __len__(self):
        '''Returns length of dataset.'''
        return len(self.file_list)
    def __getitem__(self, idx):
        '''Returns tensor of image tiles and labels if available.'''
        img_path = self.file_list[idx]
        img = self.__load_make_tiles(f"{self.root}{img_path}.tif", self.dim[1], self.dim[0])
        # tiles arrive as (N, H, W, C); PyTorch expects (N, C, H, W)
        img_tensor = torch.from_numpy(img).permute(0, 3, 1, 2)
        if self.labels is None:
            return img_tensor
        label = torch.tensor([self.labels[idx]])
        return img_tensor, label
    def __get_bg(self, img):
        # cred: https://github.com/libvips/libvips/issues/790#issuecomment-340568993
        # the margin of pixel we extract to get the average edge
        margin = 10 

        # paste black over the centre, take the histogram of the whole image
        square = pyvips.Image.black(img.width - 2 * margin, img.height - 2 * margin)
        hist = img.insert(square, margin, margin).hist_find()

        # zap the 0 column to remove the black square
        onepx = pyvips.Image.black(1, 1)
        hist = hist.insert(onepx, 0, 0) 

        # then the histogram peak is the most common value in each band
        bg = [x.maxpos()[1] for x in hist.gaussblur(1).bandsplit()]
        
        return bg 

    def __add_bg(self, img, embed=False, size=600):
        # cred: https://github.com/libvips/libvips/issues/790#issuecomment-340568993
        bg = self.__get_bg(img)
        # centre the image on a size x size canvas filled with that background
        # (integer division keeps the offsets as whole pixels)
        img = img.embed((size - img.width) // 2, (size - img.height) // 2, size, size,
                  extend='background', background=bg)
        return img, bg
    def __load_make_tiles(self, img_path, tile_size, num_tiles):
        '''
        Loads the slide at 1/16 scale and returns 5 tiles as an np.ndarray
        of shape (5, 200, 200, channel). tile_size and num_tiles are currently
        unused: the patch size and tile count are hard-coded below.
        '''
        # inspired by: https://www.kaggle.com/code/analokamus/a-fast-tile-generation 
        img = pyvips.Image.new_from_file(img_path).resize(1/16)
        patch_size = 200
        n_across = img.width // patch_size
        n_down = img.height // patch_size
        if n_across*n_down < 5:
            # too small for 5 tiles: pad to a square of at least 600 px.
            # size must be passed by keyword, otherwise it lands in `embed`
            img, bg = self.__add_bg(img, size=max(img.width, img.height, 600))
            n_across = img.width // patch_size
            n_down = img.height // patch_size
        else:
            bg = self.__get_bg(img)
        patches = []
        notpatches = []
        for y in range(n_down):
            for x in range(n_across):
                patch = img.crop(x * patch_size, y * patch_size,
                                   patch_size, patch_size)
                # a patch holds tissue if any diagonal pixel is darker than the background
                for i in range(patch_size):
                    if sum(patch.getpoint(i,i))/3<sum(bg)/3:
                        patches.append(patch)
                        break
                else:
                    # no break: every diagonal pixel was background
                    notpatches.append(patch)
        if len(patches)>=5:
            patches = [x.numpy() for x in random.sample(patches, 5)]
        else:
            # pad with background patches to reach 5 tiles
            for j in range(5-len(patches)):
                patches.append(random.choice(notpatches))
            patches = [x.numpy() for x in patches]
        # stack to (5, patch_size, patch_size, channel)
        return np.stack(patches, axis=0)
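
The tile selection in `__load_make_tiles` can be hard to follow inside the pyvips calls. The sketch below reproduces only the selection idea on a plain 2-D list of grayscale values, so it runs without pyvips. It is a simplified illustration, not the dataset code: the function name `select_tissue_tiles` is hypothetical, and a single grayscale value stands in for the per-channel mean used above.

```python
def select_tissue_tiles(img, patch_size, background):
    """Split a 2-D grayscale image (list of rows) into square tiles and
    separate tiles that contain tissue from pure-background tiles.

    A tile counts as tissue if any pixel on its main diagonal is darker
    than the background value, mirroring the diagonal check above.
    Returns two lists of (x, y) tile origins.
    """
    height, width = len(img), len(img[0])
    tissue, empty = [], []
    for y in range(0, height - patch_size + 1, patch_size):
        for x in range(0, width - patch_size + 1, patch_size):
            tile = [row[x:x + patch_size] for row in img[y:y + patch_size]]
            if any(tile[i][i] < background for i in range(patch_size)):
                tissue.append((x, y))
            else:
                empty.append((x, y))
    return tissue, empty

# Toy 4x4 image: 255 is background, darker values are "tissue".
toy = [
    [255, 255, 255, 255],
    [255, 100, 255, 255],
    [255, 255, 255, 255],
    [255, 255, 255,  90],
]
tissue, empty = select_tissue_tiles(toy, patch_size=2, background=255)
```

Checking only the diagonal is cheap but can miss tissue that lies entirely off the diagonal of a tile; that trade-off is inherited from the approach above.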


In [None]:
train = CustomDataset(train_ds['image_id'], train_ds['label'])
train_data_loader = DataLoader(train, batch_size=4, shuffle=True)
val = CustomDataset(test_ds['image_id'], test_ds['label'])
val_data_loader = DataLoader(val, batch_size=4, shuffle=False)


In [None]:
%pip install torch-summary
%pip install timm

In [None]:
import torch.nn as nn
from timm import create_model
from torchsummary import summary


In [None]:
class Flatten(nn.Module):
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x): 
        input_shape = x.shape
        output_shape = [input_shape[i] for i in range(self.dim)] + [-1]
        return x.view(*output_shape)
class SimpleMIL(nn.Module):
    # inspired by: https://www.kaggle.com/code/analokamus/a-sample-of-multi-instance-learning-model 
    def __init__(
        self, 
        model_name, 
        num_instances=5, 
        num_classes=2, 
        pretrained=False):
        super().__init__()

        self.num_instances = num_instances
        self.encoder = create_model(
            model_name, 
            pretrained=pretrained, 
            num_classes=num_classes)
        enc_type = self.encoder.__class__.__name__
        feature_dim = self.encoder.get_classifier().in_features
        self.head = nn.Sequential(
            nn.AdaptiveMaxPool2d(1), Flatten(),
            nn.Linear(feature_dim, 256), nn.ReLU(inplace=True), 
            nn.Linear(256, num_classes), nn.Sigmoid()
        )

    def forward(self, x):
        # x: bs x N x C x W x W
        x=x.float()
        bs, _, ch, w, h = x.shape
        x = x.view(bs*self.num_instances, ch, w, h) # x: N bs x C x W x W
        x = self.encoder.forward_features(x) # x: N bs x C' x W' x W'

        # Concat and pool
        bs2, ch2, w2, h2 = x.shape
        x = x.view(-1, self.num_instances, ch2, w2, h2).permute(0, 2, 1, 3, 4)\
            .contiguous().view(bs, ch2, self.num_instances*w2, h2) # x: bs x C' x N W'' x W''
        x = self.head(x)

        return x
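
The reshaping in `SimpleMIL.forward` is the core of the multi-instance idea: tiles are folded into the batch for the encoder, then their feature maps are stacked along the height axis before pooling. The sketch below traces only the shape bookkeeping with NumPy, using a zero-filled stand-in for the encoder output; the sizes chosen (`ch2`, `h2`, `w2`) are arbitrary assumptions, not properties of any particular timm model.

```python
import numpy as np

bs, n, ch, h, w = 2, 5, 3, 8, 8      # batch, tiles per slide, channels, height, width
x = np.zeros((bs, n, ch, h, w), dtype=np.float32)

# 1. Fold the tile axis into the batch axis so the encoder sees ordinary images.
flat = x.reshape(bs * n, ch, h, w)

# 2. Stand-in for encoder.forward_features: any map to (bs*n, ch2, h2, w2).
ch2, h2, w2 = 16, 2, 2
feats = np.zeros((bs * n, ch2, h2, w2), dtype=np.float32)
# Tag each tile's feature map with its tile index to track where it ends up.
for k in range(bs * n):
    feats[k] = k % n

# 3. Unfold, move channels ahead of the tile axis, and stack tiles along height.
merged = (
    feats.reshape(bs, n, ch2, h2, w2)
    .transpose(0, 2, 1, 3, 4)
    .reshape(bs, ch2, n * h2, w2)
)
```

After step 3 each slide is a single feature map of height `n * h2`, so the adaptive max pool in the head takes the maximum over all tiles at once. NumPy's `reshape` copies when needed, which is why no equivalent of `.contiguous()` appears here.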

In [None]:
# map_location lets the checkpoint load on CPU-only machines as well
model = torch.load('../input/strip-model-train/model.pt', map_location=device)
model.to(device)

In [None]:
model.train()


In [None]:
# inspired by: https://www.kaggle.com/competitions/mayo-clinic-strip-ai/discussion/353305
class WeightedLogLoss(nn.Module):
    def __init__(self, weights=torch.tensor([0.5,0.5])):
        super().__init__()
        self.weights = weights
    
    def forward(self, probs, target):
        weights = self.weights.to(probs.device)
        # clamp to avoid log(0), as the competition metric does
        log_probs = torch.log(probs.clamp(min=1e-15))
        res = 0
        # torch.unique visits each class once, even when labels are interleaved
        for c in torch.unique(target):
            c = int(c)
            class_log_probs = log_probs[target == c][:, c]
            res += weights[c] * class_log_probs.mean()
        # return the tensor itself (not a torch.tensor copy) so gradients reach the model
        return -(res / weights.sum())


In [None]:
loss_func = WeightedLogLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.01)  
epochs=1
final_losses=[]

for epoch in range(1, epochs + 1):
    loss = 0
    for inputs, labels in train_data_loader: 
        optimizer.zero_grad(set_to_none=True)
        # labels arrive as (batch, 1); flatten to (batch,)
        labels = labels.view(-1).to(device)
        inputs = inputs.to(device)
        y_pred = model(inputs)
        bloss = loss_func(y_pred, labels)
        bloss.backward()
        optimizer.step()
        loss += bloss.item() 
    # report the mean loss per batch
    epoch_loss = loss / len(train_data_loader)
    final_losses.append(epoch_loss)
    print(f"Epoch number: {epoch} and the loss : {epoch_loss}")


# Inference

In [None]:
subc = pd.read_csv("../input/mayo-clinic-strip-ai/test.csv")


In [None]:
sub = CustomDataset(subc['image_id'], None, dire="../input/mayo-clinic-strip-ai/test/")
sub_data_loader = DataLoader(sub, batch_size=1, shuffle=False)


In [None]:
model.eval()
ce = []
laa = []

def clip_prob(p):
    '''Maps near-0 values to 1e-15 and near-1 values to 0.999999.'''
    if p < 1e-6:
        return 1e-15
    if p >= 0.9999994:
        return 0.999999
    return p

with torch.no_grad():
    for data in sub_data_loader:
        inputs = data.to(device)
        predb = model(inputs)
        for row in predb:
            # .item() converts 0-dim tensors to floats so pandas can average them
            ce.append(clip_prob(row[0].item()))
            laa.append(clip_prob(row[1].item()))

sub_df = pd.DataFrame(list(zip(subc['patient_id'], ce, laa)), columns=['patient_id', 'CE', 'LAA'])
# one prediction per patient: average over that patient's images
sub_df = sub_df.groupby("patient_id").mean()
assert len(sub_df) == len(set(subc['patient_id'])), f"# of rows ({len(sub_df)}) != # of patient ids ({len(set(subc['patient_id']))})"
sub_df = sub_df[["CE", "LAA"]].round(6)
sub_df.to_csv("./submission.csv")

In [None]:
sub_df

# Results

* Public score (loss): **0.68415**
* **53rd** out of **888**
* **Top 6%** 
* **Bronze** Medal (top 100)