In [1]:
import argparse, os
from pathlib import Path

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import kaiming_uniform_
from torch.optim.lr_scheduler import LambdaLR

In [2]:
import os
DEVICE = torch.device(os.getenv('OP_DEVICE', 'cuda' if torch.cuda.is_available() else 'cpu'))
DTYPE  = torch.float64


In [4]:
path = ''

In [5]:
# -----------------------------------------------------------------------------
#  Hard‑wired optimiser hyper‑parameters
# -----------------------------------------------------------------------------
learn_initial = 0.0001
learn_middle = 0.000001
learn_final = 1e-7

layers = [50,50,50]

epochs_initial = 30
epochs_middle = epochs_initial + 10
epochs_total = epochs_middle + 10

weight = 0.7
neg_weight = 10.0
batch_size = 128

seed = 33

In [6]:
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [7]:
# -----------------------------------------------------------------------------
#  Hard‑wired model parameters
# -----------------------------------------------------------------------------
crra = 5.0
beta = 0.96
alpha = 0.5

psi = 7.08
pi_star = 1.0484

rho_y = 0.8118
eta = 0.0347
sigma_y = eta / np.sqrt(1.0 - (rho_y ** 2.0))

mu = 0.0
y_ubnd = mu + 3.0 * sigma_y
y_lbnd = mu - 3.0 * sigma_y

b_star_ubnd = 0.3
b_star_lbnd = -0.1

b_tilde_ubnd = 0.3
b_tilde_lbnd = -0.1

kappa = 40.0

r = 0.04
R = np.exp(r)
delta = 0.757

qstar = 1.0 / (R - delta)

EPS = 1e-12

In [9]:
#n_int sets the number of nodes in the Gauss-Hermite quadrature used to approximate expectations
n_int = 19

x_int_norm, w_int = np.polynomial.hermite.hermgauss(n_int)
w_int = w_int/np.sqrt(np.pi)
x_int_norm = x_int_norm * np.sqrt(2)

w_int = torch.from_numpy(w_int).to(DEVICE)
x_int_norm = torch.from_numpy(x_int_norm).to(DEVICE)

innovation_i = x_int_norm.view(1, n_int, 1)
weight_i = w_int.view(1, n_int, 1)

In [10]:
def print_gpu_memory():
    
    num_gpus = torch.cuda.device_count()
    for i in range(num_gpus):
        # Get the current GPU's total memory and memory allocated
        total_memory = torch.cuda.get_device_properties(i).total_memory
        allocated_memory = torch.cuda.memory_allocated(i)
        cached_memory = torch.cuda.memory_reserved(i)

        # Convert bytes to GB for easier interpretation
        total_memory_gb = total_memory / (1024 ** 3)
        allocated_memory_gb = allocated_memory / (1024 ** 3)
        cached_memory_gb = cached_memory / (1024 ** 3)
        free_memory_gb = total_memory_gb - (allocated_memory_gb + cached_memory_gb)

        print(f"GPU {i}:")
        print(f"  Total Memory: {total_memory_gb:.2f} GB")
        print(f"  Allocated Memory: {allocated_memory_gb:.2f} GB")
        print(f"  Cached Memory: {cached_memory_gb:.2f} GB")
        print(f"  Free Memory (approx.): {free_memory_gb:.2f} GB")

In [11]:
class NN_pi(nn.Module):
    def __init__(self, in_szs, out_szs, layers):
        super().__init__()

        layerlist = []
        n_in = in_szs
        
        for i in layers:
            kaiming_uniform_(nn.Linear(n_in, i).weight, nonlinearity='leaky_relu')
            layerlist.append(nn.Linear(n_in, i))
            layerlist.append(nn.LeakyReLU())
            n_in = i
        
        layerlist.append(nn.Linear(layers[-1], out_szs))
        layerlist.append(nn.Sigmoid())
        
        self.layers = nn.Sequential(*layerlist)
                
    def forward(self, x):
        
        pred = self.layers(x)
        
        pred = pred * (1.1 - 1.0) + 1.0
                    
        return pred

