## Imports

In [1]:
import copy
import numpy as np
import torch
import torch.optim as optim
from sklearn.model_selection import train_test_split
from config import config
import os
import sys
# Remember where the notebook lives so the working directory can be restored
# after the project modules have been imported.
current_dir = os.getcwd()

# Folder that contains the islbbnn sources (plot_functions.py,
# pipeline_functions.py and the networks/ sub-folder).
# Set the ISLBBNN_PATH environment variable to point at your own checkout;
# when it is unset the original hard-coded location is used, e.g.
#   ISLBBNN_PATH="C:\\you\\path\\to\\islbbnn\\folder\\here"
path = os.environ.get(
    "ISLBBNN_PATH",
    "C:\\Users\\eirik\\Documents\\Master\\ISLBBNN\\islbbnn",
)

# The project modules are imported by temporarily switching the working
# directory to the project folder.
os.chdir(path)
try:
    import plot_functions as pf
    import pipeline_functions as pip_func

    # Register the networks folder by absolute path so the entry stays valid
    # after the working directory is switched back below.
    sys.path.append(os.path.join(os.getcwd(), 'networks'))
    from flow_sigmoid_net import BayesianNetwork
finally:
    # Always return to the notebook directory, even if an import fails,
    # so later relative paths (e.g. saved figures) are not silently redirected.
    os.chdir(current_dir)

CPUs are used!


# Obtain data

Problem:

$$y = x_1 + x_2 + 100 +\epsilon$$

where $\epsilon \sim N(0,0.01)$. 


$x_3$ can be made dependent on $x_1$. The dependence is defined in the following way:

\begin{align*}
 x_1 &\sim Unif(-10,10) \\
 x_3 &\sim Unif(-10,10) \\
 x_3 &= \text{dep}\cdot x_1 + (1-\text{dep})\cdot x_3
\end{align*}

## Preprocessing and batch size

In [2]:
# ---------------------------------------------------------------------------
# Hyperparameters, all read from the project's `config` dictionary
# ---------------------------------------------------------------------------
# Architecture
HIDDEN_LAYERS = config['n_layers'] - 2
dim = config['hidden_dim']
num_transforms = config['num_transforms']
non_lin = config["non_lin"]

# Training schedule
epochs = config['num_epochs']
post_train_epochs = config['post_train_epochs']
lr = config['lr']
patience = config['patience']

# Experiment set-up
n_nets = config['n_nets']
n_samples = config['n_samples']
class_problem = config["class_problem"]
verbose = config['verbose']
save_res = config['save_res']
SAMPLES = 1

# ---------------------------------------------------------------------------
# Synthetic linear data (a regression problem when class_problem is False)
# ---------------------------------------------------------------------------
y, X = pip_func.create_data_unif(
    n_samples,
    beta=[100, 1, 1, 1, 1],
    dep_level=0.0,
    classification=class_problem,
)

# p (number of input features) is needed later to build the network.
n, p = X.shape
print(n, p, dim)

# ---------------------------------------------------------------------------
# Batch sizes and split sizes
# ---------------------------------------------------------------------------
# Naming caveat kept from the original notebook: the TEST_* quantities refer
# to the part used *during* training (what is usually called validation),
# while the VAL_* quantities refer to the part used *after* training
# (what is usually called test).
holdout_size = int(n*0.10)

BATCH_SIZE = int((n*0.8)/100)
TEST_BATCH_SIZE = holdout_size
VAL_BATCH_SIZE = holdout_size

# NOTE(review): TRAIN_SIZE uses the same formula as BATCH_SIZE (80% of n
# divided by 100), so NUM_BATCHES below always evaluates to 1.0 and the first
# assert can never fail. If TRAIN_SIZE was meant to be the full training-set
# size, int(n*0.80), NUM_BATCHES would be 100 instead. Left unchanged because
# it is passed to pip_func.train and would alter training -- confirm intent.
TRAIN_SIZE = int((n*0.80)/100)
TEST_SIZE = holdout_size
VAL_SIZE = holdout_size

NUM_BATCHES = TRAIN_SIZE/BATCH_SIZE

print(NUM_BATCHES)

# Sanity checks: each split must divide evenly into its batches.
assert (TRAIN_SIZE % BATCH_SIZE) == 0
assert (TEST_SIZE % TEST_BATCH_SIZE) == 0

40000 4 10
1.0


## Separate a test set for later

In [3]:
# Hold out 10% of the samples for evaluation after training. The split is
# reproducible (random_state=42) and not stratified.
# Caution: X and y are overwritten with the remaining 90%, so re-running this
# cell without regenerating the data shrinks them again.
X, X_test, y, y_test = train_test_split(X, y, test_size=0.10, random_state=42)

# Store the held-out features and target side by side (target in the last
# column) as one float32 tensor.
test_features_and_target = np.column_stack((X_test, y_test))
test_dat = torch.tensor(test_features_and_target, dtype=torch.float32)

