In [1]:
import pandas as pd
import pybasilica.run as run
import torch
import pyro
import numpy as np
import seaborn as sns
import sklearn.metrics
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
m_g = pd.read_csv("test_datasets/counts_sbs.N150.G3.csv")
m_sbs = m_g.drop(["groups"], axis=1)
g_sbs = m_g["groups"].tolist()
cosmic_sbs = pd.read_csv("test_datasets/COSMIC_filt.csv", index_col=0)

In [None]:
m_g = pd.read_csv("test_datasets/counts_dbs.N150.G3.csv")
m_dbs = m_g.drop(["groups"], axis=1)
g_dbs = m_g["groups"].tolist()
cosmic_dbs = pd.read_csv("test_datasets/COSMIC_dbs.csv", index_col=0) 

In [29]:
dn_sbs = torch.tensor(cosmic_sbs.loc[["SBS6","SBS17b"]].values, dtype=torch.float64)
ref_sbs = torch.tensor(cosmic_sbs.loc[["SBS1","SBS2","SBS5"]].values, dtype=torch.float64)
k_denovo = dn_sbs.shape[0]
k_fixed = ref_sbs.shape[0]

def mix_weights(beta):
    '''
    Function used for the stick-breaking process.
    '''
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

with pyro.plate("beta_d_plate", k_denovo):
    pi_beta = pyro.sample("beta", pyro.distributions.Beta(torch.ones(k_fixed, dtype=torch.float64), 0.6).to_event(1))
    print(pi_beta)
    pi = mix_weights(pi_beta)

print(pi.shape)
print(pi)

tensor([[0.8228, 0.9394, 0.6845],
        [0.7946, 0.4966, 0.5945]], dtype=torch.float64)
torch.Size([2, 4])
tensor([[0.8228, 0.1665, 0.0074, 0.0034],
        [0.7946, 0.1020, 0.0615, 0.0419]], dtype=torch.float64)


In [36]:
beta_fixed_cum = pyro.distributions.Dirichlet(torch.ones(5)).sample()

In [38]:
rev = torch.ones(5) - beta_fixed_cum
rev / torch.sum(rev)


tensor([0.2100, 0.0892, 0.2479, 0.2105, 0.2425])

In [82]:
torch.cat((torch.ones(5) * 100, torch.ones(1)))

tensor([100., 100., 100., 100., 100.,   1.])

In [97]:
obj_sbs = run.fit(
    x=m_sbs, 
    k_list=[2], 
    lr=0.005, 
    optim_gamma=0.1,
    n_steps=10, 
    # cluster=6,
    dirichlet_prior=True,
    beta_fixed=cosmic_sbs.loc[["SBS1","SBS5"]], 
    hyperparameters={"alpha_sigma":.15, "alpha_p_sigma":1., "alpha_p_conc0":0.6, 
                     "alpha_p_conc1":0.6, "alpha_rate":1., "pi_conc0":0.6, "alpha_conc":100,
                     "scale_factor_alpha":5000, "scale_factor_centroid":5000, "scale_tau":0},
    enforce_sparsity = True, 
    reg_weight=0., 
    store_parameters = True, 
    seed_list=[92,30],
    nonparametric=True,
    store_fits=True
    )


