## Imports

In [1]:
import copy
import numpy as np
import torch
import torch.optim as optim
from sklearn.model_selection import train_test_split
from config import config
import os
import sys
current_dir = os.getcwd()
# Location of the local `islbbnn` checkout that holds the helper modules
# imported below. Set the ISLBBNN_PATH environment variable to point at your
# own checkout; the original author's location is kept as the fallback so the
# notebook behaves as before when the variable is not set.
path = os.environ.get(
    "ISLBBNN_PATH",
    "C:\\Users\\eirik\\Documents\\Master\\ISLBBNN\\islbbnn",
)
os.chdir(path)
try:
    # These imports are done while the working directory is the checkout.
    import plot_functions as pf
    import pipeline_functions as pip_func
    sys.path.append('networks')
    from lrt_sigmoid_net import BayesianNetwork
finally:
    # Always set the working directory back to this one, even when an import
    # fails; otherwise the kernel is left inside `path` and a re-run of the
    # notebook starts from the wrong directory.
    os.chdir(current_dir)

# Obtain data

Problem:

$$y = x_1 + x_2 + x_1\cdot x_2 + x_1^2 + x_2^2 + 100 +\epsilon$$

where $\epsilon \sim N(0,0.01)$. 


$x_3$ can be made dependent on $x_1$. The dependence is defined in the following way:

\begin{align*}
 x_1 &\sim Unif(-10,10) \\
 x_3 &\sim Unif(-10,10) \\
 x_3 &= \text{dep}\cdot x_1 + (1-\text{dep})\cdot x_3
\end{align*}

## Preprocessing and batch size

In [2]:
# Hyper-parameters, all read from the shared `config` dictionary.
# `n_layers` is reduced by two to obtain the hidden-layer count
# (presumably it also counts the input and output layers -- confirm in config).
HIDDEN_LAYERS = config['n_layers'] - 2
epochs, post_train_epochs = config['num_epochs'], config['post_train_epochs']
dim, n_nets, n_samples = config['hidden_dim'], config['n_nets'], config['n_samples']
lr = config['lr']
class_problem, non_lin = config["class_problem"], config["non_lin"]
verbose, save_res, patience = config['verbose'], config['save_res'], config['patience']
SAMPLES = 1

# Simulate the data set (uniform covariates, no dependence between them).
y, X = pip_func.create_data_unif(
    n_samples,
    beta=[100, 1, 1, 1, 1],
    dep_level=0.0,
    classification=class_problem,
    non_lin=non_lin,
)

# Number of observations and number of covariates; `p` is needed later to
# build the network and to slice the feature columns.
n, p = X.shape
print(n, p, dim)

# Batch sizes and split sizes.
# Naming caveat kept from the original: "TEST" is what would normally be called
# the validation part (used during training) and "VAL" is what would normally
# be called the test part (used after training).
BATCH_SIZE = int((n * 0.8) / 100)
TEST_BATCH_SIZE = TEST_SIZE = int(n * 0.10)
VAL_BATCH_SIZE = VAL_SIZE = int(n * 0.10)

# NOTE(review): TRAIN_SIZE uses the same expression as BATCH_SIZE, so
# NUM_BATCHES below is always 1.0 (the recorded output confirms this). If
# TRAIN_SIZE is meant to be the number of training rows it should be
# int(n * 0.80) -- confirm against how pip_func.train uses NUM_BATCHES.
TRAIN_SIZE = int((n * 0.80) / 100)

NUM_BATCHES = TRAIN_SIZE / BATCH_SIZE
print(NUM_BATCHES)

# Sanity checks: each split must be an exact multiple of its batch size.
assert TRAIN_SIZE % BATCH_SIZE == 0
assert TEST_SIZE % TEST_BATCH_SIZE == 0

40000 4 10
1.0


## Separate a test set for later

In [3]:
# Hold out 10% of the rows for evaluation after training. The remaining 90%
# stay in `X` / `y`.
# NOTE(review): `X` and `y` are overwritten here, so running this cell a second
# time without regenerating the data shrinks them again.
X, X_test, y, y_test = train_test_split(X, y, test_size=0.10, random_state=42)

# Pack the held-out features and response side by side (response in the last
# column) as one float32 tensor.
test_dat = torch.tensor(
    np.column_stack((X_test, y_test)),
    dtype=torch.float32,
)

# Train network

## Device setup

In [4]:
# Pick the compute device once: the first CUDA GPU when one is available,
# otherwise the CPU. (To target Apple "mps" instead, swap the fallback string.)
cuda_available = torch.cuda.is_available()

