In [1]:
import os 
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from dataloader import *
from evaluation import *
import warnings
import pyro

### Create dataset

Here, we load the data. Supply the paths to the CSV files containing:

- raw counts
- cell centroids

Optionally, a metadata file can also be supplied, together with a `label` naming the metadata column that holds each cell's type, so that cell-type labels can be used when training the model.

In [2]:
# Input files for the dataset. Exactly one group below should be active;
# the other groups are kept commented out as ready-made alternatives.

# --- MERFISH (inactive) ---
# counts_path = 'data/merfish/hypo_ani1_counts.csv'
# centroids_path = 'data/merfish/hypo_ani1_cellcentroids.csv'
# metadata_path = 'data/merfish/hypo_ani1_metadata.csv'

# --- ST (inactive) ---
# counts_path = 'data/ST/ob_counts.csv'
# centroids_path = 'data/ST/ob_centroids.csv'
# metadata_path = 'data/ST/ob_metadata.csv'

# --- ISH (active) ---
counts_path = 'data/ISH/drosophila_counts.csv'        # raw counts
centroids_path = 'data/ISH/drosophila_centroids.csv'  # cell centroids
metadata_path = 'data/ISH/drosophila_metadata.csv'    # optional metadata

In [3]:
# Build the dataset from the three CSV paths chosen above.
# `label` is the metadata column used for cell-type labels and `axes` lists
# the centroid columns passed to SpatialDataset.
# NOTE(review): these column names must exist in the selected files - confirm
# when switching between the MERFISH / ST / ISH path groups.
dset = SpatialDataset(
    counts_path=counts_path,
    centroids_path=centroids_path,
    metadata_path=metadata_path,
    label='Cell_class',
    axes=['Centroid_X', 'Centroid_Z'],
)

### Data loader

In [4]:
# Wrap the dataset in loaders: one sample per batch, no shuffling.
# NOTE(review): train_split=1 presumably puts every sample in the training
# split (leaving val_loader empty) - confirm against data_loader.
train_loader, val_loader, train_idx, val_idx = data_loader(
    dset,
    batch_size=1,
    train_split=1,
    shuffle=False,
)

In [5]:
# Peek at the first batch only: keep its cell counts in `x` and its
# neighbour counts in `y`, then stop iterating.
for idx, sample in enumerate(train_loader):
    x = sample['cell_counts']
    y = sample['neighbor_counts']
    break  # the first iteration (idx == 0) is all that is needed

### Model

# Full training on server 

import os
import time
import pickle
import yaml
import torch
import pyro
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
from pyro.optim import Adam, ClippedAdam
import numpy as np
from dataloader import *
from shutil import copyfile

# Hyperparameters for the training run below.
LATENT_DIM = 10                                  # passed to the model as z_dim
HIDDEN_DIM1 = HIDDEN_DIM2 = HIDDEN_DIM3 = 128    # the three hidden-layer widths
USE_CUDA = False                                 # forwarded as use_cuda
LEARNING_RATE = 1e-3                             # "lr" handed to the optimizer
NUM_EPOCHS = 5                                   # iterations of the training loop
TEST_FREQUENCY = 2                               # evaluate every N-th epoch
# Input width of the network is taken from the dataset's feature count.
n_input = dset.n_features

# The CVAE class used here comes from models.nbcvae, so the banner must say
# "NB CVAE"; elsewhere in this notebook "ZINB CVAE" is the label for the
# class imported from models.cvae.
from models.nbcvae import *
print('Training NB CVAE')
vae = CVAE(
    n_input=n_input,
    z_dim=LATENT_DIM,
    hidden_dim1=HIDDEN_DIM1,
    hidden_dim2=HIDDEN_DIM2,
    hidden_dim3=HIDDEN_DIM3,
    use_cuda=USE_CUDA
)

