In [1]:
import sys 
sys.path += ["../src"]
import BC_leaders, BC_update
import numpy as np
from tqdm import tqdm
from time import time
import pickle
from glob import glob
from pyABC_ import pyabc
from scipy.special import expit as np_sigmoid

import os
from tempfile import gettempdir
from pyABC_.pyabc.sampler import SingleCoreSampler
from jax.scipy.special import expit as sigmoid
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from numpyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO, MCMC, NUTS
from numpyro.infer.autoguide import AutoNormal, AutoBNAFNormal, AutoIAFNormal
from numpyro import distributions
import numpyro
from numpyro.optim import Adam
import jax.random as random
from datetime import timedelta
# numpyro.set_platform("gpu")
from diptest import dipstat
from scipy.stats import kurtosis, skew
import matplotlib.pyplot as plt



  from .autonotebook import tqdm as notebook_tqdm


In [2]:


def compute_X_from_X0_params(X0, edges_iter, mu_plus, mu_minus, is_backfire = True):
    # edges_iter = (edges_t for edges_t in edges)
    # Xt = jax.lax.stop_gradient(X0.copy())
    Xt = X0.copy()
    X_list = [Xt.copy()]
    
    N = len(Xt)
    
    while True:
        edges_t = next(edges_iter, None)
        if edges_t is None:
            break
        
        u,v,s_plus,s_minus = edges_t.T
        u,v = u.astype(int),v.astype(int)
        diff_X = Xt[u] - Xt[v]

        updates_plus = mu_plus * s_plus * diff_X
        updates_minus = (mu_minus * s_minus * diff_X) * is_backfire
        # print(updates_minus)
        # Xt = Xt.at[v].add(updates_plus - updates_minus).clip(1e-5, 1 - 1e-5)
        Xt[v] += updates_plus - updates_minus
        Xt[v] = np.clip(Xt[v], 1e-5, 1 - 1e-5)
        
        X_list.append(Xt.copy())

    return jnp.stack(X_list)


def initialize_training(X, edges, mu_plus, mu_minus, rho = 32):
    T, N = X.shape    
    u,v,s_plus,s_minus,t = BC_leaders.convert_edges_uvst(edges)
    s_plus, s_minus = jnp.float32(s_plus), jnp.float32(s_minus)

    X0 = np.array(X[0])
    edges_iter = (edge for edge in edges)
    X_bc = compute_X_from_X0_params(X0, edges_iter, mu_plus, mu_minus, is_backfire = False)
    edges_iter = (edge for edge in edges)
    X_back = compute_X_from_X0_params(X0, edges_iter, mu_plus, mu_minus, is_backfire = True)
    u,v,t = u.astype(int), v.astype(int), t.astype(int)

    diff_X_bc = X_bc[t,u] - X_bc[t,v]
    diff_X_back = X_back[t,u] - X_back[t,v]

    return {"u": u, "v": v, "s_plus": s_plus, "s_minus": s_minus, "t": t,
            "N": N, "T": T, "rho": rho,
            "diff_X_bc": diff_X_bc, "diff_X_back": diff_X_back}

def model(data):
    dim = 3
    dist = distributions.Normal(jnp.zeros(dim), jnp.ones(dim)).to_event(1)
    params = numpyro.sample("theta", dist)
    
    theta = params[:2]
    param_backfire = params[2:]
    epsilon_plus, epsilon_minus = sigmoid(theta) /  2 + jnp.array([0.,.5])

    ############
    diff_X_bc,diff_X_back,u,v,s_plus, s_minus,t, rho, N, T = [data[k] for k in ["diff_X_bc", "diff_X_back","u","v",
                                                                      "s_plus", "s_minus","t",
                                                                      "rho", "N", "T"]]
    
    
    backfire_sample = numpyro.sample("backfire", distributions.RelaxedBernoulli(probs = param_backfire, temperature = jnp.array([0.1])).to_event(1))
    is_backfire = backfire_sample[0]
        
    s_plus = jnp.array(s_plus)
    s_minus = jnp.array(s_minus)
 
    diff_X = (1 - is_backfire) * diff_X_bc + is_backfire * diff_X_back
    kappas_plus = BC_leaders.kappa_plus_from_epsilon(epsilon_plus, diff_X, rho, with_jax = True)
    kappas_minus = BC_leaders.kappa_minus_from_epsilon(epsilon_minus, diff_X, rho, with_jax = True)
    kappas_ = jnp.concatenate([kappas_minus, kappas_plus])
    s = jnp.concatenate([s_minus, s_plus])

    with numpyro.plate("data", s.shape[0]):
        numpyro.sample("obs", distributions.Bernoulli(probs = kappas_), obs = s)

def train_svi(X, edges, mu_plus, mu_minus, guide_family = "normal", rho = 32,
              n_steps = 4000, intermediate_steps = None, lr = 0.01, 
              progress_bar = False, id = None, timeout = 3600):
    if intermediate_steps is None:
        intermediate_steps = n_steps
    
    if guide_family == "normal":
        guide = AutoNormal(model)
    if guide_family == "NF":
        guide = AutoBNAFNormal(model, num_flows = 1, hidden_factors = (8,8))
        n_steps = int(n_steps / 2)
        intermediate_steps = int(intermediate_steps / 2)
    
    data = initialize_training(jnp.array(X), jnp.array(edges), mu_plus, mu_minus, rho = rho)
    optimizer = Adam(step_size = lr)
    svi = SVI(model, guide, optimizer, loss = TraceGraph_ELBO())
    res = []
    last_state = None

    tot_time = 0
    
    for _ in range(int(n_steps / intermediate_steps)):
        t0 = time()
        svi_results = svi.run(random.PRNGKey(0), intermediate_steps, data, init_state = last_state, progress_bar = progress_bar)
        t1 = time()
        tot_time += t1 - t0

        theta_samples = guide.sample_posterior(random.PRNGKey(0), svi_results.params, sample_shape = (200,))
        param_mean, param_std, backfire_mean, backfire_std = analyse_samples(theta_samples)
        
        res_analysis = {"param_mean": param_mean,
                        "param_std": param_std,
                        "backfire_mean": backfire_mean,
                        "backfire_std": backfire_std,
                        "tot_time": tot_time,
                        "n_simulations": None,
                        "method": "svi" + guide_family,
                        "n_steps": intermediate_steps * (_ + 1),
                        "n_samples": None,
                        "id": id
                        }
        res.append(res_analysis)

        last_state = svi_results.state
        if tot_time > timeout:
            break

    return res


