In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch import  nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm.notebook import tqdm
import cv2
from torchvision import transforms, utils
import skimage.io as skio

In [None]:
_metadata = pd.read_csv("/kaggle/input/data/Data_Entry_2017.csv")
# Preview the rows whose findings string mentions "No Finding" (plain substring match).
_no_finding_mask = _metadata["Finding Labels"].str.contains("No Finding", regex=False)
_metadata[_no_finding_mask]

In [None]:
!pip install einops

In [None]:
import einops
def show_image(image):
    """Show ``image`` on a newly created matplotlib figure and render it immediately."""
    plt.figure()  # fresh figure so earlier plots are not drawn over
    plt.imshow(image)
    plt.show()

def rearrange_tensor(img):
    """Move the leading (channel) axis to the end: (C, A, B) -> (A, B, C).

    Works for any array/tensor type that ``einops`` supports.
    """
    channels_last_pattern = "c h w -> h w c"
    return einops.rearrange(img, channels_last_pattern)

def unnormalize_tensor(img):
    """Undo ImageNet normalization on a channel-first tensor.

    Args:
        img: torch tensor of shape (3, H, W), normalized with the same
            mean/std as ``normalize_image``. It may live on any device and
            may require grad.

    Returns:
        numpy array of shape (H, W, 3) equal to ``img * std + mean`` per channel.
    """
    # .cpu() is required: Tensor.numpy() raises for CUDA tensors.
    array = img.detach().cpu().numpy()
    # (C, H, W) -> (H, W, C); the same permutation as einops "c w h -> w h c".
    array = np.transpose(array, (1, 2, 0))

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    return (array * std) + mean

