In [None]:
from os.path import join
from socket import gethostname

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

from janus.datasets import Boyd2019
from janus.losses import ContrastiveLoss
from janus.networks import SiameseNet
from src.viz import plot_cell, tsne, umap


# Seed every RNG this notebook draws from (pandas sampling below, the numpy
# cell picker, and torch DataLoader shuffling) so re-runs give the same split,
# the same example cells and the same batches.
# NOTE(review): 5 mirrors the "seed_5" in the checkpoint filename loaded later;
# it is NOT verified that training used this same train/test split. If it did
# not, "test" wells may overlap the loaded model's training wells — confirm.
SEED = 5
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

data_path = '../data/boyd_2019'
results_path = '../results/boyd_2019'
metadata_file = '../data/boyd_2019_PlateMap-KPP_MOA.xlsx'

metadata = Boyd2019.read_metadata(metadata_file)

# filter by 2 moas and make a reproducible 70/30 train/test split of the rows
moas_of_interest = ['Neutral', 'EGF Receptor Kinase Inhibitor']
metadata = metadata.loc[metadata.moa.isin(moas_of_interest)]
train_metadata = metadata.sample(frac=0.7, random_state=SEED)
test_metadata = metadata.drop(train_metadata.index)

boyd2019 = Boyd2019(data_path, train_metadata, padding=64, scale=0.5)

# Show 4 random cells. Indices are drawn from the dataset's real length instead
# of a hard-coded 1..999 range, which could overrun a small dataset and never
# picked index 0.
n_cells = len(boyd2019.dataset_1)
for i in np.random.randint(0, n_cells, 4):
    plot_cell(boyd2019.dataset_1[i][0])

In [None]:
# Preview the first 10 metadata rows of the training dataset.
boyd2019.metadata.iloc[:10]

In [None]:
# Report how many cells each of the two sub-datasets holds.
n_cells_1 = len(boyd2019.dataset_1)
n_cells_2 = len(boyd2019.dataset_2)
print('Dataset 1: %d cells' % n_cells_1)
print('Dataset 2: %d cells' % n_cells_2)

# Data generator: inspect a batch of image pairs

In [None]:
def imshow(img, text=None, should_save=False):
    """Display a channels-first image tensor (e.g. a ``make_grid`` output).

    Args:
        img: torch tensor laid out (C, H, W); it is transposed to (H, W, C)
            for matplotlib. It is detached and moved to CPU first, so tensors
            that live on CUDA or require grad are accepted too.
        text: optional caption, drawn in an italic bold white box at data
            coordinates (75, 8) of the image.
        should_save: accepted for backward compatibility; currently ignored.
    """
    fig, ax = plt.subplots(figsize=(15, 5))

    # .numpy() raises on CUDA tensors and on tensors that require grad.
    npimg = img.detach().cpu().numpy()

    if text:
        # Draw on the explicit axes rather than via the pyplot state machine.
        ax.text(75, 8, text, style='italic', fontweight='bold',
                bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})

    ax.imshow(np.transpose(npimg, (1, 2, 0)))
    ax.axis('off')
    plt.show()

def show_plot(iteration, loss):
    """Plot a loss curve against iteration on a fresh figure and show it.

    Args:
        iteration: x values (iteration/epoch counters).
        loss: y values, same length as ``iteration``.
    """
    fig, ax = plt.subplots()
    ax.plot(iteration, loss)
    # Label the axes so the figure stands alone when the notebook is skimmed.
    ax.set(xlabel='iteration', ylabel='loss')
    plt.show()

# Draw one shuffled batch of pairs and show both images of each pair in one
# grid. make_grid puts 8 images per row by default, so with batch_size=8 the
# first images of a full batch fill the top row and the second images the next.
vis_dataloader = DataLoader(boyd2019, batch_size=8, shuffle=True, num_workers=8)
dataiter = iter(vis_dataloader)

example_batch = next(dataiter)
first_imgs, second_imgs = example_batch[0], example_batch[2]
concatenated = torch.cat((first_imgs, second_imgs), 0)
# Fifth element of each batch; its meaning is defined by Boyd2019
# (it is also left unnamed in the embedding loop further down).
print(example_batch[4].numpy())
imshow(torchvision.utils.make_grid(concatenated))

# Testing on held-out wells

In [None]:
# Load a previously trained SiameseNet checkpoint from results_path and put it
# in eval mode. torch.load unpickles the file, so only load trusted checkpoints.
model = 'sn_dropout_0.5_margin_1.0_seed_5_epoch_100.torch'

checkpoint_path = join(results_path, model)
net = SiameseNet().to(device)
state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
net.load_state_dict(state_dict)
net = net.eval()

