In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
from nnTreeVB.data import evolve_seqs_full_homogeneity                                              
from nnTreeVB.data import build_tree_from_nwk
from nnTreeVB.data import SeqCollection
from nnTreeVB.data import build_msa_categorical

from nnTreeVB.models import VB_nnTree
# from nnTreeVB.models.vb_models.vb_nntree import VB_nnTree



import time
import math
import random
from pprint import pprint
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F

import matplotlib.pyplot as plt
# import matplotlib.ticker as ticker

# import logomaker as lm
# import pandas as pd
# import seaborn as sns

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# torch.use_deterministic_algorithms(True)
# Print tensors with 4 digits of precision and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
import warnings

# Silence deprecation and user-level warnings for the rest of the session.
for ignored_category in (DeprecationWarning, UserWarning):
    warnings.filterwarnings("ignore", category=ignored_category)

## Alignment simulation

In [None]:
# Verbosity flag forwarded to the sequence simulator and to the model's fit() call.
verbose = True

In [None]:
# Data preparation
#
# Seed every random number generator in scope *before* simulating, so that
# the alignment is identical after "Restart Kernel -> Run All".
# NOTE(review): assumes evolve_seqs_full_homogeneity draws from one of these
# generators (random / numpy / torch) -- confirm against its implementation.
sim_seed = 798
random.seed(sim_seed)
np.random.seed(sim_seed)
torch.manual_seed(sim_seed)

# Number of sites (alignment columns) to simulate.
alignment_len = 1000

# Newick tree used for the simulation; alternative topologies are kept below.
str_tree = "(((1:0.3,2:0.7)N1:0.15,3:0.2)N2:0.1,4:0.4);"
# str_tree = "(((tx1:0.3,tx2:0.7)N1:0.15,tx3:0.2)N2:0.1,tx4:0.4);"
# str_tree = "(tx2:1.0411,tx4:0.8121,(tx1:1.1242,tx3:0.9130)N1:1.1095);"
# str_tree = "((1:0.3,2:0.7)N1:0.15,3:0.2,4:0.4);"
# str_tree = "((tx2:0.3,tx1:0.7)N1:0.15,tx3:0.2);"
# str_tree = "(1:0.1,2:0.2,3:0.3,4:0.4);"

# Relative substitution rates; sim_* feeds the simulator, evo_* is used
# later as the reference ("true") values.
#            "AG"  "AC"  "AT"  "GC"  "GT" "CT"
sim_rates = [0.16, 0.05, 0.16, 0.09, 0.3, 0.24]
evo_rates = [0.16, 0.05, 0.16, 0.09, 0.3, 0.24]

# State frequencies. The two lists hold the same values in different state
# orders (C and G are swapped between them).
#             A     C    G     T
sim_freqs = [0.1, 0.45, 0.3, 0.15]
#             A     G    C     T
evo_freqs = [0.1, 0.3, 0.45, 0.15]

# Guard against the two hand-written frequency vectors drifting apart.
assert sorted(sim_freqs) == sorted(evo_freqs), \
    "sim_freqs and evo_freqs must contain the same values"
assert math.isclose(sum(sim_freqs), 1.0), "state frequencies must sum to 1"

# Parse the tree; `nodes` is not used in the cells visible here.
ete_tree, taxa, nodes = build_tree_from_nwk(str_tree)

# Simulate the sequences along the tree (ancestral sequences are requested
# as well through return_anc=True; nothing is written to disk).
all_seqdict = evolve_seqs_full_homogeneity(
        str_tree,
        fasta_file=None,
        nb_sites=alignment_len,
        subst_rates=sim_rates,
        state_freqs=sim_freqs,
        return_anc=True,
        verbose=verbose)

# Keep only the sequences of the taxa, in the order given by `taxa`.
sequences = [all_seqdict[taxon] for taxon in taxa]

In [None]:
# Device passed to the model constructor; the CPU is used here.
device = torch.device("cpu")
# device = torch.device("cuda")  # uncomment to run on a CUDA device instead

In [None]:
# Reference substitution-model parameters as tensors.
gtr_freqs = torch.tensor(evo_freqs)
gtr_rates = torch.tensor(evo_rates)  # AG, AC, AT, GC, GT, CT

# Show each vector together with its sum.
for title, values in (("\nFrequencies", gtr_freqs),
                      ("\nRelative rates", gtr_rates)):
    print(title)
    print(values)
    print(values.sum())

In [None]:
# Gather the branch lengths of the tree into a tensor indexed by node rank.
nb_branches = len(ete_tree.get_descendants())
true_branches = torch.zeros(nb_branches)

for node in ete_tree.traverse("postorder"):
    # Nodes whose rank falls outside the tensor are skipped
    # (presumably only the root, which has no parent branch -- TODO confirm).
    if node.rank >= nb_branches:
        continue
    true_branches[node.rank] = node.dist