# Stochastic variational inference: Adam optimizer at LEARNING_RATE,
# optimising the Trace_ELBO objective over the model/guide pair.
adam_args = dict(lr=LEARNING_RATE)
optimizer = Adam(adam_args)
loss = Trace_ELBO()
svi = SVI(model=vae.model, guide=vae.guide, optim=optimizer, loss=loss)

# Per-epoch history; each entry is the negated value returned by
# train() / evaluate().
train_elbo = []
val_elbo = []

for epoch in range(NUM_EPOCHS):
    # One pass over the training loader.
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    # Diagnostics run only on every TEST_FREQUENCY-th epoch, never on epoch 0.
    if epoch == 0 or epoch % TEST_FREQUENCY != 0:
        continue

    # report test diagnostics
    total_epoch_loss_test = evaluate(svi, val_loader, use_cuda=USE_CUDA)
    val_elbo.append(-total_epoch_loss_test)
    print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))


### Load trained model

In [6]:
import yaml

# Read the YAML configuration stored next to the checkpoint.
CONFIG_FILE = 'checkpoints/NBCVAE_20211108-174537/config.yaml'
with open(CONFIG_FILE) as config_stream:
    config = yaml.safe_load(config_stream)

# Model architecture
model_params = config['model_params']
MODEL_NAME = model_params['model_name']
LATENT_DIM = model_params['latent_dim']
HIDDEN_DIM1, HIDDEN_DIM2, HIDDEN_DIM3 = (
    model_params[key] for key in ('hidden_dim1', 'hidden_dim2', 'hidden_dim3')
)

# Data
data_params = config['data_params']
COUNTS_PATH, CENTROIDS_PATH, METADATA_PATH = (
    data_params[key] for key in ('counts_path', 'centroids_path', 'metadata_path')
)
N_NEIGHBORS = data_params['n_neighbors']

# Experiment parameters
exp_params = config['exp_params']
BATCH_SIZE = exp_params['batch_size']
TRAIN_SPLIT = exp_params['train_split']
LEARNING_RATE = exp_params['learning_rate']
USE_CUDA = False  # hard-coded here rather than read from the config
NUM_EPOCHS = exp_params['num_epochs']
TEST_FREQUENCY = exp_params['test_frequency']
SAVE_FREQUENCY = exp_params['save_frequency']
SAVE_DIR = exp_params['save_dir']

# Network dimensions come from the dataset already in memory, not the config.
n_input = dset.n_features
n_class = dset.n_class

# Instantiate the architecture named by MODEL_NAME with the dimensions read
# from the config. Each branch imports its own module because the class name
# CVAE is used for both the models.cvae and the models.nbcvae variants.
if MODEL_NAME == 'ZINBVAE':
    from models.vae import *
    print('Training ZINB VAE')
    vae = ZINBVAE(
        n_input=n_input,
        z_dim=LATENT_DIM,
        hidden_dim1=HIDDEN_DIM1,
        hidden_dim2=HIDDEN_DIM2,
        hidden_dim3=HIDDEN_DIM3,
        use_cuda=USE_CUDA
    )

elif MODEL_NAME == 'NBVAE':
    from models.vae import *
    print('Training NB VAE')
    vae = NBVAE(
        n_input=n_input,
        z_dim=LATENT_DIM,
        hidden_dim1=HIDDEN_DIM1,
        hidden_dim2=HIDDEN_DIM2,
        hidden_dim3=HIDDEN_DIM3,
        use_cuda=USE_CUDA
    )

# conditional vae
elif MODEL_NAME == 'CVAE':
    from models.cvae import *
    print('Training ZINB CVAE')
    vae = CVAE(
        n_input=n_input,
        z_dim=LATENT_DIM,
        hidden_dim1=HIDDEN_DIM1,
        hidden_dim2=HIDDEN_DIM2,
        hidden_dim3=HIDDEN_DIM3,
        use_cuda=USE_CUDA
    )

