In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from nnTreeVB.data import evolve_seqs_full_homogeneity
from nnTreeVB.data import build_tree_from_nwk
from nnTreeVB.data import SeqCollection
from nnTreeVB.data import build_msa_categorical

from nnTreeVB.models import VB_nnTree
# from nnTreeVB.models.vb_models.vb_nntree import VB_nnTree
from nnTreeVB.models.evo_models import pruning
from nnTreeVB.models.evo_models import build_GTR_transition_matrix



import time
import math
import random
import copy
from pprint import pprint
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F

import matplotlib.pyplot as plt
# import matplotlib.ticker as ticker

# import logomaker as lm
# import pandas as pd
# import seaborn as sns

%matplotlib inline

# torch.use_deterministic_algorithms(True)
# Display tensors with 4 decimal places and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

In [3]:
import warnings
# From here on, suppress every DeprecationWarning and UserWarning.
# NOTE(review): this is a blanket filter, so library warnings about numerical
# problems would be hidden as well.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

## Alignment simulation

In [4]:
# Verbosity flag, forwarded as `verbose=` to the sequence simulation and to the model fit.
verbose = True

In [5]:
# Data preparation
# Simulate an alignment along a fixed 4-taxon tree under a GTR model and keep
# the leaf sequences for inference.
alignment_len = 1000

# Seed the global RNGs *before* simulating, so that the alignment (and every
# number derived from it) is identical after "Restart Kernel -> Run All".
# Without this the simulation ran unseeded, because the only seeding in the
# notebook happens later, in the hyper-parameter cell.
# Assumes the simulator draws from the global `random` / `numpy.random`
# generators -- TODO confirm against Pyvolve.
sim_seed = 7353453
np.random.seed(sim_seed)
random.seed(sim_seed)

str_tree = "(((1:0.3,2:0.7)N1:0.15,3:0.2)N2:0.1,4:0.4);"
# str_tree = "(((tx1:0.3,tx2:0.7)N1:0.15,tx3:0.2)N2:0.1,tx4:0.4);"
# str_tree = "(tx2:1.0411,tx4:0.8121,(tx1:1.1242,tx3:0.9130)N1:1.1095);"
# str_tree = "((1:0.3,2:0.7)N1:0.15,3:0.2,4:0.4);"
# str_tree = "((tx2:0.3,tx1:0.7)N1:0.15,tx3:0.2);"
# str_tree = "(1:0.1,2:0.2,3:0.3,4:0.4);"

# Relative substitution rates; the same values are used for simulation
# (sim_rates) and for inference (evo_rates).
#            "AG"  "AC"  "AT"  "GC"  "GT" "CT"
sim_rates = [0.16, 0.05, 0.16, 0.09, 0.3, 0.24]
evo_rates = [0.16, 0.05, 0.16, 0.09, 0.3, 0.24]

# Stationary frequencies. Both lists hold the same per-nucleotide values but in
# different orders: A,C,G,T for the simulator and A,G,C,T for inference.
#             A     C    G     T
sim_freqs = [0.1, 0.45, 0.3, 0.15]
#             A     G    C     T
evo_freqs = [0.1, 0.3, 0.45, 0.15]


ete_tree, taxa, nodes = build_tree_from_nwk(str_tree)
# Separate tree handle used for the reference log-likelihood computation.
# NOTE(review): copy.copy is a shallow copy, so child nodes are presumably
# shared with ete_tree -- confirm that this is intended.
logl_tree = copy.copy(ete_tree)

all_seqdict = evolve_seqs_full_homogeneity(
        str_tree,
        fasta_file=None,
        nb_sites=alignment_len,
        subst_rates=sim_rates,
        state_freqs=sim_freqs,
        return_anc=True,
        verbose=verbose)

# Keep only the sequences of the taxa, in the order given by `taxa`
# (return_anc=True presumably adds ancestral sequences to the dict as well).
sequences = [all_seqdict[s] for s in taxa]

Evolving new sequences with the amazing Pyvolve for None


In [6]:
# Device handed to the model constructor (`device=device`); CPU by default.
device = torch.device("cpu")
# # device = torch.device("cuda")