# Append a trailing singleton dimension: (nb_branches,) -> (nb_branches, 1).
true_branches = true_branches.unsqueeze(-1)
print("\ntrue_branches")
print(true_branches)
print(true_branches.sum())

In [None]:
# Encode the alignment categorically and wrap it as a tensor.
motifs_cats = build_msa_categorical(sequences)
site_patterns = torch.from_numpy(motifs_cats.data)

# V: full (non-collapsed) copy of the encoded alignment, each row weighted 1.
V = site_patterns.clone().detach()
V_counts = torch.ones(V.shape[0]).detach()

# X: unique rows along dim 0, X_counts: how often each unique row occurs.
X, X_counts = torch.unique(site_patterns, dim=0, return_counts=True)

print(X.shape)
print(X_counts)

## nnTreeVB model initialization

In [None]:
# Hyper parameters

# Seed Python, NumPy and PyTorch with the same value.
seed = 798
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Dimensions -----------------------------------------------------------
# presumably the number of character states (A, G, C, T) -- TODO confirm
x_dim = 4
# Number of sequences
m_dim = len(ete_tree.get_leaf_names())
# Number of internal nodes
a_dim = len(ete_tree.get_descendants()) - m_dim + 1
# Number of branches
b_dim = len(ete_tree.get_edges()) - 1

# --- Network size ---------------------------------------------------------
h_dim = 32
nb_layers = 3

# --- Prior hyper-parameters (plain Python lists) --------------------------
# Uniform values over the four states for the ancestral states.
ancestor_prior_hp = [1.0 / 4] * 4

# One value per branch.
branch_prior_hp = [1.0] * b_dim

# Tree length.
tl_prior_hp = [1.0, 1.0]

# Kappa; kept as a tensor round trip so the stored values stay exactly as
# torch produces them.
kappa_prior_hp = torch.tensor([0.1, 0.1]).tolist()

# Relative rates (order: AG, AC, AT, GC, GT, CT) -- Dirichlet
rates_prior_hp = [1.0] * 6
# State frequencies -- Dirichlet
freqs_prior_hp = [1.0] * 4

print(m_dim)

In [None]:
# Encoder settings, grouped by the latent variable they configure. Each group
# holds the encoder type, its initial distribution parameters and the prior
# hyper-parameters.

# Compound branch lengths. (Alternative tried previously: independent gamma
# branch lengths, b_encoder_type="gamma_ind" with b_init_distr=[0.1, 0.1].)
branch_cfg = dict(
    b_encoder_type="dirichlet_ind",
    b_init_distr=[1.0] * b_dim,
    b_hp=branch_prior_hp,
)

# Tree length
tree_length_cfg = dict(
    t_encoder_type="gamma_ind",
    t_init_distr=[1.0, 1.0],
    t_hp=tl_prior_hp,
)

# Kappa
kappa_cfg = dict(
    k_encoder_type="gamma_nn_ind",
    k_init_distr=[1.0, 0.1],
    k_hp=kappa_prior_hp,
)

# Relative rates
rates_cfg = dict(
    r_encoder_type="dirichlet_ind",
    r_init_distr=[1.0] * 6,
    r_hp=rates_prior_hp,
)

# State frequencies
freqs_cfg = dict(
    f_encoder_type="dirichlet_ind",
    f_init_distr=[1.0] * 4,
    f_hp=freqs_prior_hp,
)

# Network size and device
network_cfg = dict(h_dim=h_dim, nb_layers=nb_layers, device=device)

evoModel = VB_nnTree(
    x_dim,
    m_dim,
    b_dim,
    a_dim,
    subs_model="hky",  # jc69 | k80 | hky | gtr
    predict_ancestors=False,
    **branch_cfg,
    **tree_length_cfg,
    **kappa_cfg,
    **rates_cfg,
    **freqs_cfg,
    **network_cfg,
)

In [None]:
# Fitting parameters

# Optimisation: iteration budget, Adam step size and weight decay
# (these are passed to evoModel.fit in the fitting cell).
n_epochs = 1000
learning_rate, weight_decay = 0.005, 0.00001

# Sampling: number of latent samples per step, and a sampling temperature
# (sample_temp is not referenced in the cells visible here).
nb_samples, sample_temp = 10, 0.1

## nnTreeVB model fitting

In [None]:
# Options for the fit: sampling, optimiser and bookkeeping.
fit_options = dict(
    latent_sample_size=nb_samples,
    max_iter=n_epochs,
    optim="adam",
    optim_learning_rate=learning_rate,
    optim_weight_decay=weight_decay,
    keep_fit_history=True,
    verbose=verbose,
)

# Fit on the unique site patterns X weighted by their counts X_counts.
r = evoModel.fit(ete_tree, X, X_counts, **fit_options)

In [None]:
# Show which entries the object returned by fit() contains.
r.keys()