In [None]:
import torch
import torch.nn.functional as F
import torchvision.models as models
from sklearn.manifold import TSNE
from umap import UMAP

from janus.datasets import Boyd2019
from src.viz import xy_plot, norm_crop_for_vis

# Locations of the image data and of the plate-map spreadsheet, relative to
# the notebook's working directory.
data_path = '../data/boyd_2019'
metadata_file = '../data/boyd_2019_PlateMap-KPP_MOA.xlsx'

metadata = Boyd2019.read_metadata(metadata_file)

# Keep only the rows whose mechanism of action (`moa`) is one of the two
# classes compared in this notebook.
metadata = metadata.loc[metadata.moa.isin(['Neutral', 'EGF Receptor Kinase Inhibitor'])]

# 70 / 30 train / test split. The fixed `random_state` makes the split (and
# therefore every figure below) identical after Restart Kernel -> Run All.
train_metadata = metadata.sample(frac=0.7, random_state=0)
test_metadata = metadata.drop(train_metadata.index)

# Build the dataset from the training rows only; `padding` and `scale` are
# forwarded to Boyd2019 unchanged.
boyd2019 = Boyd2019(data_path, train_metadata, padding=64, scale=0.5)

# Element 0 of every dataset item is stacked along a new leading axis, giving
# one tensor per sub-dataset (dataset_1 -> mda231, dataset_2 -> mda468).
# NOTE(review): element 0 is presumably the image crop -- confirm in Boyd2019.
mda231 = torch.stack([item[0] for item in boyd2019.dataset_1], dim=0)
mda468 = torch.stack([item[0] for item in boyd2019.dataset_2], dim=0)

print(mda231.shape, mda468.shape)

## Raw pixels

In [None]:
nb_samples = 2000
# Half of the samples come from each cell line; an odd total would leave
# `labels` / `all_samples` one row short of `nb_samples` and break the
# `reshape(nb_samples, -1)` calls further down.
assert nb_samples % 2 == 0, 'nb_samples must be even'
nb_per_class = nb_samples // 2
labels = nb_per_class * ['mda231'] + nb_per_class * ['mda468']

# Dedicated, seeded generator so the same crops are drawn on every run
# without touching the global torch RNG state.
sample_rng = torch.Generator().manual_seed(0)

# Indices are drawn uniformly *with replacement*, so a crop can appear more
# than once in the sample.
idx = torch.randint(0, mda231.shape[0], (nb_per_class,), generator=sample_rng)
mda231_samples = mda231[idx]

idx = torch.randint(0, mda468.shape[0], (nb_per_class,), generator=sample_rng)
mda468_samples = mda468[idx]

# Row order matches `labels`: first the mda231 crops, then the mda468 crops.
all_samples = torch.cat([mda231_samples, mda468_samples], dim=0)
print(torch.min(all_samples), torch.max(all_samples))

In [None]:
# t-SNE on the flattened pixels (one row per crop). t-SNE is stochastic, so
# the random state is fixed to make the embedding reproducible across runs.
x_emb = TSNE(random_state=0).fit_transform(all_samples.reshape(nb_samples, -1))
xy_plot(x_emb, labels, 'TSNE')

## Raw pixels - unnormalised

In [None]:
# Run every sampled crop through `norm_crop_for_vis`, move its last axis to
# the front and add a leading batch axis of size 1.
# NOTE(review): the (2, 0, 1) permutation suggests the helper returns
# height x width x channel data -- confirm against src.viz.
norm_mda231 = [norm_crop_for_vis(crop).permute(2, 0, 1).unsqueeze(0) for crop in mda231_samples]
norm_mda468 = [norm_crop_for_vis(crop).permute(2, 0, 1).unsqueeze(0) for crop in mda468_samples]

# Same row order as `labels`: all mda231 crops followed by all mda468 crops.
all_samples_un = torch.cat(norm_mda231 + norm_mda468, dim=0)

print(all_samples_un.min(), all_samples_un.max())

In [None]:
# t-SNE on the flattened `norm_crop_for_vis` pixels (one row per crop).
# t-SNE is stochastic, so the random state is fixed to make the embedding
# reproducible across runs.
x_emb = TSNE(random_state=0).fit_transform(all_samples_un.reshape(nb_samples, -1))
xy_plot(x_emb, labels, 'TSNE')