def unnormalize_img(img):
    """Invert ImageNet normalization on a channels-last image.

    ``img`` must broadcast against length-3 per-channel vectors, i.e. its
    last axis holds the three colour channels. Returns ``img * std + mean``.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    return imagenet_mean + img * imagenet_std


def normalize_image(img):
    """Resize, center-crop and ImageNet-normalize an image for the network.

    Args:
        img: an image accepted by ``transforms.ToPILImage`` (the dataset passes
            an H x W x 3 numpy array).

    Returns:
        A float tensor of shape (C, 224, 224), normalized per channel with the
        ImageNet mean/std (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225).
    """
    pipeline_steps = [
        transforms.ToPILImage(),
        transforms.Resize(256),      # shorter side scaled to 256 px
        transforms.CenterCrop(224),  # keep the central 224 x 224 patch
        transforms.ToTensor(),       # channels-first float tensor
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    preprocess = transforms.Compose(pipeline_steps)
    return preprocess(img)

## Define dataset, dataloader, and model

In [None]:
class MulticlassChestDataset(Dataset):
    """Multi-label chest X-ray dataset backed by ``Data_Entry_2017.csv``.

    Rows are split by order: the first 60% form the training split and all
    remaining rows form the validation split. There is no shuffling, so the
    split is deterministic.

    Each item is ``(image, label)``:
      * ``image``: the tensor returned by ``normalize_image`` for a 3-channel
        copy of the (single-channel) X-ray.
      * ``label``: float32 multi-hot vector of shape (15,), indexed by
        ``self.classes``. This shape/dtype matches the (batch, 15) float32
        logits that ``BCEWithLogitsLoss`` compares it against after batching.

    Args:
        val: if True, expose the validation split instead of the training split.
        data_root: directory holding ``Data_Entry_2017.csv`` and the
            ``images_XXX/images/`` folders (default keeps the original relative path).
    """

    def __init__(self, val=False, data_root="input/data"):
        self.data_root = data_root
        _metadata = pd.read_csv(os.path.join(data_root, "Data_Entry_2017.csv"))
        self.metadata = _metadata
        self.num_train = int(.6 * len(_metadata))
        # Validation takes the remainder, so no row is silently dropped by rounding.
        self.num_val = len(_metadata) - self.num_train
        self.data = {}  # optional idx -> item cache; this class never fills it
        self.val = val
        self.classes = {"Atelectasis" : 0, "Cardiomegaly" : 1, "Effusion" : 2, "Infiltration" : 3, "Mass" : 4, "Nodule" : 5, "Pneumonia" : 6, "Pneumothorax" : 7,
                        "Consolidation" : 8, "Edema" : 9, "Emphysema" : 10,
                        "Fibrosis" : 11, "Pleural_Thickening" : 12, "Hernia" : 13, "No Finding" : 14}

    def __len__(self):
        return self.num_val if self.val else self.num_train

    def _image_path(self, row, file_name):
        """Return the on-disk path for ``file_name`` whose metadata row label is ``row``.

        Assumes the archive layout this code was written for: rows 0..4998 live
        in ``images_001`` and each following block of 10000 rows in the next
        ``images_XXX`` folder.
        """
        folder_num = 1
        if row >= 4999:
            folder_num = (row - 4999) // 10000 + 2
        return os.path.join(self.data_root, "images_" + str(folder_num).zfill(3), "images", file_name)

    def _encode_labels(self, label):
        """Convert a ``"A|B|..."`` findings string to a float32 multi-hot vector of shape (15,).

        Raises KeyError for a finding name that is not in ``self.classes``.
        """
        vectorized_label = np.zeros(len(self.classes), dtype=np.float32)
        for diagnosis in label.split("|"):
            vectorized_label[self.classes[diagnosis]] = 1.0
        return vectorized_label

    def __getitem__(self, idx):
        if self.val:
            idx += self.num_train  # validation rows follow the training rows
        if idx in self.data:
            return self.data[idx]
        record = self.metadata.iloc[idx]
        img = skio.imread(self._image_path(record.name, record["Image Index"]))
        if img.ndim >= 3:
            img = img[:, :, 0]  # keep only the first channel of multi-channel files
        # turn 1-channel image to 3-channel image
        img = np.stack((img,) * 3, axis=-1)
        # crop & normalize
        img = normalize_image(img)
        return img, self._encode_labels(record["Finding Labels"])

In [None]:
batch_size = 16
train_dataset = MulticlassChestDataset()
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataset = MulticlassChestDataset(val=True)
# Evaluation does not benefit from shuffling. A fixed order makes the
# per-batch-averaged validation metrics reproducible: the last, smaller batch
# always holds the same rows.
validation_data_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
!pip -q install vit_pytorch linformer

In [None]:
from vit_pytorch.efficient import ViT
from linformer import Linformer

# Model geometry in one place, so the Linformer sequence length cannot drift
# out of sync with the image/patch sizes.
IMAGE_SIZE = 224   # must match the CenterCrop size used in normalize_image
PATCH_SIZE = 32
EMBED_DIM = 128
NUM_CLASSES = 15   # one logit per entry of MulticlassChestDataset.classes
assert IMAGE_SIZE % PATCH_SIZE == 0, "image size must be divisible by patch size"
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # 7 x 7 = 49 for 224 / 32

efficient_transformer = Linformer(
    dim=EMBED_DIM,
    seq_len=NUM_PATCHES + 1,  # patches + 1 cls-token
    depth=12,
    heads=8,
    k=64,
)

model = ViT(
    dim=EMBED_DIM,
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    transformer=efficient_transformer,
    channels=3,
).to(device)

## Train model

In [None]:
# Training hyper-parameters: number of epochs, Adam learning rate, StepLR decay factor.
epochs, lr, gamma = 6, 3e-4, 0.1

In [None]:
# Multi-label objective. BCEWithLogitsLoss applies the sigmoid itself, so it
# should be fed raw model logits.
# TODO: try triplet or siamese losses, as used for typical CBIR tasks
#       (e.g. nn.TripletMarginLoss(margin=0.1)).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)
# Multiplies the learning rate by `gamma` on every scheduler.step() call.
# NOTE(review): nothing in this notebook calls scheduler.step(); confirm intent.
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [None]:
import tqdm
checkpoint_file = "model_weights"
sigmoid = nn.Sigmoid()  # not used by the loss below (it expects raw logits)
epoch = 0
resume_from_checkpoint = False  # set True to continue from `checkpoint_file`


def _multilabel_accuracy(logits, targets):
    """Fraction of (sample, class) entries whose prediction matches ``targets``.

    A logit >= 0 is equivalent to sigmoid(logit) >= 0.5, i.e. a predicted positive.
    ``logits`` and ``targets`` must have the same shape.
    """
    predictions = (logits >= 0).to(targets.dtype)
    return (predictions == targets).float().mean().item()


if resume_from_checkpoint and os.path.exists(checkpoint_file):
    print('gathering info from checkpoint file')
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    saved_optimizer_state = checkpoint.get('optimizer_state_dict', {})
    # Older checkpoints stored the criterion's (empty) state under this key.
    if 'param_groups' in saved_optimizer_state:
        optimizer.load_state_dict(saved_optimizer_state)
    epoch = checkpoint['epoch'] + 1

for epoch in range(epoch, epochs + epoch):
    model.train()
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    # Wrap in a progress bar to display progress during training.
    progress_bar = tqdm.tqdm(train_data_loader)

    for data, label in progress_bar:
        data = data.to(device)
        optimizer.zero_grad()

        output = model(data)
        # BCEWithLogitsLoss needs target and input of identical shape; align the
        # labels with the (batch, num_classes) logits, including dtype.
        label = label.to(device=device, dtype=output.dtype).reshape(output.shape)

        # BCEWithLogitsLoss applies the sigmoid internally; passing
        # sigmoid(output) here would squash the logits twice.
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # Plain floats: accumulating tensors would keep every batch's graph alive.
        loss_value = loss.item()
        progress_bar.set_description(f"Loss: {loss_value:.4f}")
        epoch_accuracy += _multilabel_accuracy(output.detach(), label) / len(train_data_loader)
        epoch_loss += loss_value / len(train_data_loader)

    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0.0
        epoch_val_loss = 0.0
        for data, label in tqdm.tqdm(validation_data_loader):
            data = data.to(device)
            val_output = model(data)
            label = label.to(device=device, dtype=val_output.dtype).reshape(val_output.shape)
            val_loss = criterion(val_output, label)

            epoch_val_accuracy += _multilabel_accuracy(val_output, label) / len(validation_data_loader)
            epoch_val_loss += val_loss.item() / len(validation_data_loader)

    # NOTE(review): `scheduler` is never stepped here; confirm whether StepLR should run.
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                # Real optimizer state, so training can actually be resumed.
                'optimizer_state_dict': optimizer.state_dict(),
                'acc': epoch_accuracy,
                'loss': epoch_loss,
                'validation_acc' : epoch_val_accuracy,
                'validation_loss': epoch_val_loss,
                }, checkpoint_file)
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

In [None]:
# Restore a previously trained model from the uploaded checkpoint dataset.
# map_location lets a checkpoint saved on GPU load on a CPU-only session.
checkpoint = torch.load("/kaggle/input/model-weights/model_weights", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
saved_optimizer_state = checkpoint.get('optimizer_state_dict', {})
if 'param_groups' in saved_optimizer_state:
    # Newer checkpoints hold a genuine optimizer state dict.
    optimizer.load_state_dict(saved_optimizer_state)
else:
    # Older checkpoints stored criterion.state_dict() under this key;
    # restore it the same way the original cell did.
    criterion.load_state_dict(saved_optimizer_state)
epoch = checkpoint['epoch'] + 1

In [None]:
from IPython.display import FileLink

# Render a download link for the checkpoint written by the training loop.
FileLink("model_weights")

In [None]:
from einops import rearrange, repeat
def get_latent(vit_model, img):
    """Return the pooled ViT representation of ``img`` without any classification head.

    Steps:
      1. Patch-embed ``img``.
      2. Prepend one copy of ``cls_token`` per batch element.
      3. Add the positional embeddings and run ``transformer``.
      4. Pool (token mean if ``pool == 'mean'``, else the cls token) and apply ``to_latent``.

    Presumably this is the vector the model's classifier head consumes; confirm
    against the vit_pytorch version in use.
    """
    tokens = vit_model.to_patch_embedding(img)
    batch_size, num_patches, _ = tokens.shape

    cls = repeat(vit_model.cls_token, '() n d -> b n d', b=batch_size)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens = tokens + vit_model.pos_embedding[:, :num_patches + 1]
    tokens = vit_model.transformer(tokens)

    if vit_model.pool == 'mean':
        pooled = tokens.mean(dim=1)
    else:
        pooled = tokens[:, 0]
    return vit_model.to_latent(pooled)


In [None]:
# Walk the training set in dataset order, one image at a time, so row i of the
# embedding matrix built below corresponds to train_dataset[i].
unshuffled_train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)

In [None]:
embeddings = np.zeros((len(train_dataset), 128))  # 128 == the ViT `dim` the model was built with
labels = []
model.eval()
imgs = []
# NOTE(review): every un-normalized image is kept in memory, one float array per
# training sample; this is very large for the full split. Consider subsampling.
with torch.no_grad():  # inference only; don't build autograd graphs
    for i, sample in enumerate(tqdm.tqdm(unshuffled_train_dataloader)):
        img, label = sample
        # Labels are multi-hot (one entry per class), so int(label) fails.
        # Record the index of the first positive class as a single
        # category/colour per image.
        labels.append(int(label.reshape(-1).argmax()))
        img = img.to(device)
        embedding = get_latent(model, img)
        embeddings[i] = embedding.cpu().numpy()
        original_image = unnormalize_img(einops.rearrange(img.cpu().numpy(), "b c w h -> w h (b c)"))
        imgs.append(original_image)

In [None]:
from IPython.display import FileLink

# Persist the un-normalized images so they can be downloaded and inspected offline.
np.save("imgs.npy", imgs)
FileLink("imgs.npy")

In [None]:
np.shape(embeddings)  # (num training samples, embedding dim)

In [None]:
from sklearn.manifold import TSNE

# Project the 128-d ViT embeddings to 2-D for plotting, using default TSNE
# settings. A fixed random_state makes the layout reproducible across runs, so
# the hand-picked 2-D coordinates used in later cells stay meaningful.
tsne = TSNE(random_state=0)

low_dim_embeddings = tsne.fit_transform(embeddings)

In [None]:
# Scatter the 2-D t-SNE projection, coloured by each image's label index.
fig, ax = plt.subplots()
points = ax.scatter(low_dim_embeddings[:, 0], low_dim_embeddings[:, 1], c=labels, cmap="Accent")
fig.colorbar(points, ax=ax, label="label index")
ax.set(title="t-SNE of ViT embeddings", xlabel="t-SNE 1", ylabel="t-SNE 2");

In [None]:
from numpy import random
from scipy.spatial import distance

def closest_node(node, nodes, k=5):
    """Return indices of the ``k`` points in ``nodes`` nearest to ``node``.

    Args:
        node: array of shape (1, D) holding the query point (if more rows are
            given, only the first is used, as before).
        nodes: array of shape (N, D).
        k: number of neighbours to return; defaults to 5 (the original behaviour).

    Returns:
        1-D integer array of length ``min(k, N)``, ordered from nearest to
        farthest by Euclidean distance.
    """
    # Only the first query row is needed, so sort just that row of distances.
    distances = distance.cdist(node, nodes)[0]
    return np.argsort(distances)[:k]

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

def image_grid(img_arr, nrows_ncols=(1, 5)):
    """Display images on a single ImageGrid figure.

    Args:
        img_arr: iterable of images accepted by ``Axes.imshow``. Images beyond
            the number of grid cells are ignored.
        nrows_ncols: (rows, cols) of the grid. Defaults to one row of 5, the
            number of neighbours ``closest_node`` returns by default.
    """
    fig = plt.figure(figsize=(20., 20.))
    grid = ImageGrid(fig, 111,
                     nrows_ncols=nrows_ncols,  # 1 x 5 grid of axes by default
                     axes_pad=0.1,  # pad between axes
                     )

    for ax, im in zip(grid, img_arr):
        ax.imshow(im)

    plt.show()

In [None]:
# NOTE(review): (40, 40) was hand-picked from one t-SNE layout. Re-check this
# point whenever the t-SNE projection is recomputed.
indices = closest_node(np.array([[40, 40]]), low_dim_embeddings)
# Pick the images straight from the list; np.array(imgs)[...] would first copy
# every stored image into one huge array just to select five.
top_images = [imgs[i] for i in indices]
print("CARDIOMEGALY IMAGES:")
print(np.array(labels)[indices])
image_grid(top_images)

In [None]:
# NOTE(review): (20, -100) was hand-picked from one t-SNE layout. Re-check this
# point whenever the t-SNE projection is recomputed.
indices = closest_node(np.array([[20, -100]]), low_dim_embeddings)
# Pick the images straight from the list; np.array(imgs)[...] would first copy
# every stored image into one huge array just to select five.
top_images = [imgs[i] for i in indices]
print("NO FINDING IMAGES:")
print(np.array(labels)[indices])
image_grid(top_images)

In [None]:
# Interactive explorer: click near a point in the t-SNE scatter (left) to show
# the corresponding training image in the preview panel (right).
# NOTE: click events need an interactive matplotlib backend (e.g. `%matplotlib widget`);
# with the default inline backend the figure is a static image and clicks do nothing.

import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt

fig, (ax, preview_ax) = plt.subplots(1, 2, figsize=(12, 5))
x = low_dim_embeddings[:, 0]
y = low_dim_embeddings[:, 1]
ax.scatter(x, y, c=labels, cmap="Accent")
ax.set(title="t-SNE of ViT embeddings (click a point)", xlabel="t-SNE 1", ylabel="t-SNE 2")
preview_ax.set_axis_off()


def onclick(event):
    """Show the image of the point nearest the click, if it lies within 5% of the axis span."""
    # Ignore clicks outside the scatter axes; they carry no data coordinates there.
    if event.inaxes is not ax or event.xdata is None or event.ydata is None:
        return
    ix, iy = event.xdata, event.ydata
    print("I clicked at x={0:5.2f}, y={1:5.2f}".format(ix, iy))

    # Tolerance: 5% of the visible extent in each direction. abs() keeps this
    # positive even when an axis is inverted.
    dx = abs(0.05 * (ax.get_xlim()[1] - ax.get_xlim()[0]))
    dy = abs(0.05 * (ax.get_ylim()[1] - ax.get_ylim()[0]))
    close = np.flatnonzero((np.abs(x - ix) < dx) & (np.abs(y - iy) < dy))
    if close.size == 0:
        return

    # Among the close points, pick the single nearest in tolerance-scaled units.
    scaled_dist = ((x[close] - ix) / dx) ** 2 + ((y[close] - iy) / dy) ** 2
    nearest = close[np.argmin(scaled_dist)]
    print("You clicked close enough!")

    preview_ax.clear()
    preview_ax.imshow(imgs[nearest])
    preview_ax.set_title(f"sample {nearest}, label {labels[nearest]}")
    preview_ax.set_axis_off()
    fig.canvas.draw_idle()  # repaint so the new preview appears


cid = fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()