In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import sys
import json

## Standard VAE + Bedford VAE

In [None]:
# Directory appended to sys.path before `from models import ...`, and the root of the
# `model_saves/` checkpoint loaded below.
# NOTE(review): later cells reassign `abspath = "."`; re-running cells that read it out of
# order will resolve different paths.
abspath = "./VAE_standard"

In [None]:
# Make `models` importable from `abspath`; this must run before the `from models import` line.
sys.path.append(abspath)
from models import DNADataset, ALPHABET, SEQ_LENGTH, LATENT_DIM, VAE
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE as tsne
from matplotlib import pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd

# Bedford VAE variant; in this notebook it is referenced only by commented-out code.
import bedford_code.models_bedford as bedford
from treetime.utils import datetime_from_numeric
import pymc as pm
from collections.abc import Iterable
import altair as alt

In [None]:
BATCH_SIZE = 64

# The "data" directory is generated as described in README.md.
training_fasta = f"{abspath}/../data/training_spike.fasta"
dataset = DNADataset(training_fasta)

# NOTE(review): `dataloader` is not used by any later cell in this notebook.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
)

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

In [None]:
# Flattened model input size: len(ALPHABET) values for each of the SEQ_LENGTH positions.
input_dim = len(ALPHABET) * SEQ_LENGTH

# BEDFORD (alternative model)
# vae_model = bedford.VAE(input_dim=len(bedford.ALPHABET) * bedford.SEQ_LENGTH, latent_dim=bedford.LATENT_DIM).to(DEVICE)
# vae_model.load_state_dict(torch.load(f"{abspath}/bedford_code/results_bedford/BEST_vae_ce_anneal.pth"))

# STANDARD
# NOTE(review): latent_dim is hard-coded to 50 although LATENT_DIM is imported from
# `models`; confirm both agree with the saved checkpoint.
checkpoint_path = f"{abspath}/model_saves/standard_VAE_model_BEST.pth"
vae_model = VAE(input_dim=input_dim, latent_dim=50).to(DEVICE)
state_dict = torch.load(checkpoint_path, weights_only=True, map_location=DEVICE)
vae_model.load_state_dict(state_dict)

vae_model.eval()

In [None]:
# Which split to analyse: one of "training", "valid", "test".
dset = "training"
print(dset)
abspath = "."
dataset = DNADataset(f"{abspath}/data/{dset}_spike.fasta")

# Fetch every item exactly once (the original indexed the dataset twice per item).
# Element 0 is a tensor and element 1 is the sample name used for the metadata lookup.
samples = [dataset[x] for x in range(len(dataset))]
new_dataset = np.array([s[0].numpy() for s in samples])
vals = np.array([s[1] for s in samples])
del samples  # release the per-item tensors now that they are stacked

# labeling
metadata = pd.read_csv(f"{abspath}/data/all_data/all_metadata.tsv", sep="\t")
# Index the metadata by name once instead of scanning the whole table per sample.
# drop_duplicates keeps the first row per name, matching the first-match behaviour of a
# boolean-mask lookup followed by `.values[0]`.
metadata_by_name = metadata.drop_duplicates("name").set_index("name")
# Both lists are in dataset order (aligned with `vals` / `new_dataset`).
clade_labels = list(metadata_by_name.loc[vals, "clade_membership"])
dates = [datetime_from_numeric(x) for x in metadata_by_name.loc[vals, "date"]]

In [None]:
def flatten(xs):
    """Recursively yield the leaf items of an arbitrarily nested iterable.

    Strings and floats are treated as leaves, so strings are not split into characters.
    """
    for x in xs:
        if isinstance(x, Iterable) and not isinstance(x, (str, float)):
            yield from flatten(x)
        else:
            yield x

# Bin the sample collection dates into calendar weeks.
collection_dates = pd.DataFrame([[x] for i,x in enumerate(dates)], columns=["date"])
collection_dates.index = pd.to_datetime(collection_dates["date"])
collection_dates = collection_dates.groupby(pd.Grouper(freq='W'))
# NOTE(review): for a time Grouper, `.groups` values appear to be cumulative end positions,
# one per weekly bin (empty weeks included); confirm against the pandas version in use.
collection_dates = list(collection_dates.groups.values())
print(collection_dates)
# Differences of consecutive cumulative positions -> number of samples in each week.
collection_dates = [collection_dates[0]] + [collection_dates[i] - collection_dates[i-1] for i in range(len(collection_dates)-1, 0, -1)][::-1]
# Repeat each week's index by its count -> one week index per sample.
# NOTE(review): the result is non-decreasing (sorted by date), NOT in dataset order, so it
# cannot be paired element-wise with per-sample arrays such as the latent embedding.
collection_dates = list(flatten([[i for j in range(x)] for i,x in enumerate(collection_dates)]))

