In [None]:
%load_ext rpy2.ipython
from functools import lru_cache

import numpy as np
import torch
import torchvision.models as models
from torchvision import transforms

from janus.datasets import Boyd2019, MultiCellDataset
from janus.networks import SiameseNet
from janus.transforms import RandomRot90, RGB
from src.viz import embed_matrix, sample_imgs, xy_plot


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Backbone passed to SiameseNet as `feature_extractor`; set to None to use
# SiameseNet's default backbone instead.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=...`; kept as-is for compatibility with the installed version.
fs = models.vgg19(pretrained=True)
# fs = None

metadata_file = '../data/boyd_2019_PlateMap-KPP_MOA.xlsx'
results_path = '../results/boyd_2019/'
labels = ['Neutral', 'PKC Inhibitor']
# labels = ['Neutral', 'EGF Receptor Kinase Inhibitor', 'Cysteine Protease Inhibitor', 
#           'PKC Inhibitor', 'Tyrosine Kinase Inhibitor', 'Protein Tyrosine Phosphatase Inhibitor']

# Each experiment variant reads from a sibling results directory whose name
# encodes the variant, e.g. '../results/boyd_2019_multiclass_vgg_finetune/'.
# rstrip('/') makes the suffixing independent of whether results_path ends in '/'.
if len(labels) > 2:
    results_path = results_path.rstrip('/') + '_multiclass/'
# Explicit None check: an nn.Module's truthiness is not a reliable "is set" test.
if fs is not None:
    results_path = results_path.rstrip('/') + '_vgg_finetune/'

In [None]:
%%R -i results_path

library(tidyverse)

# Read every training-loss TSV in results_path into one long data frame.
# File names are split on '_' and fields 3, 5 and 7 are taken as dropout,
# margin and seed (the seed field still carries the file extension), i.e. names
# presumably look like 'sn_dropout_<d>_margin_<m>_seed_<s>.tsv' -- TODO confirm.
# list.files() takes a regular expression, not a glob, hence '\\.tsv$'.
sn <- lapply(list.files(path = results_path, pattern = '\\.tsv$'), function(f) {

    params <- strsplit(f, '_') %>% unlist

    read_tsv(paste0(results_path, f), col_types = 'dd') %>%
        mutate(dropout = as.numeric(params[3]),
               margin = as.numeric(params[5]),
               seed = strsplit(params[7], '.', fixed = TRUE) %>% unlist %>% head(1) %>% as.numeric,
               # row_number() is safe for an empty file (1:n() would yield c(1, 0))
               i = row_number()) %>%
        # keep only every 1000th iteration to thin the curves
        filter(i %% 1000 == 0)

}) %>% bind_rows %>%
    # one row per (iteration, dataset); columns like '<dataset>_loss' become 'dataset'
    pivot_longer(ends_with('loss'), names_to = 'dataset', values_to = 'loss') %>%
    mutate(dataset = gsub('_loss', '', dataset))

In [None]:
%%R -w 8 -h 8 --units in

# Mean loss over seeds for each (dropout, margin, dataset, iteration), drawn as
# step curves on a log10 scale with one panel per dropout/margin combination.
sn %>%
    group_by(dropout, margin, dataset, i) %>%
    summarize(loss = mean(loss), .groups = 'drop') %>%
    ggplot(aes(x = i, y = log10(loss), color = dataset)) +
        geom_step() +
        facet_wrap(dropout ~ margin, scales = 'free') +
        labs(x = 'Iteration', y = 'log10(Loss)', color = 'Dataset') +
        scale_x_continuous(breaks = c(1, 5000, 15000)) +
        theme(text = element_text(size = 18),
              legend.position = 'bottom')

