In [1]:
import pandas as pd
import pybasilica.run as run
import torch
import pyro
import numpy as np
import seaborn as sns
import sklearn.metrics
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

In [2]:
# SBS test cohort: read the table, then separate the "groups" column from the
# remaining columns.
m_g = pd.read_csv("test_datasets/counts_sbs.N150.G3.csv")
g_sbs = m_g["groups"].tolist()
m_sbs = m_g.drop(columns=["groups"])
# COSMIC table; the first CSV column becomes the row index (rows are looked up
# by signature name further down).
cosmic_sbs = pd.read_csv("test_datasets/COSMIC_filt.csv", index_col=0)

In [80]:
# DBS test cohort: same layout as the SBS load above. Note that `m_g` is
# rebound here, so after this cell it refers to the DBS table.
m_g = pd.read_csv("test_datasets/counts_dbs.N150.G3.csv")
g_dbs = m_g["groups"].tolist()
m_dbs = m_g.drop(columns=["groups"])
# COSMIC DBS table; the first CSV column becomes the row index.
cosmic_dbs = pd.read_csv("test_datasets/COSMIC_dbs.csv", index_col=0)

In [27]:
# Two groups of COSMIC SBS rows as float64 tensors (one row per selected
# signature): `dn_sbs` plays the "de novo" role, `ref_sbs` the "fixed" one.
dn_sbs = torch.tensor(
    cosmic_sbs.loc[["SBS6", "SBS17b"]].values,
    dtype=torch.float64,
)
ref_sbs = torch.tensor(
    cosmic_sbs.loc[["SBS1", "SBS2", "SBS5"]].values,
    dtype=torch.float64,
)
# Number of rows in each group.
k_denovo, k_fixed = len(dn_sbs), len(ref_sbs)

def mix_weights(beta):
    '''
    Turn stick-breaking fractions into mixture weights.

    For fractions ``beta`` of shape ``(..., K)`` the result has shape
    ``(..., K + 1)``: entry ``k`` is ``beta[k]`` times the product of
    ``1 - beta[j]`` over all ``j < k``, and the extra last entry is the
    product of every ``1 - beta[j]`` (whatever is left of the stick).
    '''
    # Fraction of the stick still unbroken after each break.
    remaining = (1 - beta).cumprod(dim=-1)
    # Append a trailing 1 so the last component takes all that remains.
    fractions = F.pad(beta, (0, 1), value=1)
    # Prepend a leading 1: nothing has been broken off before the first piece.
    leftover = F.pad(remaining, (1, 0), value=1)
    return fractions * leftover

# Beta prior over k_fixed stick-breaking fractions, treated as one event.
beta_prior = pyro.distributions.Beta(torch.ones(k_fixed, dtype=torch.float64), 0.6).to_event(1)

# One draw per plate slot (k_denovo of them); each row of fractions is then
# converted into k_fixed + 1 mixture weights.
with pyro.plate("beta_d_plate", k_denovo):
    pi_beta = pyro.sample("beta", beta_prior)
    print(pi_beta)
    pi = mix_weights(pi_beta)

print(pi.shape)
print(pi.dtype)

tensor([[0.9456, 0.9812, 0.7598],
        [0.9397, 0.7446, 0.9962]], dtype=torch.float64)
torch.Size([2, 4])
torch.float64


In [78]:
# Hyperparameters forwarded unchanged to run.fit for the SBS fit.
sbs_hyperparameters = {
    "alpha_sigma": .15,
    "alpha_p_sigma": 1.,
    "alpha_p_conc0": 0.6,
    "alpha_p_conc1": 0.6,
    "alpha_rate": 1.,
    "pi_conc0": 0.5,
    "alpha_conc": 100,
    "scale_factor_alpha": 5000,
    "scale_factor_centroid": 5000,
    "scale_tau": 0,
}

# Fit on the SBS counts with SBS1 and SBS5 supplied as fixed signatures and
# two seeds.
obj_sbs = run.fit(
    x=m_sbs,
    k_list=[0],
    lr=0.005,
    optim_gamma=0.1,
    n_steps=3000,
    # cluster=6,
    dirichlet_prior=True,
    beta_fixed=cosmic_sbs.loc[["SBS1", "SBS5"]],
    hyperparameters=sbs_hyperparameters,
    enforce_sparsity=True,
    reg_weight=0.,
    store_parameters=True,
    seed_list=[92, 30],
    nonparametric=True,
    store_fits=True,
)