tensor([0.3868, 0.4930], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2753, 0.2351], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2604, 0.1426], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2426, 0.1438], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2506, 0.1412], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2802, 0.1398], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.3898, 0.2314], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2575, 0.1835], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.3593, 0.3533], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2672, 0.1846], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2464, 0.1418], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2100, 0.1514], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.1966, 0.2244], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.1558, 0.1685], dtype=torch.float64, grad_fn=<SumBackward1>)
tensor([0.2434, 0.19

In [104]:
obj_sbs.init_params["beta_dn_param"]

Unnamed: 0,A[C>A]A,A[C>A]C,A[C>A]G,A[C>A]T,A[C>G]A,A[C>G]C,A[C>G]G,A[C>G]T,A[C>T]A,A[C>T]C,...,T[T>A]G,T[T>A]T,T[T>C]A,T[T>C]C,T[T>C]G,T[T>C]T,T[T>G]A,T[T>G]C,T[T>G]G,T[T>G]T
D1,0.026399,0.008559,0.011029,0.007494,0.003059,0.007532,0.011988,0.008269,0.020798,0.017555,...,0.002774,0.003429,0.011589,0.009749,0.001417,0.004271,0.016829,0.002012,0.009513,0.020003
D2,0.013991,0.026576,0.015232,0.003289,0.005713,0.023301,0.010336,0.001697,0.020441,0.009896,...,0.000811,0.015997,0.004745,0.002808,0.015485,0.015374,0.004541,0.019371,0.013401,0.007315


In [101]:
pyro.param("beta_weights")

tensor([[0.2734, 0.3623, 0.3643],
        [0.3783, 0.2133, 0.4084]], dtype=torch.float64, grad_fn=<DivBackward0>)

In [99]:
obj_sbs.params["beta_w"] 

tensor([[0.2734, 0.3623, 0.3643],
        [0.3783, 0.2133, 0.4084]], dtype=torch.float64)

In [100]:
obj_sbs.gradient_norms.keys() 

dict_keys(['beta_denovo', 'alpha'])

In [42]:
obj_sbs.params["alpha"].sum(axis=1)

0      1.0
1      1.0
2      1.0
3      1.0
4      1.0
      ... 
145    1.0
146    1.0
147    1.0
148    1.0
149    1.0
Length: 150, dtype: float32

In [9]:
k_dn = 2
k_f = 3
n_samples = 5
beta_weights = pyro.distributions.Dirichlet(torch.ones(k_dn, k_f+1)).sample()
alpha_star = pyro.distributions.Dirichlet(torch.ones(n_samples, k_dn)).sample()
print("beta weights\n", beta_weights)
print("alpha star\n", alpha_star)

beta weights
 tensor([[0.5773, 0.1750, 0.0487, 0.1990],
        [0.0189, 0.2324, 0.1488, 0.5999]])
alpha star
 tensor([[0.4076, 0.5924],
        [0.0937, 0.9063],
        [0.3129, 0.6871],
        [0.5145, 0.4855],
        [0.9945, 0.0055]])


In [36]:
beta_weights[1,2]

tensor(0.1488)

In [13]:
alpha = torch.zeros((n_samples, k_dn+k_f))

for n in range(n_samples):
    for j in range(k_dn):
        for r in range(k_f):
            alpha[n, r] += torch.sum(alpha_star[n,j]) * beta_weights[j,r]
        
        for d in range(k_f, k_f+k_dn):
            alpha[n, d] += torch.sum(alpha_star[n,j]) * beta_weights[j,-1]

print(alpha)

tensor([[0.2465, 0.2090, 0.1080, 0.4365, 0.4365],
        [0.0712, 0.2271, 0.1394, 0.5624, 0.5624],
        [0.1936, 0.2145, 0.1174, 0.4745, 0.4745],
        [0.3062, 0.2029, 0.0973, 0.3937, 0.3937],
        [0.5742, 0.1753, 0.0492, 0.2012, 0.2012]])


In [None]:
obj_sbs.params["beta_w"]

In [None]:
obj_sbs.params["beta_d"]

In [None]:
obj_dbs = run.fit(
    x=m_dbs, 
    k_list=3, 
    lr=0.005, 
    optim_gamma=0.1,
    n_steps=10, 
    # cluster=6, 
    dirichlet_prior=True,
    beta_fixed=cosmic_dbs.loc[["DBS4"]], 
    hyperparameters={"alpha_sigma":.15, "alpha_p_sigma":1., "alpha_p_conc0":0.6, 
                     "alpha_p_conc1":0.6, "alpha_rate":1., "pi_conc0":0.5, "alpha_conc":100,
                     "scale_factor_alpha":10000, "scale_factor_centroid":1000, "scale_tau":0},
    enforce_sparsity = True, 
    reg_weight=0., 
    store_parameters = True, 
    seed_list=[92],
    nonparametric=True,
    store_fits=True
    )


In [None]:
alpha_sbs = obj_sbs.params["alpha"] 
alpha_dbs = obj_dbs.params["alpha"] 

In [None]:
input = [alpha_sbs, alpha_dbs] 
input_tensor = [torch.tensor(alpha_sbs.values), torch.tensor(alpha_dbs.values)]
max_shape = max([i.shape[1] for i in input_tensor])
# stacked = torch.stack(input_tensor)

In [None]:
mixture = run.fit(
    alpha=input, 
    lr=0.005, 
    optim_gamma=0.1,
    n_steps=3000,
    cluster=5, 
    hyperparameters={"alpha_sigma":.15, "alpha_p_sigma":1., "alpha_p_conc0":0.6, 
                     "alpha_p_conc1":0.6, "alpha_rate":1., "pi_conc0":0.5, "alpha_conc":100,
                     "scale_factor_alpha":10000, "scale_factor_centroid":1000, "scale_tau":0},
    store_parameters = True, 
    seed_list=[92],
    nonparametric=True,
    store_fits=True
    )


In [None]:
import torch.nn.functional as F
def mix_weights(beta):
    '''
    Function used for the stick-breaking process.
    '''
    print("beta =", beta)
    beta1m_cumprod = (1 - beta).cumprod(-1)
    print("beta1m_cumprod =", beta1m_cumprod)
    res1 = F.pad(beta, (0, 1), value=1)
    res2 = F.pad(beta1m_cumprod, (1, 0), value=1)
    res = res1 * res2
    print(f"res1 = {res1}, res2 = {res2}, res = {res}\n")
    return res


In [None]:
cluster = 6
with pyro.plate("beta_plate", cluster-1):
    pi_beta = pyro.sample(f"beta", pyro.distributions.Beta(1, 1.1755e-36))
    # pi_beta = torch.tensor([1.1755e-36, 2.1648e-18, 1.1755e-36, 6.6389e-33, 1.1755e-36])
    print("pi_beta =", pi_beta)
    pi = mix_weights(pi_beta)

print(pi)

In [None]:
beta_star = torch.zeros(k_denovo, 96, dtype=torch.float64) 
for i in range(k_denovo):
    tmp_sbs = torch.cat((ref_sbs, dn_sbs[i].unsqueeze(0)))
    beta_star[i] = pi[i].unsqueeze(0).matmul(tmp_sbs) 

In [None]:
pyro.distributions.Gamma(0.01, 0.01).sample((5,))

In [None]:
pyro.distributions.Dirichlet(torch.ones(5)).sample()

In [None]:
(1 - pyro.distributions.Beta(1, 1e-10).sample((cluster-1,))).cumprod(-1)

In [None]:
pi = torch.zeros((10,))
pi[:5] = 5
pi 

In [None]:
alpha_centr = mixture[0].params["alpha_prior"]
print(alpha_centr) 

In [None]:
print(sklearn.metrics.normalized_mutual_info_score(mixture.groups, g_sbs)) 
print(sklearn.metrics.normalized_mutual_info_score(mixture.groups, g_dbs)) 

In [None]:
print(obj_sbs.params["scale_factor_centroid"])
print(obj_sbs.params["scale_factor_alpha"]) 

In [None]:
obj_sbs.params

In [None]:
obj_sbs.train_params[6]["scale_factor_centroid"]

In [None]:
obj_sbs.params["pi_conc0"] 

In [None]:
sns.scatterplot(x=range(len(obj_sbs.likelihoods)), y=obj_sbs.likelihoods) 

In [None]:
sns.scatterplot(x=range(len(obj_sbs.losses)), y=obj_sbs.losses) 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["scale_factor_centroid_param"])), 
                     y=obj_sbs.gradient_norms["scale_factor_centroid_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["scale_factor_alpha_param"])), 
                     y=obj_sbs.gradient_norms["scale_factor_alpha_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["alpha_prior_param"])), y=obj_sbs.gradient_norms["alpha_prior_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["alpha_prior_param"])), y=obj_sbs.gradient_norms["alpha_prior_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["pi_param"])), y=obj_sbs.gradient_norms["pi_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["pi_conc0_param"])), y=obj_sbs.gradient_norms["pi_conc0_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["alpha"])), y=obj_sbs.gradient_norms["alpha"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["beta_denovo"])), y=obj_sbs.gradient_norms["beta_denovo"])
except: print()

In [None]:
pd.DataFrame(np.array(obj_sbs.init_params["alpha_prior_param"]), columns=obj_sbs.params["alpha"].columns).plot.bar(stacked=True, legend=False) 

In [None]:
try: pd.DataFrame(np.array(obj_sbs.params["alpha_prior"]), columns=obj_sbs.params["alpha_prior"].columns).plot.bar(stacked=True, legend=False) 
except Exception as e: print() 

In [None]:
try:
    for gid in set(np.array(obj_sbs.groups)):
        tmp = [i for i, v in enumerate(obj_sbs.groups) if v == gid]
        # tmp = [i for i, v in enumerate(obj_sbs.groups) if (v == gid and i in idxs)]
        if len(tmp) == 0: continue
        pd.DataFrame(np.array(obj_sbs.params["alpha"]), columns=obj_sbs.params["alpha"].columns, 
                     index=obj_sbs.params["alpha"].index).iloc[tmp].plot.bar(stacked=True)
except Exception as e:
    print(e)
    obj_sbs.alpha.plot.bar(stacked=True, legend=False) 


In [None]:
try:
    for sbs in pd.concat((obj_sbs.params["beta_f"], obj_sbs.params["beta_d"])).index:
        pd.concat((obj_sbs.params["beta_f"], obj_sbs.params["beta_d"])).loc[[sbs]].transpose().plot.bar()
except Exception as e:
    print(e)