In [None]:
%load_ext rpy2.ipython
import os
from functools import lru_cache

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

from janus.datasets import Boyd2019, MultiCellDataset
from janus.networks import SiameseNet
from janus.viz import embed_matrix, sample_imgs, xy_plot


# Locations of the Boyd 2019 training outputs and plate-map metadata,
# relative to this notebook.
results_path = '../results/boyd_2019/'
metadata_file = '../data/boyd_2019_PlateMap-KPP_MOA.xlsx'

# Prefer a GPU when torch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
%%R -i results_path

library(tidyverse)

# Read every per-run loss log (*.tsv) in results_path into one long data frame.
# File names are presumably sn_dropout_<d>_margin_<m>_seed_<s>.tsv: fields 3, 5
# and 7 of the '_'-split name are parsed as dropout, margin and seed.
# `pattern` is a regular expression, not a glob, so anchor it on the extension.
sn <- lapply(list.files(path = results_path, pattern = '\\.tsv$'), function(f) {

    params <- strsplit(f, '_') %>% unlist

    # Two numeric columns whose names end in 'loss' (pivoted below) -- TODO confirm.
    read_tsv(paste0(results_path, f), col_types = 'dd') %>%
        mutate(dropout = as.numeric(params[3]),
               margin = as.numeric(params[5]),
               seed = as.numeric(sub('\\.tsv$', '', params[7])),
               # Row index, plotted as 'Iteration' later.
               i = row_number()) %>%
        # Keep every 1000th row to thin the curves.
        filter(i %% 1000 == 0)

}) %>% bind_rows %>%
    # One row per (run, iteration, dataset); '<dataset>_loss' -> '<dataset>'.
    pivot_longer(ends_with('loss'), names_to = 'dataset', values_to = 'loss') %>%
    mutate(dataset = sub('_loss$', '', dataset))

In [None]:
%%R -w 10 -h 10 --units in

# Mean loss over seeds for each (dropout, margin, dataset, iteration), drawn on a
# log10 scale with one panel per dropout/margin pair.
sn %>%
    group_by(dropout, margin, dataset, i) %>%
    # Drop the grouping explicitly so the result is a plain data frame.
    summarize(loss = mean(loss), .groups = 'drop') %>%
    ggplot(aes(x = i, y = log10(loss), color = dataset)) +
        geom_step() +
        # label_both prints 'dropout: 0.5' rather than a bare '0.5' in the strips.
        facet_wrap(dropout ~ margin, scales = 'free', ncol = 4, labeller = label_both) +
        labs(x = 'Iteration', y = 'log10(Loss)', color = 'Dataset') +
        scale_x_continuous(breaks = c(1, 5000, 15000)) +
        theme(text = element_text(size = 18),
              legend.position = 'bottom')

In [None]:
@lru_cache(maxsize=None)
def _load_metadata(path):
    """Read the plate-map metadata once, keeping only the two MOAs compared here."""
    metadata = Boyd2019.read_metadata(path)
    return metadata.loc[metadata.moa.isin(['Neutral', 'PKC Inhibitor'])]


@lru_cache(maxsize=None)
def _get_embeddings_cached(dropout, margin, seed, dataset):
    """Memoized worker behind `get_embeddings`; hyper-parameters arrive as strings."""
    torch_file = 'sn_dropout_%s_margin_%s_seed_%s_epoch_100.torch' % (dropout, margin, seed)

    # load saved net
    net = SiameseNet().to(device)
    net.load_state_dict(torch.load(os.path.join(results_path, torch_file),
                                   map_location=device))
    net = net.eval()

    # NOTE(review): torch.load unpickles arbitrary objects -- only point it at
    # trusted result files. Newer torch releases default to weights_only=True,
    # which may reject these pickled datasets; pass weights_only=False if so.
    ds1 = torch.load(os.path.join(results_path, '%s_1_seed_%s.pkl' % (dataset, seed)))
    ds2 = torch.load(os.path.join(results_path, '%s_2_seed_%s.pkl' % (dataset, seed)))
    # Copy so nothing downstream can mutate the cached metadata frame.
    metadata = _load_metadata(metadata_file).copy()

    _, embeddings, moas, cell_line = \
        sample_imgs(net, MultiCellDataset(ds1, ds2, metadata))

    return embed_matrix(embeddings, 'umap'), moas, cell_line


def get_embeddings(dropout, margin, seed, dataset='test'):
    """Embed one split with a trained SiameseNet and project it with UMAP.

    Loads the checkpoint ``sn_dropout_<dropout>_margin_<margin>_seed_<seed>_epoch_100.torch``
    and the paired ``<dataset>_1_seed_<seed>.pkl`` / ``<dataset>_2_seed_<seed>.pkl``
    files from ``results_path``. It restricts the metadata to the 'Neutral' and
    'PKC Inhibitor' MOAs and embeds samples drawn by ``sample_imgs``.

    Results are memoized on the *string form* of the hyper-parameters, i.e. on
    exactly what selects the files. So ``1.0`` and ``'1.0'`` share one cached
    (and therefore identical) UMAP layout instead of being recomputed.

    Args:
        dropout, margin, seed: Hyper-parameters, formatted with ``str()`` into
            the file names.
        dataset: Split prefix of the pickled datasets (``'test'`` or ``'train'``
            in this notebook).

    Returns:
        ``(coords, moas, cell_line)``. ``coords`` is the output of
        ``embed_matrix(..., 'umap')`` (callers index ``coords[:, 0]`` and
        ``coords[:, 1]``). ``moas`` and ``cell_line`` are the labels returned by
        ``sample_imgs`` -- presumably one per embedded sample.
    """
    return _get_embeddings_cached(str(dropout), str(margin), str(seed), dataset)


