In [None]:
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torch.utils.data import DataLoader
import torchvision

from janus.datasets import Boyd2019, MultiCellDataset
from janus.networks import SiameseNet
from janus.viz import plot_cell, sample_imgs, tsne, umap


# Run on the GPU whenever torch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Boyd2019 image root and plate map, relative to this notebook.
data_path = '../data/boyd_2019'
metadata_file = '../data/boyd_2019_PlateMap-KPP_MOA.xlsx'

# Plate directories (under `data_path`) for the two cell lines.
mda231_path = '22_384_20X-hNA_D_F_C3_C5_20160031_2016.01.25.17.23.13_MDA231'
mda468_path = '22_384_20X-hNA_D_F_C3_C5_20160032_2016.01.25.16.27.22_MDA468'

def read_data(metadata, seed=1, results_path='../results/boyd_2019'):
    """Load the pickled train/test splits for one seed as paired datasets.

    Args:
        metadata: plate-map table, passed unchanged to ``MultiCellDataset``.
        seed: seed suffix of the files to load (``<split>_<n>_seed_<seed>.pkl``).
        results_path: directory containing the pickles; defaults to the
            location this notebook has always used.

    Returns:
        ``(tr_data, te_data)``: ``MultiCellDataset`` built from ``train_1`` /
        ``train_2`` and from ``test_1`` / ``test_2`` respectively.
    """
    def load(split, idx):
        # torch.load unpickles the file, so only point this at trusted results.
        return torch.load(join(results_path, '{}_{}_seed_{}.pkl'.format(split, idx, seed)))

    tr_data = MultiCellDataset(load('train', 1), load('train', 2), metadata)
    te_data = MultiCellDataset(load('test', 1), load('test', 2), metadata)

    return tr_data, te_data


# Plate map restricted to the two MOAs compared in this notebook.
metadata = Boyd2019.read_metadata(metadata_file)
metadata = metadata.loc[metadata.moa.isin(['Neutral', 'PKC Inhibitor'])]

# Normalisation statistics stored next to each cell line's images
# (later cells de-normalise crops with `crop * std + avg`).
# NOTE(review): the positional `None, False` arguments are passed through as-is;
# confirm their meaning in Boyd2019.load_parameters.
avg_mda231, std_mda231 = Boyd2019.load_parameters(
    join(data_path, mda231_path, 'norm_params.pkl'), None, False)
avg_mda468, std_mda468 = Boyd2019.load_parameters(
    join(data_path, mda468_path, 'norm_params.pkl'), None, False)

tr_data, te_data = read_data(metadata)

# Preview four random training crops. A seeded generator keeps the selection
# identical across Restart & Run All.
rng = np.random.default_rng(1)
for i in rng.integers(1, 1000, size=4):
    plot_cell(tr_data.dataset_1[i][0])

In [None]:
# Inspect the MDA468 normalisation mean loaded from norm_params.pkl in the first cell.
# NOTE(review): its layout is not visible here (std_mda468[c] is shown as a 2-D image
# below, so presumably one map per channel) — confirm.
avg_mda468

# Data loader: preview one training batch

In [None]:
def imshow(img, text=None, should_save=False):
    """Display a channel-first image tensor (e.g. a ``make_grid`` result) in a wide figure.

    Args:
        img: tensor indexed (C, H, W); transposed to (H, W, C) for matplotlib.
            It is detached and moved to the CPU first, so GPU tensors work too.
        text: optional caption, drawn in a white box at data coordinates (75, 8).
        should_save: kept for backward compatibility; currently ignored.
    """
    _, ax = plt.subplots(figsize=(15, 5))

    npimg = img.detach().cpu().numpy()

    if text:
        # Draw on the axes created above, not on whatever pyplot treats as current.
        ax.text(75, 8, text, style='italic', fontweight='bold',
                bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})

    ax.imshow(np.transpose(npimg, (1, 2, 0)))
    ax.axis('off')
    plt.show()

def show_plot(iteration, loss):
    """Plot ``loss`` against ``iteration`` in a new, labelled figure and display it.

    An explicit figure/axes pair keeps the curve from landing on a stale
    "current" pyplot figure left over from an earlier cell.
    """
    _, ax = plt.subplots()
    ax.plot(iteration, loss)
    ax.set_xlabel('iteration')
    ax.set_ylabel('loss')
    plt.show()