In [12]:
class NN_foreign(nn.Module):
    def __init__(self, in_szs, out_szs, layers):
        super().__init__()

        layerlist = []
        n_in = in_szs
        
        for i in layers:
            kaiming_uniform_(nn.Linear(n_in, i).weight, nonlinearity='leaky_relu')
            layerlist.append(nn.Linear(n_in, i))
            layerlist.append(nn.LeakyReLU())
            n_in = i
        
        layerlist.append(nn.Linear(layers[-1], out_szs))
        layerlist.append(nn.Identity())
        
        self.layers = nn.Sequential(*layerlist)
                
    def forward(self, x):
        
        pred = self.layers(x)
                    
        return pred

In [13]:
class NN_local(nn.Module):
    def __init__(self, in_szs, out_szs, layers):
        super().__init__()

        layerlist = []
        n_in = in_szs
        
        for i in layers:
            kaiming_uniform_(nn.Linear(n_in, i).weight, nonlinearity='leaky_relu')
            layerlist.append(nn.Linear(n_in, i))
            layerlist.append(nn.LeakyReLU())
            n_in = i
        
        layerlist.append(nn.Linear(layers[-1], out_szs))
        layerlist.append(nn.Identity())
        
        self.layers = nn.Sequential(*layerlist)
                
    def forward(self, x):
        
        pred = self.layers(x)
                    
        return pred

In [14]:
class NN_val(nn.Module):
    def __init__(self, in_szs, out_szs, layers):
        super().__init__()

        layerlist = []
        n_in = in_szs
        
        for i in layers:
            kaiming_uniform_(nn.Linear(n_in, i).weight, nonlinearity='leaky_relu')
            layerlist.append(nn.Linear(n_in, i))
            layerlist.append(nn.LeakyReLU())
            n_in = i
        
        layerlist.append(nn.Linear(layers[-1], out_szs))
        layerlist.append(nn.Identity())
        
        self.layers = nn.Sequential(*layerlist)
                
    def forward(self, x):
        
        pred = self.layers(x)
                    
        return pred

In [15]:
class NN_qtilde(nn.Module):
    def __init__(self, in_szs, out_szs, layers):
        super().__init__()

        layerlist = []
        n_in = in_szs
        
        for i in layers:
            kaiming_uniform_(nn.Linear(n_in, i).weight, nonlinearity='leaky_relu')
            layerlist.append(nn.Linear(n_in, i))
            layerlist.append(nn.LeakyReLU())
            n_in = i
        
        layerlist.append(nn.Linear(layers[-1], out_szs))
        layerlist.append(nn.Sigmoid())
        
        self.layers = nn.Sequential(*layerlist)
                
    def forward(self, x):
        
        pred = self.layers(x)
        
        pred = pred * 10.0
                    
        return pred

In [16]:
class NN_q(nn.Module):
    def __init__(self, in_szs, out_szs, layers):
        super().__init__()

        layerlist = []
        n_in = in_szs
        
        for i in layers:
            kaiming_uniform_(nn.Linear(n_in, i).weight, nonlinearity='leaky_relu')
            layerlist.append(nn.Linear(n_in, i))
            layerlist.append(nn.LeakyReLU())
            n_in = i
        
        layerlist.append(nn.Linear(layers[-1], out_szs))
        layerlist.append(nn.Softplus())
        
        self.layers = nn.Sequential(*layerlist)
                
    def forward(self, x):
        
        pred = self.layers(x)
                                    
        return pred