DEVICE = torch.device("cuda:0") if cuda_available else torch.device("cpu")

# Extra data-loader options, only set when running on CUDA.
if cuda_available:
    LOADER_KWARGS = {'num_workers': 1, 'pin_memory': True}
else:
    LOADER_KWARGS = {}

## Train, validate and test network

In [5]:
# Train `n_nets` independent networks. Every run re-seeds torch, draws its own
# train/validation split from `X` / `y`, trains for at most
# `epochs + post_train_epochs` epochs with patience-based early stopping, and is
# then scored on the held-out `test_dat`.
all_nets = {}
metrics_several_runs = []
metrics_median_several_runs = []

# Name fragments that identify the `lambdal` parameters of every layer in
# `net.linears`; those parameters are frozen when post-training starts.
lambdal_tags = tuple(f"linears.{i}.lambdal" for i in range(HIDDEN_LAYERS+1))

for ni in range(n_nets):
    post_train = False
    print('network', ni)
    # Initiate network (seed depends on the run index, so runs differ but are
    # reproducible)
    torch.manual_seed(ni+42)
    net = BayesianNetwork(dim, p, HIDDEN_LAYERS, classification=class_problem).to(DEVICE)
    alphas = pip_func.get_alphas_numpy(net)
    # Total number of entries over all alpha arrays
    nr_weights = np.sum([np.prod(a.shape) for a in alphas])
    print(nr_weights)

    optimizer = optim.Adam(net.parameters(), lr=lr)

    # Per-epoch training curves; reset for every network, so after the loop
    # they hold the curves of the last network only.
    all_nll = []
    all_loss = []

    # Split into training and validation set (1/9 of the rows go to validation)
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=1/9, random_state=ni)

    train_dat = torch.tensor(np.column_stack((X_train,y_train)),dtype = torch.float32)
    val_dat = torch.tensor(np.column_stack((X_val,y_val)),dtype = torch.float32)

    # Train network
    counter = 0        # epochs since the validation score last improved
    highest_acc = 0    # best validation score so far (higher is better)
    # NOTE(review): `best_model` is kept up to date below, but the final `net`
    # (not `best_model`) is what gets stored and evaluated -- confirm intended.
    best_model = copy.deepcopy(net)
    for epoch in range(epochs + post_train_epochs):
        if verbose:
            print(epoch)
        nll, loss = pip_func.train(net, train_dat, optimizer, BATCH_SIZE, NUM_BATCHES, p, DEVICE, nr_weights, post_train=post_train)
        nll_val, loss_val, ensemble_val = pip_func.val(net, val_dat, DEVICE, verbose=verbose, reg=(not class_problem))
        if ensemble_val >= highest_acc:
            counter = 0
            highest_acc = ensemble_val
            best_model = copy.deepcopy(net)
        else:
            counter += 1

        all_nll.append(nll)
        all_loss.append(loss)

        if epoch == epochs-1:
            # Last regular epoch is done: the remaining epochs are post-training.
            post_train = True   # Post-train --> use median model
            for name, param in net.named_parameters():
                if any(tag in name for tag in lambdal_tags):
                    param.requires_grad_(False)
                    # Also drop the gradient left over from the last training
                    # step. Optimizers skip parameters whose grad is None; a
                    # stale (or merely zeroed) gradient could otherwise still
                    # let Adam's momentum move a "frozen" parameter.
                    param.grad = None

        # NOTE(review): `counter` / `highest_acc` are not reset when
        # post-training starts, and stopping before epoch `epochs-1` skips
        # post-training entirely -- confirm this is the intended behaviour.
        if counter >= patience:
            break

    all_nets[ni] = net
    # Results
    metrics, metrics_median = pip_func.test_ensemble(all_nets[ni], test_dat, DEVICE, SAMPLES=10, reg=(not class_problem)) # Test same data 10 times to get average
    metrics_several_runs.append(metrics)
    metrics_median_several_runs.append(metrics_median)
    pf.run_path_graph(all_nets[ni], threshold=0.5, save_path=f"path_graphs/lrt/prob/test{ni}", show=verbose)

# Show the metrics of the last network; guarded so that `n_nets == 0` does not
# raise a NameError on an undefined `metrics`.
if verbose and metrics_several_runs:
    print(metrics_several_runs[-1])
m = np.array(metrics_several_runs)
m_median = np.array(metrics_median_several_runs)

network 0
54
0
loss 925.3873291015625
nll 219.50637817382812
density 0.9683086629267093

val_loss: 3377.9114, val_nll: 2649.6443, val_ensemble: 0.6095, used_weights_median: 54