# Train network

## Device setup

In [4]:
# Select the compute device: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU. (A previous variant of this cell fell back to "mps"
# instead of "cpu".)
cuda_available = torch.cuda.is_available()

if cuda_available:
    DEVICE = torch.device("cuda:0")
    # Extra loader options, only populated when running on CUDA.
    LOADER_KWARGS = {'num_workers': 1, 'pin_memory': True}
else:
    DEVICE = torch.device("cpu")
    LOADER_KWARGS = {}

## Train, validate, and test network

In [5]:
# Train `n_nets` independent networks. For each one: seed, build, train with
# early stopping on the validation score, evaluate on the held-out `test_dat`,
# and save a path graph. Per-network results are collected in the containers
# below.
all_nets = {}                      # network index -> trained network
metrics_several_runs = []          # first return value of test_ensemble, per network
metrics_median_several_runs = []   # second return value of test_ensemble, per network
for ni in range(n_nets):
    post_train = False
    print('network', ni)
    # Initiate network; the torch seed differs per network (42, 43, ...)
    torch.manual_seed(ni+42)
    net = BayesianNetwork(dim, p, HIDDEN_LAYERS, classification=class_problem, num_transforms=num_transforms).to(DEVICE)
    # Total number of entries over all arrays returned by get_alphas_numpy;
    # passed to pip_func.train below.
    alphas = pip_func.get_alphas_numpy(net)
    nr_weights = np.sum([np.prod(a.shape) for a in alphas])
    print(nr_weights)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    
    # Training curves. They are re-created for every network, so after the
    # outer loop only the last network's curves remain.
    all_nll = []
    all_loss = []

    # Split what is left in X, y into a training part and a validation part
    # (1/9 goes to validation). The split seed is the network index, so each
    # network gets a different split.
    # presumably X, y are the 90% left after the hold-out cell, which makes
    # this an 80/10 split of the full data -- confirm
    X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=1/9, random_state=ni)  # stratify=y is disabled
            
    # Features and target side by side (target in the last column), float32.
    train_dat = torch.tensor(np.column_stack((X_train,y_train)),dtype = torch.float32)
    val_dat = torch.tensor(np.column_stack((X_val,y_val)),dtype = torch.float32)
    
    # Train network
    counter = 0       # epochs since the validation score last matched/improved its best
    highest_acc = 0   # best validation score so far; scores below 0 never count as an improvement
    # NOTE(review): best_model is updated on every improvement but is not read
    # anywhere in the visible cells; `net` (the last state) is what gets
    # stored and evaluated below -- confirm this is intended.
    best_model = copy.deepcopy(net)
    for epoch in range(epochs + post_train_epochs):
        if verbose:
            print(epoch)
        nll, loss = pip_func.train(net, train_dat, optimizer, BATCH_SIZE, NUM_BATCHES, p, DEVICE, nr_weights, post_train=post_train)
        nll_val, loss_val, ensemble_val = pip_func.val(net, val_dat, DEVICE, verbose=verbose, reg=(not class_problem), post_train=post_train)
        # Early-stopping bookkeeping: ties count as an improvement (>=).
        if ensemble_val >= highest_acc:
            counter = 0
            highest_acc = ensemble_val
            best_model = copy.deepcopy(net)
        else:
            counter += 1
        
        all_nll.append(nll)
        all_loss.append(loss)

        # After the last regular epoch, switch to post-training. The flag is
        # set after this epoch's train/val calls, so it takes effect from the
        # next epoch on.
        # NOTE(review): counter and highest_acc are not reset at the switch.
        if epoch == epochs-1:
            post_train = True   # Post-train --> use median model 
            # Freeze every parameter whose name contains "linears.<i>.lambdal"
            # for i = 0..HIDDEN_LAYERS (presumably the inclusion parameters -- confirm).
            for name, param in net.named_parameters():
                for i in range(HIDDEN_LAYERS+1):
                    # (an earlier variant matched f"linears{i}.lambdal", without the dot)
                    if f"linears.{i}.lambdal" in name:
                        param.requires_grad_(False)

        # Stop once `patience` consecutive epochs brought no improvement.
        if counter >= patience:
            break
        
    all_nets[ni] = net 
    # Results
    # The literal SAMPLES=10 is used here, not the notebook-level SAMPLES = 1.
    metrics, metrics_median = pip_func.test_ensemble(all_nets[ni],test_dat,DEVICE,SAMPLES=10, reg=(not class_problem)) # Test same data 10 times to get average 
    metrics_several_runs.append(metrics)
    metrics_median_several_runs.append(metrics_median)
    # One path graph per network; the figure is only shown when verbose is set.
    pf.run_path_graph(all_nets[ni], threshold=0.5, save_path=f"path_graphs/flow/prob/test{ni}", show=verbose)