In [82]:
# Hyperparameters forwarded unchanged to run.fit for the DBS fit.
dbs_hyperparameters = {
    "alpha_sigma": .15,
    "alpha_p_sigma": 1.,
    "alpha_p_conc0": 0.6,
    "alpha_p_conc1": 0.6,
    "alpha_rate": 1.,
    "pi_conc0": 0.5,
    "alpha_conc": 100,
    "scale_factor_alpha": 10000,
    "scale_factor_centroid": 1000,
    "scale_tau": 0,
}

# Fit on the DBS counts with DBS4 supplied as the fixed signature.
# NOTE(review): n_steps=10 and a scalar k_list=3 here, versus n_steps=3000 and
# k_list=[0] in the SBS fit above — confirm both differences are intentional.
obj_dbs = run.fit(
    x=m_dbs,
    k_list=3,
    lr=0.005,
    optim_gamma=0.1,
    n_steps=10,
    # cluster=6,
    dirichlet_prior=True,
    beta_fixed=cosmic_dbs.loc[["DBS4"]],
    hyperparameters=dbs_hyperparameters,
    enforce_sparsity=True,
    reg_weight=0.,
    store_parameters=True,
    seed_list=[92],
    nonparametric=True,
    store_fits=True,
)


In [None]:
# Take the "alpha" entry of each fit's params; both are combined into the
# input of the joint fit in the following cells.
alpha_sbs = obj_sbs.params["alpha"] 
alpha_dbs = obj_dbs.params["alpha"] 

In [None]:
# NOTE(review): `input` shadows the Python builtin of the same name; it is kept
# because the following cell passes it to run.fit as `alpha=input`.
input = [alpha_sbs, alpha_dbs]
# Same two tables as tensors, plus the widest column count among them.
input_tensor = [torch.tensor(frame.values) for frame in input]
max_shape = max(t.shape[1] for t in input_tensor)
# stacked = torch.stack(input_tensor)

In [None]:
# Hyperparameters forwarded unchanged to run.fit for the joint fit.
mixture_hyperparameters = {
    "alpha_sigma": .15,
    "alpha_p_sigma": 1.,
    "alpha_p_conc0": 0.6,
    "alpha_p_conc1": 0.6,
    "alpha_rate": 1.,
    "pi_conc0": 0.5,
    "alpha_conc": 100,
    "scale_factor_alpha": 10000,
    "scale_factor_centroid": 1000,
    "scale_tau": 0,
}

# Fit on the list of alpha tables (SBS and DBS) with cluster=5.
mixture = run.fit(
    alpha=input,
    lr=0.005,
    optim_gamma=0.1,
    n_steps=3000,
    cluster=5,
    hyperparameters=mixture_hyperparameters,
    store_parameters=True,
    seed_list=[92],
    nonparametric=True,
    store_fits=True,
)


In [None]:
import torch.nn.functional as F
def mix_weights(beta):
    '''
    Verbose stick-breaking: fractions in, mixture weights out.

    Computes the same weights as the plain stick-breaking helper (input of
    shape ``(..., K)``, output of shape ``(..., K + 1)``) and prints every
    intermediate tensor along the way. Defining it rebinds the name
    ``mix_weights`` for all later cells.
    '''
    print("beta =", beta)

    # Fraction of the stick still unbroken after each break.
    remaining = (1 - beta).cumprod(dim=-1)
    print("beta1m_cumprod =", remaining)

    # res1: fractions with a trailing 1; res2: remaining stick with a leading 1.
    res1 = F.pad(beta, (0, 1), value=1)
    res2 = F.pad(remaining, (1, 0), value=1)
    res = res1 * res2
    print(f"res1 = {res1}, res2 = {res2}, res = {res}\n")

    return res


In [None]:
cluster = 6

# Stick-breaking fractions drawn from Beta(1, 1.1755e-36) inside a plate of
# size cluster - 1, then converted to weights with mix_weights.
stick_prior = pyro.distributions.Beta(1, 1.1755e-36)
with pyro.plate("beta_plate", cluster - 1):
    pi_beta = pyro.sample("beta", stick_prior)
    # pi_beta = torch.tensor([1.1755e-36, 2.1648e-18, 1.1755e-36, 6.6389e-33, 1.1755e-36])
    print("pi_beta =", pi_beta)
    pi = mix_weights(pi_beta)

print(pi)

In [None]:
# For each de novo signature, mix the reference signatures plus that one
# de novo signature using row i of `pi` as the weights.
# NOTE(review): this needs the `pi` produced by the "beta_d_plate" cell; the
# cell just above rebinds `pi`, so in top-to-bottom order the wrong `pi` is
# used here — confirm the intended execution order.
beta_star = torch.zeros(k_denovo, 96, dtype=torch.float64)
for i in range(k_denovo):
    tmp_sbs = torch.cat((ref_sbs, dn_sbs[i][None, :]), dim=0)
    beta_star[i] = pi[i][None, :] @ tmp_sbs