# Pull one shuffled batch of 8 and show both halves stacked in a single grid.
vis_dataloader = DataLoader(tr_data,
                            shuffle=True,
                            num_workers=0,
                            batch_size=8,
                            # Seeded generator so the previewed batch is reproducible.
                            generator=torch.Generator().manual_seed(1))
dataiter = iter(vis_dataloader)

example_batch = next(dataiter)
# NOTE(review): items 0 and 2 are used as the two image stacks and item 4 is printed
# as a per-sample array; presumably (img_1, ..., img_2, ..., label) — confirm against
# MultiCellDataset.__getitem__.
concatenated = torch.cat((example_batch[0], example_batch[2]), 0)
print(example_batch[4].numpy())
imshow(torchvision.utils.make_grid(concatenated))

# Testing: embed samples with the trained network and visualise them

In [None]:
from janus.viz import embed_matrix, plot_tiles
import matplotlib.patches as patches

# Load the saved Siamese network checkpoint onto `device`.
checkpoint_path = '../results/boyd_2019/sn_dropout_0.5_margin_1.0_seed_1_epoch_100.torch'
net = SiameseNet().to(device)
# `device` is already a torch.device, so it can be passed to map_location directly.
net.load_state_dict(torch.load(checkpoint_path, map_location=device))
net = net.eval()

# Sample training crops with their embeddings and labels.
imgs, embeddings, moas, cell_lines = sample_imgs(net, tr_data)

# UMAP view of the embeddings labelled by MOA. It is called once, before the tile
# figure exists, so it cannot draw onto the tile axes.
umap(embeddings, moas)

# Lay the crops out as tiles at their UMAP coordinates.
x_emb = embed_matrix(embeddings, 'umap')
canvas, img_idx_dict = plot_tiles(imgs, x_emb, 20, pad=1)

fig, ax = plt.subplots(figsize=(25, 25))
ax.imshow(canvas)
ax.axis('off')

# Outline each tile by MOA. The metadata is filtered to two MOAs, so the first
# palette colour marks "Neutral" and the second marks "PKC Inhibitor".
palette = list(sns.color_palette().as_hex())
line_width = 3
for img_key, (xmin, xmax, ymin, ymax) in img_idx_dict.items():
    colour = palette[0] if moas[img_key] == "Neutral" else palette[1]
    # Inset the outline by `line_width` on each side of the tile.
    rect = patches.Rectangle((xmin + line_width, ymin + line_width),
                             xmax - xmin - 2 * line_width + .5,
                             ymax - ymin - 2 * line_width + .5,
                             linewidth=line_width,
                             edgecolor=colour,
                             facecolor='none')
    ax.add_patch(rect)

In [None]:
# Show each channel of the MDA468 standard-deviation map loaded from norm_params.pkl,
# one titled panel per channel (std_mda468[c] is displayed as a 2-D image).
fig, axes = plt.subplots(figsize=(8, 8), ncols=3)
for c, panel in enumerate(axes):
    panel.imshow(std_mda468[c])
    panel.set_title('std_mda468[{}]'.format(c))

In [None]:

# Notes: rescale each channel to its 99th percentile instead of its max,
# and clip values above that percentile.
def plot_cell(crop, percentile=99):
    """Show the three channels of a single-cell crop plus an RGB composite.

    Each channel is rescaled between its minimum and its ``percentile``-th
    percentile, then clipped to [0, 1]. The brightest pixels therefore saturate
    instead of stretching the colour scale.

    Args:
        crop: channel-first array or tensor; ``crop[0]``, ``crop[1]`` and ``crop[2]``
            are drawn as DAPI, Cy5 and Cy3 respectively. uint8 and GPU tensors
            are accepted (converted to CPU float first).
        percentile: upper percentile used as the white point (default 99).
    """
    crop = torch.as_tensor(crop).detach().cpu().float()

    def rescale(channel):
        top = float(np.percentile(channel.numpy(), percentile))
        bot = float(channel.min())
        span = top - bot
        if span <= 0:
            # Flat channel: avoid a 0/0 that would fill the panel with NaNs.
            return torch.zeros_like(channel)
        return ((channel - bot) / span).clamp(0.0, 1.0)

    dapi = rescale(crop[0, ...])
    cy5 = rescale(crop[1, ...])
    cy3 = rescale(crop[2, ...])

    _, axes = plt.subplots(figsize=(5, 5), ncols=2, nrows=2)
    # vmin/vmax pin the colour scale to the rescaled range. Without them imshow
    # autoscales each 2-D panel to its own min/max, which undoes the clipping.
    for panel, img, cmap, title in ((axes[0][0], dapi, 'Blues', 'DAPI'),
                                    (axes[0][1], cy5, 'Reds', 'Cy5'),
                                    (axes[1][0], cy3, 'Greens', 'Cy3')):
        panel.imshow(img.numpy(), cmap=cmap, vmin=0, vmax=1)
        panel.set_title(title)
    # Composite: Cy5 -> red, Cy3 -> green, DAPI -> blue.
    axes[1][1].imshow(np.dstack((cy5.numpy(), cy3.numpy(), dapi.numpy())))
    axes[1][1].set_title('merge')
    for panel in axes.flat:
        panel.axis('off')