## VGG features

In [None]:
# Pretrained VGG16 used purely as a fixed feature extractor. `.eval()` puts
# the network in inference mode (disables dropout) and returns the module.
# NOTE: recent torchvision releases deprecate `pretrained=` in favour of the
# `weights=` argument.
vgg16 = models.vgg16(pretrained=True).eval()

def vgg_features(inputs):
    """Return VGG16 codes for a batch of images.

    The batch goes through the convolutional stack and the average pool of
    the module-level ``vgg16``, is flattened per sample, and is then passed
    through the first five layers of the classifier head (``classifier[:5]``).

    Args:
        inputs: batch tensor accepted by ``vgg16.features``; the callers in
            this notebook pass mean/std-normalised images resized to 224.

    Returns:
        A 2-D tensor with one row of features per input sample.
    """
    conv_maps = vgg16.features(inputs)
    pooled = vgg16.avgpool(conv_maps)
    # Collapse everything except the batch axis before the dense layers.
    flat = pooled.flatten(start_dim=1)
    return vgg16.classifier[:5](flat)

# Per-channel mean and standard deviation used to normalise the VGG inputs,
# shaped (3, 1, 1) so they broadcast over the two trailing image axes.
vgg_mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
vgg_std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]

In [None]:
# Extract VGG codes for `all_samples`.
#
# * The forward passes run under `torch.no_grad()`, so no autograd graph is
#   built (previously each pass tracked gradients and was only detached after).
# * Crops are processed in mini-batches and the per-batch results are
#   collected in a list and concatenated once, instead of regrowing the result
#   tensor with `torch.cat` on every iteration.
vgg_batch_size = 32
cnn_code_batches = []

with torch.no_grad():
    for start in range(0, nb_samples, vgg_batch_size):
        stop = min(start + vgg_batch_size, nb_samples)
        inputs = all_samples[start:stop]
        inputs = (inputs - vgg_mean) / vgg_std
        # `align_corners=False` is PyTorch's default for bilinear mode; it is
        # spelled out so the resampling convention is explicit.
        inputs = F.interpolate(inputs, size=224, mode='bilinear', align_corners=False)
        cnn_code_batches.append(vgg_features(inputs))

# One row of VGG features per crop, in the same order as `labels`.
cnn_codes = torch.cat(cnn_code_batches) if cnn_code_batches else torch.empty((0,))

In [None]:
# t-SNE on the VGG codes (one row per crop). t-SNE is stochastic, so the
# random state is fixed to make the embedding reproducible across runs.
x_emb = TSNE(random_state=0).fit_transform(cnn_codes.reshape(nb_samples, -1))
xy_plot(x_emb, labels, 'TSNE')

## VGG features - unnormalised

In [None]:
# Extract VGG codes for `all_samples_un`.
#
# * The forward passes run under `torch.no_grad()`, so no autograd graph is
#   built (previously each pass tracked gradients and was only detached after).
# * Crops are processed in mini-batches and the per-batch results are
#   collected in a list and concatenated once, instead of regrowing the result
#   tensor with `torch.cat` on every iteration.
vgg_batch_size = 32
cnn_code_batches_un = []

with torch.no_grad():
    for start in range(0, nb_samples, vgg_batch_size):
        stop = min(start + vgg_batch_size, nb_samples)
        inputs = all_samples_un[start:stop]
        inputs = (inputs - vgg_mean) / vgg_std
        # `align_corners=False` is PyTorch's default for bilinear mode; it is
        # spelled out so the resampling convention is explicit.
        inputs = F.interpolate(inputs, size=224, mode='bilinear', align_corners=False)
        cnn_code_batches_un.append(vgg_features(inputs))

# One row of VGG features per crop, in the same order as `labels`.
cnn_codes_un = torch.cat(cnn_code_batches_un) if cnn_code_batches_un else torch.empty((0,))

In [None]:
# t-SNE on the VGG codes of the `norm_crop_for_vis` crops (one row per crop).
# t-SNE is stochastic, so the random state is fixed to make the embedding
# reproducible across runs.
x_emb = TSNE(random_state=0).fit_transform(cnn_codes_un.reshape(nb_samples, -1))
xy_plot(x_emb, labels, 'TSNE')