In [7]:
# Wrap the inference-side GTR parameters as tensors.
gtr_freqs = torch.tensor(evo_freqs)
gtr_rates = torch.tensor(evo_rates) # AG, AC, AT, GC, GT, CT

# Echo each vector followed by its total (both are expected to sum to 1).
for _title, _values in (("\nFrequencies", gtr_freqs),
                        ("\nRelative rates", gtr_rates)):
    print(_title)
    print(_values)
    print(_values.sum())


Frequencies
tensor([0.1000, 0.3000, 0.4500, 0.1500])
tensor(1.)

Relative rates
tensor([0.1600, 0.0500, 0.1600, 0.0900, 0.3000, 0.2400])
tensor(1.)


In [8]:
# Collect the true branch lengths of the simulation tree into a column tensor,
# one slot per descendant of the root.
nb_branches = len(ete_tree.get_descendants())
branch_lengths = torch.zeros(nb_branches)

for node in ete_tree.traverse("postorder"):
    # `node.rank` is used as the slot index; nodes whose rank falls outside the
    # vector are skipped (presumably only the root -- TODO confirm).
    if node.rank >= nb_branches:
        continue
    branch_lengths[node.rank] = node.dist

# Add a trailing axis of size 1: (nb_branches,) -> (nb_branches, 1).
true_branches = branch_lengths.unsqueeze(-1)
print("\ntrue_branches")
print(true_branches)
print(true_branches.sum())


true_branches
tensor([[0.3000],
        [0.7000],
        [0.2000],
        [0.4000],
        [0.1500],
        [0.1000]])
tensor(1.8500)


In [15]:
# Shape check: one row per branch plus the trailing axis added by unsqueeze(-1).
true_branches.shape

torch.Size([6, 1])

In [9]:
# Encode the simulated sequences and turn the encoding into a tensor.
motifs_cats = build_msa_categorical(sequences)
full_msa = torch.from_numpy(motifs_cats.data)

# V / V_counts: every site kept as-is, each with a count of one.
V = full_msa.clone().detach()
V_counts = torch.ones(len(V)).detach()
# X_counts = torch.ones(X.shape[0])

# X / X_counts: distinct entries along the first axis and how often each occurs.
X, X_counts = torch.unique(full_msa, dim=0, return_counts=True)

print(X.shape)
print(X_counts)