# Plot crop 22 as sampled (presumably still in normalised units); plot_cell rescales
# each channel itself. The de-normalised uint8 version is built and plotted further down.
plot_cell(imgs[22])

In [None]:
# Per-channel statistics of crop 22: the 99th percentile and the minimum, each taken
# over axes 1 and 2 (crops are indexed channel-first). `channel` is defined here so
# the cell also runs on a fresh kernel; both results are displayed.
channel = imgs[22]
np.percentile(channel, 99, axis=(1, 2)), channel.amin(dim=(1, 2))

In [None]:
# Rescale crop 22 channel by channel between each channel's minimum and 99th
# percentile. top/bot keep singleton spatial dims so they broadcast over the crop.
# Values above the percentile still exceed 1 (no clipping here).
channel = imgs[22]
top = torch.as_tensor(np.percentile(channel, 99, axis=(1, 2), keepdims=True),
                      dtype=channel.dtype)
# amin returns a tensor; tensor.min(dim) returns a (values, indices) pair, which
# cannot be subtracted from a tensor.
bot = channel.amin(dim=(1, 2), keepdim=True)
(channel - bot) / (top - bot)

In [None]:
# Reduce crop 22 over axis 0. NOTE(review): if `channel` is a torch tensor this
# returns a (values, indices) pair rather than a tensor, so it cannot be used
# directly as an offset in arithmetic.
channel.min(0)

In [None]:
imgs[22].shape
# Layout check: crops are indexed channel-first elsewhere (crop[0, ...] is one channel);
# `imgs[22].permute((1, 2, 0))` would give the channel-last layout matplotlib expects.

In [None]:
# Channel 0 of crop 22 (index on the leading axis only).
imgs[22][0]

In [None]:
# De-normalise crop 22 back to raw intensities with the MDA231 statistics.
# NOTE(review): confirm imgs[22] really is an MDA231 cell (see `cell_lines[22]`).
# Round and clamp before the cast: converting out-of-range floats to uint8 is
# undefined in torch (it typically wraps), which would scramble bright/negative
# pixels. NOTE(review): this assumes 8-bit raw images; if they are 16-bit, drop the
# cast instead.
x = ((imgs[22] * std_mda231) + avg_mda231).round().clamp(0, 255).to(torch.uint8)
plot_cell(x)

In [None]:
# Shape of the de-normalised uint8 crop `x` built in the cell above.
x.shape

In [None]:
# Per-channel maximum of the de-normalised crop (reduced over axes 1 and 2).
x.numpy().max(axis=(1, 2))

In [None]:
# Redraw the tile canvas for the existing layout `x_emb`, outlining each crop by MOA.
canvas, img_idx_dict = plot_tiles(imgs, x_emb, 20, pad=1)

fig, ax = plt.subplots(figsize=(25, 25))
ax.imshow(canvas)
ax.axis('off')

palette = list(sns.color_palette().as_hex())

for img_key, bounds in img_idx_dict.items():
    xmin, xmax, ymin, ymax = bounds
    cls = moas[img_key]
    # First palette colour marks "Neutral" crops; the second marks every other MOA.
    if cls == "Neutral":
        colour = palette[0]
    else:
        colour = palette[1]
    line_width = 3
    # The outline is inset by `line_width` on every side of the tile.
    width = xmax-xmin-2*line_width+.5
    height = ymax-ymin-2*line_width+.5
    rect = patches.Rectangle((xmin+line_width, ymin+line_width), width, height,
                             linewidth=line_width, edgecolor=colour, facecolor='none')
    ax.add_patch(rect)

umap(embeddings, moas)

In [None]:
# NOTE(review): `img_key` is the loop variable left behind by the tile-drawing loop in
# the previous cell (the last tile drawn), so this only works after that cell has run.
moas[img_key]