In [None]:
import sys, os
sys.path.append(os.path.abspath(os.path.join('..')))
print(sys.path)

from sogaPreprocessor import *
from producecfg import *
from smoothcfg import *
from libSOGA import *
from time import time

from utils import get_data, mean_squared_error, mean_squared_error_bayes, neg_log_likelihood, neg_log_likelihood_one
import vi_model
from parallel_loss import run_parallel

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import pyro.optim
from pyro.infer import SVI, Trace_ELBO
from pyro.distributions import constraints
import pyro.distributions as dist

torch.set_default_dtype(torch.float32)

from concurrent.futures import ProcessPoolExecutor

In [None]:

import torch
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import time

# Tensors created from here on (including the BNN parameters) default to float64.
torch.set_default_dtype(torch.float64)

# Starting values for the BNN parameters.
# NOTE(review): judging by the names, mu*/sigma* are presumably the means and
# standard deviations of the weights (w) and biases (b) -- confirm against bnn3.soga.
initial_values = {
    'muw1': 1., 'muw2': -1., 'muw3': 1., 'muw4': -1., 'muw5': -1., 'muw6': 1., 'muw7': -1., 'muw8': -1.,
    'mub1': -1., 'mub2': -1., 'mub3': -1., 'mub4': 1., 'mub5': 1.,
    'sigmaw1': 0.1, 'sigmaw2': 0.1, 'sigmaw3': 0.1, 'sigmaw4': 0.1, 'sigmaw5': 0.1, 'sigmaw6': 0.1, 'sigmaw7': 0.1, 'sigmaw8': 0.1,
    'sigmab1': 0.1, 'sigmab2': 0.1, 'sigmab3': 0.1, 'sigmab4': 0.1, 'sigmab5': 0.1,
}

# Every parameter becomes a scalar leaf tensor that tracks gradients.
bnn_one_pars = {
    name: torch.tensor(start_value, requires_grad=True)
    for name, start_value in initial_values.items()
}

# Compile the SOGA program once, then build and smooth its control-flow graph.
compiledFile = compile2SOGA('../../programs/SOGA/Optimization/CaseStudies/bnn3.soga')
cfg = produce_cfg(compiledFile)
smooth_cfg(cfg)

# Optimisation hyper-parameters
lr = 0.001
steps = 5000
batch_size = 20

# Adam over every entry of the parameter dict except the (non-trainable) input 'x'.
optimizer = torch.optim.Adam(
    [tensor for name, tensor in bnn_one_pars.items() if name != 'x'],
    lr,
)

total_start = time.time()

time_cfg = 0

# Start from an empty timing log: opening in "w" mode truncates any previous content.
with open("cfg_creation_time.txt", "w") as log_file:
    log_file.write("")

# ✅ Function to compute one sample's loss (runs in a thread)
def compute_loss_one(j, X, Y, bnn_pars, compiledFile):
    """Compute the negative log-likelihood of one randomly drawn data point.

    Intended to be submitted to a thread pool, so it must not mutate state
    shared between workers.

    Args:
        j: Identifier of the submitted task; only used to label the timing-log
            line. The data point itself is drawn at random, not indexed by ``j``.
        X: Input tensor; indexed along its first dimension after squeezing.
        Y: Target tensor; indexed along its first dimension after squeezing.
        bnn_pars: Mapping from parameter name to tensor. It is NOT modified:
            the sampled input is added to a per-call shallow copy.
        compiledFile: Compiled SOGA program passed to ``produce_cfg``.

    Returns:
        Whatever ``neg_log_likelihood_one`` returns for the sampled target and
        the distribution produced by ``start_SOGA``.
    """
    # Build a brand-new smoothed CFG for this call (rather than sharing one
    # between threads) and record how long that takes.
    cfg_started = time.time()
    new_cfg = produce_cfg(compiledFile)
    smooth_cfg(new_cfg)
    time_elapsed = time.time() - cfg_started
    with open("cfg_creation_time.txt", "a") as f:
        f.write(f"Sample {j}: {time_elapsed:.4f} seconds\n")

    # Draw one data point uniformly at random (with replacement across tasks).
    sampled_index = np.random.randint(0, len(Y.squeeze([-1,1])))
    yj = Y.squeeze([1])[sampled_index].to(torch.float64)
    xj = X.squeeze([-1,1])[sampled_index]

    # Writing 'x' into the shared dict would race with the other worker threads:
    # one task could overwrite another's input before start_SOGA reads it.
    # A shallow copy keeps the very same parameter tensors (so gradients still
    # reach them) while giving each task its own private 'x' entry.
    # NOTE(review): assumes start_SOGA only reads the mapping -- confirm.
    local_pars = dict(bnn_pars)
    local_pars['x'] = xj.requires_grad_(False)

    current_dist = start_SOGA(new_cfg, local_pars, pruning='ranking')
    return neg_log_likelihood_one(yj, current_dist)