In [None]:
@lru_cache(maxsize=None)
def get_embeddings(dropout, margin, seed, dataset='test'):
    """Load a trained SiameseNet checkpoint and return a UMAP embedding of a split.

    Results are memoised with ``lru_cache``: all arguments must be hashable, and a
    repeated call with identical arguments returns the first result unchanged.
    Note that ``1.0`` and ``'1.0'`` format to the same file name but are cached
    as different keys.

    Reads the notebook-level globals ``fs``, ``device``, ``results_path``,
    ``metadata_file`` and ``labels``.

    Parameters
    ----------
    dropout, margin, seed
        Formatted with ``%s`` into the checkpoint name
        ``sn_dropout_<dropout>_margin_<margin>_seed_<seed>_epoch_100.torch``;
        ``seed`` also selects the pickled datasets.
    dataset : str
        Prefix of the pickled datasets, e.g. ``'test'`` or ``'train'``.

    Returns
    -------
    tuple
        ``(embed_matrix(embeddings, 'umap'), moas, cell_line)`` as produced from
        ``sample_imgs``.
    """
    torch_file = 'sn_dropout_%s_margin_%s_seed_%s_epoch_100.torch' % (dropout, margin, seed)

    # The architecture must match the saved checkpoint, otherwise
    # load_state_dict (strict by default) fails.
    if fs is not None:
        net = SiameseNet(feature_extractor=fs).to(device)
    else:
        net = SiameseNet().to(device)

    # Random flip/rotation augmentations for both backbones; the pretrained VGG
    # backbone additionally gets the standard torchvision ImageNet normalisation.
    steps = [transforms.RandomHorizontalFlip(),
             transforms.RandomVerticalFlip(),
             RandomRot90()]
    if fs is not None:
        steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225]))
    tr = transforms.Compose(steps)

    net.load_state_dict(torch.load(results_path + torch_file, map_location=device))
    net = net.eval()

    # NOTE(review): torch.load unpickles these files, which can execute arbitrary
    # code -- only load files produced by this project. On torch >= 2.6 (where
    # weights_only defaults to True) loading pickled dataset objects may also
    # need weights_only=False.
    ds1 = torch.load('%s/%s_1_seed_%s.pkl' % (results_path, dataset, seed))
    ds2 = torch.load('%s/%s_2_seed_%s.pkl' % (results_path, dataset, seed))
    metadata = Boyd2019.read_metadata(metadata_file)
    metadata = metadata.loc[metadata.moa.isin(labels)]

    # iters=20 -- presumably the number of sampling passes; TODO confirm in sample_imgs.
    _, embeddings, moas, cell_line = sample_imgs(
        net, MultiCellDataset(ds1, ds2, metadata, transform=tr), iters=20)

    return embed_matrix(embeddings, 'umap'), moas, cell_line

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

def plot_embeddings_d(splits, dropouts, dataset, margin=1.0):
    """Draw a grid of 2-D embeddings: one row per split, one column per dropout.

    Parameters
    ----------
    splits : sequence of int
        Zero-based split indices; split ``s`` is loaded with seed ``s + 1``.
        Any values are accepted -- rows follow the order of ``splits``.
    dropouts : sequence
        Dropout values, one per column.
    dataset : str
        Dataset prefix passed to ``get_embeddings`` (e.g. ``'test'``).
    margin : float or str
        Margin value shared by every panel.
    """
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(len(splits), len(dropouts), figsize=(15, 14), squeeze=False)
    # NOTE(review): only affects embeddings not already cached by get_embeddings,
    # and only if its sampling/UMAP code draws from numpy's global RNG.
    np.random.seed(1)

    # Row index comes from enumerate, not from the split value itself.
    for row, split in enumerate(splits):
        for col, d in enumerate(dropouts):
            x, moa, cell_line = get_embeddings(d, margin, split + 1, dataset)
            ax = axes[row, col]
            sns.scatterplot(x=x[:, 0], y=x[:, 1], style=cell_line, hue=moa, ax=ax)
            ax.get_legend().remove()

            if row == 0:
                ax.title.set_text('Dropout %s' % d)

    # One shared legend for the whole figure, taken from the last panel drawn.
    handles, legend_labels = ax.get_legend_handles_labels()
    fig.suptitle(dataset.capitalize(), fontsize=25, x=.15, y=.93)
    fig.legend(handles, legend_labels, loc=(.5, .91), ncol=2)

plot_embeddings_d(list(range(3)), [0.1, 0.5], 'test')

In [None]:
def plot_embeddings_m(splits, margins, dataset, dropout=.1):
    """Draw a grid of 2-D embeddings: one row per split, one column per margin.

    Parameters
    ----------
    splits : sequence of int
        Zero-based split indices; split ``s`` is loaded with seed ``s + 1``.
        Any values are accepted -- rows follow the order of ``splits``.
    margins : sequence
        Margin values, one per column.
    dataset : str
        Dataset prefix passed to ``get_embeddings`` (e.g. ``'test'``).
    dropout : float or str
        Dropout value shared by every panel.
    """
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(len(splits), len(margins), figsize=(15, 14), squeeze=False)
    # NOTE(review): only affects embeddings not already cached by get_embeddings,
    # and only if its sampling/UMAP code draws from numpy's global RNG.
    np.random.seed(1)

    # Row index comes from enumerate, not from the split value itself.
    for row, split in enumerate(splits):
        for col, m in enumerate(margins):
            x, moa, cell_line = get_embeddings(dropout, m, split + 1, dataset)
            ax = axes[row, col]
            sns.scatterplot(x=x[:, 0], y=x[:, 1], style=cell_line, hue=moa, ax=ax)
            ax.get_legend().remove()

            if row == 0:
                ax.title.set_text('Margin %s' % m)

    # One shared legend for the whole figure, taken from the last panel drawn.
    handles, legend_labels = ax.get_legend_handles_labels()
    fig.suptitle(dataset.capitalize(), fontsize=25, x=.15, y=.93)
    fig.legend(handles, legend_labels, loc=(.5, .91), ncol=2)

plot_embeddings_m(list(range(3)), ['0.1', '1.0'], 'test')

In [None]:
plot_embeddings_m(list(range(3)), ['0.1', '1.0'], 'train')  # same grid, training split