In [17]:
def compute_error(X, nn_pi, nn_foreign, nn_local, nn_val, nn_qtilde, nn_q):
        
    n_data = X.shape[0]
    
    yt = X[:, 0:1]
    bstar = X[:, 1:2]
    btilde = X[:, 2:3]
        
    pi = nn_pi(X)
    bstar_prime = nn_foreign(X)
    btilde_prime = nn_local(X)
    val = nn_val(X)
        
    exp_val_prime = torch.zeros((n_data, 1), device=DEVICE, dtype=torch.float64)
    exp_qtilde_inside = torch.zeros((n_data, 1), device=DEVICE, dtype=torch.float64)
    
    exp_dval_prime_dbstar_prime = torch.zeros((n_data, 1), device=DEVICE, dtype=torch.float64)
    exp_dval_prime_dbtilde_prime = torch.zeros((n_data, 1), device=DEVICE, dtype=torch.float64)
    
    exp_dqtilde_dbstar_prime_inside = torch.zeros((n_data, 1), device=DEVICE, dtype=torch.float64)
    exp_dqtilde_dbtilde_prime_inside = torch.zeros((n_data, 1), device=DEVICE, dtype=torch.float64)
    
    # ------------------------------------------------------------------
    # Compute expectation terms using vectorized code
    # ------------------------------------------------------------------
    #Below B represents the batch size
    yt_b = yt[:,None]     # (B,1,1)
    
    m = torch.exp(-r - kappa * innovation_i * eta - 0.5 * (kappa ** 2.0) * (eta ** 2.0))
    
    yt_prime = torch.exp(rho_y * torch.log(yt_b) + innovation_i * eta)
    
    yt_prime = yt_prime.expand(-1, n_int, -1) # (B, n_int, 1)
    m = m.expand(-1, n_int, -1)
    
    bstar_p = bstar_prime.unsqueeze(1)
    btilde_p = btilde_prime.unsqueeze(1)
    
    bstar_p = bstar_p.expand(-1, n_int, -1)
    btilde_p = btilde_p.expand(-1, n_int, -1)
    
    X_prime = torch.cat((yt_prime, bstar_p, btilde_p), dim = -1)
    X_prime_flat = X_prime.reshape(-1, 3)  # (B*n_int , 3)
    
    val_prime, qtilde_inside, dval_prime_dbstar_prime, dval_prime_dbtilde_prime, dqtilde_dbstar_prime_inside, dqtilde_dbtilde_prime_inside = get_exp_prime(X_prime_flat, nn_pi, nn_foreign, nn_local, nn_val, nn_qtilde, nn_q)

    val_prime = val_prime.view(n_data, n_int, 1)
    qtilde_inside = qtilde_inside.view(n_data, n_int, 1)
    
    dval_prime_dbstar_prime = dval_prime_dbstar_prime.view(n_data, n_int, 1)
    dval_prime_dbtilde_prime = dval_prime_dbtilde_prime.view(n_data, n_int, 1)
    
    dqtilde_dbstar_prime_inside = dqtilde_dbstar_prime_inside.view(n_data, n_int, 1)
    dqtilde_dbtilde_prime_inside = dqtilde_dbtilde_prime_inside.view(n_data, n_int, 1)
    
    exp_val_prime = (weight_i * val_prime).sum(1)
    exp_qtilde_inside = (weight_i * m * qtilde_inside).sum(1)
    
    exp_dval_prime_dbstar_prime = (weight_i * dval_prime_dbstar_prime).sum(1)
    exp_dval_prime_dbtilde_prime = (weight_i * dval_prime_dbtilde_prime).sum(1)
    
    exp_dqtilde_dbstar_prime_inside = (weight_i * m * dqtilde_dbstar_prime_inside).sum(1)
    exp_dqtilde_dbtilde_prime_inside = (weight_i * m * dqtilde_dbtilde_prime_inside).sum(1)

    # ------------------------------------------------------------------
    # Use expectations to construct error terms
    # ------------------------------------------------------------------
        
    qtilde = exp_qtilde_inside
    dqtilde_dbstar_prime = exp_dqtilde_dbstar_prime_inside
    dqtilde_dbtilde_prime = exp_dqtilde_dbtilde_prime_inside
    
    ct_discriminant = ((btilde / pi) ** 2.0) - 4.0 * (bstar - qstar * (bstar_prime - delta * bstar) - qtilde * (btilde_prime - delta * (btilde / pi)) - yt)
    ct_discriminant_clamped = torch.clamp(ct_discriminant, min = EPS)
    numerator = (-btilde / pi) + (ct_discriminant_clamped ** 0.5)
    ct = 0.25 * (numerator ** 2.0)
    
    zeta = (1.0 / pi) * (ct ** (1.0 - alpha))
    zeta_one = ct ** (1.0 - alpha)
    dzeta_dpi = -(1.0 / (pi ** 2.0)) * (ct ** (1.0 - alpha))
    dzeta_dct = (1.0 - alpha) * (1.0 / pi) * (ct ** -alpha)
    
    c = ct ** alpha
    lambda_ct_numerator = (1.0 - beta) * (c ** -crra) * alpha * (ct ** (alpha - 1.0))
    lambda_ct_denominator = 1.0 + dzeta_dct * btilde
    lambda_ct = lambda_ct_numerator / lambda_ct_denominator
    
    lambda_pi_numerator = (1.0 - beta) * psi * (pi - pi_star)
    lambda_pi_denominator = (btilde / (pi ** 2.0)) * (ct ** (1.0 - alpha)) + delta * (btilde / (pi ** 2.0)) * qtilde
    lambda_pi = lambda_pi_numerator / lambda_pi_denominator
    
    q = (1.0 / zeta_one) * qtilde

    qtilde_pred = nn_qtilde(X)
    q_pred = nn_q(X)
    
    qtilde_detached = qtilde.detach()
    q_detached = q.detach()
    
    err_qtilde = qtilde_detached - qtilde_pred
    
    err_q = q_detached - q_pred
    
    err_val = (1.0 - beta) * (((c ** (1.0 - crra)) / (1.0 - crra)) - (psi / 2.0) * ((pi - pi_star) ** 2.0)) + beta * exp_val_prime - val
    
    errREE_c_pi = lambda_ct - lambda_pi
    
    errREE_bstar = lambda_ct * (qstar + dqtilde_dbstar_prime * (btilde_prime - delta * (btilde / pi))) + beta * exp_dval_prime_dbstar_prime
    
    errREE_btilde = lambda_ct * (qtilde + dqtilde_dbtilde_prime * (btilde_prime - delta * (btilde / pi))) + beta * exp_dval_prime_dbtilde_prime
    
    err_discriminant = torch.clamp(-ct_discriminant, min = 0.0)
    
    err_numerator = torch.clamp(-numerator, min = 0.0)
    
    return err_val, errREE_c_pi, errREE_bstar, errREE_btilde, err_qtilde, err_q, err_discriminant, err_numerator
    

