In [None]:
import torch
import torch.nn as nn
import numpy as np
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
import wandb

from src.neural_net import *
from trainer import *
from prb_def import *


# Data generation

In [2]:
# Seed the global torch and NumPy random generators so that a fresh
# "Restart Kernel -> Run All" starts from the same random state every time.
# NOTE(review): whether create_data_loaders (and the model constructors used
# later) draw from these global generators is not visible in this notebook —
# confirm, and pass the seed through explicitly if they use their own RNG.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Number of training samples, number of test samples, and mini-batch size
# handed to create_data_loaders.
ntrain = 50000
ntest = 1000
batch_size = 500

# Create datasets
train_loader, test_loader = create_data_loaders(ntrain, ntest, batch_size)

# Create the model

In [8]:
# Penalty weights. Only slack_weight is read in this cell (as rho1 of the QP
# layer); constraint_weight and supervised_weight are just defined here.
slack_weight = 1e3
constraint_weight = 1e6
supervised_weight = 1e5

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def _build_mlp(insize):
    """Return an MLPWithSTE with `insize` inputs, `ny` outputs and two 128-unit ReLU hidden layers."""
    return MLPWithSTE(
        insize=insize,
        outsize=ny,
        bias=True,
        linear_map=nn.Linear,
        nonlin=nn.ReLU,
        hsizes=[128, 128],  # fresh list per call, as in the two original calls
    )


# The two networks differ only in input width: CRT_model takes `ny` extra inputs.
SL_model = _build_mlp(2 * nx + 1)
CRT_model = _build_mlp(2 * nx + 1 + ny)

# QP layer with an l1 penalty weighted by slack_weight.
cvx_layer = example_QP(nx=nx, ny=ny, penalty="l1", rho1=slack_weight)

# Combine both networks and the QP layer into the trainable model.
Model = SSL_MIQP_corrected(CRT_model, SL_model, cvx_layer, nx, ny, device=device)

In [9]:
# Settings for the supervised run: one epoch, with CHECKPOINT_AFTER set to 20
# (the log below reports every 20 steps).
training_params = {
    'TRAINING_EPOCHS': 1,
    'CHECKPOINT_AFTER': 20,
    'LEARNING_RATE': 1e-3,
    'WEIGHT_DECAY': 1e-5,
    'PATIENCE': 5,
}

# Supervised training via Model.train_SL, passing ground_truth_solver.
Model.train_SL(
    ground_truth_solver,
    train_loader,
    test_loader,
    training_params,
)


[epoch 1 | step 1] training loss = 0.2200, validation loss = 0.2250
[epoch 1 | step 20] training loss = 0.0325, validation loss = 0.0383
[epoch 1 | step 40] training loss = 0.0185, validation loss = 0.0193
[epoch 1 | step 60] training loss = 0.0135, validation loss = 0.0143
[epoch 1 | step 80] training loss = 0.0090, validation loss = 0.0108
[epoch 1 | step 100] training loss = 0.0110, validation loss = 0.0110


In [10]:
# Then train the correction NN with self-supervised learning
# (five epochs, patience raised to 10 compared with the supervised run).
training_params = {
    'TRAINING_EPOCHS': 5,
    'CHECKPOINT_AFTER': 20,
    'LEARNING_RATE': 1e-3,
    'WEIGHT_DECAY': 1e-5,
    'PATIENCE': 10,
}

# Same weight values as in the model-construction cell, restated here.
slack_weight = 1e3
constraint_weight = 1e6
supervised_weight = 1e5

# Weight list handed to train_SSL; the leading entry is 0.0.
# NOTE(review): the meaning of each position is defined by train_SSL, which is
# not visible in this notebook — confirm the ordering against its signature.
loss_weights = [
    0.0,
    slack_weight,
    constraint_weight,
    supervised_weight,
]
def quad_fcn(x, y, theta):
    """Return, per row, sum(x**2) + sum(p * x) with p = theta[:, :nx].

    `y` is accepted but does not enter the computation. `nx` is read from the
    notebook's global namespace.
    """
    linear_coeffs = theta[:, :nx]
    quadratic_term = (x ** 2).sum(dim=1)
    linear_term = (linear_coeffs * x).sum(dim=1)
    return quadratic_term + linear_term
def y_sum_con(y, theta):
    """Return the sum of `y` over its last dimension minus 1.0.

    The value should be <= 0, i.e. the entries of `y` should sum to at most 1.
    `theta` is accepted but does not enter the computation.
    """
    total = y.sum(dim=-1)
    return total - 1.0
# Constraint functions on y passed to train_SSL (only the sum constraint here).
y_cons = [y_sum_con]

# Self-supervised training with quad_fcn as the objective function and the
# loss weights defined above.
Model.train_SSL(
    ground_truth_solver,
    train_loader,
    test_loader,
    training_params,
    loss_weights,
    loss_scale=1e2,
    obj_fcn=quad_fcn,
    y_cons=y_cons,
)

Validation for the supervised learning model: 
obj_val = -98.4763, avg_opt_gap = 0.2439, slack_pen = 0.0000, y_sum_penalty = 0.0030, supervised_loss = 0.0110, 
__________________________________________________
[epoch 1 | step 1] validation: loss = 2.2094, obj_val = -2.1079, avg_opt_gap = 97.8604, slack_pen = 0.0000, y_sum_penalty = 0.0000, supervised_loss = 0.2433, 
[epoch 1 | step 20] validation: loss = 0.2793, obj_val = -92.3527, avg_opt_gap = 6.4484, slack_pen = -0.0000, y_sum_penalty = 0.0000, supervised_loss = 0.0308, 
[epoch 1 | step 40] validation: loss = 0.7016, obj_val = -97.5692, avg_opt_gap = 1.1575, slack_pen = 0.0000, y_sum_penalty = 0.0060, supervised_loss = 0.0173, 
[epoch 1 | step 60] validation: loss = 0.2520, obj_val = -97.1751, avg_opt_gap = 1.5592, slack_pen = 0.0000, y_sum_penalty = 0.0010, supervised_loss = 0.0178, 
[epoch 1 | step 80] validation: loss = 0.4360, obj_val = -98.2315, avg_opt_gap = 0.4911, slack_pen = 0.0000, y_sum_penalty = 0.0030, supervised_loss 