In [None]:
# Held-out dataset: test_metadata holds the metadata rows not sampled into
# train_metadata (presumably one row per well — confirm against the plate map).
# Earlier hand-picked alternative, kept for reference:
# test_metadata = pd.DataFrame({'well': ['C01', 'D02'], 'moa': [1, 2]})
test_boyd2019 = Boyd2019(data_path, test_metadata, scale=0.5, padding=64)

test_dataloader = DataLoader(
    test_boyd2019,
    batch_size=64,
    num_workers=8,
    shuffle=True,
)

In [None]:
from src.viz import norm_crop_for_vis

# Embed a sample of held-out cell pairs with the trained network, and keep a
# normalised crop of every embedded image for the tile plot further down.
from itertools import islice

# The original loop broke after batch index 100, i.e. it embedded 101 batches.
n_embed_batches = 101

embedding_chunks = []
cell_line_chunks = []
moas = []
all_imgs = []

with torch.no_grad():  # inference only: no autograd graph is needed
    for data in islice(test_dataloader, n_embed_batches):
        img0, moa0, img1, moa1, _ = data
        img0, img1 = img0.to(device), img1.to(device)
        output1, output2 = net(img0, img1)

        # .cpu() is required before .numpy() when the net runs on CUDA.
        embedding_chunks.append(output1.cpu().numpy())
        embedding_chunks.append(output2.cpu().numpy())
        # NOTE(review): assumes the first image of each pair is an MDA-468 cell
        # and the second an MDA-231 cell — confirm in Boyd2019.
        cell_line_chunks.append(np.repeat('mda468', output1.shape[0]))
        cell_line_chunks.append(np.repeat('mda231', output2.shape[0]))
        moas.extend(moa0)
        moas.extend(moa1)

        normed_img0 = torch.cat([norm_crop_for_vis(img)[None] for img in img0], dim=0)
        normed_img1 = torch.cat([norm_crop_for_vis(img)[None] for img in img1], dim=0)

        # Move crops off the device per batch so they don't pile up in GPU memory.
        all_imgs.append(normed_img0.cpu())
        all_imgs.append(normed_img1.cpu())

# Concatenate once at the end: growing arrays with np.concatenate inside the
# loop is quadratic, and seeding cell_line with a float array relied on
# numeric -> string dtype promotion.
embedding = np.concatenate(embedding_chunks)
cell_line = np.concatenate(cell_line_chunks)
all_imgs = torch.cat(all_imgs).cpu()

In [None]:
from umap import UMAP
from src.viz import xy_plot

# Project the embeddings to 2-D. random_state pins the layout, so this plot and
# the tile canvas built from x_emb are reproducible across runs. umap-learn
# runs single-threaded when a random_state is given, so this can be slower.
x_emb = UMAP(random_state=42).fit_transform(embedding)
xy_plot(x_emb, cell_line, 'UMAP')

In [None]:
from src.viz import plot_tiles
import matplotlib.patches as patches

# Lay the normalised crops out as tiles at their UMAP coordinates and show the
# resulting canvas.
canvas, img_idx_dict = plot_tiles(all_imgs, x_emb, 30, pad=1)

fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(canvas)
ax.set_axis_off()

palette = [colour for colour in sns.color_palette().as_hex()]

# Disabled: outline each tile in a colour chosen from its MoA label.
# NOTE(review): relies on moas[i] being a tensor with .item() and on MoA being
# encoded as 1 vs. other — confirm before re-enabling.
# for img_key in img_idx_dict.keys():
#     xmin, xmax, ymin, ymax = img_idx_dict[img_key]
#     cls = moas[img_key].item()
#     colour = palette[0] if cls == 1 else palette[1]
#     # Create a Rectangle patch
#     line_width = 3
#     rect = patches.Rectangle((xmin+line_width, ymin+line_width), xmax-xmin-2*line_width, ymax-ymin-2*line_width,
#                              linewidth=line_width, edgecolor=colour, facecolor='none')
#     # Add the patch to the Axes
#     ax.add_patch(rect)

In [None]:
# Plot the test embeddings with the MoA labels collected above, via src.viz.umap.
# The `from umap import UMAP` cell binds only UMAP, so `umap` here is still
# the src.viz helper, not the umap-learn package.
umap(embedding, moas)

# MoA Prediction

In [None]:
from src.lococv import LocoCV
from src.viz import plot_confusion_matrix

# MoA prediction on the held-out set with LocoCV, using profiles built from the
# trained network; accuracy is the confusion-matrix diagonal over its total.
loco = LocoCV(test_boyd2019, net)

df_profiles = loco.construct_profiles()
confusion = loco.lococv(df_profiles)

n_correct = np.trace(confusion)
n_total = np.sum(confusion)
ae_acc = n_correct / n_total
print('Accuracy: %.04f' % ae_acc)

# NOTE(review): labels come from .unique() (order of first appearance); confirm
# that LocoCV orders the confusion-matrix rows/columns the same way.
moa_labels = test_boyd2019.metadata.moa.unique()
plot_confusion_matrix(confusion, moa_labels)