# Only the metrics of the last trained network are printed here
# (`metrics` is undefined if n_nets is 0).
if verbose:
    print(metrics)
# Stack the per-network results into arrays for later aggregation.
m = np.array(metrics_several_runs)
m_median = np.array(metrics_median_several_runs)

network 0
54
0
loss 453.8760681152344
nll 18.871662139892578
density 0.73447059508827

val_loss: 609.0934, val_nll: 170.7656, val_ensemble: 0.9928, used_weights_median: 38

1
loss 110.76991271972656
nll 14.557331085205078
density 0.24655262684380566

val_loss: 264.5409, val_nll: 180.9322, val_ensemble: 0.9892, used_weights_median: 8

2
loss 27.08100700378418
nll 13.355836868286133
density 0.15052480423064143

val_loss: 173.4721, val_nll: 166.9564, val_ensemble: 0.9872, used_weights_median: 2

3
loss -25.55683708190918
nll 7.898965358734131
density 0.11723341923896913

val_loss: 96.5641, val_nll: 118.1691, val_ensemble: 0.9945, used_weights_median: 2

4
loss -43.99579620361328
nll 8.606891632080078
density 0.09364188144293924

val_loss: 105.6388, val_nll: 160.1157, val_ensemble: 0.9812, used_weights_median: 2

5
loss -61.675724029541016
nll 18.549068450927734
density 0.08144892845733988

val_loss: 24.1445, val_nll: 116.2475, val_ensemble: 0.9905, used_weights_median: 2

6
loss -99.64990

In [6]:
# `net` is the last network left over from the training loop. The figure is
# shown inline (show=True) and save_path is handed to the plotting helper.
pf.run_path_graph_weight(net, save_path="path_graphs/flow/weight/temp", show=True)

In [7]:
# Display the alpha arrays of the last trained network (a list of NumPy arrays;
# presumably one array of inclusion probabilities per layer -- confirm).
pip_func.get_alphas_numpy(net)

[array([[7.87760109e-06, 5.86484530e-05, 1.44919886e-05, 5.82302595e-03],
        [1.56587092e-04, 1.17364703e-04, 4.38525167e-05, 4.08909703e-03],
        [8.84148176e-05, 4.00674879e-04, 3.64325788e-05, 4.68287617e-03],
        [8.73189492e-05, 4.62921598e-05, 2.91525266e-05, 3.99240712e-03],
        [2.55780215e-05, 9.29470843e-05, 2.99269177e-05, 3.89468390e-03],
        [1.11794143e-04, 9.70581095e-05, 5.83579531e-05, 3.99460318e-03],
        [4.54293286e-06, 3.44829532e-05, 1.42100125e-05, 4.20172932e-03],
        [7.03233745e-05, 4.51400156e-05, 1.49041452e-05, 4.69373632e-03],
        [9.17788639e-05, 7.55830297e-06, 1.10179717e-05, 5.06829424e-03],
        [1.11250309e-04, 9.38020530e-05, 6.04475790e-05, 4.02996410e-03]],
       dtype=float32),
 array([[2.6273625e-07, 7.3012620e-07, 5.5216583e-06, 2.7891547e-06,
         7.3190307e-07, 2.1345820e-06, 2.8426817e-05, 2.7398286e-07,
         2.4209069e-06, 6.6183979e-06, 9.9999928e-01, 1.0000000e+00,
         1.0123221e-06, 3.386

In [8]:
# Display the weight matrices of the last trained network (a list of NumPy
# arrays; presumably one per layer -- confirm).
pip_func.weight_matrices_numpy(net)

[array([[-9.99995787e-03,  2.72842357e-04,  2.57626530e-02,
          4.77926172e-02],
        [-3.70672010e-02,  4.95574763e-03,  5.47802337e-02,
         -6.41293347e-01],
        [-9.68932727e-05,  1.68750086e-03, -3.09986770e-02,
         -1.33914024e-01],
        [ 1.34886056e-03,  3.30697112e-02,  1.88033022e-02,
         -5.24309158e-01],
        [ 2.24700174e-03,  3.79590429e-02, -2.63027381e-02,
         -3.37164104e-01],
        [ 3.83165777e-02,  3.82795185e-02,  2.28392659e-03,
         -1.11174941e-01],
        [-6.39542788e-02,  2.53586303e-02,  1.40828174e-02,
         -4.00992483e-01],
        [ 2.74011381e-02, -1.80404400e-03, -3.49637866e-02,
         -3.44291069e-02],
        [ 1.03720136e-01,  9.47691206e-06,  5.51327467e-02,
         -5.99362627e-02],
        [-7.92662555e-04,  4.97676283e-02,  5.12579605e-02,
         -8.50109696e-01]], dtype=float32),
 array([[-1.7587095e-04,  9.1777292e-06,  1.1637784e-03,  1.5023402e-05,
          1.3170655e-04,  2.3475538e-04,