torch.Size([178, 4, 4])
tensor([  4,   5,   1,   5,   3,   2,   5,   1,   1,  15,   4,   1,   2,  14,
          5,   1,   6,   3,   5,   1,   3,   4,   3,   4,   1,   2,   8,   4,
          2,   6,   2,   9,   1,   2,   1,   2,   1,   1,   1,   2,   1,   1,
          1,   2,   1,   6,   3,   2,  37,   2,   3,   4,   2,   2,   1,   3,
         15,  14,   1,  14, 175,  20,   2,   3,  14,  12,   2,   5,   1,   1,
          2,   2,   4,   4,  27,   6,   1,   6,   5,  14,   5,   2,   3,   1,
          1,   7,   1,   1,   1,   1,   1,   1,   3,   2,   1,   1,   7,   1,
          1,   3,   7,  17,   1,   1,   3,   2,   4,   1,   2,  14,  12,   2,
          5,  11,  24,   1,   2,   3,   6,   8,   6,   3,  10,   7,   1,  13,
         18,  73,   9,   2,   1,   3,   2,   1,   1,   2,   1,   1,   5,   6,
          1,   3,   1,   1,   1,   1,   2,   4,   4,   1,   3,   1,   1,   1,
          3,   1,   1,   1,   5,   1,   2,   1,   1,   1,   1,   3,   1,   1,
          2,  13,   1,   1,   3,   1,   

## True log likelihood

In [10]:
# Reference log likelihood of the alignment under the true (simulation) parameters.
# unsqueeze(0) adds a leading axis of size 1 to each input
# (presumably the sample axis expected by the library -- TODO confirm).
tm = build_GTR_transition_matrix(
    true_branches.unsqueeze(0),
    gtr_rates.unsqueeze(0),
    gtr_freqs.unsqueeze(0))

# Weight the values returned by `pruning` for the unique site patterns by how
# often each pattern occurs, then sum over patterns.
site_lls = pruning(logl_tree, X.unsqueeze(0), tm, gtr_freqs.unsqueeze(0))
true_lls = (site_lls * X_counts).sum()
true_lls

tensor(-3992.7161)

## nnTreeVB model initialization

In [68]:
# Hyper parameters
# Seed torch, numpy and the stdlib RNG (in that order) with the same value.
seed = 7353453
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)

# Dimensions derived from the tree.
x_dim = 4
# Number of sequences
m_dim = len(ete_tree.get_leaf_names())
# Number of internal nodes: all nodes (descendants + root) minus the leaves
nb_nodes = len(ete_tree.get_descendants()) + 1
a_dim = nb_nodes - m_dim
# Number of branches
b_dim = len(ete_tree.get_edges()) - 1

# Network size
h_dim = 32
nb_layers = 3

# Prior hyper-parameters, all stored as plain Python lists.
ancestor_prior_hp = [0.25] * 4
# ancestor_prior_hp = torch.tensor([f_A, f_G, f_C, f_T])

# branch_prior_hp = torch.tensor([0.1, 0.1])
# branch_prior_hp = torch.tensor([0.1, 1.])
branch_prior_hp = [0.1 for _ in range(b_dim)]

# branch_prior_hp = torch.tensor(get_lognorm_params(0.01, 0.01))
# branch_prior_hp = torch.tensor(get_lognorm_params(*compute_branch_mean_std(b_str)))

tl_prior_hp = [1.0, 1.0]

# NOTE(review): the float32 round trip makes these values differ slightly from
# the Python float 0.1; kept as written so the stored values do not change.
kappa_prior_hp = torch.tensor([0.1, 0.1]).tolist()

rates_prior_hp = [1.0] * 6 # Dirichlet
# rates_prior_hp = torch.ones(6)/6 # Cat
# rates_prior_hp = torch.tensor([m_AG, m_AC, m_AT, m_CG, m_GT, m_CT]) # AG, AC, AT, GC, GT, CT
freqs_prior_hp = [1.0] * 4 # Dirichlet
# freqs_prior_hp = torch.tensor([f_A, f_G, f_C, f_T])

# alpha_kl = 0.0001
print(m_dim)

4


In [82]:
# Build the variational model. Each parameter group (branch lengths b, tree
# length t, kappa k, rates r, frequencies f) is configured through its
# *_encoder_type / *_init_distr / *_hp arguments; the commented blocks are
# alternative configurations.
evoModel = VB_nnTree(
    x_dim,
    m_dim,
    b_dim,
    a_dim,
    subs_model="gtr",  # jc69 | k80 | hky | gtr
    predict_ancestors=False,
    
#     # branch lengths follows gamma
#     b_encoder_type="gamma_ind",
#     b_init_distr=[0.1, 0.1],
#     b_hp=branch_prior_hp,

#     # branch lengths fixed
#     b_encoder_type="fixed",
#     b_init_distr=true_branches.detach(),

    ## Compound branch lengths
    b_encoder_type="dirichlet_ind",
    b_init_distr=[1.]*b_dim,
    b_hp=branch_prior_hp,
    # Tree lengths
#     t_encoder_type="gamma_ind",
#     t_init_distr=[1., 1.],
#     t_hp=tl_prior_hp,

    # Tree length fixed to the total length of the simulation tree. It is
    # derived from true_branches (shape (1,) tensor) instead of being
    # hard-coded as 1.85, so it stays correct when str_tree is changed.
    t_encoder_type="fixed",
    t_init_distr=true_branches.sum().reshape(1).detach(),
    
    # kappa
#     k_encoder_type="gamma_nn_ind",
#     k_init_distr=[1., 0.1],
#     k_hp=kappa_prior_hp,
    
#     k_encoder_type="fixed",
#     k_init_distr=torch.tensor([1.5]),
    
    # rates
    r_encoder_type="dirichlet_ind",
    r_init_distr=[0.1]*6,
    r_hp=rates_prior_hp,
    
    # frequencies
    f_encoder_type="dirichlet_ind",
    f_init_distr=[1.]*4,
    f_hp=freqs_prior_hp,

#     f_encoder_type="fixed",
#     f_init_distr=gtr_freqs.detach(),

    #
    h_dim=h_dim,
    nb_layers=nb_layers,
    device=device
)

In [83]:
# Fitting parameters
# Optimisation schedule: iteration budget, step size and weight decay.
n_epochs = 1000
learning_rate = 5e-4
weight_decay = 1e-6

# Number of latent samples (passed as `latent_sample_size`) and a sampling
# temperature that no visible cell reads.
nb_samples = 100
sample_temp = 0.1

## nnTreeVB model fitting

In [84]:
# Keyword settings for the fit, gathered in one place.
fit_options = dict(
    latent_sample_size=nb_samples,
    max_iter=n_epochs,
    optim="adam",
    optim_learning_rate=learning_rate,
    optim_weight_decay=weight_decay,
    keep_fit_history=True,
    verbose=verbose,
)

# Fit on the unique site patterns X weighted by X_counts; the tree is passed
# as a shallow copy of ete_tree.
# NOTE(review): the saved output of this cell ends in a ValueError (NaN
# Dirichlet concentration of shape (6,)) shortly after epoch 20 -- the current
# configuration does not train to completion.
r = evoModel.fit(copy.copy(ete_tree), X, X_counts, **fit_options)

0m 0s	 Train Epoch: 20 	 ELBO: -7245.417	 Lls -7212.325	 KLs 35262.516

ValueError: Expected parameter concentration (Tensor of shape (6,)) of distribution Dirichlet(concentration: torch.Size([6])) to satisfy the constraint IndependentConstraint(GreaterThan(lower_bound=0.0), 1), but found invalid values:
tensor([nan, nan, nan, nan, nan, nan], grad_fn=<ExpBackward0>)

In [None]:
# Sample from the model on the full alignment V (every site with count 1),
# using the same number of latent samples as for fitting.
s = evoModel.sample(ete_tree, V, V_counts, latent_sample_size=nb_samples)

In [73]:
# Average the "b" samples over the first axis, then show the mean and its total
# (the recorded total is ~1, i.e. these look like branch proportions).
b_mean = s["b"].mean(0)
print(b_mean)
print(b_mean.sum())

[0.17812043 0.3619667  0.15700266 0.12714994 0.06660503 0.10915533]
1.0000001


In [75]:
# Average the "bt" samples over the first axis, then show the mean and its total
# (the recorded total is ~1.85, matching the sum of the true branch lengths).
bt_mean = s["bt"].mean(0)
print(bt_mean)
print(bt_mean.sum())

[[0.32952273]
 [0.66963845]
 [0.29045495]
 [0.23522724]
 [0.12321934]
 [0.20193739]]
1.8500001


In [76]:
# Re-print the true branch lengths and their total for side-by-side comparison
# with the sampled estimates.
print(true_branches)
print(true_branches.sum())


tensor([[0.3000],
        [0.7000],
        [0.2000],
        [0.4000],
        [0.1500],
        [0.1000]])
tensor(1.8500)


In [77]:
# Mean of the "r" samples over the first axis (presumably the substitution
# rates, to be compared with gtr_rates).
# NOTE(review): the saved output of this cell is `KeyError: 'r'`, i.e. the
# sample dict of that run had no "r" entry -- confirm the key name / model
# configuration and re-run.
s["r"].mean(0)

KeyError: 'r'

In [57]:
# Reference relative rates (order: AG, AC, AT, GC, GT, CT).
gtr_rates

tensor([0.1600, 0.0500, 0.1600, 0.0900, 0.3000, 0.2400])

In [48]:
# Mean of the "f" samples over the first axis.
# NOTE(review): the recorded values equal evo_freqs almost exactly and the
# execution count (48) predates the current model cell, so this output
# presumably comes from a run with fixed frequencies -- re-run to refresh.
s["f"].mean(0)

array([0.10000002, 0.2999997 , 0.4500004 , 0.14999986], dtype=float32)