# ✅ Main loop
# Progress is reported roughly ten times per run; max(1, ...) keeps the modulo
# below well-defined when steps < 10.
report_every = max(1, steps // 10)

# Smallest value a sigma parameter is allowed to reach.
SIGMA_FLOOR = 1e-6

# One thread pool is reused for every step instead of being rebuilt each iteration.
with ThreadPoolExecutor(max_workers=batch_size) as executor:
    for i in range(steps):
        optimizer.zero_grad()

        # Forward pass: batch_size independent single-sample losses computed in threads.
        futures = [executor.submit(compute_loss_one, j, X, Y, bnn_one_pars, compiledFile) for j in range(batch_size)]
        losses = torch.stack([f.result() for f in futures])
        total_loss = losses.sum()

        # NOTE(review): retain_graph=True looks unnecessary because the graph is
        # rebuilt on every step -- kept as in the original; confirm before removing.
        total_loss.backward(retain_graph=True)
        optimizer.step()

        # Keep every sigma at or above the floor. The clamp is done IN PLACE:
        # the optimizer holds references to the tensors it was constructed with,
        # so rebinding the dict entry to a new tensor would leave the optimizer
        # updating a tensor the model no longer uses.
        with torch.no_grad():
            for key, param in bnn_one_pars.items():
                if 'sigma' in key:
                    param.clamp_(min=SIGMA_FLOOR)

        if i % report_every == 0:
            out = f"loss: {total_loss.item()}"
            for key in bnn_one_pars:
                if key != 'x':
                    out += f" {key}: {bnn_one_pars[key].item():.4f}"
            print(out)

total_end = time.time()
print("Optimization performed in", round(total_end - total_start, 3), "seconds")

In [None]:

# Evaluate the trained parameters on every training input and plot the
# predicted mean of 'y' against the training data.
y_means = []

# A fresh smoothed CFG is built for evaluation.
cfg = produce_cfg(compiledFile)
smooth_cfg(cfg)

# Squeeze once outside the loop; the result is indexed per data point below.
x_flat = X.squeeze([-1,1])

for j in range(len(Y.squeeze([-1,1]))):
    # The current input is passed to SOGA through the 'x' entry of the parameter dict.
    bnn_one_pars['x'] = x_flat[j].requires_grad_(False)

    current_dist = start_SOGA(cfg, bnn_one_pars, pruning='ranking')

    # Mean of the output distribution, restricted to the component of variable 'y'.
    y_index = current_dist.var_list.index('y')
    y_means.append(current_dist.gm.mean()[y_index].detach().numpy())

# plot training data
plt.plot(X.squeeze(-1).numpy(), Y.squeeze(-1).numpy(), "kx")
# TODO: add a 90% confidence band (plt.fill_between) once prediction percentiles are computed.
# plot mean prediction
plt.plot(X.numpy().flatten(), y_means, "blue", ls="solid", lw=2.0)
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Mean predictions")
plt.show()


In [None]:
# Read back the per-sample CFG creation times logged during optimisation.
# Each line has the form "Sample {j}: {seconds:.4f} seconds".
with open("cfg_creation_time.txt", "r") as f:
    cfg_creation_times = f.readlines()

# Take the number between ": " and the following space on every non-blank line.
# Blank lines (e.g. a trailing newline) are skipped instead of raising IndexError.
times_list = [
    float(line.split(": ")[1].strip().split(" ")[0])
    for line in cfg_creation_times
    if line.strip()
]

# Sum of all logged times. The samples are processed by concurrent worker
# threads, so this is cumulative time across tasks, not elapsed wall-clock time.
time_cfg = sum(times_list)
print("Total time spent on cfg creation:", time_cfg, "seconds")
    