In [1]:
import pandas as pd
import pybasilica.run as run
import torch
import pyro
import pyro.distributions as dist
import numpy as np
import seaborn as sns
import sklearn.metrics
import torch.nn.functional as F
from tqdm import tqdm
from pyro.distributions import constraints

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def mix_weights(beta):
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

def model(data):
    with pyro.plate("beta_plate", T-1):
        beta = pyro.sample("beta",  dist.Beta(1, alpha))

    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu",  dist.MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))

    with pyro.plate("data", N):
        z = pyro.sample("z",  dist.Categorical(mix_weights(beta)))
        pyro.sample("obs",  dist.MultivariateNormal(mu[z], torch.eye(2)), obs=data)

def guide(data):
    kappa = pyro.param('kappa', lambda: dist.Uniform(0, 2).sample([T-1]), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: dist.MultivariateNormal(torch.zeros(2), 3 * torch.eye(2)).sample([T]))
    phi = pyro.param('phi', lambda: dist.Dirichlet(1/T * torch.ones(T)).sample([N]), constraint=constraints.simplex)

    with pyro.plate("beta_plate", T-1):
        q_beta = pyro.sample("beta",  dist.Beta(torch.ones(T-1), kappa))

    with pyro.plate("mu_plate", T):
        q_mu = pyro.sample("mu",  dist.MultivariateNormal(tau, torch.eye(2)))

    with pyro.plate("data", N):
        z = pyro.sample("z",  dist.Categorical(phi))

In [None]:
data = torch.cat((dist.MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]),
                  dist.MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]),
                  dist.MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]),
                  dist.MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50])))

N = data.shape[0]
T = 6
optim = pyro.optim.Adam({"lr": 0.05})
svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO())
losses = []

def train(num_iterations):
    pyro.clear_param_store()
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)

def truncate(alpha, centers, weights):
    threshold = alpha**-1 / 100.
    true_centers = centers[weights > threshold]
    true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
    return true_centers, true_weights

alpha = 0.1
train(1000)

# We make a point-estimate of our model parameters using the posterior means of tau and phi for the centers and weights
Bayes_Centers_01, Bayes_Weights_01 = truncate(alpha, pyro.param("tau").detach(), 
                                              torch.mean(pyro.param("phi").detach(), dim=0))

alpha = 1.5
train(1000)

# We make a point-estimate of our model parameters using the posterior means of tau and phi for the centers and weights
Bayes_Centers_15, Bayes_Weights_15 = truncate(alpha, pyro.param("tau").detach(), 
                                              torch.mean(pyro.param("phi").detach(), dim=0))


In [None]:
print(Bayes_Weights_01)
print(Bayes_Weights_15) 

In [2]:
m_g = pd.read_csv("test_datasets/counts_sbs.N150.G3.csv")
m_sbs = m_g.drop(["groups"], axis=1)
g_sbs = m_g["groups"].tolist() 
cosmic_sbs = pd.read_csv("test_datasets/COSMIC_filt.csv", index_col=0) 

In [3]:
m_g = pd.read_csv("test_datasets/counts_dbs.N150.G3.csv")
m_dbs = m_g.drop(["groups"], axis=1)
g_dbs = m_g["groups"].tolist()
cosmic_dbs = pd.read_csv("test_datasets/COSMIC_dbs.csv", index_col=0) 

In [4]:
dn_sbs = torch.tensor(cosmic_sbs.loc[["SBS6","SBS17b"]].values, dtype=torch.float64)
ref_sbs = torch.tensor(cosmic_sbs.loc[["SBS1","SBS2","SBS5"]].values, dtype=torch.float64)
# k_denovo = dn_sbs.shape[0]
# k_fixed = ref_sbs.shape[0]
clusters = 6

def mix_weights(beta):
    '''
    Function used for the stick-breaking process.
    '''
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

conc = torch.tensor([0.36432, 0.383313626, 0.257168233, 0.005969330, 0.017175399])

pi_beta = torch.zeros(len(conc))
for i in range(len(conc)):
    pi_beta[i] = pyro.sample("beta", pyro.distributions.Beta(1, conc[i])) 