In [18]:
def get_exp_prime(X_prime, nn_pi, nn_foreign, nn_local, nn_val, nn_qtilde, nn_q):
        
    n_data = X_prime.shape[0]
    
    yt_prime = X_prime[:, 0:1]
    bstar_prime = X_prime[:, 1:2]
    btilde_prime = X_prime[:, 2:3]
    
    pi_prime = nn_pi(X_prime)
    bstar_prime_next = nn_foreign(X_prime)
    btilde_prime_next = nn_local(X_prime)
    val_prime = nn_val(X_prime)
        
    qtilde_prime = nn_qtilde(X_prime)
    q_prime = nn_q(X_prime)
    
    ct_discriminant_prime = ((btilde_prime / pi_prime) ** 2.0) - 4.0 * (bstar_prime - qstar * (bstar_prime_next - delta * bstar_prime) - qtilde_prime * (btilde_prime_next - delta * (btilde_prime / pi_prime)) - yt_prime)
    ct_discriminant_prime_clamped = torch.clamp(ct_discriminant_prime, min = EPS)
    ct_prime = 0.25 * (((-btilde_prime / pi_prime) + (ct_discriminant_prime_clamped ** 0.5)) ** 2.0)
    
    zeta_prime = (1.0 / pi_prime) * (ct_prime ** (1.0 - alpha))
    dzeta_prime_dpi_prime = -(1.0 / (pi_prime ** 2.0)) * (ct_prime ** (1.0 - alpha))
    dzeta_prime_dct_prime = (1.0 - alpha) * (1.0 / pi_prime) * (ct_prime ** -alpha)
    
    qtilde_inside = zeta_prime * (1.0 + delta * q_prime)
    
    c_prime = ct_prime ** alpha
    lambda_ct_prime_numerator = (1.0 - beta) * (c_prime ** -crra) * alpha * (ct_prime ** (alpha - 1.0))
    lambda_ct_prime_denominator = 1.0 + dzeta_prime_dct_prime * btilde_prime
    lambda_ct_prime = lambda_ct_prime_numerator / lambda_ct_prime_denominator
    
    dval_prime_dbstar_prime = -lambda_ct_prime * (1.0 + qstar * delta)
    dval_prime_dbtilde_prime = -lambda_ct_prime * (zeta_prime + qtilde_prime * delta * (1.0 / pi_prime))
    
    (dct_prime_dbstar_prime, dct_prime_dbtilde_prime, dpi_prime_dbstar_prime, 
     dpi_prime_dbtilde_prime, dq_prime_dbstar_prime, dq_prime_dbtilde_prime) = get_autograd_derivative(X_prime, nn_pi, nn_foreign, nn_local, nn_qtilde, nn_q)
    
    dqtilde_dbstar_prime_inside = (dzeta_prime_dpi_prime * dpi_prime_dbstar_prime + dzeta_prime_dct_prime * dct_prime_dbstar_prime) * (1.0 + delta * q_prime) + zeta_prime * delta * dq_prime_dbstar_prime
    dqtilde_dbtilde_prime_inside = (dzeta_prime_dpi_prime * dpi_prime_dbtilde_prime + dzeta_prime_dct_prime * dct_prime_dbtilde_prime) * (1.0 + delta * q_prime) + zeta_prime * delta * dq_prime_dbtilde_prime
    
    return val_prime, qtilde_inside, dval_prime_dbstar_prime, dval_prime_dbtilde_prime, dqtilde_dbstar_prime_inside, dqtilde_dbtilde_prime_inside
    