def train_mcmc(X, edges, mu_plus, mu_minus, intermediate_samples = None, rho = 32, num_chains = 1,
               warmup_samples = None, n_samples = 400, progress_bar = False, id = None, timeout = 3600):
    if intermediate_samples is None:
        intermediate_samples = n_samples
    if warmup_samples is None:
        warmup_samples = intermediate_samples

    data = initialize_training(jnp.array(X), jnp.array(edges), mu_plus, mu_minus, rho = rho)
    key = random.PRNGKey(0)
    mcmc = MCMC(NUTS(model), num_warmup = warmup_samples, num_chains = num_chains, 
                num_samples = intermediate_samples, progress_bar = progress_bar)
    res = []
    tot_time = 0
    for _ in range(int(n_samples / intermediate_samples)):
        t0 = time()
        mcmc.run(key, data)
        t1 = time()
        tot_time += t1 - t0

        mcmc.post_warmup_state = mcmc.last_state
        key = mcmc.post_warmup_state.rng_key
        
        mcmc_samples = mcmc.get_samples()
        param_mean, param_std, backfire_mean, backfire_std = analyse_samples(mcmc_samples)
        res.append({"param_mean": param_mean,
                    "param_std": param_std,
                    "backfire_mean": backfire_mean, 
                    "backfire_std": backfire_std,
                    "tot_time": tot_time,
                    "n_simulations": None,
                    "method": "mcmc",
                    "n_steps": None,
                    "n_samples": intermediate_samples * (_ + 1),
                    "id": id})
        if tot_time > timeout:
            break

    return res


def create_summary_statistics(X0, edges_iter, edge_per_t, parameters, mu_plus, mu_minus, rho):
    summary_statistics_list = []
    Xt = X0.copy()
    N = len(Xt)
    
    while True:
        edges_t = next(edges_iter, None)
        if edges_t is None:
            break
        is_backfire = parameters["theta2"]
        epsilon_plus,epsilon_minus = epsilons_from_theta(parameters, dict_theta = True, numpy = True)
        u,v,_,_ = edges_t.T
        u,v = u.astype(int),v.astype(int)
        diff_X = Xt[u] - Xt[v]
        # s_plus = ((np.random.rand(edge_per_t) < np_sigmoid(rho * (epsilon_plus - np.abs(diff_X))))) + 0
        # s_minus = ((np.random.rand(edge_per_t) < np_sigmoid(-rho * (epsilon_minus - np.abs(diff_X))))) + 0
        s_plus =  (np.abs(diff_X) < epsilon_plus) + 0
        s_minus = (np.abs(diff_X) > epsilon_minus) + 0

        updates_plus = mu_plus * s_plus * diff_X 
        updates_minus = mu_minus * s_minus * diff_X * is_backfire
        Xt[v] += updates_plus - updates_minus
        Xt[v] = np.clip(Xt[v], 1e-5, 1 - 1e-5)
            
        summary_statistics_list.append(np.concatenate([u[None,:],v[None,:],s_plus[None,:], s_minus[None,:]])[None,:])

    edges_sim = np.concatenate(summary_statistics_list).transpose(0,2,1)
    return {"s_plus_sum": edges_sim[:,:,-2].sum(axis = 1), 
            "s_minus_sum": edges_sim[:,:,-1].sum(axis = 1)}

    

def create_trajectory(X0, edges, parameters, mu_plus, mu_minus, rho):
    X0 = X0.copy()
    edges_iter = (edges_t for edges_t in edges)
    T, edge_per_t, _ = edges.shape
    summary_statistics = create_summary_statistics(X0, edges_iter, edge_per_t, parameters, mu_plus, mu_minus, rho)
    # summary_statistics = create_s_update_X(X0, edges_iter, edge_per_t, parameters, rho, [], [X0[None,:].copy()])
    return summary_statistics

def sim_trajectory_X0_edges(X0, edges, mu_plus, mu_minus, rho):
    return lambda parameters: create_trajectory(X0, edges, parameters, mu_plus, mu_minus, rho)




In [3]:
def train_abc(X, edges, mu_plus, mu_minus, populations_budget = 10, intermediate_populations = None,
              population_size = 200, rho = 32, id = None, timeout = 3600):
    if intermediate_populations is None:
        intermediate_populations = populations_budget
    
    T = len(X)
    res = []
    tot_time = 0
    model_abc = sim_trajectory_X0_edges(X[0], edges, mu_plus, mu_minus, rho)
    prior = pyabc.Distribution(
                theta0=pyabc.RV("norm", 0, 1),
                theta1=pyabc.RV("norm", 0, 1),
                theta2=pyabc.RV("rv_discrete", values = (np.arange(2), 0.5 * np.ones(2))))
    distance = pyabc.PNormDistance(2)
    obs = {"s_plus_sum": edges[:,:,-2].sum(axis = 1), 
           "s_minus_sum": edges[:,:,-1].sum(axis = 1)}
    abc = pyabc.ABCSMC(model_abc, prior, distance, population_size = population_size)#, sampler = SingleCoreSampler())
    db = "sqlite:///" + os.path.join(gettempdir(), f"{id}_is_backfire_test.db")
    history = abc.new(db, obs)
    run_id = history.id
    for _ in range(int(populations_budget / intermediate_populations)):
        abc_continued = pyabc.ABCSMC(model_abc, prior, distance, population_size = population_size)#, sampler = SingleCoreSampler())
        abc_continued.load(db, run_id)
        t0 = time()
        history = abc_continued.run(max_nr_populations = intermediate_populations,
                                    minimum_epsilon = T,
                                    # minimum_epsilon = 5 * (T ** (1/2)),
                                    max_walltime = timedelta(hours = 3))
        t1 = time()
        tot_time += (t1 - t0)
        theta_samples = jnp.array(history.get_distribution()[0])

        param_mean, param_std, backfire_mean, backfire_std = analyse_samples(theta_samples)
        res_analysis = {"param_mean": param_mean,
                        "param_std": param_std,
                        "backfire_mean": backfire_mean, 
                        "backfire_std": backfire_std,
                        "tot_time": tot_time,
                        "n_simulations": history.total_nr_simulations,
                        "method": "abc",
                        "n_steps": None,
                        "n_samples": None,
                        "id": id
                        }
        res.append(res_analysis)
        if tot_time > timeout:
            break
    return res


In [4]:
def count_interactions(X, edges):
    T, N = X.shape
    _,edge_per_t,_ = edges.shape
    
    pos_interactions_plus, pos_interactions_minus = edges[:,:,2].sum(), edges[:,:,3].sum()
    tot_interactions = (T - 1) * edge_per_t
    

    return {"pos_interactions_plus":pos_interactions_plus, 
            "pos_interactions_minus":pos_interactions_minus, 
            "tot_interactions":tot_interactions,
            "T": T, "N": N, "edge_per_t": edge_per_t,
            "var_X_end": X[-1].var(),
            "skew_X_end": skew(X[-1]),
            "kurtosis_X_end": kurtosis(X[-1]),
            "bimodality_X_end": dipstat(X[-1]),
            }



In [5]:
########## all ##############
def epsilons_from_theta(parameters, dict_theta = False, numpy = False):
    
    sigmoid_fn = np_sigmoid if numpy else sigmoid
    if dict_theta:
        epsilon_plus = sigmoid_fn(parameters["theta0"]) / 2
        epsilon_minus = sigmoid_fn(parameters["theta1"]) / 2 + .5
        
        return epsilon_plus,epsilon_minus
    elif len(parameters.shape) == 1:
        epsilon_plus,epsilon_minus,is_backfire = sigmoid(parameters) / jnp.array([2, 2, 1]) + jnp.array([0.,.5, 0.])
        return epsilon_plus,epsilon_minus,is_backfire
    else:
        trans_theta = (sigmoid(parameters) / 2 + jnp.array([0.,.5])).T
        return trans_theta