pi = mix_weights(pi_beta)

print(pi_beta)
print(torch.round(pi * 100))

with pyro.plate("beta_d_plate", len(conc)):
    pi_beta2 = pyro.sample("beta", pyro.distributions.Beta(1, conc)) 

print(pi_beta2)
pi2 = mix_weights(pi_beta2)
print(torch.round(pi2 * 100))

tensor([0.9031, 0.7285, 0.6458, 1.0000, 1.0000])
tensor([90.,  7.,  2.,  1.,  0.,  0.])
tensor([0.0183, 0.7627, 0.9953, 1.0000, 1.0000])
tensor([ 2., 75., 23.,  0.,  0.,  0.])


In [None]:
sns.displot(dist.Gamma(torch.tensor(0.2, dtype=torch.float64), 0.1).sample((1000,)).tolist())

In [None]:
torch.round(dist.Dirichlet(torch.tensor([0.3, 0.5, 0.3])*10).sample((10,)), decimals=3)


In [None]:
torch.mean(torch.tensor([1.,2.,3.]), dim=0)

In [4]:
obj_sbs = run.fit(
    x=m_sbs, 
    k_list=[3,4], 
    lr=0.005, 
    optim_gamma=0.1,
    n_steps=2000, 
    # cluster=[3],
    dirichlet_prior=True,
    beta_fixed=cosmic_sbs.loc[["SBS1","SBS5"]], 
    store_parameters = True, 
    seed_list=[30],
    nonparametric=False,
    store_fits=True, enumer="parallel"
    )


ELBO 49954.599999: 100%|██████████| 2000/2000 [00:06<00:00, 319.57it/s]
ELBO 46872.140686: 100%|██████████| 2000/2000 [00:06<00:00, 306.38it/s]


In [5]:
obj_dbs = run.fit(
    x=m_dbs, 
    k_list=3, 
    lr=0.005, 
    optim_gamma=0.1,
    n_steps=1000, 
    # cluster=[3],
    dirichlet_prior=True,
    beta_fixed=cosmic_dbs.loc[["DBS3","DBS5"]], 
    store_parameters = True, 
    seed_list=[30],
    nonparametric=False,
    store_fits=True
    )


ELBO 132427.058856: 100%|██████████| 1000/1000 [00:02<00:00, 333.48it/s]


In [54]:
input_alpha = [obj_sbs.params["alpha"], obj_dbs.params["alpha"]]
obj_clust = run.fit(
    alpha=input_alpha,
    lr=0.005, 
    # optim_gamma=0.1,
    n_steps=1000, 
    cluster=[1],
    store_parameters=False, 
    hyperparameters={"scale_factor_alpha":1,
                     "scale_factor_centroid":1000},
    seed_list=[30],
    nonparametric=False,
    store_fits=True, 
    # enumer="sequential"
    )


Bar desc:   0%|          | 0/1000 [00:00<?, ?it/s]

ELBO -13432.927245: 100%|██████████| 1000/1000 [00:02<00:00, 361.99it/s]


In [55]:
fitted_grps = obj_clust.groups 
init_grps = obj_clust.init_params["init_clusters"] 
sklearn.metrics.normalized_mutual_info_score(fitted_grps, init_grps) 

InvalidParameterError: The 'labels_pred' parameter of normalized_mutual_info_score must be an array-like. Got None instead.

In [49]:
obj_clust.init_params