In [19]:
def get_autograd_derivative(X_prime, nn_pi, nn_foreign, nn_local, nn_qtilde, nn_q):
        
    X_prime_auto = X_prime.detach().clone().requires_grad_(True)
    
    yt_prime_auto = X_prime_auto[:, 0:1]
    bstar_prime_auto = X_prime_auto[:, 1:2]
    btilde_prime_auto = X_prime_auto[:, 2:3]
    
    pi_prime_auto = nn_pi(X_prime_auto)
    bstar_prime_next_auto = nn_foreign(X_prime_auto)
    btilde_prime_next_auto = nn_local(X_prime_auto)
    qtilde_prime_auto = nn_qtilde(X_prime_auto)
    q_prime_auto = nn_q(X_prime_auto)
    
    ct_discriminant_prime_auto = ((btilde_prime_auto / pi_prime_auto) ** 2.0) - 4.0 * (bstar_prime_auto - qstar * (bstar_prime_next_auto - delta * bstar_prime_auto) - qtilde_prime_auto * (btilde_prime_next_auto - delta * (btilde_prime_auto / pi_prime_auto)) - yt_prime_auto)
    ct_discriminant_prime_auto_clamped = torch.clamp(ct_discriminant_prime_auto, min = EPS)
    numerator_prime_auto = (-btilde_prime_auto / pi_prime_auto) + torch.sqrt(ct_discriminant_prime_auto_clamped)
    ct_prime_auto = 0.25 * (numerator_prime_auto ** 2.0)
    
    ct_grads = torch.autograd.grad(
    outputs=ct_prime_auto,                      # shape (n_data, 1)
    inputs=X_prime_auto,             # shape (n_data, 4)
    grad_outputs=torch.ones_like(ct_prime_auto),  # each row's derivative
    create_graph=True,)[0]
    
    #Compute the derivatives of ct_prime
    dct_prime_dbstar_prime = ct_grads[:, 1:2]
    dct_prime_dbtilde_prime = ct_grads[:, 2:3]
    
    pi_grads = torch.autograd.grad(
    outputs=pi_prime_auto,                      # shape (n_data, 1)
    inputs=X_prime_auto,             # shape (n_data, 3)
    grad_outputs=torch.ones_like(pi_prime_auto),  # each row's derivative
    create_graph=True,)[0]
    
    #Compute the derivatives of pi_prime
    dpi_prime_dbstar_prime = pi_grads[:, 1:2]
    dpi_prime_dbtilde_prime = pi_grads[:, 2:3]                   # shape (n_data, 1)
    
    q_grads = torch.autograd.grad(
    outputs=q_prime_auto,                      # shape (n_data, 1)
    inputs=X_prime_auto,             # shape (n_data, 3)
    grad_outputs=torch.ones_like(q_prime_auto),  # each row's derivative
    create_graph=True,)[0]
    
    #Compute the derivatives of q_prime
    dq_prime_dbstar_prime = q_grads[:, 1:2]
    dq_prime_dbtilde_prime = q_grads[:, 2:3]

    
    return dct_prime_dbstar_prime, dct_prime_dbtilde_prime, dpi_prime_dbstar_prime, dpi_prime_dbtilde_prime, dq_prime_dbstar_prime, dq_prime_dbtilde_prime
    