elif MODEL_NAME == 'NBCVAE':
    from models.nbcvae import *
    print('Training NB CVAE')
    vae = CVAE(
        n_input=n_input,
        z_dim=LATENT_DIM,
        hidden_dim1=HIDDEN_DIM1,
        hidden_dim2=HIDDEN_DIM2,
        hidden_dim3=HIDDEN_DIM3,
        use_cuda=USE_CUDA
    )
elif MODEL_NAME == 'LabelVAE':
    from models.labelvae import *
    print('Training LabelVAE')
    vae = LabelVAE(
        n_genes=n_input,
        n_class=n_class,
        z_dim=LATENT_DIM,
        hidden_dim1=HIDDEN_DIM1,
        hidden_dim2=HIDDEN_DIM2,
        hidden_dim3=HIDDEN_DIM3,
        use_cuda=USE_CUDA
    )
else:
    # Fail loudly: without this branch an unrecognised name would silently
    # leave whatever `vae` an earlier cell created (or none at all) in place.
    raise ValueError(
        "Unknown model_name %r in %s; expected one of "
        "'ZINBVAE', 'NBVAE', 'CVAE', 'NBCVAE', 'LabelVAE'" % (MODEL_NAME, CONFIG_FILE)
    )

Training NB CVAE


### Load model

In [9]:
# Restore the saved parameters into `vae`. map_location pins every tensor to
# the CPU while loading the checkpoint file.
PATH = 'checkpoints/NBCVAE_20211108-174537/model_epochs100.pt'
vae.load_state_dict(
    torch.load(PATH, map_location=torch.device('cpu'))
)
# Switch to evaluation mode; as the cell's last expression this also echoes
# the module structure.
vae.eval()

CVAE(
  (prior): Encoder(
    (fc1): Linear(in_features=84, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=128, bias=True)
    (mean_encoder): Linear(in_features=128, out_features=10, bias=True)
    (var_encoder): Linear(in_features=128, out_features=10, bias=True)
    (relu): ReLU()
  )
  (generation): Decoder(
    (fc1): Linear(in_features=10, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=128, bias=True)
    (rate): Linear(in_features=128, out_features=84, bias=True)
    (relu): ReLU()
    (softmax): Softmax(dim=-1)
  )
  (recognition): Encoder(
    (fc1): Linear(in_features=84, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=128, bias=True)
    (mean_encoder): Linear(in_features=128, out_features=10, bias=

In [10]:
# Run the CVAE evaluation helper over the training loader, passing the
# checkpoint directory as its output directory.
OUT_DIR = 'checkpoints/NBCVAE_20211108-174537'
result = nbcvae_evaluation(
    vae,
    train_loader,
    out_dir=OUT_DIR,
)

# Save model weights
# Export one tensor from the checkpoint at PATH as a CSV table.
weights = dict(torch.load(PATH, map_location=torch.device('cpu')))
# Position of the exported tensor in the checkpoint's key order.
# NOTE(review): index 1 is the *second* entry - confirm it is the intended
# layer-1 weight matrix and not, for example, a bias vector.
index = 1
lweights = weights[list(weights.keys())[index]].numpy()
df = pd.DataFrame(lweights)
# Write next to the checkpoint the weights came from (OUT_DIR) instead of into
# another run's directory, so the file cannot be attributed to the wrong model.
df.to_csv(os.path.join(OUT_DIR, 'layer1_weights.csv'), index=False)

# Export the latent representation and the reconstruction for the training
# loader as CSV files.
# NOTE(review): both calls tag the model as 'LabelVAE' and write into the
# LabelVAE_20211107-023946 run directory, while this notebook loads an NBCVAE
# checkpoint - confirm the model tag, the `distribution` and the target paths
# before running.
lat = get_latent(
    vae,
    'LabelVAE',
    train_loader,
    file='checkpoints/LabelVAE_20211107-023946/latent.csv',
)
rec = get_recon(
    vae,
    'LabelVAE',
    train_loader,
    distribution='ZINB',
    file='checkpoints/LabelVAE_20211107-023946/recon.csv',
)