{'alpha_prior': array([[3.09356898e-01, 7.23796897e-03, 1.64192199e-04, 4.61899303e-03,
         2.57856399e-03, 6.76043391e-01, 1.94365196e-02, 1.13202147e-02,
         4.19796288e-01, 4.70661074e-01, 7.87859038e-02, 1.17549435e-38],
        [2.32092232e-01, 2.54470054e-02, 7.35509038e-01, 2.73319194e-03,
         1.04476203e-05, 4.20813356e-03, 2.59973705e-01, 1.30015807e-02,
         6.36480972e-02, 8.61725062e-02, 5.77204108e-01, 1.17549435e-38],
        [2.45857686e-01, 1.59027260e-02, 6.10817969e-01, 6.32758811e-03,
         9.62918699e-02, 2.48020757e-02, 4.94775325e-02, 5.21353371e-02,
         5.36355555e-01, 2.40847722e-01, 1.21183850e-01, 1.17549435e-38]]),
 'pi': array([0.32037163, 0.48040223, 0.19922613]),
 'latent_class': array([[1.00000000e+00, 5.39649705e-13, 4.00152592e-17],
        [8.38880986e-03, 9.91611183e-01, 1.13019205e-10],
        [4.73527275e-02, 3.20487261e-01, 6.32160187e-01],
        [3.91969894e-04, 7.67805716e-07, 9.99607146e-01],
        [1.00000000e+00

In [47]:
print(obj_clust.init_params["pi"])
print(obj_clust.params["pi"]) 

[0.32037163 0.48040223 0.19922613]
[0.24832506 0.60452432 0.14715062]


In [48]:
print(obj_clust.groups) 
print(obj_clust.init_params["init_clusters"]) 

[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 0 2 2 2 2 2 2 2 1 2 2 2 2 2
 2 2 2 2 2 2 2 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 2 0 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 2 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1
 1 1]
[2 2 2 2 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 2 2 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 2 1 1 1 1 1
 1 2 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1
 1 1]


In [104]:
obj_clust.params["post_probs"] 

Unnamed: 0,0,1,2
0,2.813763e-06,4.598606e-07,0.999997
1,3.278157e-06,8.228630e-07,0.999996
2,1.964098e-05,7.995612e-07,0.999980
3,9.595536e-06,5.839855e-07,0.999990
4,1.915339e-06,2.463153e-06,0.999996
...,...,...,...
145,6.162418e-07,1.232728e-04,0.999876
146,8.909166e-07,7.659577e-05,0.999923
147,9.787764e-07,2.687341e-05,0.999972
148,1.087666e-06,1.957569e-04,0.999803


In [None]:
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
import numpy as np

# model
def model(data, K):
    N = len(data)
    hidden_dim = 2

    # As in any clustering algorithm, the mixing proportions are the assignment probabilities of the cells.
    # We sample the mixture weights from a Dirichlet distribution

    weights = pyro.sample('mixture_weights', dist.Dirichlet(torch.ones(K)))

    with pyro.plate('probabilities', K):  # cat_probs.size=(K,hidden_dim)
        cat_probs = pyro.sample("cat_probabilities", dist.Dirichlet(torch.ones(hidden_dim)))

    cat_vector = torch.tensor(np.arange(hidden_dim) + 1, dtype=torch.float)
    scale = pyro.sample("scale", dist.Gamma(5, 1))

    # likelihood
    with pyro.plate("data", N):
        pyro.factor("lk", log_lik(data, scale, weights, cat_vector, cat_probs, K))


# loglikelihood
def log_lik(data, scale, weights, cat_vector, cat_probs, K):
    N = len(data)
    hidden_dim = 2
    data = data.reshape(N, 1, 1)
    mean = scale * cat_vector.reshape(1, 1, hidden_dim)
    weights = weights.reshape(1, K, 1)
    cat_probs = cat_probs.reshape(1, K, hidden_dim)

    lk = torch.log(weights) + dist.Poisson(mean).log_prob(data) + torch.log(cat_probs)
    c = torch.max(torch.max(lk, dim=-1).values, dim=-1).values.reshape(N, 1, 1)
    log_lik = c + torch.log(torch.exp(lk - c).sum(dim=-1).sum(dim=-1)).reshape(N, 1, 1)

    return log_lik.sum()


# guide
def guide(data, K):
    pi = pyro.param("q_mixture_weights", create_params(data, K)["mixture_weights"], constraint=constraints.simplex)
    scale = pyro.param("q_scale", create_params(data, K)["scale"], constraint=constraints.positive)
    probs = pyro.param("q_cat_probabilities", create_params(data, K)["cat_probabilities"],
                       constraint=constraints.simplex)

    print(probs)

    with pyro.plate('probabilities', K):  # cat_probs.size=(K,hidden_dim)
        cat_prob = pyro.sample("cat_probabilities", dist.Delta(probs).to_event(1))
    pyro.sample('mixture_weights', dist.Delta(pi).to_event(1))
    pyro.sample("scale", dist.Delta(scale))


# initialization
def create_params(data, K):
    from sklearn.cluster import KMeans

    N = len(data)
    hidden_dim = 2

    kmeans = KMeans(init="random",
                    n_clusters=K,
                    n_init=10,
                    max_iter=300,
                    random_state=42)

    kmeans.fit(data.reshape(N, 1))

    categorical = torch.tensor(np.arange(hidden_dim) + 1, dtype=torch.float)
    mean_categorical = categorical.mean()
    mean_clusters = kmeans.cluster_centers_.mean()

    # initialize scale
    scale = mean_clusters / mean_categorical
    mean = scale * categorical.reshape(1, hidden_dim)

    # initialize mixing proportions and cat_probs
    Prob = torch.ones(K, hidden_dim) / hidden_dim
    mixing_proportions = torch.ones(K) / K

    for k in range(K):
        subset = []
        for i in range(len(data)):
            if kmeans.labels_[i] == k:
                subset.append(data[i])

        n = len(subset)
        dataset = torch.tensor(subset).reshape(n, 1)
        mixing_proportions[k] = len(subset) / len(data)

        p = dist.Poisson(mean).log_prob(dataset)
        c = torch.max(p, dim=0).values.reshape(1, hidden_dim)
        p = c + torch.log(torch.exp(p - c).sum(dim=0)).reshape(1, hidden_dim)
        c = torch.max(p, dim=-1).values
        Norm = c + torch.log(torch.exp(p - c).sum(dim=-1))
        Prob[k] = torch.exp(p - Norm)

    params = {"mixture_weights": mixing_proportions, "scale": scale, "cat_probabilities": Prob}

    return params


# cluster assignments
def cluster_assignments(data, K, pi, scale, categorical):
    N = len(data)

    mean = scale * categorical.reshape(1, K)
    pi = pi.reshape(1, K)
    data = data.reshape(N, 1)

    Prob = torch.log(pi) + dist.Poisson(mean).log_prob(data)  # Prob.size=(N,K)
    c = torch.max(Prob, dim=-1).values.reshape(N, 1)
    Norm = c + torch.log(torch.exp(Prob - c).sum(dim=-1)).reshape(N, 1)  # Norm.size= (N,1)

    return torch.exp(Prob - Norm)


#inference
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def inference(model,guide,K,data,lr=0.05,num_steps=500):

    pyro.clear_param_store()  # always clear the store before the inference

    # learning global parameters
    adam_params = {"lr": lr}
    optimizer = Adam(adam_params)
    elbo = Trace_ELBO()

    svi = SVI(model, guide, optimizer, loss=elbo)

    # inference
    # do gradient steps
    for step in range(num_steps):
        loss = svi.step(data, K)  # get the loss function after a gradient step
        if step % 10 == 0:
            print("loss =", loss)  # check the progress

    print("final loss =", svi.evaluate_loss(data, K))

    parameters = {}
    for key in pyro.get_param_store().get_all_param_names():
        parameters.update({key: torch.tensor(pyro.param(key))})

    
    cat_probs = parameters["q_cat_probabilities"]
    scale = parameters["q_scale"]
    pi = parameters["q_mixture_weights"]
    categorical = torch.argmax(cat_probs, dim=-1) + 1
    assignment_probs = cluster_assignments(data, K, pi, scale, categorical)
    assignments = torch.argmax(assignment_probs, dim=-1)
    parameters.update({"categorical_variable":categorical, 
                       "assignment_probs":assignment_probs,
                       "assignments":assignments})

    # print("parameters:", parameters)
    # print("categorical variable:", categorical)
    # print("assignment_probs:", assignment_probs)
    # print("assignments", assignments)

    return parameters


In [None]:
#visualization
import matplotlib.pyplot as plt

def plotting(data, K, assignments, bins=20):
    fig, axes = plt.subplots(1, 1, figsize=(12, 10), sharey=True)
    for k in range(K):
        subset = []
        for i in range(len(data)):
            if assignments[i] == k:
                subset.append(data[i])

        dataset = torch.tensor(subset)
        axes.hist(dataset, bins=bins, fill=True)


#data generator
def data_generator(N, scale, categorical, mixing_proportions):
    data = []
    means = scale * categorical

    for n in range(N):
        cluster = dist.Categorical(mixing_proportions).sample()
        data.append(dist.Poisson(means[cluster]).sample())

    data = torch.tensor(data)
    return data


In [None]:
#example
N = 1000
scale = 5
categorical = torch.tensor([1.,5.])
K = len(categorical)
mixing_proportions = torch.tensor([0.3,0.7])
data = data_generator(N, scale, categorical, mixing_proportions)

inferred_parameters = inference(model, guide, K, data, lr=0.05, num_steps=3)
assignments = inferred_parameters["assignments"]
plotting(data, K, assignments, bins=20) 

In [None]:
print(inferred_parameters.keys())

In [None]:
inferred_parameters["q_cat_probabilities"].shape

In [None]:
import numpy as np
from scipy.stats import dirichlet

# Define the parameters
alpha = [1, 1, 1]  # Adjust alpha values as needed
power = 2.0  # Adjust the power parameter

# Sample from the standard Dirichlet distribution
sample = np.random.dirichlet(alpha)

# Apply the transformation
sample_away_from_mode = sample ** (1 / power)

print("Sampled value away from the mode:", sample_away_from_mode)
print("Sampled:", sample)


In [None]:
import torch
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Define the parameters of the Dirichlet distribution
alpha = torch.tensor([1.0, 5.0, 1.0])  # Replace with your alpha values

# Sample from the Dirichlet distribution
sampled_value = pyro.sample("sampled_value", dist.Dirichlet(alpha))

print("Sampled value away from the mode:", sampled_value)

# Find the mode of the Dirichlet distribution
mode = torch.argmax(alpha)

print("Mode", mode)

# Create a mask to zero out the mode value
mask = torch.ones_like(sampled_value)
mask[mode] = 0

# Zero out the mode value
sampled_value = sampled_value * mask

# Renormalize to make it a valid probability distribution
sampled_value = sampled_value / sampled_value.sum()

print("Sampled value away from the mode:", sampled_value)


In [None]:
# a = a_orig**(3)
fig2, ax2 = plt.subplots()
fig3, ax3 = plt.subplots()
sns.histplot(a[:,0].tolist(), ax=ax2)
sns.histplot(a[:,1].tolist(), ax=ax3)
ax2.set_xlim(0,1)
ax3.set_xlim(0,1)

In [None]:
torch.sum(fixed, dim=0)

In [None]:
fixed = torch.tensor(obj_sbs.beta_fixed.values)
beta_w = torch.tensor(obj_sbs.params["beta_w"].values)
denovo = torch.tensor(obj_sbs.params["beta_d"].values)
cum_weights = torch.ones((obj_sbs.k_denovo, obj_sbs.k_fixed))/obj_sbs.k_fixed

fixed_cum = obj_sbs._get_unique_beta_stick_breaking(beta_fixed=fixed, beta_denovo=None, beta_weights=cum_weights)
fixed_cum = obj_sbs._norm_and_clamp(fixed_cum)

print(torch.sum((fixed_cum * (torch.abs(fixed_cum - denovo)))) * torch.tensor(obj_sbs.x.values).sum())
print(torch.sum((fixed_cum * (torch.abs(fixed_cum - denovo)))) * obj_sbs.x.shape[0] * obj_sbs.x.shape[1])


In [None]:
obj_sbs.losses

In [None]:
len(obj_sbs.train_params)

In [None]:
pyro.distributions.Dirichlet(fixed_cum*1000).log_prob(denovo) 

In [None]:
obj_sbs.gradient_norms.keys()

In [None]:
## self.x.sum() * torch.sum(beta_fixed_cum * (1 - torch.abs(beta_fixed_cum - beta_denovo)))
obj_sbs.params["beta_w"] 

In [None]:
## self.x.sum() * torch.sum(beta_fixed_cum * (torch.abs(beta_fixed_cum - beta_denovo)))
obj_sbs.params["beta_w"] 

In [None]:
## self.x.sum() * torch.sum(beta_fixed_cum * (torch.abs(beta_fixed_cum - beta_denovo)))
obj_sbs.params["beta_w"] 

In [None]:
## self.n_samples * self.contexts * pyro.distributions.Dirichlet(beta_fixed_cum*1000).to_event(1).log_prob(beta_denovo))
obj_sbs.params["beta_w"] 

In [None]:
pyro.param("beta_weights")

In [None]:
obj_sbs.params["beta_w"] 

In [None]:
obj_sbs.gradient_norms.keys() 

In [None]:
obj_sbs.params["alpha"].sum(axis=1)

In [None]:
k_dn = 2
k_f = 3
n_samples = 5
beta_weights = pyro.distributions.Dirichlet(torch.ones(k_dn, k_f+1)).sample()
alpha_star = pyro.distributions.Dirichlet(torch.ones(n_samples, k_dn)).sample()
print("beta weights\n", beta_weights)
print("alpha star\n", alpha_star)

In [None]:
beta_weights[1,2]

In [None]:
alpha = torch.zeros((n_samples, k_dn+k_f))

for n in range(n_samples):
    for j in range(k_dn):
        for r in range(k_f):
            alpha[n, r] += torch.sum(alpha_star[n,j]) * beta_weights[j,r]
        
        for d in range(k_f, k_f+k_dn):
            alpha[n, d] += torch.sum(alpha_star[n,j]) * beta_weights[j,-1]

print(alpha)

In [None]:
obj_sbs.params["beta_w"]

In [None]:
obj_sbs.params["beta_d"]

In [None]:
obj_dbs = run.fit(
    x=m_dbs, 
    k_list=3, 
    lr=0.005, 
    optim_gamma=0.1,
    n_steps=10, 
    # cluster=6, 
    dirichlet_prior=True,
    beta_fixed=cosmic_dbs.loc[["DBS4"]], 
    hyperparameters={"alpha_sigma":.15, "alpha_p_sigma":1., "alpha_p_conc0":0.6, 
                     "alpha_p_conc1":0.6, "alpha_rate":1., "pi_conc0":0.5, "alpha_conc":100,
                     "scale_factor_alpha":10000, "scale_factor_centroid":1000, "scale_tau":0},
    enforce_sparsity = True, 
    reg_weight=0., 
    store_parameters = True, 
    seed_list=[92],
    nonparametric=True,
    store_fits=True
    )


In [None]:
alpha_sbs = obj_sbs.params["alpha"] 
alpha_dbs = obj_dbs.params["alpha"] 

In [None]:
input = [alpha_sbs, alpha_dbs] 
input_tensor = [torch.tensor(alpha_sbs.values), torch.tensor(alpha_dbs.values)]
max_shape = max([i.shape[1] for i in input_tensor])
# stacked = torch.stack(input_tensor)

In [None]:
mixture = run.fit(
    alpha=input, 
    lr=0.005, 
    optim_gamma=0.1,
    n_steps=3000,
    cluster=5, 
    hyperparameters={"alpha_sigma":.15, "alpha_p_sigma":1., "alpha_p_conc0":0.6, 
                     "alpha_p_conc1":0.6, "alpha_rate":1., "pi_conc0":0.5, "alpha_conc":100,
                     "scale_factor_alpha":10000, "scale_factor_centroid":1000, "scale_tau":0},
    store_parameters = True, 
    seed_list=[92],
    nonparametric=True,
    store_fits=True
    )


In [None]:
import torch.nn.functional as F
def mix_weights(beta):
    '''
    Function used for the stick-breaking process.
    '''
    print("beta =", beta)
    beta1m_cumprod = (1 - beta).cumprod(-1)
    print("beta1m_cumprod =", beta1m_cumprod)
    res1 = F.pad(beta, (0, 1), value=1)
    res2 = F.pad(beta1m_cumprod, (1, 0), value=1)
    res = res1 * res2
    print(f"res1 = {res1}, res2 = {res2}, res = {res}\n")
    return res


In [None]:
cluster = 6
with pyro.plate("beta_plate", cluster-1):
    pi_beta = pyro.sample(f"beta", pyro.distributions.Beta(1, 1.1755e-36))
    # pi_beta = torch.tensor([1.1755e-36, 2.1648e-18, 1.1755e-36, 6.6389e-33, 1.1755e-36])
    print("pi_beta =", pi_beta)
    pi = mix_weights(pi_beta)

print(pi)

In [None]:
beta_star = torch.zeros(k_denovo, 96, dtype=torch.float64) 
for i in range(k_denovo):
    tmp_sbs = torch.cat((ref_sbs, dn_sbs[i].unsqueeze(0)))
    beta_star[i] = pi[i].unsqueeze(0).matmul(tmp_sbs) 

In [None]:
pyro.distributions.Gamma(0.01, 0.01).sample((5,))

In [None]:
pyro.distributions.Dirichlet(torch.ones(5)).sample()

In [None]:
(1 - pyro.distributions.Beta(1, 1e-10).sample((cluster-1,))).cumprod(-1)

In [None]:
pi = torch.zeros((10,))
pi[:5] = 5
pi 

In [None]:
alpha_centr = mixture[0].params["alpha_prior"]
print(alpha_centr) 

In [None]:
print(sklearn.metrics.normalized_mutual_info_score(mixture.groups, g_sbs)) 
print(sklearn.metrics.normalized_mutual_info_score(mixture.groups, g_dbs)) 

In [None]:
print(obj_sbs.params["scale_factor_centroid"])
print(obj_sbs.params["scale_factor_alpha"]) 

In [None]:
obj_sbs.params

In [None]:
obj_sbs.train_params[6]["scale_factor_centroid"]

In [None]:
obj_sbs.params["pi_conc0"] 

In [None]:
sns.scatterplot(x=range(len(obj_sbs.likelihoods)), y=obj_sbs.likelihoods) 

In [None]:
sns.scatterplot(x=range(len(obj_sbs.losses)), y=obj_sbs.losses) 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["scale_factor_centroid_param"])), 
                     y=obj_sbs.gradient_norms["scale_factor_centroid_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["scale_factor_alpha_param"])), 
                     y=obj_sbs.gradient_norms["scale_factor_alpha_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["alpha_prior_param"])), y=obj_sbs.gradient_norms["alpha_prior_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["alpha_prior_param"])), y=obj_sbs.gradient_norms["alpha_prior_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["pi_param"])), y=obj_sbs.gradient_norms["pi_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["pi_conc0_param"])), y=obj_sbs.gradient_norms["pi_conc0_param"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["alpha"])), y=obj_sbs.gradient_norms["alpha"]) 