In [20]:
def compute_loss(X, nn_pi, nn_foreign, nn_local, nn_val, nn_qtilde, nn_q):
        
    w_neg = neg_weight
    w_val = weight
    w_ERR = (1.0 -  w_val) / 5.0
    
    err_val, errREE_c_pi, errREE_bstar, errREE_btilde, err_qtilde, err_q, err_discriminant, err_numerator = compute_error(X, nn_pi, nn_foreign, nn_local, nn_val, nn_qtilde, nn_q)
        
    loss = w_ERR * torch.mean(errREE_c_pi ** 2) + w_ERR * torch.mean(errREE_bstar ** 2) + w_ERR * torch.mean(errREE_btilde ** 2) + w_ERR * torch.mean(err_qtilde ** 2) + w_ERR * torch.mean(err_q ** 2) + w_val * torch.mean(err_val ** 2) + w_neg * torch.mean(err_discriminant ** 2) + w_neg * torch.mean(err_numerator ** 2) 
    
    return loss

In [21]:
def get_data(num_points):
    
    yt = torch.FloatTensor(num_points, 1).uniform_(y_lbnd, y_ubnd)
    yt = torch.exp(yt)
    bstar = torch.FloatTensor(num_points, 1).uniform_(b_star_lbnd, b_star_ubnd)
    btilde = torch.FloatTensor(num_points, 1).uniform_(b_tilde_lbnd, b_tilde_ubnd)
    
    data = torch.cat((yt, bstar, btilde), dim = 1)
    data = data.to(device=DEVICE, dtype=torch.float64).to(DEVICE)
    
    return data


In [23]:
# ------------------------------------------------------------------
# Instantiate six neural nets
# ------------------------------------------------------------------
model_pi = NN_pi(3, 1, layers)
model_pi = model_pi.to(device=DEVICE, dtype=torch.float64).to(DEVICE)

model_foreign = NN_foreign(3, 1, layers)
model_foreign = model_foreign.to(device=DEVICE, dtype=torch.float64).to(DEVICE)

model_local = NN_local(3, 1, layers)
model_local = model_local.to(device=DEVICE, dtype=torch.float64).to(DEVICE)

