# Contextualized Bayesian Networks

For more details, please see the [NOTMAD preprint](https://arxiv.org/abs/2111.01104).

Here, we want to answer the following questions:
- How does the performance of NOTMAD compare against population-network and cluster-network methods?
- How does the performance of NOTMAD change with the loss type ("NOTEARS", "DAGMA", or "poly")?
- How does the performance of NOTMAD change with the number of factors ("num_factors")?
- How does the performance change with the number of samples (n) and the number of features (p)?

TODO:
- Automated generation of W (clusters, linear function of C)

Possibly, vary:
- signal-to-noise ratio of the context-to-network parameter relationship
- encoder type ("ngam", "mlp")
- regularization of NOTMAD:
    NOTEARS loss has parameters:
    
        alpha (float)

        rho (float)

        use_dynamic_alpha_rho (Boolean)

    DAGMA loss has parameters:
    
        alpha (strength, default 1e0)

        s (max spectral radius, default 1)

Report results in terms of:
- MSE of X predictions (measure_mses function)
- recovery of W

In [None]:
import numpy as np
import random
import os
import matplotlib.pyplot as plt
%matplotlib inline

import sys

import contextualized
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split

from contextualized.dags.graph_utils import simulate_linear_sem, is_dag
from contextualized.baselines import (
    BayesianNetwork,
    GroupedNetworks,
)
from contextualized.easy import ContextualizedBayesianNetworks
import pickle as pkl
import os
import numpy

def measure_mses(betas, X, individual_preds=False):
    """
    Measure mean-squared errors of reconstructing X from per-sample networks.

    Feature ``t`` of sample ``s`` is predicted as
    ``X[s, :] @ betas[b, s, :, t]``, i.e. ``betas[b, s, i, t]`` is the weight
    of source feature ``i`` on target feature ``t``.

    Args:
        betas: array-like of shape
            (n_bootstraps, n_samples, n_features, n_features).
        X: array-like of shape (n_samples, n_features).
        individual_preds: if True, keep one row of errors per bootstrap;
            otherwise average the errors over the bootstrap axis.

    Returns:
        Squared error per sample, averaged over target features. Shape is
        (n_bootstraps, n_samples) if ``individual_preds`` is True, else
        (n_samples,).
    """
    X = np.asarray(X)
    # Only the first len(X) samples of betas are scored.
    betas = np.asarray(betas)[:, : len(X)]
    # preds[b, s, t] = sum_i X[s, i] * betas[b, s, i, t]
    preds = np.einsum("si,bsit->bst", X, betas)
    residuals = X[np.newaxis, :, :] - preds
    mses = np.mean(residuals**2, axis=-1, dtype=float)
    if not individual_preds:
        mses = np.mean(mses, axis=0)
    return mses

def measure_recovery(W_true, W):
    """
    Average distance between true and predicted networks.

    ``W`` carries a leading bootstrap dimension (W[bootstrap][sample]);
    ``W_true`` is indexed by sample only. Each distance is
    ``np.linalg.norm(..., ord=2)`` of the difference, and the result is the
    mean over every (bootstrap, sample) pair.
    """
    distances = [
        np.linalg.norm(W_true[sample_idx] - w_pred, ord=2)
        for w_bootstrap in W
        for sample_idx, w_pred in enumerate(w_bootstrap)
    ]
    return np.mean(distances)

def make_dag(p, n_nonempty=4):
    """
    Sample a random binary adjacency matrix that ``is_dag`` accepts.

    One candidate edge is created per unordered node pair, each with a random
    orientation; ``n_nonempty`` of the candidates are selected and set to 1.
    The draw is repeated until the resulting matrix passes ``is_dag``.

    Args:
        p: number of nodes.
        n_nonempty: number of non-zero entries (edges) in the result.

    Returns:
        (p, p) array of zeros and ones.

    Raises:
        ValueError: raised by ``random.sample`` when ``n_nonempty`` exceeds
            the p * (p - 1) / 2 available node pairs.
    """
    # One candidate per unordered pair (upper triangle).
    upper_pairs = [(i, j) for i in range(p) for j in range(p) if i < j]

    # Retry in a loop rather than by recursion: when most orientations are
    # cyclic, a recursive retry can hit the interpreter's recursion limit.
    # NOTE: draws come from both `np.random` and `random`, so both need to be
    # seeded for reproducible output.
    while True:
        # Randomly flip each pair so edges are not all oriented the same way.
        oriented = [
            (j, i) if np.random.uniform(0, 1) > 0.5 else (i, j)
            for i, j in upper_pairs
        ]

        w = np.zeros((p, p))
        for i, j in random.sample(oriented, n_nonempty):
            w[i, j] = 1

        if is_dag(w):
            return w

def generate_WC(p, n, n_clusters=5, mode='linear'):
    """Generate sample-specific networks W and contexts C.

    Modes:
        'cluster': ``n_clusters`` centroid DAGs are drawn with ``make_dag``.
            Each sample picks a centroid at random, adds N(0, 0.2) noise on
            the centroid's non-zero entries, and stores the centroid index
            as its context.
        'linear': a single DAG is drawn with ``make_dag``. Contexts are
            N(0, 1) and every edge weight is ``C[i, 0] * coeff``, with one
            coefficient per edge drawn from U(-1, -0.5) or U(0.5, 1).

    Args:
        p: number of features (nodes).
        n: number of samples.
        n_clusters: number of centroid networks (used by 'cluster' only).
        mode: 'cluster' or 'linear'.

    Returns:
        Tuple ``(W, C)`` with W of shape (n, p, p) and C of shape (n, 1).

    Raises:
        ValueError: if ``mode`` is not 'cluster' or 'linear'.
    """
    # Validate up front; `assert "some string"` is always true and never fires.
    if mode not in ('cluster', 'linear'):
        raise ValueError(
            f"Error, mode must be 'cluster' or 'linear', got {mode!r}"
        )

    W = np.zeros((n, p, p))

    if mode == 'cluster':
        C = np.zeros((n, 1))
        centroid_ws = [make_dag(p) for _ in range(n_clusters)]

        for i in range(n):
            c_idx = np.random.choice(len(centroid_ws))
            # Perturb only the entries where the centroid has an edge.
            masked_std = np.random.normal(0, 0.2, size=(p, p)) * (centroid_ws[c_idx] != 0)
            W[i] = centroid_ws[c_idx] + masked_std
            C[i, 0] = c_idx
    else:
        dag = make_dag(p)

        C = np.random.normal(0, 1, size=(n, 1))
        coeffs = [
            np.random.choice([np.random.uniform(-1, -0.5), np.random.uniform(0.5, 1)])
            for _ in range(int(dag.sum()))
        ]

        # Edge positions in row-major order; coeffs[j] belongs to the j-th edge.
        rows, cols = np.nonzero(dag == 1)
        # (n, 1) * (n_edges,) -> (n, n_edges): each edge weight is c * coeff.
        W[:, rows, cols] = C * np.asarray(coeffs)

    return W, C

In [None]:
from ray import tune
import logging
import functools
from ray.tune.schedulers import HyperBandScheduler

import ray
# Only let FATAL messages through from Ray Tune's logger.
logging.getLogger("ray.tune").setLevel(logging.FATAL)

# Settings for the local Ray instance used by the tuning cells.
_ray_init_kwargs = {
    "num_cpus": 2,
    "num_gpus": 1,
    "object_store_memory": 2 * 1024**3,
    "log_to_driver": False,
    "logging_level": logging.FATAL,
}
ray.init(**_ray_init_kwargs)



In [None]:
# TODO: Sweep over graph types.
# TODO: Sweep over factors
# TODO: Sweep over signal-to-noise ratio.
import pickle as pkl
def get_data():
    """
    Simulate one dataset per (n, p, W mode, repetition) combination.

    For every combination, ``generate_WC`` produces the networks and contexts
    and ``simulate_linear_sem`` draws a single observation from each
    sample-specific network. One descriptive line per dataset is written to
    ``data2/data_info.csv``.

    Returns:
        List of ``(C, X, W)`` tuples in generation order.

    NOTE(review): the loop variable used to rebind the list of modes and the
    mode passed to ``generate_WC`` was fixed to 'cluster'. With both fixed,
    this yields 2 * 2 * 2 * 2 = 16 datasets, half of them 'linear'. A
    previously pickled dataset list may differ in length and content.
    """
    n_data_gens = 2
    ns = [100, 1000]
    ps = [4, 10]
    W_modes = ['cluster', 'linear']
    n_clusters = 5
    data = []

    # The info file lives in data2/; create the directory if it is missing.
    os.makedirs("data2", exist_ok=True)
    with open("data2/data_info.csv", 'w') as out_file:
        for n in ns:
            for p in ps:
                for W_mode in W_modes:
                    for data_gen in range(n_data_gens):
                        W, C = generate_WC(p, n, n_clusters, mode=W_mode)
                        # One observation per sample, drawn from that sample's own network.
                        X = np.array([
                            simulate_linear_sem(w, n_samples=1, sem_type='uniform', noise_scale=0.5)[0]
                            for w in W
                        ])
                        data.append((C, X, W))
                        print(f"{n}, {p}, {W_mode}, {data_gen}", file=out_file)
    # The data are not pickled here; persisting them is left to the caller.
    return data

# datas = get_data()
# Load the previously generated datasets. The file is opened in a `with` block
# so the handle is closed after reading.
# NOTE: unpickling can execute arbitrary code; only load files you created.
with open("data2/data2.pkl", 'rb') as data_file:
    datas = pkl.load(data_file)
print(len(datas))


In [6]:
def _dag_loss_kwargs(config):
    """Map the sampled hyperparameters to the keyword arguments of the chosen DAG loss."""
    loss_type = config['loss_type']
    if loss_type == 'NOTEARS':
        return {
            'archetype_alpha': 0.0,
            'archetype_rho': 0.0,
            'sample_specific_alpha': config['alpha'],
            'sample_specific_rho': config['rho'],
        }
    if loss_type == 'DAGMA':
        return {
            'archetype_alpha': 0.0,
            'archetype_s': config['s'],
            'sample_specific_alpha': config['alpha'],
            'sample_specific_s': config['s'],
        }
    if loss_type == 'poly':
        return {}
    raise ValueError(f"Unknown loss_type: {loss_type!r}")


def train_and_evaluate(config, data, strat_config, checkpoint_dir='checkpoints/'):
    """
    Ray Tune trainable: fit a contextualized Bayesian network and report metrics.

    Args:
        config: hyperparameters for this trial. Keys read here: 'ks',
            'loss_type', 'encoder_types', 'n_factors', plus 'alpha'/'rho'
            (NOTEARS) or 'alpha'/'s' (DAGMA).
        data: tuple ``(C, X, W)`` of contexts, observations and true networks.
        strat_config: fixed hyperparameter(s) merged into ``config``.
        checkpoint_dir: unused; kept for interface compatibility.

    Reports (via ``tune.report``): train_mse, train_recovery, test_mse and
    test_recovery, each averaged over the fit iterations.

    Raises:
        ValueError: if ``config['loss_type']`` is not 'NOTEARS', 'DAGMA' or 'poly'.
    """
    C, X, W = data

    # Merged in place, as before: the fixed (stratified) value overrides/extends
    # the dict handed in by the caller.
    config.update(strat_config)

    k = config['ks']
    n_fit_iters = 1

    C_train, C_test, X_train, X_test, W_train, W_test = train_test_split(C, X, W, test_size=0.3)

    # Fails fast on an unknown loss type instead of hitting an unbound local later.
    # NOTE(review): the search space also samples 'use_dynamic_alpha_rho', but it
    # is not forwarded to the model here -- confirm whether that is intended.
    loss_kwargs = _dag_loss_kwargs(config)

    test_mses = []
    test_recoveries = []
    train_mses = []
    train_recoveries = []
    for _ in range(n_fit_iters):
        cbn = ContextualizedBayesianNetworks(
            encoder_type=config['encoder_types'],
            num_archetypes=k,
            num_factors=config['n_factors'],
            archetype_dag_loss_type=config['loss_type'],
            sample_specific_dag_loss_type=config['loss_type'],
            **loss_kwargs,
            n_bootstraps=3,
            learning_rate=1e-3,
        )

        cbn.fit(C_train, X_train, max_epochs=100)
        cbn_preds_train = cbn.predict_networks(C_train, individual_preds=True)
        cbn_preds_test = cbn.predict_networks(C_test, individual_preds=True)

        test_mses.append(np.mean(measure_mses(cbn_preds_test, X_test)))
        test_recoveries.append(measure_recovery(W_test, cbn_preds_test))
        train_mses.append(np.mean(measure_mses(cbn_preds_train, X_train)))
        train_recoveries.append(measure_recovery(W_train, cbn_preds_train))

    train_mse = np.mean(train_mses)
    train_recovery = np.mean(train_recoveries)
    test_mse = np.mean(test_mses)
    test_recovery = np.mean(test_recoveries)

    tune.report(train_mse=train_mse, train_recovery=train_recovery, test_mse=test_mse, test_recovery=test_recovery)


from ray.tune.schedulers import ASHAScheduler
from ray.tune import CLIReporter
from ray.air import RunConfig

# Hyperparameter groups
losses = ["NOTEARS", "DAGMA", "poly"]
encoder_types = ['ngam', 'mlp']
ks = [4, 8, 16]

# Each key is fixed to each of its values in turn while the remaining
# hyperparameters are searched.
stratify_by = {
    'loss_type': losses,
    'encoder_types': encoder_types,
    'ks': ks,
}

# Budget estimate:
# [3 x (500) num_models / samples] x [(2 + 3 + 2) hps] x [20 datasets] = 210,000 runs
# / 3600 = ~60 hours training time
n_runs = 3

for d, dataset in enumerate(datas):
    for strat, strat_values in stratify_by.items():
        full_space = {
            "loss_type": tune.choice(losses),
            "encoder_types": tune.choice(encoder_types),
            'use_dynamic_alpha_rho': tune.choice([True, False]),
            "ks": tune.choice(ks),
            "n_factors": tune.sample_from(lambda spec: int(np.random.randint(0,spec.config.ks))),
            "alpha": tune.uniform(0.001, 0.1),
            "rho": tune.uniform(0.001, 0.1),
            "s": tune.uniform(0.001, 5),
        }

        # The stratified hyperparameter is fixed per run, so it is not searched.
        config = {key: domain for key, domain in full_space.items() if key != strat}
        if strat == 'ks':
            # n_factors is sampled relative to ks; it is re-added per fixed value below.
            del config['n_factors']

        for instance in strat_values:
            if strat == 'ks':
                config['n_factors'] = tune.sample_from(lambda spec: int(np.random.randint(0,instance)))

            for i in range(n_runs):
                run_name = f'tune_analysis_d{d}_s{strat}_ins{instance}_i{i}'
                pickle_path = f'data3/{run_name}.pickle'

                if not os.path.exists('data3/'):
                    os.mkdir('data3/')

                if os.path.exists(pickle_path):
                    with open(pickle_path, 'rb') as handle:
                        r = pkl.load(handle)
                        results = r.get_best_result()
                        # A saved run is redone only when its best result has
                        # exactly one metrics key; otherwise it is kept.
                        if len(results.metrics.keys()) != 1:
                            continue
                        print(f'redoing broken run: {run_name}')

                reporter = CLIReporter(max_progress_rows=10)
                for metric_name in ("test_mse", "test_recovery", "train_mse", "train_recovery"):
                    reporter.add_metric_column(metric_name)

                trainable = functools.partial(
                    train_and_evaluate,
                    data=dataset,
                    strat_config={strat: instance},
                )
                tuner = tune.Tuner(
                    trainable,
                    param_space=config,
                    run_config=RunConfig(progress_reporter=reporter),
                    tune_config=tune.TuneConfig(
                        num_samples=500,
                        scheduler=ASHAScheduler(),
                        metric="test_mse",
                        mode='min',
                    ),
                )
                results = tuner.fit()

                with open(pickle_path, 'wb') as handle:
                    pkl.dump(results, handle)


== Status ==
Current time: 2023-03-31 14:06:06 (running for 00:15:37.14)
Memory usage on this node: 51.8/64.0 GiB 
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 64.000: None | Iter 16.000: None | Iter 4.000: None | Iter 1.000: None
Resources requested: 2.0/2 CPUs, 0/1 GPUs, 0.0/38.21 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/wtlo/ray_results/train_and_evaluate_2023-03-31_13-50-29
Number of trials: 300/500 (282 ERROR, 16 PENDING, 2 RUNNING)
+--------------------------------+----------+-------+-----------+-----------------+------+-------------+-----------+---------+------------------------+
| Trial name                     | status   | loc   |     alpha | encoder_types   |   ks |   n_factors |       rho |       s | use_dynamic_alpha_rh   |
|                                |          |       |           |                 |      |             |           |         | o                      |
|--------------------------------+----------+-------+-----------+-----------------+---

In [None]:
# Benchmark population, clustered and contextualized networks on synthetic
# data and write one CSV row per (n, p, k, data_gen, fit_iter).
# NOTE(review): this cell reads `header`, `ns`, `ps`, `n_data_gens` and
# `n_fit_iters` as globals; no earlier cell shown here assigns them at module
# level -- confirm they are defined before running.
with open("results.csv", 'w') as out_file:
    print(header, file=out_file)
    for n in ns:
        for p in ps:
            for k in ks:
                for data_gen in range(n_data_gens):
                    # Generate data: each edge weight is a fixed function of the scalar context.
                    C = np.random.normal(0, 1, size=(n, 1))
                    W = np.zeros((p, p, n, 1))

                    # TODO: Automate generation of W -- clusters?
                    W[0, 1] = C - 2
                    W[2, 1] = C**2
                    W[3, 1] = C**3
                    W[3, 2] = C

                    # (p, p, n, 1) -> (p, p, n) -> (n, p, p)
                    W = np.squeeze(W)
                    W = np.transpose(W, (2, 0, 1))
                    X = np.array([simulate_linear_sem(w, n_samples=1, sem_type="uniform", noise_scale=0.0)[0] for w in W])

                    C_train, C_test, X_train, X_test, W_train, W_test = train_test_split(C, X, W, test_size=0.3)

                    def results_string(W_pred_train, W_pred_test):
                        """Format train/test recovery and train/test MSE as CSV fields."""
                        recovery_train = measure_recovery(W_train, W_pred_train)
                        recovery_test = measure_recovery(W_test, W_pred_test)
                        mse_train = np.mean(measure_mses(W_pred_train, X_train))
                        mse_test = np.mean(measure_mses(W_pred_test, X_test))
                        return f", {recovery_train}, {recovery_test}, {mse_train}, {mse_test}"

                    for fit_iter in range(n_fit_iters):
                        results_line = f"{n}, {p}, {k}, {data_gen}, {fit_iter}"

                        # Population baseline: a single network for all samples.
                        # expand_dims adds the leading bootstrap axis the metrics expect.
                        dag = BayesianNetwork().fit(X_train, max_epochs=100)
                        pop_preds_train = np.expand_dims(dag.predict(len(X_train)), 0)
                        pop_preds_test  = np.expand_dims(dag.predict(len(X_test)), 0)
                        results_line += results_string(pop_preds_train, pop_preds_test)

                        # Cluster baseline: one network per context cluster.
                        # The clustering is fitted on training contexts only so
                        # that test contexts do not influence the baseline.
                        km = KMeans(n_clusters=4)
                        km.fit(C_train)
                        cluster_dag = GroupedNetworks(BayesianNetwork).fit(X_train, km.predict(C_train))
                        cluster_preds_train = np.expand_dims(cluster_dag.predict(km.predict(C_train)), 0)
                        cluster_preds_test = np.expand_dims(cluster_dag.predict(km.predict(C_test)), 0)
                        results_line += results_string(cluster_preds_train, cluster_preds_test)

                        # Contextualized networks, one fit per DAG loss type.
                        # NOTE(review): the alpha/rho keyword arguments are passed for
                        # every loss type -- confirm they apply to "DAGMA" and "poly".
                        for loss in losses:
                            cbn = ContextualizedBayesianNetworks(
                                encoder_type='ngam',
                                num_archetypes=k,
                                num_factors=-1,
                                archetype_alpha=0.,
                                archetype_rho=0.,
                                sample_specific_alpha=1e-1,
                                sample_specific_rho=1e-2,
                                archetype_dag_loss_type=loss,
                                sample_specific_dag_loss_type=loss,
                                n_bootstraps=3,
                                learning_rate=1e-3)
                            cbn.fit(C_train, X_train, max_epochs=100)
                            cbn_preds_train = cbn.predict_networks(C_train, individual_preds=True)
                            cbn_preds_test = cbn.predict_networks(C_test, individual_preds=True)
                            results_line += results_string(cbn_preds_train, cbn_preds_test)

                        print(results_line, file=out_file)

In [None]:
# Plot test recovery error against n, one figure per p, with every method
# normalized by the population model's mean error.
# NOTE(review): `results_df` is not created in any cell shown here; presumably
# it is results.csv loaded into a DataFrame (the leading spaces in the column
# names would match its "n, p, k, ..." header) -- confirm.
print(results_df.columns)
# Sorted so the x-axis (and slice arithmetic below) follows a stable order;
# iteration order of a set is arbitrary.
ns = sorted(set(results_df['n'].values))
ps = sorted(set(results_df[' p'].values))
ks = sorted(set(results_df[' k'].values))
data_gens = sorted(set(results_df[' data_gen'].values))
fit_iters = sorted(set(results_df[' fit_iter'].values))

# Plot error by n
for p in ps:
    plt.figure()
    pop_errs_to_plot = np.zeros((len(ns), len(data_gens)*len(fit_iters)*len(ks)))
    cluster_errs_to_plot = np.zeros((len(ns), len(data_gens)*len(fit_iters)*len(ks)))
    notmad_notears_errs_to_plot = {k: np.zeros((len(ns), len(data_gens)*len(fit_iters))) for k in ks}
    for i, n in enumerate(ns):
        for data_gen in data_gens:
            # data_gen values are used directly as block offsets into the columns.
            plot_idxs_start = data_gen*len(fit_iters)*len(ks)
            plot_idxs_end = (data_gen+1)*len(fit_iters)*len(ks)

            idxs = np.logical_and(
                results_df['n'] == n,
                np.logical_and(
                    results_df[' p'] == p,
                    results_df[' data_gen'] == data_gen)
            )
            pop_errs = results_df[' recovery_pop_test'].loc[idxs]  # one row per (fit_iter, k)
            pop_err = np.mean(pop_errs)
            # Population error is the normalizer, so it plots as 1.
            pop_errs_to_plot[i, plot_idxs_start:plot_idxs_end] = 1

            # Divide into new arrays instead of in place on a slice of results_df.
            cluster_errs = results_df[' recovery_cluster_test'].loc[idxs].to_numpy() / pop_err
            cluster_errs_to_plot[i, plot_idxs_start:plot_idxs_end] = cluster_errs

            for k in ks:
                k_idxs = np.logical_and(idxs, results_df[' k'] == k)
                notmad_notears_errs = results_df[' recovery_notmad_notears_test'].loc[k_idxs].to_numpy() / pop_err
                notmad_notears_errs_to_plot[k][i, data_gen*len(fit_iters):(data_gen+1)*len(fit_iters)] = notmad_notears_errs

    plt.errorbar(ns,
                 np.mean(pop_errs_to_plot, axis=1),
                 yerr=2*np.std(pop_errs_to_plot, axis=1), label="Population")
    plt.errorbar(ns,
                 np.mean(cluster_errs_to_plot, axis=1),
                 yerr=2*np.std(cluster_errs_to_plot, axis=1), label="Cluster")
    for k in ks:
        plt.errorbar(ns,
                     np.mean(notmad_notears_errs_to_plot[k], axis=1),
                     yerr=2*np.std(notmad_notears_errs_to_plot[k], axis=1), label=f"NOTMAD-NOTEARS-{k}")
    plt.legend()
    plt.title(p)
    plt.xlabel("N")
    plt.ylabel("Test recovery error / population error")

In [None]:
# TODO: How to compare across data gen runs? Probably normalize by population error.
def plot_mses(mses, title):
    """
    Bar chart of the mean MSE per method with error bars of two standard deviations.

    ``mses`` is indexed by the keys "pop", "cluster" and "notmad"; ``title``
    is used as the figure title.
    """
    plt.figure()
    method_keys = ("pop", "cluster", "notmad")
    bar_positions = [0, 1, 2]
    bar_heights = [np.mean(mses[key]) for key in method_keys]
    bar_errors = [2*np.std(mses[key]) for key in method_keys]
    plt.bar(bar_positions, bar_heights, yerr=bar_errors)
    plt.xticks(bar_positions, ["Population", "Cluster", "NOTMAD"], rotation=45, fontsize=16)
    plt.yticks(fontsize=16)
    plt.title(title, fontsize=16)
    plt.show()

# One train and one test bar chart per (n, p, data_gen) entry of `results`.
for n in ns:
    for p in ps:
        for data_gen in range(n_data_gens):
            # Look up both splits before plotting anything.
            mses_train, mses_test = (
                results[(n, p, data_gen, split)] for split in ("train", "test")
            )

            for split_mses, split_label in ((mses_train, "Train"), (mses_test, "Test")):
                plot_mses(split_mses, f"n={n}, p={p}, {split_label}")