In [6]:
def analyse_samples(samples):
        param_samples = epsilons_from_theta(samples["theta"][:,:2], dict_theta = False, numpy = False).T
        param_mean, param_std = param_samples.mean(axis = 0), param_samples.std(axis = 0)
        backfire_mean = samples["backfire"].mean(axis = 0)
        backfire_std = samples["backfire"].std(axis = 0)
        
        return param_mean, param_std, backfire_mean, backfire_std

def analyse_results(epsilon_plus, epsilon_minus, mu_plus, mu_minus, is_backfire, backfire_mean, backfire_std,
                    param_mean, param_std, n_samples, n_steps, n_simulations, id,
                    tot_time, method):
    params = np.array([epsilon_plus, epsilon_minus])

    param_names = ["epsilon_plus", "epsilon_minus"]
    out = {
            "id": id,
            "mse_epsilon": ((params[:2] - param_mean[:2])**2).mean().item(), 
            "tot_time": tot_time,
            "n_steps": n_steps, 
            "n_samples": n_samples,
            "method": method,
            "n_simulations": n_simulations,
            "mu_plus": mu_plus, 
            "mu_minus": mu_minus,
            "is_backfire": is_backfire,
            "round_backfire": round(backfire_mean.item()),
            "backfire_mean": backfire_mean.item(), 
            "backfire_std": backfire_std.item(),
            }

    out.update({u + "_error": np.abs(params[k] - param_mean[k]) for k, u in enumerate(param_names)})
    out.update({u + "_mean": param_mean[k].item() for k, u in enumerate(param_names)})
    out.update({u + "_std": param_std[k].item() for k, u in enumerate(param_names)})
    out.update({u + "_real": params[k].item() for k, u in enumerate(param_names)})

    return out
        

def save_pickle(out, path):
    if path is not None:
        with open(path, "wb") as f:
            pickle.dump(out, f)



In [9]:
N, T, edge_per_t = 100, 34, 10
epsilon_plus = 0.2
epsilon_minus = 0.8
mu_plus, mu_minus = 0.1, 0.1
is_backfire = False
rho = 32

In [10]:
X, edges = BC_update.simulate_trajectory(N = N, T = T, edge_per_t = edge_per_t, 
                                                  epsilon_plus = epsilon_plus, epsilon_minus = epsilon_minus, 
                                                  mu_plus = mu_plus, mu_minus = mu_minus * is_backfire, rho = rho)  


In [11]:
analysis_data = count_interactions(X, edges)

### abc

In [41]:
population_size = 200
populations_budget = 20
intermediate_populations = 20

In [13]:
T = len(X)
res = []
tot_time = 0
model_abc = sim_trajectory_X0_edges(X[0], edges, mu_plus, mu_minus, rho)
prior = pyabc.Distribution(
            theta0=pyabc.RV("norm", 0, 1),
            theta1=pyabc.RV("norm", 0, 1),
            theta2=pyabc.RV("rv_discrete", values = (np.arange(2), 0.5 * np.ones(2))))
distance = pyabc.PNormDistance(2)
obs = {"s_plus_sum": edges[:,:,-2].sum(axis = 1), 
       "s_minus_sum": edges[:,:,-1].sum(axis = 1)}
abc = pyabc.ABCSMC(model_abc, prior, distance, population_size = population_size)#, sampler = SingleCoreSampler())
db = "sqlite:///" + os.path.join(gettempdir(), f"{id}_is_backfire_test.db")
history = abc.new(db, obs)
run_id = history.id
for _ in range(int(populations_budget / intermediate_populations)):
    abc_continued = pyabc.ABCSMC(model_abc, prior, distance, population_size = population_size)#, sampler = SingleCoreSampler())
    abc_continued.load(db, run_id)
    t0 = time()
    history = abc_continued.run(max_nr_populations = intermediate_populations,
                                minimum_epsilon = T,
                                # minimum_epsilon = 5 * (T ** (1/2)),
                                max_walltime = timedelta(hours = 3))
    

In [14]:
theta_samples = jnp.array(history.get_distribution()[0])


In [25]:
def analyse_samples(samples):
    param_samples = epsilons_from_theta(samples["theta"][:,:2], dict_theta = False, numpy = False).T
    param_mean, param_std = param_samples.mean(axis = 0), param_samples.std(axis = 0)
    backfire_mean = samples["backfire"].mean(axis = 0)
    backfire_std = samples["backfire"].std(axis = 0)
        
    return param_mean, param_std, backfire_mean, backfire_std


In [27]:
analyse_samples({"theta": theta_samples[:,:2], "backfire": theta_samples[:,2]})

(Array([0.20836131, 0.7762659 ], dtype=float32),
 Array([0.05554182, 0.08334755], dtype=float32),
 Array(0.49499997, dtype=float32),
 Array(0.49997497, dtype=float32))

In [None]:
param_mean, param_std, backfire_mean, backfire_std = analyse_samples(theta_samples)
    

### svi

In [42]:
guide = AutoNormal(model)

In [43]:
data = initialize_training(jnp.array(X), jnp.array(edges), mu_plus, mu_minus, rho = rho)

In [44]:
optimizer = Adam(step_size = 0.01)
svi = SVI(model, guide, optimizer, loss = TraceGraph_ELBO())
    

In [45]:
svi_results = svi.run(random.PRNGKey(0), 1000, data)

100%|██████████| 1000/1000 [00:06<00:00, 156.40it/s, init loss: 202.0720, avg. loss [951-1000]: 38.8567]


In [48]:
theta_samples = guide.sample_posterior(random.PRNGKey(0), svi_results.params, sample_shape = (200,))

In [67]:
theta_samples["backfire"].mean(axis = 0)[0].item()

0.46815863251686096

In [49]:
param_mean, param_std, backfire_mean, backfire_std = analyse_samples(theta_samples)

In [52]:
theta_samples

