In [1]:
import pandas as pd
import pybasilica.run as run
import torch
import pyro
import pyro.distributions as dist
import numpy as np
import seaborn as sns
import sklearn.metrics
import torch.nn.functional as F
from tqdm import tqdm
from pyro.distributions import constraints

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
m_g = pd.read_csv("test_datasets/counts_sbs.N150.G3.csv")
m_sbs = m_g.drop(["groups"], axis=1)
g_sbs = m_g["groups"].tolist() 
cosmic_sbs = pd.read_csv("test_datasets/COSMIC_filt.csv", index_col=0) 

In [3]:
m_g = pd.read_csv("test_datasets/counts_dbs.N150.G3.csv")
m_dbs = m_g.drop(["groups"], axis=1)
g_dbs = m_g["groups"].tolist()
cosmic_dbs = pd.read_csv("test_datasets/COSMIC_dbs.csv", index_col=0) 

In [146]:
obj_sbs = run.fit(
    x=m_sbs, 
    k_list=[3,4], 
    lr=0.005, 
    optim_gamma=0.1,
    n_steps=2000, 
    # cluster=[3],
    dirichlet_prior=True,
    beta_fixed=cosmic_sbs.loc[["SBS1","SBS3","SBS5"]], 
    store_parameters = True, 
    seed_list=[30],
    nonparametric=False,
    store_fits=True, enumer="parallel"
    )


ELBO 48445.147023: 100%|██████████| 2000/2000 [00:07<00:00, 267.45it/s]
ELBO 48017.537012: 100%|██████████| 2000/2000 [00:08<00:00, 224.64it/s]


In [5]:
obj_dbs = run.fit(
    x=m_dbs, 
    k_list=3, 
    lr=0.005, 
    optim_gamma=0.1,
    n_steps=1000, 
    # cluster=[3],
    dirichlet_prior=True,
    beta_fixed=cosmic_dbs.loc[["DBS3","DBS5"]], 
    store_parameters=False, 
    seed_list=[30],
    store_fits=True
    )


ELBO 132427.058856: 100%|██████████| 1000/1000 [00:02<00:00, 339.58it/s]


In [138]:
input_alpha = [obj_sbs.params["alpha"], obj_dbs.params["alpha"]]
obj_clust = run.fit(
    alpha=input_alpha,
    lr=0.005, 
    # optim_gamma=0.1,
    n_steps=2000, 
    cluster=[5],
    store_parameters=False, 
    hyperparameters={"scale_factor_alpha":1,
                     "scale_factor_centroid":10},
    seed_list=[30],
    nonparametric=False,
    store_fits=True
    # enumer="sequential"
    )


ELBO -4518.112029: 100%|██████████| 2000/2000 [00:12<00:00, 164.50it/s]


In [139]:
obj_clust.params["scale_factor_centroid"] 

array(11.83005616)

In [140]:
fitted_grps = obj_clust.groups 
init_grps = obj_clust.init_params["init_clusters"] 
sklearn.metrics.normalized_mutual_info_score(fitted_grps, init_grps) 

0.762040473321298

In [141]:
print(obj_clust.init_params["pi"])
print(obj_clust.params["pi"]) 

[0.16474561 0.42632571 0.06390914 0.30485305 0.0401665 ]
[0.16903113 0.41488137 0.05082647 0.32840042 0.03686061]


In [None]:
print(obj_clust.groups) 
print(obj_clust.init_params["init_clusters"]) 

In [None]:
obj_clust.params["post_probs"] 

In [None]:
def find_C_for_variance(alpha, Y):
    """
    Find the smallest C such that the variance of the Dirichlet distribution
    with parameter alpha*C is above Y for at least one dimension.
    """
    def variance_condition(C):
        variances = dirichlet_variance(alpha, C)
        max_variance = max(variances)
        return max_variance - Y

    # Initial guess for C
    initial_C = 1.0

    # Solve for C
    C_solution = fsolve(variance_condition, initial_C)
    return C_solution[0]

# Calculate the smallest C for the given alpha and Y
Y = 20
C_solution = find_C_for_variance(alpha, Y)
C_solution

In [9]:
np.amax(np.array([12,3,4]))

12

In [57]:
from scipy.optimize import fsolve

def dirichlet_variance(alpha, alpha_hat):
    # alpha_hat = np.sum(alpha)
    num = alpha * (alpha_hat - alpha)
    denomin = alpha_hat**2 * (alpha_hat + 1)
    return num / denomin

def optim_fn(c, true_var, alpha, alpha_hat):
    variances = dirichlet_variance(alpha=alpha*c, alpha_hat=alpha_hat*c)
    return np.abs(true_var - variances) + 1.0


In [78]:
alpha = dist.Dirichlet(torch.ones(5)).sample().numpy()
true_var = np.ones(5) * 1e-4 
print(f"alpha = {alpha}, true_var = {true_var}")
sols = fsolve(func=optim_fn, x0=1.0, args=(true_var[0], alpha[0], np.sum(alpha))) 

alpha = [0.31264767 0.03801846 0.10291303 0.29218847 0.25423238], true_var = [0.0001 0.0001 0.0001 0.0001 0.0001]


  improvement from the last ten iterations.


In [87]:
def solver(target, alpha_hat, alpha_k):
    a = target*alpha_hat**3
    b = target*alpha_hat**2
    c = alpha_k**2 - alpha_k*alpha_hat

    d = np.sqrt(b**2 - 4*a*c)
    xs = np.array([(-b + d) / (2*a), (-b - d) / (2*a)])
    return np.amax(xs)

c = solver(target=true_var[0], alpha_hat=1, alpha_k=alpha[0])

In [89]:
dirichlet_variance(alpha=alpha, alpha_hat=alpha.sum()) 

array([0.10744955, 0.01828653, 0.04616097, 0.10340718, 0.09479913],
      dtype=float32)