In [None]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

from IPython.display import HTML

# Competition goal

* Classify the etiology (origin) of a blood clot shown in a tissue slice of a patient who has experienced an acute ischemic stroke.
* It's a **binary classification task** with labels CE (cardioembolic) or LAA (large artery atherosclerosis).
* There are supplemental slides with either an **unknown etiology or an etiology other than CE or LAA**!

Aha! ... :-O Do you know the difference between CE and LAA or what is meant by an acute ischemic stroke? ... Neither do I! ;-) 

In [None]:
HTML('<iframe width="800" height="444" src="https://www.youtube.com/embed/abxcrAvw-O0" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>')

## Table of contents

1. [Prepare to start](#prepare)
2. [Exploratory analysis - EDA](#eda)
    * [What do we know about the training data?](#train_data)
3. [The nightmare of overfitting](#nightmare)
4. [Tell me what do you see?](#explain)
    * [Choosing a model](#choose_model)
5. [Explaining predictions with LIME](#lime)


# Prepare to start <a class="anchor" id="prepare"></a>

## Loading packages

In [None]:
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image
import cv2
from skimage.segmentation import mark_boundaries

import tensorflow as tf

from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

import seaborn as sns
sns.set()

from lime import lime_image

import shap

## Loading data

In [None]:
train = pd.read_csv("../input/mayo-clinic-strip-ai/train.csv")

# Exploratory analysis <a class="anchor" id="eda"></a>

## What do we know about our training data? <a class="anchor" id="train_data"></a>


In [None]:
train.head()

In [None]:
train.shape

* The **image_id** consists of two different parts: *{patient_id}_{image_number}*
* The **center_id** defines the medical center or institute where the image was taken
* **patient_id** is the id of the patient
* **image_num** is a counter for slices of a patient
* **label** is our target
    * CE: cardioembolic
    * LAA: Large artery atherosclerosis
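
To make the `{patient_id}_{image_number}` structure concrete, here is a minimal sketch that splits an image id into its two parts. The helper name `split_image_id` is hypothetical and not part of the competition data or any library; the example id is the one opened later in this notebook.

```python
def split_image_id(image_id):
    """Split '{patient_id}_{image_number}' into (patient_id, image_num)."""
    # rsplit from the right, so only the last underscore is used as separator
    patient_id, image_num = image_id.rsplit("_", 1)
    # keep patient_id as a string so that leading zeros survive
    return patient_id, int(image_num)


parts = split_image_id("006388_0")
print(parts)
```

Keeping the patient id as a string is a deliberate choice here: casting it to an integer would drop the leading zeros.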
    

In [None]:
train.patient_id.nunique()

632 patients in train.

In [None]:
plt.figure(figsize=(10,5))
# Pass the counts as x explicitly; newer seaborn versions reject positional data here
sns.countplot(x=train.groupby("patient_id").image_num.size(), palette="Greens_r")
plt.xlabel("Number of images per patient")
plt.title("Number of images per patient in train");

We can clearly see that most patients only have one slice in the training data and only a few show more than one. Maybe these are hard cases where it's more difficult to see the etiology?! Or does it depend on the medical center?

In [None]:
train.groupby("patient_id").center_id.nunique().max()

Ok, every patient belongs to exactly one center: no patient has images that were taken in different medical centers. How many centers do we have, and how many patients do we have per center?

In [None]:
# nunique() counts patients; size() would count images (rows) per center instead
train.groupby("center_id").patient_id.nunique().sort_values(ascending=False)

Definitely not balanced.

Let's take a look at the target distribution! :-)

In [None]:
sns.countplot(x=train.label, palette="Reds_r")

Ok, it's an imbalanced classification problem. 

In [None]:
train.label.value_counts() / train.shape[0]
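
One common way to react to such an imbalance is to weight the classes inversely to their frequency. The following is only a sketch under stated assumptions: `inverse_frequency_weights` is a hypothetical helper, and the toy labels are made up for illustration and are not the real label distribution. The formula is the same "balanced" heuristic that scikit-learn documents, `n_samples / (n_classes * count)`.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Return one weight per class: n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * cnt) for cls, cnt in counts.items()}


# Toy labels for illustration only, NOT the real competition distribution
toy_labels = ["CE"] * 3 + ["LAA"] * 1
weights = inverse_frequency_weights(toy_labels)
print(weights)
```

With the real data one could call it as `inverse_frequency_weights(train.label.tolist())` and pass the result to a weighted loss. Whether this actually helps here is an open question.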

Ok, we know that some patients have more than one slice. Is it possible that a patient shows more than one label?

In [None]:
train.groupby("patient_id").label.nunique().max()

No! :-)

In [None]:
test = pd.read_csv("../input/mayo-clinic-strip-ai/test.csv")
test["image_path"] = test["image_id"].apply(lambda x: os.path.join("../input/mayo-clinic-strip-ai/test/", x+".tif"))
test.head()

## The nightmare of overfitting ;-) <a class="anchor" id="nightmare"></a>

Can't wait anymore! Let's start to discover the images! 

In [None]:
from PIL import Image

Image.MAX_IMAGE_PIXELS = None

In [None]:
example = "../input/mayo-clinic-strip-ai/train/006388_0.tif"
img = Image.open(example)
img.size

Uff!

In [None]:
x1 = img.size[0]
x2 = img.size[1]
sc = x2/x1

In [None]:
factor = 1/10
resized_img = img.resize((int(x1*factor), int(sc*x1*factor)))

In [None]:
resized_img.size

In [None]:
resized_img.rotate(90, expand=True)

Oh, this will be a nightmare for validation... such a complex feature space. I can't see the blood clot, can you? How should we prevent our model from overfitting? It seems hopeless. Now imagine that medical centers might also use different methods or protocols to obtain these slices... uh-oh! The number of patients is low for such a high-dimensional feature space, and the public test data seems to be quite small as well. Hmm.
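
One thing we can control is the validation split itself: all slices of a patient should end up in the same fold, otherwise near-duplicate tissue leaks between train and validation. Below is a minimal sketch of such a patient-grouped fold assignment. The function `grouped_folds` and the toy patient ids are hypothetical; scikit-learn's `GroupKFold` and `StratifiedGroupKFold` are ready-made alternatives, and the same idea could be applied to `center_id` to hold out whole centers.

```python
import random


def grouped_folds(patient_ids, n_folds=5, seed=0):
    """Assign each row a fold index so that all rows of a patient share one fold."""
    unique_patients = sorted(set(patient_ids))
    rng = random.Random(seed)
    rng.shuffle(unique_patients)
    # Deal the shuffled patients round-robin into the folds
    fold_of_patient = {p: i % n_folds for i, p in enumerate(unique_patients)}
    return [fold_of_patient[p] for p in patient_ids]


# Toy patient ids for illustration only
toy_patients = ["a", "a", "b", "c", "c", "c", "d", "e"]
folds = grouped_folds(toy_patients, n_folds=2)
print(list(zip(toy_patients, folds)))
```

On the real data this could be used as `train["fold"] = grouped_folds(train.patient_id.tolist())`. Note that this sketch does not stratify by label, which may matter given the class imbalance.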

# Tell me, what do you see? <a class="anchor" id="explain"></a>

Ok, so far I can't imagine that any model will be able to see something useful in this complex feature space... but let's at least try to understand and explain what our models "see". For this purpose I'd like to use, and learn more about, some machine learning explainability tools for computer vision.

## What is meant by Machine Learning Explainability?

## Choosing a model <a class="anchor" id="choose_model"></a>


Let's pick one of the best scoring public models so far with the hope that it has found an interesting and useful signal. At the moment it's this one:

* https://www.kaggle.com/code/realneuralnetwork/cnn-strip-ai-inference/notebook

In [None]:
class ImgDataset(Dataset):
    """Loads the pre-converted JPG version of each slide, resized to 512x512."""

    def __init__(self, df):
        self.df = df
        # Train frames carry a 'label' column, test frames do not
        self.train = 'label' in df.columns

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        if self.train:
            folder = "../input/jpg-images-strip-ai/train/"
        else:
            folder = "../input/jpg-images-strip-ai/test/"
        # cv2.imread does not raise on a missing or unreadable file, it returns None
        image = cv2.imread(folder + row.image_id + ".jpg")
        try:
            # Resize, then move the channels first: (H, W, C) -> (C, H, W)
            image = cv2.resize(image, (512, 512)).transpose(2, 0, 1)
        except Exception:
            # Fall back to a black image so that inference never crashes
            image = np.zeros((3, 512, 512))
        label = 0
        if self.train:
            label = {"CE": 0, "LAA": 1}[row.label]
        return image, label, row.patient_id
    
    
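def predict(model, dataloader):
    """Return the softmax probabilities (CE, LAA) and the patient id per batch."""
    model.cuda()
    model.eval()
    outputs = []
    ids = []
    s = nn.Softmax(dim=1)
    for item in tqdm(dataloader, leave=False):
        # Assumes batch_size=1: only the first element of each batch is used
        patient_id = item[2][0]
        ids.append(patient_id)
        try:
            images = item[0].cuda().float()
            with torch.no_grad():
                output = model(images)
            # Keep only the first two logits (CE, LAA) before the softmax
            outputs.append(s(output.cpu()[:, :2])[0].numpy())
        except Exception:
            # If inference fails, fall back to an uninformative 50/50 prediction
            outputs.append(np.array([0.5, 0.5], dtype=np.float32))
    return np.array(outputs), ids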

In [None]:
model = torch.jit.load('../input/strip-ai-models-to-explain/kabir_ivan_cnn_strip_ai_v1.pth')

In [None]:
batch_size = 1
test_loader = DataLoader(
    ImgDataset(test.iloc[0:4]), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=1
)

In [None]:
anss, ids = predict(model, test_loader)

In [None]:
anss

In [None]:
ids

In [None]:
prob = pd.DataFrame({"CE" : anss[:,0], "LAA" : anss[:,1], "id" : ids}).groupby("id").mean()
submission = pd.read_csv("../input/mayo-clinic-strip-ai/sample_submission.csv")
submission.CE = prob.CE.to_list()
submission.LAA = prob.LAA.to_list()
submission.to_csv("submission.csv", index = False)

# Explaining predictions with LIME <a class="anchor" id="lime"></a>

LIME stands for Local Interpretable Model-Agnostic Explanations. You can read the paper, published in 2016, [here](https://arxiv.org/pdf/1602.04938.pdf). I know that newer tools exist, but I think it's a good starting point for the topic. ;-) Here is a great video that introduces LIME:

In [None]:
HTML('<iframe width="708" height="400" src="https://www.youtube.com/embed/hUnRCxnydCc" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>')

### Insights

Let's collect a few key messages of the video and the paper:

* LIME explains the predictions of any classifier (model) by treating it as a black box.
* It trains an interpretable model locally around the prediction.
* It helps to decide whether to trust a prediction or a model.
* We need to keep in mind that the performance measured by our evaluation metric might not match the performance of our model on real-world data (public/private test data here), and even with good performance it can happen that the predictions don't make sense.
* The explanation we get with LIME has to be simple enough for a human to understand, regardless of how many features were used by a model to make its prediction.
* For an image classifier an interpretable representation can be a binary vector that indicates the presence or absence of a contiguous patch of similar pixels (super-pixel).
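
The key messages above can be sketched in a few lines of NumPy. This is a strongly simplified illustration of the idea, not the implementation of the `lime` package: the names `lime_sketch` and `toy_black_box` are hypothetical, the "image" is reduced to a binary superpixel vector, and the distance (the fraction of hidden superpixels) is an assumption made for brevity.

```python
import numpy as np


def lime_sketch(black_box, n_superpixels, n_samples=500, kernel_width=0.25, seed=0):
    """Fit a locally weighted linear model on binary superpixel vectors."""
    rng = np.random.default_rng(seed)
    # Each row is a perturbed sample: 1 = superpixel kept, 0 = superpixel hidden
    z = rng.integers(0, 2, size=(n_samples, n_superpixels))
    z[0] = 1  # the first sample is the unperturbed instance itself
    y = np.array([black_box(row) for row in z], dtype=float)
    # Samples that hide fewer superpixels are closer to the original instance
    distance = 1.0 - z.mean(axis=1)
    weights = np.exp(-(distance ** 2) / kernel_width ** 2)
    # Weighted least squares with an intercept column
    X = np.hstack([np.ones((n_samples, 1)), z])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return coef[1:]  # one importance value per superpixel


def toy_black_box(z):
    # Hypothetical classifier score that only reacts to superpixels 0 and 3
    return 0.2 + 0.5 * z[0] + 0.3 * z[3]


importance = lime_sketch(toy_black_box, n_superpixels=5)
print(np.round(importance, 3))
```

Because the toy black box is itself linear, the local model should recover its coefficients for superpixels 0 and 3 and assign roughly zero to the others. A real image classifier is not linear, which is why the explanation is only valid locally.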



Ok let's start! 

In [None]:
def batch_predict(images):
    s = nn.Softmax(dim=1)
    images = images.transpose((0,3,1,2))
    images = torch.tensor(images).cuda().float()
    outputs = model(images)
    preds = s(outputs.cpu()[:,:2]).detach().numpy()
    return preds

In [None]:
batch_size = 1
test_loader = DataLoader(
    ImgDataset(test.iloc[0:2]), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=1
)
test_iterator = iter(test_loader)

As we have only two classes, we can set top_labels to 2:

In [None]:
model.cuda()
model.eval()
batch = next(test_iterator)
patient_ids = batch[2]
image = batch[0].cpu().detach()
lm_image = image.squeeze().numpy().transpose((1,2,0)) 

explainer = lime_image.LimeImageExplainer()


explanation = explainer.explain_instance(lm_image, 
                                     batch_predict, 
                                     top_labels=2, 
                                     hide_color=0, batch_size=1)
output = model(image.cuda().float())
s = nn.Softmax(dim=1)
probas = s(output[:,:2])[0].cpu().detach().numpy()

In [None]:
probas

We can see that the probability of the first label (CE) is higher than that of the second one (LAA). To understand which parts of the image caused the model to make this prediction, we can take a look at the superpixels for the top_label (CE in this case). The tissue in these images is very detailed, and I expect the signals of the stroke's origin to be very tiny. For this reason I'd like to show a high number of superpixels, which is set by the num_features parameter. We only want to see which superpixels contribute to the top_label prediction (CE) and hide the rest of the image. And this is what we get:

In [None]:
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=100, hide_rest=True)
img_boundry1 = mark_boundaries(temp/255.0, mask)


images = batch[0].cpu().detach().numpy()
print(images.shape)
images = images.transpose((0,2,3,1))

fig, ax = plt.subplots(1,2,figsize=(20,10))
ax[0].imshow(images[0])

ax[0].set_title("Patient {}".format(patient_ids[0]))
ax[1].imshow(img_boundry1)

for n in range(2):
    ax[n].axis("off");

# Conclusion 

* We can't trust our model! Do you see it? Empty superpixels clearly contributed to the top_label prediction. This doesn't make sense!
* There are also a few regions that belong to tissue, but so far it's not clear why. It could be the color or the way the image was stretched or distorted during preprocessing... who knows?!
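
To put a number on the "empty superpixels" observation instead of judging by eye, one could measure how much of the highlighted mask falls on background. This is a rough sketch under a loud assumption: background pixels are taken to be near-black (every channel at or below `threshold`). If the background in your images is white, the test has to be inverted. The helper `background_fraction` and the toy arrays are hypothetical.

```python
import numpy as np


def background_fraction(image, mask, threshold=10):
    """Share of highlighted pixels that lie on (assumed near-black) background."""
    highlighted = np.asarray(mask).astype(bool)
    if not highlighted.any():
        return 0.0
    # ASSUMPTION: a pixel is background if all of its channels are <= threshold
    is_background = np.asarray(image).max(axis=-1) <= threshold
    return float((highlighted & is_background).sum() / highlighted.sum())


# Toy example: the two top rows are "tissue", the two bottom rows are empty
toy_image = np.zeros((4, 4, 3), dtype=np.uint8)
toy_image[:2] = 200
toy_mask = np.zeros((4, 4), dtype=int)
toy_mask[1:3] = 1  # highlights one tissue row and one empty row
frac = background_fraction(toy_image, toy_mask)
print(frac)
```

With the objects from this notebook the call would look like `background_fraction(lm_image, mask)`. A value far above zero would support the suspicion that the model reacts to something other than tissue.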

In my opinion we should try to understand the problem better. Can we really see differences in these tissues depending on the origin of the blood clot after an acute ischemic stroke? How do medical experts perform this task? Do they have more information available when they need to make a choice?