{'backfire': Array([[4.28798705e-01],
        [1.37343807e-02],
        [9.73539591e-01],
        [1.39971122e-01],
        [9.84072387e-01],
        [1.93138242e-01],
        [2.59950105e-03],
        [1.36527479e-01],
        [9.98240352e-01],
        [4.06109363e-01],
        [6.15140438e-01],
        [1.09412089e-01],
        [2.52987281e-03],
        [8.60712588e-01],
        [7.36232281e-01],
        [9.35697436e-01],
        [9.48804259e-01],
        [1.47793546e-01],
        [9.64626014e-01],
        [1.22292833e-02],
        [9.90873218e-01],
        [2.04008847e-01],
        [8.80107760e-01],
        [4.09331508e-02],
        [1.33425305e-02],
        [8.25155992e-03],
        [6.41632173e-03],
        [9.98303771e-01],
        [8.35782170e-01],
        [1.48697279e-03],
        [6.59540832e-01],
        [2.14003205e-01],
        [2.53981277e-02],
        [5.32608449e-01],
        [7.08801150e-01],
        [8.36271867e-02],
        [9.87271011e-01],
        [8.63143444e-01],


In [None]:
def train_svi(X, edges, mu_plus, mu_minus, guide_family = "normal", rho = 32,
              n_steps = 4000, intermediate_steps = None, lr = 0.01, 
              progress_bar = False, id = None, timeout = 3600):
    if intermediate_steps is None:
        intermediate_steps = n_steps
    
    if guide_family == "normal":
        guide = AutoNormal(model)
    if guide_family == "NF":
        guide = AutoBNAFNormal(model, num_flows = 1, hidden_factors = (8,8))
        n_steps = int(n_steps / 2)
        intermediate_steps = int(intermediate_steps / 2)
    
    data = initialize_training(jnp.array(X), jnp.array(edges), mu_plus, mu_minus, rho = rho)
    optimizer = Adam(step_size = lr)
    svi = SVI(model, guide, optimizer, loss = TraceGraph_ELBO())
    res = []
    last_state = None

    tot_time = 0
    
    for _ in range(int(n_steps / intermediate_steps)):
        t0 = time()
        svi_results = svi.run(random.PRNGKey(0), intermediate_steps, data, init_state = last_state, progress_bar = progress_bar)
        t1 = time()
        tot_time += t1 - t0

        theta_samples = guide.sample_posterior(random.PRNGKey(0), svi_results.params, sample_shape = (200,))
        param_mean, param_std, backfire_mean, backfire_std = analyse_samples(theta_samples)
        
        res_analysis = {"param_mean": param_mean,
                        "param_std": param_std,
                        "backfire_mean": backfire_mean,
                        "backfire_std": backfire_std,
                        "tot_time": tot_time,
                        "n_simulations": None,
                        "method": "svi" + guide_family,
                        "n_steps": intermediate_steps * (_ + 1),
                        "n_samples": None,
                        "id": id
                        }
        res.append(res_analysis)

        last_state = svi_results.state
        if tot_time > timeout:
            break

    return res



In [None]:
def complete_experiment(N, T, edge_per_t, rho = 32,
                        method = "svinormal",
                        epsilon_plus = None, epsilon_minus = None, mu_plus = None, mu_minus = None,
                        n_steps = 1000, n_samples = 100, populations_budget = 10, num_chains = 1,
                        intermediate_steps = None, intermediate_samples = None, warmup_samples = None, 
                        intermediate_populations = None, population_size = 200, 
                        lr = 0.01, progress_bar = False, timeout = 25000, id = None, date = None, save_data = True
                        ):
    if len(glob(f"../data/is_backfire_{date}/X_{id}*")) > 0:
        X_file = glob(f"../data/is_backfire_{date}/X_{id}*")[0]
        edges_file = glob(f"../data/is_backfire_{date}/edges_{id}*")[0]
        X = np.load(X_file)
        edges = np.load(edges_file)
        
        _,_,epsilon_plus,epsilon_minus, mu_plus, mu_minus = [int(u) for u in X_file.split("/")[-1].split("_")[2:-1]]
        epsilon_plus, epsilon_minus, mu_plus, mu_minus = np.array([epsilon_plus, epsilon_minus, mu_plus, mu_minus]) / 100
    else:
        if epsilon_plus is None:
            epsilon_plus = round(np.random.randint(5) * 0.1 + 0.05, 4)
            epsilon_minus = round(np.random.randint(5) * 0.1 + 0.55, 4)
            mu_plus = round(np.random.randint(10) * 0.02 + 0.01, 4)
            mu_minus = round(np.random.randint(10) * 0.02 + 0.01, 4)
            is_backfire = np.random.randint(2)
    
        X, edges = BC_update.simulate_trajectory(N = N, T = T, edge_per_t = edge_per_t, 
                                                  epsilon_plus = epsilon_plus, epsilon_minus = epsilon_minus, 
                                                  mu_plus = mu_plus, mu_minus = mu_minus * is_backfire, rho = rho)  


        if save_data:
            np.save(f"../data/is_backfire_{date}/X_{id}_{int(epsilon_plus * 100)}_{int(epsilon_minus * 100)}_{int(mu_plus * 100)}_{int(mu_minus * 100)}_.npy", X)
            np.save(f"../data/is_backfire_{date}/edges_{id}_{int(epsilon_plus * 100)}_{int(epsilon_minus * 100)}_{int(mu_plus * 100)}_{int(mu_minus * 100)}_.npy", edges)
        

    analysis_data = count_interactions(X, edges)
    
    
    out = []
    if method == "svinormal":
        res_svinormal = train_svi(X, edges, mu_plus = mu_plus, mu_minus = mu_minus, guide_family = "normal", rho = rho,
             n_steps = n_steps, intermediate_steps = intermediate_steps, lr = lr, 
             progress_bar = progress_bar, id = id, timeout = timeout)
        out += res_svinormal
    if method == "sviNF":
        res_svinf = train_svi(X, edges, mu_plus = mu_plus, mu_minus = mu_minus, guide_family = "NF", rho = rho,
             n_steps = n_steps, intermediate_steps = intermediate_steps, lr = lr, 
             progress_bar = progress_bar, id = id, timeout = timeout)
        out += res_svinf
    if method == "mcmc":
        res_mcmc = train_mcmc(X, edges, mu_plus = mu_plus, mu_minus = mu_minus, intermediate_samples = intermediate_samples, warmup_samples = warmup_samples,  rho = rho,
                              n_samples = n_samples, num_chains = num_chains, progress_bar = progress_bar, id = id, timeout = timeout)
        out += res_mcmc
    if method == "abc":
        res_abc = train_abc(X, edges, mu_plus = mu_plus, mu_minus = mu_minus, populations_budget = populations_budget, intermediate_populations = intermediate_populations,
                            population_size = population_size, rho = rho, id = id, timeout = timeout)
        out += res_abc
    complete_analysis = [analyse_results(epsilon_plus, epsilon_minus, mu_plus, mu_minus, is_backfire, **res)|analysis_data for res in out]
    return complete_analysis
    # return out, analysis_data










In [73]:
import pandas as pd

In [78]:
pd.concat([pd.DataFrame(pd.read_pickle(file)) for file in glob("../data/*isback*/*svi*pkl")])

Unnamed: 0,id,mse_epsilon,tot_time,n_steps,n_samples,method,n_simulations,mu_plus,mu_minus,is_backfire,...,pos_interactions_plus,pos_interactions_minus,tot_interactions,T,N,edge_per_t,var_X_end,skew_X_end,kurtosis_X_end,bimodality_X_end
0,0_100_128,,34.544479,20000,,svinormal,,0.09,0.15,0,...,144.0,11.0,1270,128,100,10,0.072942,-0.090875,-1.235996,0.041326
0,0_400_128,,33.668364,20000,,svinormal,,0.03,0.03,0,...,353.0,1.0,1270,128,400,10,0.087642,-0.080332,-1.288702,0.028026
0,0_50_2048,,55.951718,20000,,svinormal,,0.03,0.05,0,...,10413.0,163.0,20470,2048,50,10,0.075836,0.241758,-1.941531,0.218525
0,0_400_2048,,58.551247,20000,,svinormal,,0.13,0.09,0,...,2111.0,4128.0,20470,2048,400,10,0.08162,-0.038461,-1.239857,0.027112
0,0_100_512,,36.379876,20000,,svinormal,,0.09,0.17,1,...,4126.0,0.0,5110,512,100,10,0.001417,-0.532276,0.690528,0.01954
0,0_200_128,,39.120497,20000,,svinormal,,0.13,0.15,0,...,550.0,2.0,1270,128,200,10,0.083816,0.077025,-1.346729,0.03467
0,0_100_2048,,57.458236,20000,,svinormal,,0.17,0.09,0,...,8800.0,2463.0,20470,2048,100,10,0.058122,0.451101,-1.796439,0.191722
0,0_400_512,,35.618566,20000,,svinormal,,0.15,0.11,1,...,3038.0,320.0,5110,512,400,10,0.079944,-0.075401,-0.843043,0.0275
0,0_200_512,,37.851959,20000,,svinormal,,0.05,0.19,0,...,498.0,642.0,5110,512,200,10,0.081496,0.078301,-1.234075,0.025048
0,0_50_128,,32.722241,20000,,svinormal,,0.09,0.19,0,...,433.0,358.0,1270,128,50,10,0.09168,0.119093,-1.702834,0.114042


In [2]:
import sys 
sys.path += ["../src"]
import BC_leaders, BC_update
import numpy as np
from tqdm import tqdm
from time import time
import pickle
from glob import glob
from pyABC_ import pyabc
from scipy.special import expit as np_sigmoid

import os
from tempfile import gettempdir
from pyABC_.pyabc.sampler import SingleCoreSampler
from jax.scipy.special import expit as sigmoid
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from numpyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO, MCMC, NUTS
from numpyro.infer.autoguide import AutoNormal, AutoBNAFNormal, AutoIAFNormal
from numpyro import distributions
import numpyro
from numpyro.optim import Adam
import jax.random as random
from datetime import timedelta
# numpyro.set_platform("gpu")
from diptest import dipstat
from scipy.stats import kurtosis, skew
import matplotlib.pyplot as plt


In [3]:
def compute_X_from_X0_params(X0, edges_iter, mu_plus, mu_minus, is_backfire = True):
    # edges_iter = (edges_t for edges_t in edges)
    # Xt = jax.lax.stop_gradient(X0.copy())
    Xt = X0.copy()
    X_list = [Xt.copy()]
    
    N = len(Xt)
    
    while True:
        edges_t = next(edges_iter, None)
        if edges_t is None:
            break
        
        u,v,s_plus,s_minus = edges_t.T
        u,v = u.astype(int),v.astype(int)
        diff_X = Xt[u] - Xt[v]

        updates_plus = mu_plus * s_plus * diff_X
        updates_minus = (mu_minus * s_minus * diff_X) * is_backfire
        # print(updates_minus)
        # Xt = Xt.at[v].add(updates_plus - updates_minus).clip(1e-5, 1 - 1e-5)
        Xt[v] += updates_plus - updates_minus
        Xt[v] = np.clip(Xt[v], 1e-5, 1 - 1e-5)
        
        X_list.append(Xt.copy())

    return jnp.stack(X_list)


def initialize_training(X, edges, mu_plus, mu_minus, rho = 32):
    T, N = X.shape    
    u,v,s_plus,s_minus,t = BC_leaders.convert_edges_uvst(edges)
    s_plus, s_minus = jnp.float32(s_plus), jnp.float32(s_minus)

    X0 = np.array(X[0])
    edges_iter = (edge for edge in edges)
    X_bc = compute_X_from_X0_params(X0, edges_iter, mu_plus, mu_minus, is_backfire = False)
    edges_iter = (edge for edge in edges)
    X_back = compute_X_from_X0_params(X0, edges_iter, mu_plus, mu_minus, is_backfire = True)
    u,v,t = u.astype(int), v.astype(int), t.astype(int)

    diff_X_bc = X_bc[t,u] - X_bc[t,v]
    diff_X_back = X_back[t,u] - X_back[t,v]

    return {"u": u, "v": v, "s_plus": s_plus, "s_minus": s_minus, "t": t,
            "N": N, "T": T, "rho": rho,
            "diff_X_bc": diff_X_bc, "diff_X_back": diff_X_back}



In [4]:
def model(data):
    dim = 3
    dist = distributions.Normal(jnp.zeros(dim), jnp.ones(dim)).to_event(1)
    params = numpyro.sample("theta", dist)
    
    theta = params[:2]
    param_backfire = params[2:]
    epsilon_plus, epsilon_minus = sigmoid(theta) /  2 + jnp.array([0.,.5])

    ############
    diff_X_bc,diff_X_back,u,v,s_plus, s_minus,t, rho, N, T = [data[k] for k in ["diff_X_bc", "diff_X_back","u","v",
                                                                      "s_plus", "s_minus","t",
                                                                      "rho", "N", "T"]]
    
    
    backfire_sample = numpyro.sample("backfire", distributions.RelaxedBernoulli(probs = param_backfire, temperature = jnp.array([0.1])).to_event(1))
    is_backfire = backfire_sample[0]
        
    s_plus = jnp.array(s_plus)
    s_minus = jnp.array(s_minus)
 
    diff_X = (1 - is_backfire) * diff_X_bc + is_backfire * diff_X_back
    kappas_plus = BC_leaders.kappa_plus_from_epsilon(epsilon_plus, diff_X, rho, with_jax = True)
    kappas_minus = BC_leaders.kappa_minus_from_epsilon(epsilon_minus, diff_X, rho, with_jax = True)
    kappas_ = jnp.concatenate([kappas_minus, kappas_plus])
    s = jnp.concatenate([s_minus, s_plus])

    with numpyro.plate("data", s.shape[0]):
        numpyro.sample("obs", distributions.Bernoulli(probs = kappas_), obs = s)



In [27]:
def train_svi(X, edges, mu_plus, mu_minus, guide_family = "normal", rho = 32,
              n_steps = 4000, intermediate_steps = None, lr = 0.01, 
              progress_bar = False, id = None, timeout = 3600):
    if intermediate_steps is None:
        intermediate_steps = n_steps
    
    if guide_family == "normal":
        guide = AutoNormal(model)
    if guide_family == "NF":
        guide = AutoBNAFNormal(model, num_flows = 1, hidden_factors = (8,8))
        n_steps = int(n_steps / 2)
        intermediate_steps = int(intermediate_steps / 2)
    
    data = initialize_training(jnp.array(X), jnp.array(edges), mu_plus, mu_minus, rho = rho)
    optimizer = Adam(step_size = lr)
    svi = SVI(model, guide, optimizer, loss = TraceGraph_ELBO())
    res = []
    last_state = None

    tot_time = 0
    
    for _ in range(int(n_steps / intermediate_steps)):
        
        t0 = time()
        svi_results = svi.run(random.PRNGKey(np.random.randint(low = 0, high = 10**7)), intermediate_steps, data, 
                              init_state = last_state, progress_bar = progress_bar)
        t1 = time()
        tot_time += t1 - t0

        return guide, svi_results

    #     theta_samples = guide.sample_posterior(random.PRNGKey(0), svi_results.params, sample_shape = (200,))
    #     print(theta_samples["theta"][0,:])
    #     param_mean, param_std, backfire_mean, backfire_std = analyse_samples(theta_samples)
        
    #     res_analysis = {"param_mean": param_mean,
    #                     "param_std": param_std,
    #                     "backfire_mean": backfire_mean,
    #                     "backfire_std": backfire_std,
    #                     "tot_time": tot_time,
    #                     "n_simulations": None,
    #                     "method": "svi" + guide_family,
    #                     "n_steps": intermediate_steps * (_ + 1),
    #                     "n_samples": None,
    #                     "id": id
    #                     }
    #     res.append(res_analysis)

    #     last_state = svi_results.state
    #     if tot_time > timeout:
    #         break

    # return res




In [6]:
def train_mcmc(X, edges, mu_plus, mu_minus, intermediate_samples = None, rho = 32, num_chains = 1,
               warmup_samples = None, n_samples = 400, progress_bar = False, id = None, timeout = 3600):
    if intermediate_samples is None:
        intermediate_samples = n_samples
    if warmup_samples is None:
        warmup_samples = intermediate_samples

    data = initialize_training(jnp.array(X), jnp.array(edges), mu_plus, mu_minus, rho = rho)
    key = random.PRNGKey(0)
    mcmc = MCMC(NUTS(model), num_warmup = warmup_samples, num_chains = num_chains, 
                num_samples = intermediate_samples, progress_bar = progress_bar)
    res = []
    tot_time = 0
    for _ in range(int(n_samples / intermediate_samples)):
        t0 = time()
        mcmc.run(key, data)
        t1 = time()
        tot_time += t1 - t0

        mcmc.post_warmup_state = mcmc.last_state
        key = mcmc.post_warmup_state.rng_key
        
        mcmc_samples = mcmc.get_samples()
        param_mean, param_std, backfire_mean, backfire_std = analyse_samples(mcmc_samples)
        res.append({"param_mean": param_mean,
                    "param_std": param_std,
                    "backfire_mean": backfire_mean, 
                    "backfire_std": backfire_std,
                    "tot_time": tot_time,
                    "n_simulations": None,
                    "method": "mcmc",
                    "n_steps": None,
                    "n_samples": intermediate_samples * (_ + 1),
                    "id": id})
        if tot_time > timeout:
            break

    return res


def create_summary_statistics(X0, edges_iter, edge_per_t, parameters, mu_plus, mu_minus, rho):
    summary_statistics_list = []
    Xt = X0.copy()
    N = len(Xt)
    
    while True:
        edges_t = next(edges_iter, None)
        if edges_t is None:
            break
        is_backfire = parameters["theta2"]
        epsilon_plus,epsilon_minus = epsilons_from_theta(parameters, dict_theta = True, numpy = True)
        u,v,_,_ = edges_t.T
        u,v = u.astype(int),v.astype(int)
        diff_X = Xt[u] - Xt[v]
        # s_plus = ((np.random.rand(edge_per_t) < np_sigmoid(rho * (epsilon_plus - np.abs(diff_X))))) + 0
        # s_minus = ((np.random.rand(edge_per_t) < np_sigmoid(-rho * (epsilon_minus - np.abs(diff_X))))) + 0
        s_plus =  (np.abs(diff_X) < epsilon_plus) + 0
        s_minus = (np.abs(diff_X) > epsilon_minus) + 0

        updates_plus = mu_plus * s_plus * diff_X 
        updates_minus = mu_minus * s_minus * diff_X * is_backfire
        Xt[v] += updates_plus - updates_minus
        Xt[v] = np.clip(Xt[v], 1e-5, 1 - 1e-5)
            
        summary_statistics_list.append(np.concatenate([u[None,:],v[None,:],s_plus[None,:], s_minus[None,:]])[None,:])

    edges_sim = np.concatenate(summary_statistics_list).transpose(0,2,1)
    return {"s_plus_sum": edges_sim[:,:,-2].sum(axis = 1), 
            "s_minus_sum": edges_sim[:,:,-1].sum(axis = 1)} 

def create_trajectory(X0, edges, parameters, mu_plus, mu_minus, rho):
    X0 = X0.copy()
    edges_iter = (edges_t for edges_t in edges)
    T, edge_per_t, _ = edges.shape
    summary_statistics = create_summary_statistics(X0, edges_iter, edge_per_t, parameters, mu_plus, mu_minus, rho)
    # summary_statistics = create_s_update_X(X0, edges_iter, edge_per_t, parameters, rho, [], [X0[None,:].copy()])
    return summary_statistics

def sim_trajectory_X0_edges(X0, edges, mu_plus, mu_minus, rho):
    return lambda parameters: create_trajectory(X0, edges, parameters, mu_plus, mu_minus, rho)

def train_abc(X, edges, mu_plus, mu_minus, populations_budget = 10, intermediate_populations = None,
              population_size = 200, rho = 32, id = None, timeout = 3600):
    if intermediate_populations is None:
        intermediate_populations = populations_budget
    
    T = len(X)
    res = []
    tot_time = 0
    model_abc = sim_trajectory_X0_edges(X[0], edges, mu_plus, mu_minus, rho)
    prior = pyabc.Distribution(
                theta0=pyabc.RV("norm", 0, 1),
                theta1=pyabc.RV("norm", 0, 1),
                theta2=pyabc.RV("rv_discrete", values = (np.arange(2), 0.5 * np.ones(2))))
    distance = pyabc.PNormDistance(2)
    obs = {"s_plus_sum": edges[:,:,-2].sum(axis = 1), 
           "s_minus_sum": edges[:,:,-1].sum(axis = 1)}
    abc = pyabc.ABCSMC(model_abc, prior, distance, population_size = population_size)#, sampler = SingleCoreSampler())
    db = "sqlite:///" + os.path.join(gettempdir(), f"{id}_is_backfire_test.db")
    history = abc.new(db, obs)
    run_id = history.id
    for _ in range(int(populations_budget / intermediate_populations)):
        abc_continued = pyabc.ABCSMC(model_abc, prior, distance, population_size = population_size)#, sampler = SingleCoreSampler())
        abc_continued.load(db, run_id)
        t0 = time()
        history = abc_continued.run(max_nr_populations = intermediate_populations,
                                    # minimum_epsilon = T,
                                    minimum_epsilon = 5 * (T ** (1/2)),
                                    max_walltime = timedelta(hours = 3))
        t1 = time()
        tot_time += (t1 - t0)
        theta_samples = jnp.array(history.get_distribution()[0])

        param_mean, param_std, backfire_mean, backfire_std = analyse_samples({"theta": theta_samples[:,:2], 
                                                                              "backfire": theta_samples[:,2]})
        res_analysis = {"param_mean": param_mean,
                        "param_std": param_std,
                        "backfire_mean": backfire_mean, 
                        "backfire_std": backfire_std,
                        "tot_time": tot_time,
                        "n_simulations": history.total_nr_simulations,
                        "method": "abc",
                        "n_steps": None,
                        "n_samples": None,
                        "id": id
                        }
        res.append(res_analysis)
        if tot_time > timeout:
            break
    return res

def count_interactions(X, edges):
    T, N = X.shape
    _,edge_per_t,_ = edges.shape
    
    pos_interactions_plus, pos_interactions_minus = edges[:,:,2].sum(), edges[:,:,3].sum()
    tot_interactions = (T - 1) * edge_per_t
    

    return {"pos_interactions_plus":pos_interactions_plus, 
            "pos_interactions_minus":pos_interactions_minus, 
            "tot_interactions":tot_interactions,
            "T": T, "N": N, "edge_per_t": edge_per_t,
            "var_X_end": X[-1].var(),
            "skew_X_end": skew(X[-1]),
            "kurtosis_X_end": kurtosis(X[-1]),
            "bimodality_X_end": dipstat(X[-1]),
            }

########## all ##############
def epsilons_from_theta(parameters, dict_theta = False, numpy = False):
    
    sigmoid_fn = np_sigmoid if numpy else sigmoid
    if dict_theta:
        epsilon_plus = sigmoid_fn(parameters["theta0"]) / 2
        epsilon_minus = sigmoid_fn(parameters["theta1"]) / 2 + .5
        
        return epsilon_plus,epsilon_minus
    elif len(parameters.shape) == 1:
        epsilon_plus,epsilon_minus,is_backfire = sigmoid(parameters) / jnp.array([2, 2, 1]) + jnp.array([0.,.5, 0.])
        return epsilon_plus,epsilon_minus,is_backfire
    else:
        trans_theta = (sigmoid(parameters) / 2 + jnp.array([0.,.5])).T
        return trans_theta



In [7]:
def analyse_samples(samples):
        param_samples = epsilons_from_theta(samples["theta"][:,:2], dict_theta = False, numpy = False).T
        param_mean, param_std = param_samples.mean(axis = 0), param_samples.std(axis = 0)
        backfire_mean = samples["backfire"].mean(axis = 0)
        backfire_std = samples["backfire"].std(axis = 0)
        
        return param_mean, param_std, backfire_mean, backfire_std

def analyse_results(epsilon_plus, epsilon_minus, mu_plus, mu_minus, is_backfire, backfire_mean, backfire_std,
                    param_mean, param_std, n_samples, n_steps, n_simulations, id,
                    tot_time, method):
    params = np.array([epsilon_plus, epsilon_minus])

    param_names = ["epsilon_plus", "epsilon_minus"]
    
    out = {
            "id": id,
            "mse_epsilon": ((params[:2] - param_mean[:2])**2).mean().item(), 
            "tot_time": tot_time,
            "n_steps": n_steps, 
            "n_samples": n_samples,
            "method": method,
            "n_simulations": n_simulations,
            "mu_plus": mu_plus, 
            "mu_minus": mu_minus,
            "is_backfire": is_backfire,
            "round_backfire": backfire_mean.round().item(),
            "backfire_mean": backfire_mean.item(), 
            "backfire_std": backfire_std.item(),
            }
    
    out.update({u + "_error": np.abs(params[k] - param_mean[k]) for k, u in enumerate(param_names)})
    out.update({u + "_mean": param_mean[k].item() for k, u in enumerate(param_names)})
    out.update({u + "_std": param_std[k].item() for k, u in enumerate(param_names)})
    out.update({u + "_real": params[k].item() for k, u in enumerate(param_names)})

    return out
        

def save_pickle(out, path):
    if path is not None:
        with open(path, "wb") as f:
            pickle.dump(out, f)

def complete_experiment(N, T, edge_per_t, rho = 32,
                        method = "svinormal",
                        epsilon_plus = None, epsilon_minus = None, mu_plus = None, mu_minus = None,
                        n_steps = 1000, n_samples = 100, populations_budget = 10, num_chains = 1,
                        intermediate_steps = None, intermediate_samples = None, warmup_samples = None, 
                        intermediate_populations = None, population_size = 200, 
                        lr = 0.01, progress_bar = False, timeout = 25000, id = None, date = None, save_data = True
                        ):
    if len(glob(f"../data/isbackfire_{date}/X_{id}*")) > 0:
        X_file = glob(f"../data/isbackfire_{date}/X_{id}*")[0]
        edges_file = glob(f"../data/isbackfire_{date}/edges_{id}*")[0]
        X = np.load(X_file)
        edges = np.load(edges_file)
        
        _,_,epsilon_plus,epsilon_minus, mu_plus, mu_minus, is_backfire = [int(u) for u in X_file.split("/")[-1].split("_")[2:-1]]
        epsilon_plus, epsilon_minus, mu_plus, mu_minus = np.array([epsilon_plus, epsilon_minus, mu_plus, mu_minus]) / 100
    else:
        if epsilon_plus is None:
            epsilon_plus = round(np.random.randint(5) * 0.1 + 0.05, 4)
            epsilon_minus = round(np.random.randint(5) * 0.1 + 0.55, 4)
            mu_plus = round(np.random.randint(10) * 0.02 + 0.01, 4)
            mu_minus = round(np.random.randint(10) * 0.02 + 0.01, 4)
            is_backfire = np.random.randint(2)
    
        X, edges = BC_update.simulate_trajectory(N = N, T = T, edge_per_t = edge_per_t, 
                                                  epsilon_plus = epsilon_plus, epsilon_minus = epsilon_minus, 
                                                  mu_plus = mu_plus, mu_minus = mu_minus * is_backfire, rho = rho)  


        if save_data:
            np.save(f"../data/isbackfire_{date}/X_{id}_{int(epsilon_plus * 100)}_{int(epsilon_minus * 100)}_{int(mu_plus * 100)}_{int(mu_minus * 100)}_{int(is_backfire)}_.npy", X)
            np.save(f"../data/isbackfire_{date}/edges_{id}_{int(epsilon_plus * 100)}_{int(epsilon_minus * 100)}_{int(mu_plus * 100)}_{int(mu_minus * 100)}_{int(is_backfire)}_.npy", edges)
        

    analysis_data = count_interactions(X, edges)
    
    
    out = []
    if method == "svinormal":
        res_svinormal = train_svi(X, edges, mu_plus = mu_plus, mu_minus = mu_minus, guide_family = "normal", rho = rho,
             n_steps = n_steps, intermediate_steps = intermediate_steps, lr = lr, 
             progress_bar = progress_bar, id = id, timeout = timeout)
        out += res_svinormal
    if method == "sviNF":
        res_svinf = train_svi(X, edges, mu_plus = mu_plus, mu_minus = mu_minus, guide_family = "NF", rho = rho,
             n_steps = n_steps, intermediate_steps = intermediate_steps, lr = lr, 
             progress_bar = progress_bar, id = id, timeout = timeout)
        out += res_svinf
    if method == "mcmc":
        res_mcmc = train_mcmc(X, edges, mu_plus = mu_plus, mu_minus = mu_minus, intermediate_samples = intermediate_samples, warmup_samples = warmup_samples,  rho = rho,
                              n_samples = n_samples, num_chains = num_chains, progress_bar = progress_bar, id = id, timeout = timeout)
        out += res_mcmc
    if method == "abc":
        res_abc = train_abc(X, edges, mu_plus = mu_plus, mu_minus = mu_minus, populations_budget = populations_budget, intermediate_populations = intermediate_populations,
                            population_size = population_size, rho = rho, id = id, timeout = timeout)
        out += res_abc
    complete_analysis = [analyse_results(epsilon_plus, epsilon_minus, mu_plus, mu_minus, is_backfire, **res)|analysis_data for res in out]
    return complete_analysis
    # return out, analysis_data

In [9]:
res = complete_experiment(N = 50, T = 128, edge_per_t = 10, rho = 32,
                        method = "svinormal",
                        epsilon_plus = None, epsilon_minus = None, mu_plus = None, mu_minus = None,
                        n_steps = 1000, n_samples = 100, populations_budget = 10, num_chains = 1,
                        intermediate_steps = None, intermediate_samples = None, warmup_samples = None, 
                        intermediate_populations = None, population_size = 200, 
                        lr = 0.01, progress_bar = False, timeout = 25000, id = None, date = None, save_data = False)

In [10]:
res

[{'id': None,
  'mse_epsilon': nan,
  'tot_time': 28.872684717178345,
  'n_steps': 1000,
  'n_samples': None,
  'method': 'svinormal',
  'n_simulations': None,
  'mu_plus': 0.09,
  'mu_minus': 0.15,
  'is_backfire': 0,
  'round_backfire': nan,
  'backfire_mean': nan,
  'backfire_std': nan,
  'epsilon_plus_error': nan,
  'epsilon_minus_error': nan,
  'epsilon_plus_mean': nan,
  'epsilon_minus_mean': nan,
  'epsilon_plus_std': nan,
  'epsilon_minus_std': nan,
  'epsilon_plus_real': 0.35,
  'epsilon_minus_real': 0.85,
  'pos_interactions_plus': 937.0,
  'pos_interactions_minus': 9.0,
  'tot_interactions': 1270,
  'T': 128,
  'N': 50,
  'edge_per_t': 10,
  'var_X_end': 0.032890177281449426,
  'skew_X_end': 1.6981624694014712,
  'kurtosis_X_end': 1.974350854754495,
  'bimodality_X_end': 0.038635571821924}]

In [21]:
res_svinf = train_svi(X, edges, mu_plus, mu_minus, guide_family = "normal", rho = 32,
             n_steps = 1000, intermediate_steps = None, lr = 0.01,
             progress_bar = True)

100%|██████████| 1000/1000 [00:13<00:00, 74.35it/s, init loss: 201.2233, avg. loss [951-1000]: 41.4307]


In [13]:
res_svinf

(<numpyro.infer.autoguide.AutoNormal at 0x7f7a767e77f0>,
 SVIRunResult(params={'backfire_auto_loc': Array([-3.1589024], dtype=float32), 'backfire_auto_scale': Array([3.2499418], dtype=float32), 'theta_auto_loc': Array([-0.3719332 ,  0.69258356,  0.40357932], dtype=float32), 'theta_auto_scale': Array([0.06966158, 0.14158776, 0.2262375 ], dtype=float32)}, state=SVIState(optim_state=(Array(1000, dtype=int32, weak_type=True), OptimizerState(packed_state=([Array([-3.1589024], dtype=float32), Array([0.07490384], dtype=float32), Array([0.15609263], dtype=float32)], [Array([3.2103934], dtype=float32), Array([-0.2429873], dtype=float32), Array([0.17901044], dtype=float32)], [Array([-0.3719332 ,  0.69258356,  0.40357932], dtype=float32), Array([-4.1514473,  2.360871 ,  0.8149014], dtype=float32), Array([449.30005 ,  33.678112, 843.2271  ], dtype=float32)], [Array([-2.6290734, -1.8832065, -1.3709195], dtype=float32), Array([ 0.31187022, -0.24694623, -0.00954531], dtype=float32), Array([ 5.408944 

In [32]:
res_train = train_svi(X, edges, mu_plus, mu_minus, guide_family = "normal", rho = 32,
              n_steps = 1000, intermediate_steps = None, lr = 0.01)#, 
            #   progress_bar = True)

In [22]:
analyse_samples(res_svinf[0].sample_posterior(random.PRNGKey(10), res_svinf[1].params, sample_shape = (200,)))

(Array([0.2032994, 0.8300455], dtype=float32),
 Array([0.00855401, 0.01788632], dtype=float32),
 Array([0.17058524], dtype=float32),
 Array([0.27980375], dtype=float32))

In [33]:
analyse_samples(res_train[0].sample_posterior(random.PRNGKey(10), res_train[1].params, sample_shape = (200,)))

(Array([nan, nan], dtype=float32),
 Array([nan, nan], dtype=float32),
 Array([nan], dtype=float32),
 Array([nan], dtype=float32))

In [30]:
res_train

[{'param_mean': Array([nan, nan], dtype=float32),
  'param_std': Array([nan, nan], dtype=float32),
  'backfire_mean': Array([nan], dtype=float32),
  'backfire_std': Array([nan], dtype=float32),
  'tot_time': 8.956254482269287,
  'n_simulations': None,
  'method': 'svinormal',
  'n_steps': 4000,
  'n_samples': None,
  'id': None}]

In [31]:
guide = AutoNormal(model)
data = initialize_training(jnp.array(X), jnp.array(edges), mu_plus, mu_minus, rho = rho)
optimizer = Adam(step_size = 0.01)
svi = SVI(model, guide, optimizer, loss = TraceGraph_ELBO())

svi_results = svi.run(random.PRNGKey(0), 1000, data)
        
theta_samples = guide.sample_posterior(random.PRNGKey(0), svi_results.params, sample_shape = (200,))
    

100%|██████████| 1000/1000 [00:06<00:00, 149.38it/s, init loss: 206.3122, avg. loss [951-1000]: 38.6567]


In [52]:
svi_results.params

{'backfire_auto_loc': Array([-3.252989], dtype=float32),
 'backfire_auto_scale': Array([4.0210066], dtype=float32),
 'theta_auto_loc': Array([-0.35893068,  0.5886629 ,  0.40260231], dtype=float32),
 'theta_auto_scale': Array([0.07025062, 0.16956775, 0.22583275], dtype=float32)}

In [32]:
param_mean, param_std, backfire_mean, backfire_std = analyse_samples(theta_samples)

In [33]:
{"param_mean": param_mean,
                        "param_std": param_std,
                        "backfire_mean": backfire_mean,
                        "backfire_std": backfire_std,
                        "tot_time": 0,
                        "n_simulations": None,
                        "method": "svi",
                        "n_steps": 0,
                        "n_samples": None,
                        "id": 2
                        }

{'param_mean': Array([0.2054908 , 0.82061934], dtype=float32),
 'param_std': Array([0.00793032, 0.01811719], dtype=float32),
 'backfire_mean': Array([0.22971423], dtype=float32),
 'backfire_std': Array([0.32948285], dtype=float32),
 'tot_time': 0,
 'n_simulations': None,
 'method': 'svi',
 'n_steps': 0,
 'n_samples': None,
 'id': 2}

In [39]:
pd.concat([pd.DataFrame(pd.read_pickle(file)) for file in glob("../data/*back*/**pkl")])["mse_epsilon"].isna().sum()

0

In [35]:
import pandas as pd