In [None]:
# A clade is "well populated" when it has more than this many rows in the metadata.
MIN_CLADE_SIZE = 5

# Count each clade once instead of re-filtering the full metadata table per sample.
# NaN clades are dropped by value_counts and therefore get 0, as a `== NaN` filter would.
clade_sizes = metadata["clade_membership"].value_counts()
good_clade_labels = [c for c in clade_labels if clade_sizes.get(c, 0) > MIN_CLADE_SIZE]
print(set(good_clade_labels))

In [None]:
# Alternative: keep only well-populated clades
# clusters = np.sort(np.array(list(set(good_clade_labels))))
clusters = np.sort(np.array(list(set(clade_labels))))
print(clusters)


def get_clade(x):
    """Return one bool per sample: True where the sample's clade label equals `x`."""
    return [bool(elem == x) for elem in clade_labels]


# One array of dataset indices per clade, in the sorted order of `clusters`.
sample_positions = np.arange(len(clade_labels))
indexes = tuple(sample_positions[get_clade(label)] for label in clusters)

In [None]:
# Name -> clade lookup built once (first row per name, as a mask + `.values[0]` lookup picks),
# instead of scanning the whole metadata table for every sample.
clade_by_name = metadata.drop_duplicates("name").set_index("name")["clade_membership"]

# Sample names, in dataset order, whose clade is one of the selected `clusters`.
# NOTE(review): when `clusters` holds every clade label, nothing is filtered out, so a
# position in new_vals equals the sample's dataset index.
new_vals = [v for v in vals if clade_by_name.loc[v] in clusters]

In [None]:
parents = pd.read_csv(f"{abspath}/data/all_data/all_branches.tsv", sep="\t")

# Sample name -> its position in new_vals.
node_dict = {name: i for i, name in enumerate(new_vals)}

# (parent position, child position) for every tree branch whose endpoints are both samples.
pairs = []
for parent, child in zip(parents["parent"], parents["child"]):
    i1 = node_dict.get(parent)
    i2 = node_dict.get(child)
    # Position 0 is a valid node, so compare against None rather than relying on truthiness.
    if i1 is not None and i2 is not None:
        pairs.append((i1, i2))

# Always shape (n_pairs, 2), including when no branch matched.
pairs = np.array(pairs, dtype=int).reshape(-1, 2)

In [None]:
# Per-clade colours (used by the commented-out clade view below).
cmap = plt.get_cmap("gist_ncar")
colors = [cmap(x) for x in np.arange(len(indexes)) / len(indexes)]

# Dataset indices regrouped clade by clade: row k of X / Z_embedded is sample ranges[k].
ranges = np.concatenate(indexes)

X = torch.tensor(new_dataset[ranges,:,:])
print("X shape")
print(new_dataset.shape)
print(X.shape)
X = X.view(X.size(0), -1).to(DEVICE)
pca = PCA(n_components=3, svd_solver="full")

recon = None
Z_mean = None
Z_embedded = None
scatterplot = None
with torch.no_grad():
    # STANDARD
    Z_mean, Z_logvar = vae_model.encoder.forward(X)
    recon = vae_model.decoder.forward(Z_mean)
    recon = recon.view(recon.shape[0], -1).cpu()
    Z_mean = Z_mean.cpu()
    Z_std = torch.exp(0.5 * Z_logvar).cpu()

    # BEDFORD
    # recon, Z_mean, Z_logvar = vae_model.forward(X)
    # recon = recon.cpu().numpy()
    # Z_mean = Z_mean.cpu().numpy()

    print("\nRecon shape")
    print(recon.shape)

# Alternative embedding:
# Z_embedded = tsne(n_components=2, learning_rate='auto', init='random', perplexity=3).fit_transform(X = Z_mean)
# PCA subtracts the per-feature mean itself, so no extra centring is applied here.
Z_embedded = pca.fit_transform(np.asarray(Z_mean))
variances = pca.explained_variance_ratio_
tot = np.sum(variances)
print(variances)
print(f"total variance: {tot}")