In [None]:
# Scratch check: five draws from Gamma(0.01, 0.01).
gamma_prior = pyro.distributions.Gamma(0.01, 0.01)
gamma_prior.sample((5,))

In [None]:
# Scratch check: one draw from a Dirichlet with five unit concentrations.
flat_dirichlet = pyro.distributions.Dirichlet(torch.ones(5))
flat_dirichlet.sample()

In [None]:
# Scratch check: running product of (1 - draw) for cluster - 1 draws from
# Beta(1, 1e-10).
beta_draws = pyro.distributions.Beta(1, 1e-10).sample((cluster - 1,))
(1 - beta_draws).cumprod(-1)

In [None]:
# Scratch tensor of length 10: first five entries are 5, the rest 0.
# Note this rebinds `pi`, replacing the stick-breaking weights computed earlier.
pi = torch.cat((torch.full((5,), 5.0), torch.zeros(5)))
pi

In [None]:
# Read the "alpha_prior" entry from the params of the first element of `mixture`.
# NOTE(review): this cell indexes `mixture[0]`, while the next cell reads
# `mixture.groups` directly on the same object — confirm which form run.fit returns.
alpha_centr = mixture[0].params["alpha_prior"]
print(alpha_centr) 

In [None]:
# Normalized mutual information between the groups of the joint fit and the
# group labels loaded with each dataset (SBS first, then DBS).
for reference_groups in (g_sbs, g_dbs):
    print(sklearn.metrics.normalized_mutual_info_score(mixture.groups, reference_groups))

In [None]:
# Print the two scale-factor entries of the SBS fit's params.
for key in ("scale_factor_centroid", "scale_factor_alpha"):
    print(obj_sbs.params[key])

In [None]:
# Display the whole params object of the SBS fit.
obj_sbs.params

In [None]:
# Display "scale_factor_centroid" from entry 6 of train_params.
# NOTE(review): the index 6 is hard-coded; presumably it selects one stored
# training snapshot — confirm against how train_params is populated.
obj_sbs.train_params[6]["scale_factor_centroid"]

In [None]:
# Display the "pi_conc0" entry of the SBS fit's params.
obj_sbs.params["pi_conc0"] 

In [None]:
# Scatter of obj_sbs.likelihoods against positional index.
likelihoods = obj_sbs.likelihoods
sns.scatterplot(x=range(len(likelihoods)), y=likelihoods)

In [None]:
# Scatter of obj_sbs.losses against positional index.
losses = obj_sbs.losses
sns.scatterplot(x=range(len(losses)), y=losses)

In [None]:
# Scatter of the gradient norms recorded under "scale_factor_centroid_param"
# against positional index. Best-effort: if the entry cannot be read or plotted
# (presumably when this parameter was not tracked for the fit), the reason is
# reported instead of being swallowed silently.
try:
    grad_norms = obj_sbs.gradient_norms["scale_factor_centroid_param"]
    sns.scatterplot(x=range(len(grad_norms)), y=grad_norms)
except Exception as err:
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    print(f"Skipping gradient-norm plot for 'scale_factor_centroid_param': {err!r}")

In [None]:
# Scatter of the gradient norms recorded under "scale_factor_alpha_param"
# against positional index. Best-effort: if the entry cannot be read or plotted
# (presumably when this parameter was not tracked for the fit), the reason is
# reported instead of being swallowed silently.
try:
    grad_norms = obj_sbs.gradient_norms["scale_factor_alpha_param"]
    sns.scatterplot(x=range(len(grad_norms)), y=grad_norms)
except Exception as err:
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    print(f"Skipping gradient-norm plot for 'scale_factor_alpha_param': {err!r}")

In [None]:
# Scatter of the gradient norms recorded under "alpha_prior_param" against
# positional index. Best-effort: if the entry cannot be read or plotted
# (presumably when this parameter was not tracked for the fit), the reason is
# reported instead of being swallowed silently.
try:
    grad_norms = obj_sbs.gradient_norms["alpha_prior_param"]
    sns.scatterplot(x=range(len(grad_norms)), y=grad_norms)
except Exception as err:
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    print(f"Skipping gradient-norm plot for 'alpha_prior_param': {err!r}")

