# Contextualized Bayesian Networks

For more details, please see the [NOTMAD preprint](https://arxiv.org/abs/2111.01104).

Here, we want to answer the following questions:
- How does the performance of NOTMAD compare against population network and cluster network methods?
- How does the performance of NOTMAD change with the loss type ("NOTEARS", "DAGMA", or "poly")?
- How does the performance of NOTMAD change with the number of factors ("num_factors")?
- How does the performance change with the number of samples (n) and the number of features (p)?

TODO:
- Automated generation of W (clusters, linear function of C)

Possibly also vary:
- signal-to-noise ratio of the context-to-network parameter relationship
- encoder type ("ngam", "mlp")
- regularization of NOTMAD:
    NOTEARS loss has parameters:
    
        alpha (float)

        rho (float)

        use_dynamic_alpha_rho (Boolean)

    DAGMA loss has parameters:
    
        alpha (strength, default 1e0)

        s (max spectral radius, default 1)

Report results in terms of:
- MSE of X predictions (measure_mses function)
- recovery of W (measure_recovery function)

In [None]:
import numpy as np
import random
import os
import matplotlib.pyplot as plt
%matplotlib inline


#auto update
%load_ext autoreload
%autoreload 2
import sys

import contextualized
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split

from contextualized.dags.graph_utils import simulate_linear_sem, is_dag
from contextualized.baselines import (
    BayesianNetwork,
    GroupedNetworks,
)
from contextualized.easy import ContextualizedBayesianNetworks
import pickle as pkl
import os
import numpy

from ray import tune
import logging
import functools
from ray.tune.schedulers import HyperBandScheduler

import ray
# Silence Ray Tune's own logger: only FATAL-level messages get through.
logging.getLogger("ray.tune").setLevel(logging.FATAL)
# Start the Ray runtime used by the tuning sweep further down:
# 2 CPUs, 1 GPU, a 2 GiB object store, worker logs not forwarded to this
# notebook (log_to_driver=False), and Ray's own logging limited to FATAL.
ray.init(
    num_cpus=2,
    num_gpus=1,
    object_store_memory=2*1024*1024*1024,  # 2 GiB, given in bytes
    log_to_driver=False,
    logging_level=logging.FATAL,
)

In [None]:
def measure_mses(betas, X, individual_preds=False):
    """
    Measure mean-squared errors of reconstructing X from its own networks.

    For every bootstrap b, sample s and target feature t, the prediction is
    ``X[s] . betas[b, s, :, t]`` (column t of that sample's network; no
    intercept term is added). The squared residual against ``X[s, t]`` is
    averaged over the features.

    Parameters
    ----------
    betas : array-like, shape (n_bootstraps, n_samples, n_features, n_features)
        Per-sample network weights. A list of per-bootstrap arrays is also
        accepted.
    X : array-like, shape (n_samples, n_features)
        Observations.
    individual_preds : bool, default False
        If True, keep one row of errors per bootstrap; otherwise average the
        errors over bootstraps.

    Returns
    -------
    numpy.ndarray
        Shape (n_bootstraps, n_samples) if ``individual_preds`` is True,
        else shape (n_samples,).
    """
    betas = np.asarray(betas)
    X = np.asarray(X)
    # preds[b, s, t] = sum_j X[s, j] * betas[b, s, j, t]. This is a single
    # vectorized contraction instead of a Python loop over bootstraps,
    # target features and samples.
    preds = np.einsum("sj,bsjt->bst", X, betas)
    residuals = X[np.newaxis, :, :] - preds
    # Mean over target features == sum of squared residuals / n_features.
    mses = np.mean(residuals ** 2, axis=-1)
    if not individual_preds:
        mses = np.mean(mses, axis=0)
    return mses

def measure_recovery(W_true, W):
    """Average matrix-2-norm distance between true and predicted networks.

    ``W`` carries a leading bootstrap dimension and is indexed as
    ``W[bootstrap][sample]``; ``W_true`` is indexed as ``W_true[sample]``.
    Every (bootstrap, sample) pair contributes
    ``np.linalg.norm(W_true[sample] - W[bootstrap][sample], ord=2)``. For 2-D
    entries this is the largest singular value of the difference. The mean of
    all contributions is returned.
    """
    per_pair_errors = [
        np.linalg.norm(W_true[idx] - boot_preds[idx], ord=2)
        for boot_preds in W
        for idx in range(len(boot_preds))
    ]
    return np.mean([per_pair_errors])

def make_dag(p, n_nonempty=4):
    """Sample a random p x p binary adjacency matrix with n_nonempty ones that is a DAG.

    Each of the p*(p-1)/2 unordered node pairs gets a random orientation
    (either way with probability 0.5). Then n_nonempty oriented pairs are
    drawn uniformly without replacement, and ``w[i, j] = 1`` is set for each
    chosen pair ``(i, j)``. If ``is_dag(w)`` rejects the result, the whole
    draw is repeated (rejection sampling).

    Randomness comes from both ``np.random`` and Python's ``random`` module,
    so both must be seeded for reproducible output.

    Parameters
    ----------
    p : int
        Number of nodes.
    n_nonempty : int, default 4
        Number of non-zero entries (edges) in the returned matrix.

    Returns
    -------
    numpy.ndarray, shape (p, p)
        Matrix of 0.0/1.0 entries accepted by ``is_dag``.

    Raises
    ------
    ValueError
        If n_nonempty is negative or larger than p*(p-1)/2.
    """
    n_pairs = p * (p - 1) // 2
    if not 0 <= n_nonempty <= n_pairs:
        raise ValueError(
            f"n_nonempty must be between 0 and p*(p-1)/2 = {n_pairs}, got {n_nonempty}"
        )

    # Upper-triangular pairs: every unordered pair appears exactly once, so
    # the chosen edges can never contain both (i, j) and (j, i).
    pairs = [(i, j) for i in range(p) for j in range(p) if i < j]

    # Loop instead of recursing, so that many rejected draws cannot exhaust
    # the recursion limit.
    while True:
        # Randomly flip each pair so that edges are not all oriented i < j.
        oriented = [(j, i) if np.random.uniform(0, 1) > 0.5 else (i, j) for i, j in pairs]
        chosen = random.sample(oriented, n_nonempty)

        w = np.zeros((p, p))
        for i, j in chosen:
            w[i, j] = 1

        # Cycles of length >= 3 are still possible after random orientation.
        if is_dag(w):
            return w

def generate_WC(p, n, n_clusters = 5, mode='linear'):
    '''Generate per-sample networks W and contexts C for a given p, n, and mode.

    Parameters
    ----------
    p : int
        Number of features; each network is a p x p matrix.
    n : int
        Number of samples.
    n_clusters : int, default 5
        Number of centroid DAGs. Only used when mode == 'cluster'.
    mode : {'cluster', 'linear'}, default 'linear'
        'cluster': n_clusters DAGs are drawn with make_dag(p). Each sample
        picks one uniformly at random and adds Gaussian noise (std 0.2) to
        that DAG's non-zero entries only. The chosen cluster index is stored
        as the sample's context C[i, 0].
        'linear': a single DAG from make_dag(p) is shared by all samples, and
        C ~ N(0, 1). Each edge weight of sample i is C[i, 0] times a fixed
        per-edge coefficient drawn from U(-1, -0.5) or U(0.5, 1) with equal
        probability.

    Returns
    -------
    W : numpy.ndarray, shape (n, p, p)
    C : numpy.ndarray, shape (n, 1)

    Raises
    ------
    ValueError
        If mode is not 'cluster' or 'linear'.
    '''
    # Validate up front so that an unknown mode fails loudly, instead of
    # reaching `return` with C never assigned.
    if mode not in ('cluster', 'linear'):
        raise ValueError(f"mode must be 'cluster' or 'linear', got {mode!r}")

    W = np.zeros((n, p, p))

    if mode == 'cluster':
        C = np.zeros((n, 1))
        centroid_ws = [make_dag(p) for _ in range(n_clusters)]

        for i in range(n):
            c_idx = np.random.choice(len(centroid_ws))
            # Jitter only the edges that exist in the chosen centroid.
            masked_noise = np.random.normal(0, 0.2, size=(p, p)) * (centroid_ws[c_idx] != 0)
            W[i] = centroid_ws[c_idx] + masked_noise
            C[i, 0] = c_idx

    else:  # mode == 'linear'
        dag = make_dag(p)

        C = np.random.normal(0, 1, size=(n, 1))
        coeffs = [np.random.choice([np.random.uniform(-1, -0.5), np.random.uniform(0.5, 1)]) for _ in range(int(dag.sum()))]

        # Edge positions in row-major order, matching the order of coeffs.
        # This is computed once, not once per sample.
        rows, cols = np.nonzero(dag == 1)
        # W[i, r, c] = C[i, 0] * coeffs[k] for the k-th edge (r, c).
        W[:, rows, cols] = C * np.asarray(coeffs)

    return W, C

### Dataset Generation

In [None]:
# TODO: Sweep over graph types.
# TODO: Sweep over factors
# TODO: Sweep over signal-to-noise ratio.

import pickle as pkl

# Directory holding the cached dataset pickle and the Ray Tune result pickles.
data_dir = 'tuning_directory'

# Create it up front; exist_ok=True makes this a no-op if it already exists.
os.makedirs(data_dir,exist_ok=True)

def get_data():
    """Generate the synthetic benchmark datasets, pickle them, and return them.

    One dataset is built for every combination of n in ``ns``, p in ``ps``,
    W mode in ``W_mode`` and ``n_data_gens`` repeats. The loop order is n, then
    p, then mode, then repeat, so ``data[0]`` is (n=100, p=4, 'cluster').
    Each entry is a tuple ``(C, X, W)``:
    - W, C come from ``generate_WC(p, n, n_clusters, mode=mode)``.
    - Row i of X is the first row returned by ``simulate_linear_sem`` for
      network W[i], with n_samples=1, sem_type='uniform' and noise_scale=0.0.

    The list is also written to ``f"{data_dir}/{data_dir}_data.pkl"``, using
    the module-level ``data_dir``.
    """
    n_data_gens = 1
    ns = [100, 1000]
    ps = [4, 10]
    # Graph types are not swept yet (see the TODO at the top of this cell):
    # "gauss", "exp", "gumbel", "uniform", "logistic", "poisson".
    W_mode = ['cluster', 'linear']
    n_clusters = 5

    data = []
    for n in ns:
        for p in ps:
            for mode in W_mode:
                for _ in range(n_data_gens):
                    W, C = generate_WC(p, n, n_clusters, mode=mode)
                    X = np.array([simulate_linear_sem(w, n_samples=1, sem_type='uniform', noise_scale=0.0)[0] for w in W])
                    data.append((C, X, W))

    # Use a context manager so the file is flushed and closed deterministically.
    with open(f"{data_dir}/{data_dir}_data.pkl", 'wb') as out_file:
        pkl.dump(data, out_file)
    return data

# To regenerate (and overwrite) the cached datasets instead, use: datas = get_data()
# The pickle is produced locally by get_data(); only load trusted files here.
with open(f"{data_dir}/{data_dir}_data.pkl", 'rb') as data_file:
    datas = pkl.load(data_file)


### Training Loop

In [None]:
def _loss_kwargs(config):
    """Return the DAG-loss keyword arguments for ContextualizedBayesianNetworks.

    Parameters
    ----------
    config : dict
        Must contain 'loss_type'. 'NOTEARS' also reads 'alpha' and 'rho', and
        'tol' if present. 'DAGMA' reads 'alpha' and 's'. 'poly' reads nothing
        else.

    Raises
    ------
    ValueError
        If config['loss_type'] is not 'NOTEARS', 'DAGMA' or 'poly'.
    """
    loss_type = config['loss_type']
    if loss_type == 'NOTEARS':
        loss_kwargs = {
            'archetype_alpha': 0.0,
            'archetype_rho': 0.0,
            'archetype_use_dynamic_alpha_rho': True,
            'sample_specific_alpha': config['alpha'],
            'sample_specific_rho': config['rho'],
            'sample_specific_use_dynamic_alpha_rho': True,
        }
        # 'tol' is not part of the Ray Tune search space, so forward it only
        # when the caller supplied one. Otherwise the model's own defaults
        # presumably apply (the DAGMA branch never passes a tol either).
        if 'tol' in config:
            loss_kwargs['archetype_tol'] = config['tol']
            loss_kwargs['sample_specific_tol'] = config['tol']
        return loss_kwargs
    if loss_type == 'DAGMA':
        return {
            'archetype_alpha': 0.0,
            'archetype_s': 0.0,
            'archetype_use_dynamic_alpha_rho': False,
            'sample_specific_alpha': config['alpha'],
            'sample_specific_s': config['s'],
            'sample_specific_use_dynamic_alpha_rho': False,
        }
    if loss_type == 'poly':
        return {}
    raise ValueError(
        f"loss_type must be 'NOTEARS', 'DAGMA' or 'poly', got {loss_type!r}"
    )


def train_and_evaluate(config, data, strat_config, checkpoint_dir='checkpoints/'):
    """Fit ContextualizedBayesianNetworks on one dataset and report metrics to Ray Tune.

    Parameters
    ----------
    config : dict
        Hyperparameters: 'ks' (number of archetypes), 'encoder_types',
        'n_factors', 'loss_type' and the loss-specific keys read by
        ``_loss_kwargs``.
    data : tuple
        ``(C, X, W)``: contexts, observations and ground-truth per-sample
        networks. A random 70/30 train/test split is taken from it.
    strat_config : dict
        Stratified (held-fixed) hyperparameters. These override entries of
        ``config``.
    checkpoint_dir : str
        Unused; kept so existing callers keep working.

    Returns
    -------
    The model fitted in the last fit iteration. Train/test MSE
    (``measure_mses``) and W recovery (``measure_recovery``), averaged over
    fit iterations, are reported through ``tune.report``.

    Raises
    ------
    ValueError
        If the loss type is not 'NOTEARS', 'DAGMA' or 'poly'.
    """
    C, X, W = data
    # Merge into a new dict rather than mutating the caller's config.
    config = {**config, **strat_config}
    k = config['ks']
    n_fit_iters = 1
    # Resolve (and validate) the loss settings before any expensive work.
    loss_kwargs = _loss_kwargs(config)

    C_train, C_test, X_train, X_test, W_train, W_test = train_test_split(C, X, W, test_size=0.3)

    test_mses = []
    test_recoveries = []
    train_mses = []
    train_recoveries = []
    for _ in range(n_fit_iters):
        cbn = ContextualizedBayesianNetworks(
            encoder_type=config['encoder_types'],
            num_archetypes=k,
            num_factors=config['n_factors'],
            archetype_dag_loss_type=config['loss_type'],
            sample_specific_dag_loss_type=config['loss_type'],
            **loss_kwargs,
            n_bootstraps=3,
            learning_rate=1e-3,
        )
        cbn.fit(C_train, X_train, max_epochs=100)

        # measure_mses / measure_recovery index these as
        # [bootstrap][sample]; presumably individual_preds=True keeps the
        # bootstrap axis -- confirm against predict_networks.
        preds_train = cbn.predict_networks(C_train, individual_preds=True)
        preds_test = cbn.predict_networks(C_test, individual_preds=True)

        test_mses.append(np.mean(measure_mses(preds_test, X_test)))
        test_recoveries.append(measure_recovery(W_test, preds_test))
        train_mses.append(np.mean(measure_mses(preds_train, X_train)))
        train_recoveries.append(measure_recovery(W_train, preds_train))

    train_mse = np.mean(train_mses)
    train_recovery = np.mean(train_recoveries)
    test_mse = np.mean(test_mses)
    test_recovery = np.mean(test_recoveries)

    tune.report(train_mse=train_mse, train_recovery=train_recovery,test_mse=test_mse, test_recovery=test_recovery)
    return cbn

In [None]:
# Hyperparameter groups. Each key of `stratify_by` is a config entry whose
# listed values are held fixed, one at a time, while the remaining
# hyperparameters are tuned.
losses = ["NOTEARS", "DAGMA", "poly"]
encoder_types = ["ngam", "mlp"]
ks = [4, 8, 16]
stratify_by = dict(loss_type=losses, encoder_types=encoder_types, ks=ks)

# Number of independent tuning repeats per stratified setting.
n_runs = 3
# Results are written next to the cached datasets.
data_dir = "tuning_directory"


In [None]:
# Print NOTMAD params

def debug_loop():
    """Fit one model per DAG loss type on the first dataset, for manual inspection.

    A fixed configuration is used: ngam encoder, 4 archetypes, 2 factors,
    tol 0.1, alpha/rho 0.01 and s 0.65. The loss types are tried in the order
    DAGMA, NOTEARS, poly, and a dict mapping each loss type to the model
    returned by ``train_and_evaluate`` is returned.
    """
    base_config = dict(
        encoder_types="ngam",
        # Appears unused by train_and_evaluate; kept for reference.
        use_dynamic_alpha_rho=False,
        ks=4,
        tol=0.1,
        n_factors=2,
        alpha=0.01,
        rho=0.01,
        s=0.65,
    )
    first_dataset = datas[0]
    return {
        loss_name: train_and_evaluate({**base_config, 'loss_type': loss_name}, first_dataset, {})
        for loss_name in ("DAGMA", "NOTEARS", "poly")
    }

# cbns = debug_loop()

# for loss, cbn in cbns.items():
#     print(loss)
#     print(f'ss params: {cbn.models[0].ss_dag_params}')
#     print(f'arch params: {cbn.models[0].arch_dag_params}')
#     print('')

In [None]:
# Sweep size: len(datas) datasets x 8 stratified settings (3 losses + 2 encoders
# + 3 ks) x n_runs (3) repeats x 500 Ray Tune samples per repeat.
# With the 8 datasets produced by get_data() as written, that is 96,000 trials.

from ray.air import RunConfig
from ray.tune import CLIReporter

os.makedirs(data_dir, exist_ok=True)


def _has_usable_result(path, run_name):
    """Return True if a finished tuning result is already pickled at ``path``.

    A saved result is reused when it loads, ``get_best_result()`` succeeds,
    and the best result does not have exactly one metric key. A single-key
    result is treated as a broken run (presumably a trial that failed before
    reporting its metrics -- confirm) and is redone.
    """
    if not os.path.exists(path):
        return False
    try:
        with open(path, 'rb') as handle:
            best = pkl.load(handle).get_best_result()
    except Exception:
        # Unreadable pickle or no best result available: rerun this setting.
        return False
    if len(best.metrics.keys()) != 1:
        return True
    print(f'redoing broken run: {run_name}')
    return False


for d, data in enumerate(datas):
    for strat, instances in stratify_by.items():
        config = {
            "loss_type": tune.choice(losses),
            "encoder_types": tune.choice(encoder_types),
            "ks": tune.choice(ks),
            "n_factors": tune.sample_from(lambda spec: int(np.random.randint(0,spec.config.ks))),
            "alpha": tune.uniform(0.001, 0.1),
            "rho": tune.uniform(0.001, 0.1),
            "s": tune.uniform(0.001, 5),
        }

        # The stratified key is not sampled; it is fixed per run through
        # strat_config.
        del config[strat]

        if strat == 'ks':
            # The default n_factors sampler reads spec.config.ks, which is no
            # longer in the search space; it is redefined per instance below.
            del config['n_factors']

        for instance in instances:
            if strat == 'ks':
                # Bind the current instance as a default argument so that the
                # sampler never depends on a later value of the loop variable.
                config['n_factors'] = tune.sample_from(
                    lambda spec, max_k=instance: int(np.random.randint(0, max_k))
                )

            for i in range(n_runs):
                run_name = f'tune_analysis_d{d}_s{strat}_ins{instance}_i{i}'
                result_path = f'{data_dir}/{run_name}.pickle'
                if _has_usable_result(result_path, run_name):
                    continue

                reporter = CLIReporter(max_progress_rows=10)
                for metric_name in ("test_mse", "test_recovery", "train_mse", "train_recovery"):
                    reporter.add_metric_column(metric_name)

                tuner = tune.Tuner(
                    functools.partial(train_and_evaluate, data=data, strat_config={strat: instance}),
                    param_space=config,
                    run_config=RunConfig(
                        progress_reporter=reporter
                    ),
                    tune_config=tune.TuneConfig(
                        num_samples=500,
                        scheduler=HyperBandScheduler(),
                        metric="test_mse",
                        mode='min',
                    ),
                )

                results = tuner.fit()

                with open(result_path, 'wb') as handle:
                    pkl.dump(results, handle)
