In [25]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer, LightningDataModule
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
import pickle as pkl

import numpy as np

In [42]:
class VAE(LightningModule):
    def __init__(self, latent_dim=10, input_height=2, kernel_size=10, width=56):
        super().__init__()

        self.save_hyperparameters()

        # encoder, decoder
        self.encoder = nn.Sequential(
          nn.Conv2d(2,20, (kernel_size,2)),
          nn.ReLU(),
          nn.Conv2d(20,10, (kernel_size,2)),
          nn.ReLU(),
          nn.Conv2d(20,5, (kernel_size,2)),
          nn.ReLU()
        )
        
        self.decoder = nn.Sequential(
          nn.Conv2d(1,20, (kernel_size,2)),
          nn.ReLU(),
          nn.Conv2d(20,10, (kernel_size,2)),
          nn.ReLU(),
          nn.Conv2d(10,2, (kernel_size,2)),
          nn.ReLU()
        )
        
        # distribution parameters
        self.fc_mu = nn.Linear(width*5, latent_dim)
        self.fc_var = nn.Linear(width*5, latent_dim)
        
        # back to proper width 
        self.to_dec = nn.Linear(latent_dim, width)

        # for the gaussian likelihood
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def gaussian_likelihood(self, x_hat, logscale, x):
        scale = torch.exp(logscale)
        mean = x_hat
        dist = torch.distributions.Normal(mean, scale)

        # measure prob of seeing image under p(x|z)
        log_pxz = dist.log_prob(x)
        return log_pxz.sum(dim=(1, 2, 3))

    def kl_divergence(self, z, mu, std):
        # --------------------------
        # Monte carlo KL divergence
        # --------------------------
        # 1. define the first two probabilities (in this case Normal for both)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)

        # 2. get the probabilities from the equation
        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)

        # kl
        kl = (log_qzx - log_pz)
        kl = kl.sum(-1)
        return kl

    def training_step(self, batch, batch_idx):
        
        print(batch.shape)
        
        x = batch

        # encode x to get the mu and variance parameters
        x_encoded = self.encoder(x)
        mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)

        # sample z from q
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()

        # decoded
        zdec = self.to_dec(z)
        x_hat = self.decoder(zdec)

        # reconstruction loss
        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)

        # kl
        kl = self.kl_divergence(z, mu, std)

        # elbo
        elbo = (kl - recon_loss)
        elbo = elbo.mean()

        self.log_dict({
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': recon_loss.mean(),
            'reconstruction': recon_loss.mean(),
            'kl': kl.mean(),
        })

        return elbo

In [43]:
class MyDataModule(LightningDataModule):

    def __init__(self, X):
        super().__init__()
        
        
        self.train, X_test = train_test_split(X, test_size=0.3, random_state=0)
        self.test, self.val = train_test_split(X_test, test_size=0.5, random_state=0)
        
    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64)

In [44]:
with open('../data/H190923.pkl', 'rb') as f: 
    data = pkl.load(f)
data.head()

Unnamed: 0,track_index,erk_time(h)=0.0,erk_time(h)=0.2,erk_time(h)=0.5,erk_time(h)=0.8,erk_time(h)=1.0,erk_time(h)=1.2,erk_time(h)=1.5,erk_time(h)=1.8,erk_time(h)=2.0,...,akt_time(h)=13.5,akt_time(h)=13.8,akt_site,akt_median,conc_nm,inhibitor,cell_line,control_OD,optical_density,cell_viab
0,78_1000101697_k562_erk_akt__untreated,0.666206,0.583154,0.464076,0.401721,0.574265,0.628919,0.533323,0.63951,0.808772,...,0.49822,0.523217,78,0.496346,0.0,none,k562,,,1.0
1,77_1000098381_k562_erk_akt__untreated,0.646522,0.419623,0.63642,0.608784,0.59337,0.594475,0.764531,0.781966,0.703256,...,0.630918,0.623906,77,0.503596,0.0,none,k562,,,1.0
2,4_1000007026_k562_erk_akt__untreated,0.731499,0.644271,0.565533,0.405811,0.538459,0.716891,0.771548,0.708265,0.789968,...,0.835404,0.62695,4,0.64291,0.0,none,k562,,,1.0
3,2_1000002957_k562_erk_akt__untreated,0.860727,0.69016,0.633814,0.886553,0.915976,0.856283,0.759563,0.927312,0.948254,...,0.74985,0.384942,2,0.541348,0.0,none,k562,,,1.0
4,82_1000111191_k562_erk_akt__untreated,0.547424,0.521789,0.448863,0.573173,0.5043,0.525133,0.493384,0.638652,0.621023,...,0.684539,0.720796,82,0.609912,0.0,none,k562,,,1.0


In [45]:
df = data
X = np.concatenate((df[df.attrs['clover_selector']].values.reshape(-1, len(df.attrs['clover_selector']), 1), 
                        df[df.attrs['mscarlet_selector']].values.reshape(-1, len(df.attrs['clover_selector']), 1)), 
                       axis=2)

X.shape

(816, 56, 2)

In [46]:
mydata = MyDataModule(X)

model = VAE()
trainer = Trainer()
trainer.fit(model, datamodule=mydata)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 6.8 K 
1 | decoder | Sequential | 4.8 K 
2 | fc_mu   | Linear     | 2.8 K 
3 | fc_var  | Linear     | 2.8 K 
4 | to_dec  | Linear     | 616   
---------------------------------------
17.9 K    Trainable params
0         Non-trainable params
17.9 K    Total params


Training: |          | 0/? [00:00<?, ?it/s]

torch.Size([64, 56, 2])


RuntimeError: Expected 4-dimensional input for 4-dimensional weight [20, 2, 10, 2], but got 3-dimensional input of size [64, 56, 2] instead