except: print() 

In [None]:
try: sns.scatterplot(x=range(len(obj_sbs.gradient_norms["beta_denovo"])), y=obj_sbs.gradient_norms["beta_denovo"])
except: print()

In [None]:
pd.DataFrame(np.array(obj_sbs.init_params["alpha_prior_param"]), columns=obj_sbs.params["alpha"].columns).plot.bar(stacked=True, legend=False) 

In [None]:
try: pd.DataFrame(np.array(obj_sbs.params["alpha_prior"]), columns=obj_sbs.params["alpha_prior"].columns).plot.bar(stacked=True, legend=False) 
except Exception as e: print() 

In [None]:
try:
    for gid in set(np.array(obj_sbs.groups)):
        tmp = [i for i, v in enumerate(obj_sbs.groups) if v == gid]
        # tmp = [i for i, v in enumerate(obj_sbs.groups) if (v == gid and i in idxs)]
        if len(tmp) == 0: continue
        pd.DataFrame(np.array(obj_sbs.params["alpha"]), columns=obj_sbs.params["alpha"].columns, 
                     index=obj_sbs.params["alpha"].index).iloc[tmp].plot.bar(stacked=True)
except Exception as e:
    print(e)
    obj_sbs.alpha.plot.bar(stacked=True, legend=False) 


In [None]:
try:
    for sbs in pd.concat((obj_sbs.params["beta_f"], obj_sbs.params["beta_d"])).index:
        pd.concat((obj_sbs.params["beta_f"], obj_sbs.params["beta_d"])).loc[[sbs]].transpose().plot.bar()
except Exception as e:
    print(e)