In [None]:
# Scatter of the gradient norms recorded under "alpha_prior_param" against
# positional index.
# NOTE(review): this cell plots the same key as the cell directly before it —
# confirm whether a different parameter was intended here.
try:
    grad_norms = obj_sbs.gradient_norms["alpha_prior_param"]
    sns.scatterplot(x=range(len(grad_norms)), y=grad_norms)
except Exception as err:
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    print(f"Skipping gradient-norm plot for 'alpha_prior_param': {err!r}")

In [None]:
# Scatter of the gradient norms recorded under "pi_param" against positional
# index. Best-effort: if the entry cannot be read or plotted (presumably when
# this parameter was not tracked for the fit), the reason is reported instead
# of being swallowed silently.
try:
    grad_norms = obj_sbs.gradient_norms["pi_param"]
    sns.scatterplot(x=range(len(grad_norms)), y=grad_norms)
except Exception as err:
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    print(f"Skipping gradient-norm plot for 'pi_param': {err!r}")

In [None]:
# Scatter of the gradient norms recorded under "pi_conc0_param" against
# positional index. Best-effort: if the entry cannot be read or plotted
# (presumably when this parameter was not tracked for the fit), the reason is
# reported instead of being swallowed silently.
try:
    grad_norms = obj_sbs.gradient_norms["pi_conc0_param"]
    sns.scatterplot(x=range(len(grad_norms)), y=grad_norms)
except Exception as err:
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    print(f"Skipping gradient-norm plot for 'pi_conc0_param': {err!r}")

In [None]:
# Scatter of the gradient norms recorded under "alpha" against positional
# index. Best-effort: if the entry cannot be read or plotted (presumably when
# this parameter was not tracked for the fit), the reason is reported instead
# of being swallowed silently.
try:
    grad_norms = obj_sbs.gradient_norms["alpha"]
    sns.scatterplot(x=range(len(grad_norms)), y=grad_norms)
except Exception as err:
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    print(f"Skipping gradient-norm plot for 'alpha': {err!r}")

In [None]:
# Scatter of the gradient norms recorded under "beta_denovo" against positional
# index. Best-effort: if the entry cannot be read or plotted (presumably when
# this parameter was not tracked for the fit), the reason is reported instead
# of being swallowed silently.
try:
    grad_norms = obj_sbs.gradient_norms["beta_denovo"]
    sns.scatterplot(x=range(len(grad_norms)), y=grad_norms)
except Exception as err:
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    print(f"Skipping gradient-norm plot for 'beta_denovo': {err!r}")

In [None]:
# Stacked bars of the initial "alpha_prior_param" values, with columns
# labelled like the fitted "alpha" table.
init_alpha_prior = pd.DataFrame(
    np.array(obj_sbs.init_params["alpha_prior_param"]),
    columns=obj_sbs.params["alpha"].columns,
)
init_alpha_prior.plot.bar(stacked=True, legend=False)

In [None]:
# Stacked bars of the fitted "alpha_prior" table. Best-effort: if the entry
# cannot be read or plotted, the reason is reported instead of printing only a
# blank line.
try:
    alpha_prior = obj_sbs.params["alpha_prior"]
    pd.DataFrame(np.array(alpha_prior), columns=alpha_prior.columns).plot.bar(stacked=True, legend=False)
except Exception as e:
    print(f"Skipping alpha_prior bar plot: {e!r}")

In [None]:
# One stacked-bar figure of "alpha" per distinct value in obj_sbs.groups,
# restricted to the rows at the positions carrying that value. On any error
# the message is printed and obj_sbs.alpha is plotted as a single figure.
try:
    for gid in set(np.array(obj_sbs.groups)):
        member_rows = [pos for pos, label in enumerate(obj_sbs.groups) if label == gid]
        # tmp = [i for i, v in enumerate(obj_sbs.groups) if (v == gid and i in idxs)]
        if not member_rows:
            continue
        alpha_frame = pd.DataFrame(
            np.array(obj_sbs.params["alpha"]),
            columns=obj_sbs.params["alpha"].columns,
            index=obj_sbs.params["alpha"].index,
        )
        alpha_frame.iloc[member_rows].plot.bar(stacked=True)
except Exception as e:
    print(e)
    obj_sbs.alpha.plot.bar(stacked=True, legend=False)


In [None]:
# One bar figure per row of the stacked "beta_f" + "beta_d" tables (the row is
# transposed so its columns become the bars). Errors are printed, not raised.
try:
    all_betas = pd.concat((obj_sbs.params["beta_f"], obj_sbs.params["beta_d"]))
    for sbs in all_betas.index:
        all_betas.loc[[sbs]].T.plot.bar()
except Exception as e:
    print(e)