# Colour = weeks elapsed since the week of the earliest sample. `dates` is in dataset order,
# so it is reordered with `ranges` to line up with the rows of Z_embedded.
sample_week_start = pd.to_datetime(pd.Series(dates)).dt.to_period("W").dt.start_time
sample_week_index = ((sample_week_start - sample_week_start.min()).dt.days // 7).to_numpy()
point_weeks = sample_week_index[ranges]

# Map a dataset index to its row in Z_embedded (-1 if the sample was not embedded).
# NOTE(review): assumes the positions stored in `pairs` are dataset indices, i.e. that
# new_vals kept every sample in dataset order.
row_of_sample = np.full(len(dates), -1, dtype=int)
row_of_sample[ranges] = np.arange(len(ranges))

fig, ax = plt.subplots(1, 1, figsize=(14, 10), subplot_kw=dict(projection="3d"))

# Alternative: colour by clade instead of time.
# for i, arr in enumerate(indexes):
#     rows = row_of_sample[arr]
#     ax.scatter(Z_embedded[rows,0], Z_embedded[rows,1], Z_embedded[rows,2], label=clade_labels[arr[0]], alpha=0.6, color=colors[i], s=150)
# ax.legend(bbox_to_anchor=(0.1, 0, 1.1, 1))

# TIME
scatterplot = ax.scatter(Z_embedded[:,0], Z_embedded[:,1], Z_embedded[:,2], c=point_weeks, cmap="viridis", s=150)
fig.colorbar(scatterplot, ax=ax, shrink=0.5, label="weeks since first sample")

ax.set_title("PCA visualization of (standard) VAE latent space")
# Draw tree branches between embedded parent/child samples.
for p in pairs:
    rows = row_of_sample[p]
    if (rows < 0).any():
        continue  # an endpoint was not embedded
    ax.plot(Z_embedded[rows,0], Z_embedded[rows,1], Z_embedded[rows,2], color="gray", alpha=0.5)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
plt.show()


In [None]:
# Shape of the fitted PCA component matrix, then the singular values of those components.
for pca_summary in (pca.components_.shape, pca.singular_values_):
    print(pca_summary)

In [None]:
# Reshape each flattened input row to (position, len(ALPHABET)) and collapse the last axis
# into one integer code per position.
# NOTE(review): assumes each position's slot is one-hot; a slot with several 1s would sum codes.
alphabet_codes = np.arange(len(ALPHABET))
per_position = X.view(X.shape[0], -1, len(ALPHABET)).cpu().numpy().astype("int")
genome = per_position @ alphabet_codes

In [None]:
from collections.abc import Iterable
def flatten(xs):
    """Recursively yield the leaves of a nested iterable, treating str and float as leaves."""
    for item in xs:
        is_container = isinstance(item, Iterable) and not isinstance(item, (str, float))
        if not is_container:
            yield item
            continue
        yield from flatten(item)

In [None]:
N = Z_embedded.shape[0]
print(N)

# All pairs i < j, enumerated in the same order for both lists so entries pair up.
# Rows of `genome` and `Z_embedded` are in the same sample order.
hamming = [d for i in range(N - 1) for d in np.sum(np.not_equal(genome[i, :], genome[(i + 1):, :]), axis=-1)]
euclid = [d for i in range(N - 1) for d in np.linalg.norm(Z_embedded[(i + 1):, :] - Z_embedded[i, :], axis=-1)]

fig, arr = plt.subplots(1, 1, figsize=(14, 10))
arr.set_title("Hamming vs. Euclidean dist")
arr.set_xlabel("Hamming distance (number of differing positions)")
arr.set_ylabel("Euclidean distance in the 3-D PCA projection of latent means")
arr.scatter(hamming, euclid, alpha=0.3)
plt.show()

## GP regression

In [None]:
# Columns available in the sample metadata table.
metadata_columns = metadata.columns
print(metadata_columns)

In [None]:
# Re-encode the samples (rows in `ranges` order) and project the latent means onto 3 PCs.
X = torch.tensor(new_dataset[ranges,:,:])
X = X.view(X.size(0), -1).to(DEVICE)
print("X shape")
print(new_dataset.shape)
print(X.shape)

pca = PCA(n_components=3, svd_solver="full")

Z_mean = None
Z_embedded = None
with torch.no_grad():
    # STANDARD (only the encoder is needed here; the decoder output was never used)
    Z_mean, Z_logvar = vae_model.encoder.forward(X)
    Z_mean = Z_mean.cpu()
    # BEDFORD
    # recon, Z_mean, Z_logvar = vae_model.forward(X)
    # Z_mean = Z_mean.cpu().numpy()

# PCA centres each feature itself, so the scores are zero-mean. That matters for the
# zero-mean GP fitted later; subtracting an extra scalar mean would shift every score.
Z_embedded = pca.fit_transform(np.asarray(Z_mean))
variances = pca.explained_variance_ratio_
tot = np.sum(variances)
print("\n",variances)
print(f"total variance: {tot}")

In [None]:
metadata = pd.read_csv(f"{abspath}/data/all_data/all_metadata.tsv", sep="\t")
# Name-indexed lookup built once (first row per name) instead of a full scan per sample.
metadata_lookup = metadata.drop_duplicates("name").set_index("name")
# Both lists are in dataset order (aligned with `vals`).
clade_labels = list(metadata_lookup.loc[vals, "clade_membership"])
dates = [datetime_from_numeric(x) for x in metadata_lookup.loc[vals, "date"]]

# Rows of Z_embedded follow `ranges` (dataset indices regrouped by clade), so time and
# clade must be reordered the same way or each point gets another sample's date.
coords = pd.DataFrame(
    {
        "dim0": Z_embedded[:, 0],
        "dim1": Z_embedded[:, 1],
        "dim2": Z_embedded[:, 2],
        "time": [dates[i] for i in ranges],
        "clade": [clade_labels[i] for i in ranges],
    }
)

In [None]:
# Median coordinates per collection timestamp, then the median of those per month-end bin.
# Empty months are dropped.
per_time_median = coords.groupby("time")[["dim0", "dim1", "dim2"]].median()
avg_coords = (
    per_time_median
    .resample("ME")
    .median()
    .dropna()
    .reset_index()
)

In [None]:
ubound = len(avg_coords)

plt.ioff()
fig, ax = plt.subplots(1, 3, figsize=(22, 8))
# Evenly spaced x positions over [0, ubound], one per monthly bin (same grid as the GP input).
month_positions = np.linspace(0, ubound, num=len(avg_coords))
for i, d in enumerate(["dim0", "dim1", "dim2"]):
    ax[i].plot(month_positions, avg_coords[d])
    ax[i].set_title(f"time vs. {d}")
    ax[i].set_xlabel("monthly bin (rescaled to [0, n_months])")
    ax[i].set_ylabel(f"monthly median of {d}")
plt.show()

In [None]:
# GP inputs: evenly spaced positions over [0, ubound], one per monthly bin, as a column vector.
x_vals = np.linspace(0,ubound,num=len(avg_coords)).astype("float32")[:,np.newaxis]


def build_coords_model(dim, draws=4000, random_seed=None):
    """Fit a latent GP with Student-t noise to one monthly-median PCA coordinate.

    Args:
        dim: column of ``avg_coords`` to regress against ``x_vals`` (e.g. "dim0").
        draws: posterior draws per chain, forwarded to ``pm.sample``.
        random_seed: forwarded to ``pm.sample``; set it for reproducible traces.

    Returns:
        (trace, gp): the result of ``pm.sample`` and the ``pm.gp.Latent`` whose prior ``f``
        was fitted.
    """
    y_vals = avg_coords[dim].values.astype('float32')
    print(x_vals.shape, y_vals.shape)

    with pm.Model() as model:
        # Length-scale of the squared-exponential kernel.
        l = pm.Uniform('l', 0, 30)

        # Signal variance, sampled on the log scale. Use the PyTensor op rather than
        # relying on NumPy ufunc dispatch for a symbolic variable.
        log_s2_f = pm.Uniform('log_s2_f', lower=-10, upper=5)
        s2_f = pm.Deterministic('s2_f', pm.math.exp(log_s2_f))
        f_cov = s2_f * pm.gp.cov.ExpQuad(input_dim=1, ls=l)

        # Observation noise: StudentT precision lam = 1 / s2_n.
        s2_n = pm.HalfCauchy('s2_n', beta=5)

        gp = pm.gp.Latent(cov_func=f_cov)
        f = gp.prior("f",X=x_vals)

        # Degrees of freedom kept above 1.
        df = 1 + pm.Gamma("df",alpha=2,beta=1)
        y_obs = pm.StudentT("y", mu=f, lam=1.0 / s2_n, nu=df, observed=y_vals)

        trace = pm.sample(draws=draws, random_seed=random_seed)
    return trace, gp

### Run and save regression

In [None]:
import pickle
import cloudpickle

# One GP fit per PCA coordinate; each call returns (trace, gp).
ret_vals = [build_coords_model(d) for d in ["dim0","dim1","dim2"]]

idata = [trace for trace, _ in ret_vals]
GPs = [gp for _, gp in ret_vals]

abspath = "."
# Saved layout: {dim name: (trace, gp)}, written with cloudpickle.
dict_to_save = dict(zip(["dim0","dim1","dim2"], zip(idata, GPs)))
with open(f"{abspath}/king_regression_data.pkl","wb") as buff:
    cloudpickle.dump(dict_to_save, buff)

In [None]:
import arviz as az

# Free parameters whose R-hat must stay below the 1.03 convergence threshold.
checked_vars = ["l", "log_s2_f", "s2_n", "f_rotated_", "df"]

for dim_name, dim_trace in zip(["dim0","dim1","dim2"], idata):
    print(dim_name)
    rhat_too_high = az.rhat(dim_trace)[checked_vars].to_array() > 1.03
    n_nonconverged = int(np.sum(rhat_too_high).values)
    if n_nonconverged:
        print(f"The MCMC chains for {n_nonconverged} RVs appear not to have converged.")
    else:
        print("No Rhat values above 1.03, \N{check mark}")

In [None]:
# plot the samples from the gp posterior with samples and shading
import pickle
from pymc.gp.util import plot_gp_dist

idata = None
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
with open(f"{abspath}/king_regression_data.pkl", "rb") as buff:
    saved_regressions = pickle.load(buff)
# Each saved entry is a (trace, gp) tuple; only the traces are plotted here.
idata = [saved_regressions[x][0] for x in ["dim0","dim1","dim2"]]

for d, n in zip(idata, ["dim0","dim1","dim2"]):
    fig = plt.figure(figsize=(16, 8))
    ax = fig.gca()
    # Posterior draws of f are plotted against their bin index.
    disp_x_vals = np.arange(len(avg_coords))

    f_post = az.extract(d, var_names="f").transpose("sample", ...)
    plot_gp_dist(ax, f_post, disp_x_vals)
    ax.scatter(disp_x_vals, avg_coords[n])
    ax.set_title(f"GP posterior for {n}")
    ax.set_xlabel("monthly bin index")
    ax.set_ylabel(n)
    plt.show()

In [None]:
import pickle
import arviz as az
from pymc.gp.util import plot_gp_dist

# PCA coordinate whose GP fit is extrapolated.
PRED_DIM = "dim0"

ubound = len(avg_coords)
# Prediction grid over [0, ubound + 2]: the training range plus a short extension.
# NOTE(review): spacing differs slightly from x_vals, which is linspace(0, ubound, num=ubound).
new_x_vals = np.linspace(0,ubound+2,num=len(avg_coords)+2).astype("float32")[:,np.newaxis]


def _rebuild_coords_model(dim):
    """Recreate the graph of ``build_coords_model(dim)`` without sampling.

    Variable names must match those in the saved trace so its posterior draws can be
    reused by ``pm.sample_posterior_predictive``.

    Returns:
        (model, gp): the rebuilt PyMC model and its latent GP.
    """
    y_vals = avg_coords[dim].values.astype("float32")
    with pm.Model() as rebuilt:
        l = pm.Uniform("l", 0, 30)
        log_s2_f = pm.Uniform("log_s2_f", lower=-10, upper=5)
        s2_f = pm.Deterministic("s2_f", pm.math.exp(log_s2_f))
        f_cov = s2_f * pm.gp.cov.ExpQuad(input_dim=1, ls=l)
        s2_n = pm.HalfCauchy("s2_n", beta=5)
        latent_gp = pm.gp.Latent(cov_func=f_cov)
        f = latent_gp.prior("f", X=x_vals)
        df = 1 + pm.Gamma("df", alpha=2, beta=1)
        pm.StudentT("y", mu=f, lam=1.0 / s2_n, nu=df, observed=y_vals)
    return rebuilt, latent_gp


# Saved entries are (trace, gp) tuples; only the posterior draws are reused.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
with open(f"{abspath}/king_regression_data.pkl", "rb") as buff:
    pred_trace = pickle.load(buff)[PRED_DIM][0]

model, pred_gp = _rebuild_coords_model(PRED_DIM)
with model:
    # The conditional must be added to a model that contains the GP prior `f`.
    f_pred = pred_gp.conditional("f_pred", new_x_vals, jitter=1e-4)
    # Sample f_pred given the saved posterior draws.
    pred_samples = pm.sample_posterior_predictive(pred_trace, var_names=["f_pred"])

fig = plt.figure(figsize=(10, 4))
ax = fig.gca()

f_pred_draws = az.extract(pred_samples, group="posterior_predictive", var_names="f_pred").transpose("sample", ...)
plot_gp_dist(ax, f_pred_draws, new_x_vals.ravel())
# Training points at the same x positions the GP was fitted on.
ax.scatter(x_vals.ravel(), avg_coords[PRED_DIM])
ax.set_title(f"GP extrapolation of {PRED_DIM}")
ax.set_xlabel("time (GP input units)")
ax.set_ylabel(PRED_DIM)
plt.show()