In [1]:
# path to sequence_vaes directory
# abspath = "/content/drive/MyDrive/bedford_lab/code/seq_vaes"
abspath = ".."

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import sys
import json
from treetime.utils import datetime_from_numeric
from collections.abc import Iterable

# path to sequence_vaes directory and pip install
# %cd "/content/drive/MyDrive/bedford_lab/code/seq_vaes"
# !pip install -r requirements.txt

In [3]:
sys.path.append(abspath)
from models import DNADataset, ALPHABET, SEQ_LENGTH, LATENT_DIM, VAE
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE as tsne
from matplotlib import pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd

import bedford_code.models_bedford as bedford
import utils

In [4]:
BATCH_SIZE = 64

# "data" directory is generated as shown in README.md file
dataset = DNADataset(f"{abspath}/data/training_spike.fasta")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

In [5]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


In [6]:
input_dim = len(ALPHABET) * SEQ_LENGTH
# input_dim = 29903 * 5
# input_dim = 29903

#STANDARD
vae_model = VAE(input_dim=input_dim, latent_dim=50, non_linear_activation=nn.Softplus(beta=1.0)).to(DEVICE)
vae_model.load_state_dict(torch.load("./model_saves/standard_VAE_model_BEST.pth", weights_only=True, map_location=DEVICE))


vae_model.eval()

  return self.fget.__get__(instance, owner)()


VAE(
  (non_linear_activation): Softplus(beta=1.0, threshold=20)
  (encoder): Encoder(
    (non_linear_activation): Softplus(beta=1.0, threshold=20)
    (encode): Sequential(
      (0): Linear(in_features=19110, out_features=512, bias=True)
      (1): Softplus(beta=1.0, threshold=20)
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): Softplus(beta=1.0, threshold=20)
    )
    (fc_mean): Linear(in_features=256, out_features=50, bias=True)
    (fc_logvar): Linear(in_features=256, out_features=50, bias=True)
  )
  (decoder): Decoder(
    (non_linear_activation): Softplus(beta=1.0, threshold=20)
    (decode): Sequential(
      (0): Linear(in_features=50, out_features=256, bias=True)
      (1): Softplus(beta=1.0, threshold=20)
      (2): Linear(in_features=256, out_features=512, bias=True)
      (3): Softplus(beta=1.0, threshold=20)
      (4): Linear(in_features=512, out_features=19110, bias=True)
    )
  )
)

In [7]:
import utils
dset = ["training", "valid", "test"]
dset = dset[0]
abspath = ".."


# LOAD DATA
data_keys, data_dict = utils.get_data_dict(dset, abspath)
print(data_keys)
new_dataset = data_dict["new_dataset"]
vals = data_dict["vals"]
metadata = data_dict["metadata"]
clade_labels = data_dict["clade_labels"]
collection_dates = data_dict["collection_dates"]
indexes = data_dict["indexes"]
pairs = data_dict["pairs"]
get_parents_dict = data_dict["get_parents_dict"]

collection_dates
 [21, 54, 67, 86, 108, 140, 154, 187, 206, 220, 240, 257, 264, 275, 282, 288, 291, 308, 333, 345, 358, 369, 377, 386, 397, 408, 416, 426, 435, 461, 798, 1373, 2807, 4105, 5241, 6264, 6368]

unique clusters
 ['19A' '21K (BA.1)' '21L (BA.2)' '21M (Omicron)' '22A (BA.4)' '22B (BA.5)'
 '22C (BA.2.12.1)' '22D (BA.2.75)' '22E (BQ.1)' '22F (XBB)'
 '23A (XBB.1.5)' '23B (XBB.1.16)' '23C (CH.1.1)' '23D (XBB.1.9)'
 '23E (XBB.2.3)' '23F (EG.5.1)' '23G (XBB.1.5.70)' '23H (HK.3)'
 '23I (BA.2.86)' '24A (JN.1)' '24D (XDV.1)' '24E (KP.3.1.1)' '24F (XEC)'
 '24G (KP.2.3)' '24H (LF.7)' '24I (MV.1)' '25A (LP.8.1)' '25B (NB.1.8.1)'
 '25C (XFG)']

sanity check - len(new_vals), len(vals)
 6368   6368
['new_dataset', 'vals', 'metadata', 'clade_labels', 'collection_dates', 'indexes', 'pairs', 'get_parents_dict']


In [75]:
vae_model = vae_model.requires_grad_(False)

X = torch.tensor(new_dataset)
X = X.view(X.size(0), -1).to(DEVICE)
X = utils.mask_gaps(X,zero_idx=4)

Z_mean, _ = vae_model.encoder.forward(X)

z = Z_mean[0,:]
r_z = vae_model.decoder.forward(z[None,:])
print(r_z.shape)

torch.Size([1, 3822, 5])


In [78]:
print(vae_model.decoder)

Decoder(
  (non_linear_activation): Softplus(beta=1.0, threshold=20)
  (decode): Sequential(
    (0): Linear(in_features=50, out_features=256, bias=True)
    (1): Softplus(beta=1.0, threshold=20)
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): Softplus(beta=1.0, threshold=20)
    (4): Linear(in_features=512, out_features=19110, bias=True)
  )
)


In [80]:
grads = torch.autograd.functional.jacobian(vae_model.decoder.decode, z)
print(grads.shape)

torch.Size([19110, 50])


In [82]:
metric = torch.matmul(grads.T,grads)

In [83]:
print(metric)

tensor([[ 2.8120e+01,  4.6728e+00,  6.8707e+02,  ..., -5.6763e+00,
          1.0328e+01,  1.1984e+01],
        [ 4.6728e+00,  2.8615e+00,  1.0975e+02,  ..., -1.9634e-01,
          4.9572e+00,  1.9510e-01],
        [ 6.8707e+02,  1.0975e+02,  1.7877e+04,  ..., -1.6073e+02,
          1.8618e+02,  3.6894e+02],
        ...,
        [-5.6763e+00, -1.9634e-01, -1.6073e+02,  ...,  2.2123e+00,
          5.6787e-01, -5.2093e+00],
        [ 1.0328e+01,  4.9572e+00,  1.8618e+02,  ...,  5.6787e-01,
          1.8152e+01, -6.7884e+00],
        [ 1.1984e+01,  1.9510e-01,  3.6894e+02,  ..., -5.2093e+00,
         -6.7884e+00,  1.6493e+01]], device='cuda:0')


In [84]:
print(metric.shape)

torch.Size([50, 50])


In [87]:
n1 = 10
n2 = 5

print(metric[n1,n2])
print(metric[n2,n1])

tensor(-8.0744, device='cuda:0')
tensor(-8.0744, device='cuda:0')


In [90]:
print(torch.all(torch.real(torch.linalg.eigvals(metric)) > 0))

tensor(True, device='cuda:0')