# Keep the lru_cache helpers that the original decorated function exposed.
get_embeddings.cache_info = _get_embeddings_cached.cache_info
get_embeddings.cache_clear = _get_embeddings_cached.cache_clear

In [None]:
sns.set_theme()

dropouts = [0.05, 0.1, 0.25, 0.5]
seeds = range(1, 6)
fig, ax = plt.subplots(len(seeds), len(dropouts), figsize=(15, 14))
fig.suptitle('Test set, margin 1.0')

# get_embeddings is memoized, so this seed only affects entries not yet cached.
np.random.seed(1)

for i, seed in enumerate(seeds):
    for j, d in enumerate(dropouts):
        x, moa, cell_line = get_embeddings(d, 1.0, seed, 'test')
        # Keep a single legend (top-right panel) rather than deleting them all,
        # so the hue/style encoding can still be read off the figure.
        show_legend = (i, j) == (0, len(dropouts) - 1)
        sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=moa, style=cell_line,
                        ax=ax[i, j], legend='auto' if show_legend else False)
        if i == 0:
            ax[i, j].set_title('Dropout %s' % d)
    ax[i, 0].set_ylabel('Seed %s' % seed)

In [None]:
margins = ['0.001', '0.01', '0.1', '1.0']
seeds = range(1, 6)
fig, ax = plt.subplots(len(seeds), len(margins), figsize=(15, 14))
fig.suptitle('Test set, dropout 0.5')

for i, seed in enumerate(seeds):
    for j, m in enumerate(margins):
        x, moa, cell_line = get_embeddings(0.5, m, seed, 'test')
        # Keep a single legend (top-right panel) rather than deleting them all,
        # so the hue/style encoding can still be read off the figure.
        show_legend = (i, j) == (0, len(margins) - 1)
        sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=moa, style=cell_line,
                        ax=ax[i, j], legend='auto' if show_legend else False)
        if i == 0:
            ax[i, j].set_title('Margin %s' % m)
    ax[i, 0].set_ylabel('Seed %s' % seed)

In [None]:
margins = ['0.001', '0.01', '0.1', '1.0']
seeds = range(1, 6)
fig, ax = plt.subplots(len(seeds), len(margins), figsize=(15, 14))
fig.suptitle('Train set, dropout 0.5')

for i, seed in enumerate(seeds):
    for j, m in enumerate(margins):
        x, moa, cell_line = get_embeddings(0.5, m, seed, 'train')
        # Keep a single legend (top-right panel) rather than deleting them all,
        # so the hue/style encoding can still be read off the figure.
        show_legend = (i, j) == (0, len(margins) - 1)
        sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=moa, style=cell_line,
                        ax=ax[i, j], legend='auto' if show_legend else False)
        if i == 0:
            ax[i, j].set_title('Margin %s' % m)
    ax[i, 0].set_ylabel('Seed %s' % seed)

# Close-ups: dropout 0.5, margin 1.0, one panel per seed

In [None]:
dropout, margin = 0.5, 1.0
seeds = range(1, 6)
fig, ax = plt.subplots(1, len(seeds), figsize=(50, 10))
# A figure-level title replaces the old hack of titling only the middle panel.
fig.suptitle('Dropout %s, margin %s (test set)' % (dropout, margin))

# get_embeddings is memoized, so this seed only affects entries not yet cached.
np.random.seed(1)

for k, seed in enumerate(seeds):
    x, moa, cell_line = get_embeddings(dropout, margin, seed, 'test')
    # Legend only on the last panel so the encoding stays readable.
    sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=cell_line, style=moa, ax=ax[k], s=100,
                    legend='auto' if k == len(seeds) - 1 else False)
    ax[k].set_title('Seed %s' % seed)

In [None]:
dropout, margin = 0.5, 1.0
seeds = range(1, 6)
fig, ax = plt.subplots(1, len(seeds), figsize=(25, 5))
# A figure-level title replaces the old hack of titling only the middle panel.
fig.suptitle('Dropout %s, margin %s (train set)' % (dropout, margin))

# get_embeddings is memoized, so this seed only affects entries not yet cached.
np.random.seed(1)

for k, seed in enumerate(seeds):
    x, moa, cell_line = get_embeddings(dropout, margin, seed, 'train')
    # Legend only on the last panel so the encoding stays readable.
    sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=moa, style=cell_line, ax=ax[k], s=300,
                    legend='auto' if k == len(seeds) - 1 else False)
    ax[k].set_title('Seed %s' % seed)