model_val = NN_val(3, 1, layers)
model_val = model_val.to(device=DEVICE, dtype=torch.float64).to(DEVICE)

model_qtilde = NN_qtilde(3, 1, layers)
model_qtilde = model_qtilde.to(device=DEVICE, dtype=torch.float64).to(DEVICE)

model_q = NN_q(3, 1, layers)
model_q = model_q.to(device=DEVICE, dtype=torch.float64).to(DEVICE)

# ------------------------------------------------------------------
# Set up optimisers and schedulers
# ------------------------------------------------------------------
optimizer_pi = torch.optim.AdamW(model_pi.parameters(), lr = learn_initial)
optimizer_foreign = torch.optim.AdamW(model_foreign.parameters(), lr = learn_initial)
optimizer_local = torch.optim.AdamW(model_local.parameters(), lr = learn_initial)
optimizer_val = torch.optim.AdamW(model_val.parameters(), lr = learn_initial)
optimizer_qtilde = torch.optim.AdamW(model_qtilde.parameters(), lr = learn_initial)
optimizer_q = torch.optim.AdamW(model_q.parameters(), lr = learn_initial)

# One LR scheduler shared across all nets (piece‑wise constant ratio)
def lr_lambda(epoch):
    if   epoch < epochs_initial:               
        return 1.0
    elif epoch < epochs_middle:      
        return learn_middle / learn_initial
    else:
        return learn_final / learn_initial

scheduler_pi = LambdaLR(optimizer_pi, lr_lambda = lr_lambda)
scheduler_foreign = LambdaLR(optimizer_foreign, lr_lambda = lr_lambda)
scheduler_local = LambdaLR(optimizer_local, lr_lambda = lr_lambda)
scheduler_val = LambdaLR(optimizer_val, lr_lambda = lr_lambda)
scheduler_qtilde = LambdaLR(optimizer_qtilde, lr_lambda = lr_lambda)
scheduler_q = LambdaLR(optimizer_q, lr_lambda = lr_lambda)

# ------------------------------------------------------------------
# Training loop 
# ------------------------------------------------------------------
for i in range(epochs_total):

    optimizer_pi.zero_grad()
    optimizer_foreign.zero_grad()
    optimizer_local.zero_grad()
    optimizer_val.zero_grad()
    optimizer_qtilde.zero_grad()
    optimizer_q.zero_grad()

    data = get_data(batch_size)

    loss = compute_loss(data, model_pi, model_foreign, model_local, model_val, model_qtilde, model_q)

    loss.backward()

    optimizer_pi.step()
    optimizer_foreign.step()
    optimizer_local.step()
    optimizer_val.step()
    optimizer_qtilde.step()
    optimizer_q.step()

    scheduler_pi.step()
    scheduler_foreign.step()
    scheduler_local.step()
    scheduler_val.step()
    scheduler_qtilde.step()
    scheduler_q.step()

    if i%20 == 0:
        print(f'Epoch {i} loss is {loss}')

    if i%1000 == 0:
        print_gpu_memory()

# ------------------------------------------------------------------
# Checkpoint models
# ------------------------------------------------------------------
base = Path(path or ".").expanduser().resolve()
out = base / "pickle"
out.mkdir(parents=True, exist_ok=True)
tag = (f"kappa{kappa}")

torch.save(model_pi.state_dict(),        out / f"{tag}_Model_pi.pt")
torch.save(model_foreign.state_dict(),   out / f"{tag}_Model_foreign.pt")
torch.save(model_local.state_dict(),     out / f"{tag}_Model_local.pt")
torch.save(model_val.state_dict(),       out / f"{tag}_Model_val.pt")
torch.save(model_qtilde.state_dict(),    out / f"{tag}_Model_qtilde.pt")
torch.save(model_q.state_dict(),         out / f"{tag}_Model_q.pt")


Epoch 0 loss is 122717.8708385917
Epoch 20 loss is 1.0152205848674115e+25
Epoch 40 loss is 16441355.624562748