1
loss 458.15728759765625
nll 152.3841094970703
density 0.5228981839285957

val_loss: 2052.9136, val_nll: 1764.6082, val_ensemble: 0.8468, used_weights_median: 22

2
loss 274.00311279296875
nll 124.36599731445312
density 0.26022213918191417

val_loss: 1627.4569, val_nll: 1486.9828, val_ensemble: 0.8702, used_weights_median: 11

3
loss 244.03372192382812
nll 108.3678970336914
density 0.23993244667158084

val_loss: 1452.2245, val_nll: 1319.9252, val_ensemble: 0.8800, used_weights_median: 11

4
loss 220.27041625976562
nll 98.10523223876953
density 0.216781132992495

val_loss: 1354.3467, val_nll: 1234.6224, val_ensemble: 0.8798, used_weights_median: 10

5
loss 216.26412963867188
nll 97.57791900634766
density 0.2131863366122599

val_loss: 1311.7034, val_nll: 1194.5437, val_ensemble: 0.8758, used_weights_median: 10

6

TODO: It seems like val_nll and val_loss are the same, but this is not true. Fix this...

## Function for plotting weight magnitude

In [6]:
# Plot the weight-magnitude path graph for `net` and show it inline. `net` is
# whatever the training loop bound last, i.e. the final network trained.
pf.run_path_graph_weight(net, save_path="path_graphs/lrt/weight/temp", show=True)

In [7]:
# Display the alpha arrays of the last trained network as NumPy arrays
# (presumably per-weight inclusion probabilities, one array per layer -- confirm
# in pipeline_functions).
pip_func.get_alphas_numpy(net)

[array([[4.9000159e-03, 9.7776996e-04, 9.7243045e-04, 7.1462744e-04],
        [1.0562978e-03, 1.8786686e-03, 1.8330807e-03, 2.5666223e-03],
        [1.0000000e+00, 9.9999404e-01, 3.0152025e-04, 2.5752204e-04],
        [7.9982483e-04, 7.6221547e-04, 8.2809562e-03, 3.4691701e-03],
        [1.5113523e-03, 6.9436670e-04, 8.6942367e-04, 8.9786836e-04],
        [1.3443066e-03, 7.3418405e-04, 9.5240952e-04, 8.4099185e-04],
        [1.4998798e-03, 9.6966058e-04, 1.0888408e-03, 7.4326072e-04],
        [4.0966843e-04, 9.9994969e-01, 3.7607978e-04, 4.1295847e-04],
        [9.9755563e-03, 1.0411217e-03, 2.9606109e-03, 7.0231557e-03],
        [3.8803008e-04, 1.0000000e+00, 3.3020217e-04, 3.1847859e-04]],
       dtype=float32),
 array([[1.5701920e-04, 7.6811435e-04, 1.0000000e+00, 8.2615798e-04,
         7.1956799e-04, 8.0848840e-04, 7.4273249e-04, 1.0000000e+00,
         8.0727239e-04, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
         3.5897046e-04, 3.6562642e-04]], dtype=float32)]

In [8]:
# Display the weight matrices of the last trained network as NumPy arrays.
pip_func.weight_matrices_numpy(net)

[array([[ 1.0308874e+00,  8.6741424e-01,  2.7092358e-02, -1.1303907e-03],
        [ 2.3953363e-02, -2.8360793e-03,  3.4015826e-03, -3.4925495e-03],
        [ 4.2224801e-01,  2.1433486e-01,  1.5572631e-02, -1.9895034e-03],
        [ 3.2922179e-01,  8.5675091e-01,  9.6578622e-01, -8.2574207e-01],
        [ 6.5951347e-01,  1.0355248e+00,  8.8233262e-02, -4.3807101e-01],
        [-6.5660334e-01,  3.8346007e-01,  5.9250742e-02,  9.2551476e-01],
        [-6.4604022e-02, -1.0967822e+00, -1.0869149e+00,  1.5684751e-01],
        [ 2.0661561e-02, -6.1666161e-01, -1.4194698e-05,  5.8925105e-03],
        [ 7.7836841e-01,  7.4955505e-01,  1.7783241e-02,  3.0547616e-03],
        [-2.1044217e-02, -1.1262019e+00, -9.3743455e-04, -2.0147987e-04]],
       dtype=float32),
 array([[ 8.0842662e-01,  6.9610514e-02,  4.2719688e+01,  2.1565706e-02,
          3.1601645e-02,  3.3257671e-02,  2.4771078e-02, -1.6498758e+01,
         -4.6596909e-04, -1.9760948e+01, -1.8349903e+00, -2.6289647e+00,
         -9.53128