In [1]:
import torch
import torch.nn as nn
import numpy as np
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
import sys
import os
import wandb

# plotting
import matplotlib.pyplot as plt
import matplotlib.patheffects as patheffects

from models import *

# Problem Formulation

We solve the following MIQP problem:
$$
\begin{align}
&\text{minimize } && \mathbf{x}^T \mathbf{x} + \mathbf{p}^T \mathbf{x} \\
&\text{subject to} && \mathbf{x} \le \mathbf{b}, \\
& && \mathbf{1}^T \mathbf{x} \le a, \\
& && \mathbf{x} \le M \mathbf{y}, \\
& && \mathbf{1}^T \mathbf{y} \le 1,
\end{align}
$$
where $\mathbf{x} \in \mathbb{R}^2$ and $\mathbf{y} \in \{0,1\}^2$ are the continuous and binary optimization variables, and $M$ is a large positive constant (big-M) that forces $x_i \le 0$ whenever $y_i = 0$. The parameters include $\mathbf{p}$, $\mathbf{b}$, and $a$.


# Data generation

In [2]:
# Problem dimensions
nx = 2  # number of continuous decision variables
ny = 2  # number of integer decision variables

# Seed both NumPy and PyTorch so the sampled data set is reproducible
data_seed = 18
np.random.seed(data_seed)
torch.manual_seed(data_seed)

# Uniform sampling ranges of the problem parameters
p_low, p_high = -30.0, 5.0   # linear term in objective
b_low, b_high = 5.0, 25.0    # RHS of constraint x <= b
a_low, a_high = 10.0, 30.0   # RHS of constraint 1^T x <= a

ntrain = 50000
ntest = 1000


def _draw_samples(count):
    """Draw `count` parameter samples (p, b, a) uniformly from their ranges.

    The entries are drawn in the order p, b, a, so the random stream is
    consumed exactly as if the three tensors were created one after another.
    """
    return {
        "p": torch.FloatTensor(count, nx).uniform_(p_low, p_high),
        "b": torch.FloatTensor(count, nx).uniform_(b_low, b_high),
        "a": torch.FloatTensor(count, 1).uniform_(a_low, a_high),
    }


# Generate samples (train, dev, test -- in this order)
samples_train = _draw_samples(ntrain)
# NOTE(review): the dev split is drawn with `ntrain` samples -- confirm this size is intended.
samples_dev = _draw_samples(ntrain)
samples_test = _draw_samples(ntest)

# --- Custom dataset class ---
class SampleDataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of per-sample parameter containers.

    `samples` must provide the keys "p", "b" and "a", each indexable along its
    first axis. Item `idx` is a dict holding the `idx`-th entry of each key.
    """

    # Keys copied into every item, in this order.
    _FIELDS = ("p", "b", "a")

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        # The "p" entry defines how many samples the dataset holds.
        return len(self.samples["p"])

    def __getitem__(self, idx):
        return {field: self.samples[field][idx] for field in self._FIELDS}

# Wrap each split in a SampleDataset (train, dev, test -- in this order)
train_dataset, dev_dataset, test_dataset = (
    SampleDataset(split) for split in (samples_train, samples_dev, samples_test)
)

# Slacked QP Formulation

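We solve the following slack-relaxed version of the MIQP, in which the big-M constraint is softened by a penalized slack variable $\mathbf{s}$: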
$$
\begin{align}
&\text{minimize } && \mathbf{x}^\top \mathbf{x} + \mathbf{p}^\top \mathbf{x} + \mathbf{s}^\top \mathbf{s} \\
&\text{subject to} && \mathbf{x} \le \mathbf{b}, \\
& && \mathbf{1}^\top \mathbf{x} \le a, \\
& && \mathbf{x} \le M \mathbf{y} + \mathbf{s}, \\
& && \mathbf{1}^\top \mathbf{y} \le 1, \\
& && \mathbf{s} \ge 0,
\end{align}
$$
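where $\mathbf{x} \in \mathbb{R}^2$ and $\mathbf{y} \in \{0,1\}^2$ are the continuous and binary optimization variables, and $\mathbf{s} \in \mathbb{R}^2$, $\mathbf{s} \ge 0$, is the slack variable that relaxes the big-M constraint. The parameters include $\mathbf{p}$, $\mathbf{b}$, and $a$. In the implementation below (`QP_Layer`), $\mathbf{y}$ is passed in as a fixed parameter, and the slack term is weighted by `rho1` and is either the $\ell_1$ penalty $\mathbf{1}^\top \mathbf{s}$ (the default) or the $\ell_2$ penalty $\mathbf{s}^\top \mathbf{s}$ shown above.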

In [3]:
@torch.no_grad()
def _check_shapes(p, b, a, y):
    """    Quick assertion helper. Accepts batched tensors:
    - p: (B,n), b: (B,n), a: (B,1) or (B,), M: (B,n,ny), y: (B,ny)
    """
    B = p.shape[0]
    n = p.shape[1]
    ny = y.shape[1]
    assert b.shape == (B, n)
    assert y.shape == (B, ny)
    assert a.shape in [(B,), (B, 1)]
    return B, n, ny

def solve_qp_with_slacks(layer: CvxpyLayer, p, b, a, y):
    """
    Evaluate the CVXPY layer on a batch of problem data.

    The inputs are shape-checked, `a` is promoted to a column if it is 1-D,
    and the layer and all tensors are moved to the device of `p` before the
    layer is called as `layer(p, b, a, y)`.

    Returns the pair (x_opt, s_opt) produced by the layer.
    """
    _check_shapes(p, b, a, y)

    # The layer is called with `a` as a (B, 1) column
    a_col = a.unsqueeze(-1) if a.ndim == 1 else a

    # Everything follows the device of `p`
    target = p.device
    layer = layer.to(target)
    p_dev, b_dev, a_dev, y_dev = (t.to(target) for t in (p, b, a_col, y))

    x_opt, s_opt = layer(p_dev, b_dev, a_dev, y_dev)
    return x_opt, s_opt

def QP_Layer(nx, ny, penalty="l1", rho1=1.0, bigM = 1e3, **kwargs):
    """Build the slack-relaxed QP as a differentiable CvxpyLayer.

    The problem minimizes ``x^T x + p^T x + rho1 * penalty(s)`` subject to
    ``x <= b``, ``1^T x <= a``, ``x <= bigM * y + s`` and ``s >= 0``, where
    ``y`` enters as a parameter (not a decision variable).

    Args:
        nx: number of continuous decision variables (also the size of ``s``).
        ny: size of the ``y`` parameter.
        penalty: ``"l1"`` for ``sum(s)`` or ``"l2"`` for ``s^T s``.
        rho1: weight of the slack penalty in the objective.
        bigM: big-M constant multiplying ``y``.
        **kwargs: accepted and ignored.

    Returns:
        A CvxpyLayer taking parameters ``(p, b, a, y)`` and returning ``(x, s)``.

    Raises:
        ValueError: if ``penalty`` is neither ``"l1"`` nor ``"l2"``.
    """
    # Fail early with a clear message; previously an unknown penalty left
    # `objective` unbound and surfaced as an UnboundLocalError further down.
    if penalty not in ("l1", "l2"):
        raise ValueError(f"Unknown penalty {penalty!r}; expected 'l1' or 'l2'.")

    # Define CVXPY variables and parameters
    x = cp.Variable((nx,))  # continuous decision variables
    y = cp.Parameter((ny,))  # integer decisions, supplied from outside the layer

    p = cp.Parameter((nx,))  # linear term in the objective
    b = cp.Parameter((nx,))  # RHS of the constraint x <= b
    a = cp.Parameter((1,))   # RHS of the constraint 1.T*x <= a
    s = cp.Variable((nx,), nonneg=True)  # slack variables

    # Slack penalty term of the objective
    if penalty == "l1":
        slack_cost = rho1 * cp.sum(s)
    else:  # "l2"
        slack_cost = rho1 * cp.quad_form(s, np.eye(nx))

    # Define the QP problem
    objective = cp.Minimize(cp.quad_form(x, np.eye(nx)) + p.T @ x + slack_cost)
    constraints = [
        x <= b,
        cp.sum(x) <= a,
        x <= bigM * y + s,
        s >= 0,
    ]
    problem = cp.Problem(objective, constraints)

    # Create CVXPY layer
    cvxpylayer = CvxpyLayer(problem, parameters=[p, b, a, y], variables=[x, s])
    return cvxpylayer

def _solve_single_miqp(args):
    """Solve a single MIQP instance with Gurobi (worker-side helper).

    Args:
        args: tuple ``(i, p_i, b_i, a_i, nx, ny)`` with the sample index, the
            problem data of that sample, and the sizes of the continuous and
            binary variable vectors.

    Returns:
        Tuple ``(i, x, y)``. If the solver raises or leaves either variable
        without a value, ``x`` and ``y`` are all-zero arrays of length ``nx``
        and ``ny``.
    """
    import contextlib  # local import so this block does not depend on the import cell

    i, p_i, b_i, a_i, nx, ny = args
    failure = None

    # Silence Python-level output while building and solving the model. The
    # context managers close the null sink and restore the original streams on
    # exit, so repeated calls in a reused worker no longer leak file handles.
    with open(os.devnull, 'w') as sink, \
            contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
        # Variables
        x = cp.Variable(nx)
        y = cp.Variable(ny, boolean=True)

        # Objective and constraints
        objective = cp.Minimize(cp.sum_squares(x) + p_i @ x)
        constraints = [
            x <= b_i,
            cp.sum(x) <= a_i,
            cp.sum(y) <= 1,
            x <= 1e3 * y
        ]

        # Problem definition
        prob = cp.Problem(objective, constraints)

        try:
            prob.solve(solver=cp.GUROBI, verbose=False, OutputFlag=0)
        except Exception as e:  # best effort: fall back to zeros below
            failure = e

    if failure is not None:
        # Reported after the streams are restored; inside the redirect the
        # message would be written to the null device and never be seen.
        print(f"Error solving sample {i}: {failure}", file=sys.stderr)
        return i, np.zeros(nx), np.zeros(ny)

    if x.value is not None and y.value is not None:
        return i, x.value, y.value
    return i, np.zeros(nx), np.zeros(ny)

@torch.no_grad()
def GUROBI_solve_parallel(p: torch.Tensor, b: torch.Tensor, a: torch.Tensor, max_workers=None):
    """
    Solve the MIQP for each sample in the batch using parallel processes.

    Args:
        p: batch of linear objective terms; one row per sample.
        b: batch of right-hand sides of ``x <= b``; one row per sample.
        a: batch of right-hand sides of ``1^T x <= a``; a trailing singleton
            dimension is squeezed away when ``a`` is 2-D.
        max_workers: number of worker processes; defaults to
            ``min(batch size, CPU count)``.

    Returns:
        Tuple ``(x, y)`` of float tensors on the device of ``p`` with shapes
        ``(B, nx)`` and ``(B, ny)``. ``nx`` and ``ny`` are read from the
        module-level variables of the same name.
    """
    device = p.device
    p_np = p.detach().cpu().numpy()
    b_np = b.detach().cpu().numpy()
    a_np = a.detach().cpu().numpy()

    if a_np.ndim == 2:
        a_np = a_np.squeeze(-1)

    B = p_np.shape[0]

    # Preallocate result arrays
    x_results = np.zeros((B, nx))
    y_results = np.zeros((B, ny))

    # An empty batch would request zero workers, which ProcessPoolExecutor
    # rejects with a ValueError; there is nothing to solve in that case.
    if B > 0:
        # Prepare arguments for parallel execution
        args_list = [(i, p_np[i], b_np[i], a_np[i], nx, ny) for i in range(B)]

        if max_workers is None:
            max_workers = min(B, mp.cpu_count())

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(_solve_single_miqp, args) for args in args_list]

            # Each result carries its own sample index, so completion order is irrelevant
            for future in as_completed(futures):
                i, x_sol, y_sol = future.result()
                x_results[i] = x_sol
                y_results[i] = y_sol

    return torch.tensor(x_results).float().to(device), torch.tensor(y_results).float().to(device)

### Supervised learning demo

In [4]:
def train_SL(nn_model, train_loader, test_loader, training_params, device = None, wandb_log = False):
    """Train `nn_model` to reproduce the binary solution returned by Gurobi.

    For every training batch the MIQP is solved with `GUROBI_solve_parallel`
    and the model output for ``theta = [p, b, a]`` is regressed onto the
    solver's ``y`` with a Huber loss.

    Args:
        nn_model: module mapping ``theta`` of shape (B, 2n+1) to ``y`` of shape (B, ny).
        train_loader: iterable of batches (dicts with keys 'p', 'b', 'a').
        test_loader: iterable of validation batches with the same keys.
        training_params: dict with keys 'TRAINING_EPOCHS', 'CHECKPOINT_AFTER',
            'LEARNING_RATE', 'WEIGHT_DECAY', 'PATIENCE'; optional
            'WANDB_PROJECT' and 'RUN_NAME' are used when `wandb_log` is set.
        device: torch device; defaults to CUDA when available, else CPU.
        wandb_log: when True, log training/validation losses to Weights & Biases.

    Returns:
        None. Training stops early once the validation loss has not improved
        for 'PATIENCE' consecutive checkpoints.
    """
    TRAINING_EPOCHS = training_params['TRAINING_EPOCHS']
    CHECKPOINT_AFTER = training_params['CHECKPOINT_AFTER']
    LEARNING_RATE = training_params['LEARNING_RATE']
    WEIGHT_DECAY = training_params['WEIGHT_DECAY']
    PATIENCE = training_params['PATIENCE']
    # Put all layers in device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    nn_model.to(device)

    global_step = 0
    optimizer = torch.optim.Adam(nn_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=int(PATIENCE/2))
    supervised_loss_fn = nn.HuberLoss()
    best_val_loss = float("inf")  # best validation loss seen so far
    checkpoints_no_improve = 0    # consecutive checkpoints without improvement
    if wandb_log:
        wandb.init(
            project=training_params.get("WANDB_PROJECT", "supervised_learning"),
            name=training_params.get("RUN_NAME", None),
            config=training_params
        )

    # try/finally guarantees the wandb run is closed on every exit path,
    # including the early-stopping return below.
    try:
        for epoch in range(1, TRAINING_EPOCHS+1):
            nn_model.train()
            for batch in train_loader:
                p_batch = batch['p'].to(device)
                b_batch = batch['b'].to(device)
                a_batch = batch['a'].to(device)
                # ---- Predict y from theta = [p,b,a] ----
                theta = torch.cat([p_batch, b_batch, a_batch], dim=-1)  # (B, 2n+1)
                y_pred = nn_model(theta).float()  # (B, ny)
                # Reference labels from the exact solver
                _, y_solver = GUROBI_solve_parallel(p_batch, b_batch, a_batch)
                loss = supervised_loss_fn(y_pred, y_solver.float())
                if wandb_log:
                    wandb.log({
                        "train/loss": loss.item(),
                        "step": global_step})
                # ---- Backprop ----
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                global_step += 1

                # ---- Checkpoint: validate, log, adapt LR, early-stop ----
                if global_step != 1 and (global_step % CHECKPOINT_AFTER) != 0:
                    continue

                training_loss = loss.item()
                val_loss_total = 0.0
                with torch.no_grad():
                    for val_batch in test_loader:
                        p_val = val_batch['p'].to(device)
                        b_val = val_batch['b'].to(device)
                        a_val = val_batch['a'].to(device)
                        theta_val = torch.cat([p_val, b_val, a_val], dim=-1)  # (B, 2n+1)
                        y_pred_val = nn_model(theta_val).float()  # (B, ny)
                        _, y_solver_val = GUROBI_solve_parallel(p_val, b_val, a_val)
                        val_loss_total += supervised_loss_fn(y_pred_val, y_solver_val).item()
                avg_val_loss = val_loss_total / len(test_loader)

                print(f"[epoch {epoch} | step {global_step}] "
                    f"training loss = {training_loss:.4f}, "
                    f"validation loss = {avg_val_loss:.4f}")
                # --- Log losses to wandb ---
                if wandb_log:
                    wandb.log({
                        "val/loss": avg_val_loss,
                        "epoch": epoch})

                # Drive the plateau scheduler with the validation average rather
                # than the loss of a single (noisy) training mini-batch.
                last_lr = optimizer.param_groups[0]['lr']
                scheduler.step(avg_val_loss)
                current_lr = optimizer.param_groups[0]['lr']
                if current_lr != last_lr:
                    print(f"Learning rate updated: {last_lr:.6f} -> {current_lr:.6f}")

                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    checkpoints_no_improve = 0
                else:
                    checkpoints_no_improve += 1
                    if checkpoints_no_improve >= PATIENCE:
                        print("Early stopping triggered!")
                        return
    finally:
        if wandb_log:
            wandb.finish()
                

In [5]:
# Mini-batch loaders: the train and dev splits are shuffled, the test split is not
batch_size = 500
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
dev_loader = torch.utils.data.DataLoader(
    dev_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [6]:
# Pick the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Classifier mapping theta = [p, b, a] (size 2*nx+1) to the ny integer decisions
SL_model = MLPWithSTE(
    insize=2*nx+1,
    outsize=ny,
    bias=True,
    linear_map=torch.nn.Linear,
    nonlin=nn.ReLU,
    hsizes=[128] * 2,
)

# Hyper-parameters of the supervised run
training_params = {
    'TRAINING_EPOCHS': 1,
    'CHECKPOINT_AFTER': 20,
    'LEARNING_RATE': 1e-3,
    'WEIGHT_DECAY': 1e-5,
    'PATIENCE': 5,
}

train_SL(SL_model, train_loader, test_loader, training_params, device = device)

[epoch 1 | step 1] training loss = 0.2290, validation loss = 0.2260
[epoch 1 | step 20] training loss = 0.0340, validation loss = 0.0220
[epoch 1 | step 40] training loss = 0.0195, validation loss = 0.0130
[epoch 1 | step 60] training loss = 0.0155, validation loss = 0.0108
[epoch 1 | step 80] training loss = 0.0090, validation loss = 0.0095
[epoch 1 | step 100] training loss = 0.0060, validation loss = 0.0088


### Self-Supervised Learning Demo

In [7]:
# Penalty functions that may be applied to the slack variables
def l1_penalty(s):
    """Per-sample sum of the entries of `s` along dim 1."""
    return s.sum(dim=1)


def l2_penalty(s):
    """Per-sample sum of the squared entries of `s` along dim 1."""
    return (s**2).sum(dim=1)

def combined_loss_fcn(loss_components, weights):
    """
    Return the weighted sum of the loss components.

    `loss_components` and `weights` must have the same length; the i-th
    component is multiplied by the i-th weight. The sum starts from 0.
    """
    assert len(loss_components) == len(weights), "Number of loss components must match number of weights."
    total = 0
    for weight, component in zip(weights, loss_components):
        total = total + weight * component
    return total

def _ssl_validation_metrics(predict_y, cvx_layer, test_loader, weights, supervised_loss_fn,
                            slack_penalty, constraint_penalty, device):
    """Average the validation metrics of a y-predictor over `test_loader`.

    `predict_y` maps ``theta = [p, b, a]`` to a ``(B, ny)`` tensor. For every
    batch the convex subproblem is solved for the predicted ``y`` and the MIQP
    is solved with Gurobi as reference. Returns a dict of floats (mean over
    batches) with keys 'loss', 'obj_val', 'slack_pen', 'y_sum_penalty' and
    'supervised_loss'; values are NaN when the loader yields no batch.
    """
    names = ("loss", "obj_val", "slack_pen", "y_sum_penalty", "supervised_loss")
    history = {name: [] for name in names}
    with torch.no_grad():
        for val_batch in test_loader:
            p_val = val_batch['p'].to(device)
            b_val = val_batch['b'].to(device)
            a_val = val_batch['a'].to(device)
            theta_val = torch.cat([p_val, b_val, a_val], dim=-1)  # (B, 2n+1)
            y_val = predict_y(theta_val).float()  # (B, ny)
            # ---- Solve convex subproblem given y ----
            x_opt, s_opt = solve_qp_with_slacks(cvx_layer, p_val, b_val, a_val, y_val)
            # Reference integer solution from the exact solver
            _, y_solver = GUROBI_solve_parallel(p_val, b_val, a_val)
            metrics = {
                "obj_val": ((x_opt**2).sum(dim=1) + (p_val * x_opt).sum(dim=1)).mean(),
                # Slack penalty: violation of the relaxed big-M constraint
                "slack_pen": slack_penalty(s_opt).mean(),
                # Violation of 1^T y <= 1 by the predicted integer decisions
                "y_sum_penalty": constraint_penalty(y_val.sum(dim=1) - 1.0).mean(),
                "supervised_loss": supervised_loss_fn(y_val, y_solver.float()),
            }
            metrics["loss"] = combined_loss_fcn(
                [metrics["obj_val"], metrics["slack_pen"],
                 metrics["y_sum_penalty"], metrics["supervised_loss"]],
                weights)
            for name in names:
                history[name].append(metrics[name].item())
    return {name: (sum(vals) / len(vals) if vals else float("nan"))
            for name, vals in history.items()}


def train_SSL(nn_model, cvx_layer, sl_model, train_loader, test_loader, training_params,  loss_weights, 
            slack_penalty = l1_penalty, constraint_penalty = torch.relu,
            device = None, wandb_log = False):
    """Refine integer predictions with a self-supervised loss through the QP layer.

    `sl_model` (kept in eval mode and never updated) predicts ``y_hat`` from
    ``theta = [p, b, a]``; `nn_model` receives ``[theta, y_hat]`` and predicts
    ``y``. The training loss combines, with the normalized `loss_weights`,
    the QP objective, the slack penalty, the violation of ``1^T y <= 1`` and
    the Huber deviation of ``y`` from ``y_hat``.

    Args:
        nn_model: module being trained; input size 2n+1+ny, output size ny.
        cvx_layer: layer called through `solve_qp_with_slacks`.
        sl_model: frozen predictor providing ``y_hat``.
        train_loader, test_loader: iterables of dict batches with keys 'p', 'b', 'a'.
        training_params: dict with keys 'TRAINING_EPOCHS', 'CHECKPOINT_AFTER',
            'LEARNING_RATE', 'WEIGHT_DECAY', 'PATIENCE'; optional
            'WANDB_PROJECT' and 'RUN_NAME' are used when `wandb_log` is set.
        loss_weights: four weights in the order
            [objective, slack, constraint, deviation]; normalized to sum to 1.
        slack_penalty: callable applied to the slack tensor, one value per sample.
        constraint_penalty: callable applied to ``sum(y) - 1``.
        device: torch device; defaults to CUDA when available, else CPU.
        wandb_log: when True, log metrics to Weights & Biases.

    Returns:
        None. Training stops early once the validation loss has not improved
        for 'PATIENCE' consecutive checkpoints.
    """
    TRAINING_EPOCHS = training_params['TRAINING_EPOCHS']
    CHECKPOINT_AFTER = training_params['CHECKPOINT_AFTER']
    LEARNING_RATE = training_params['LEARNING_RATE']
    WEIGHT_DECAY = training_params['WEIGHT_DECAY']
    PATIENCE = training_params['PATIENCE']
    # Put all layers in device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    nn_model.to(device)
    nn_model.train()
    cvx_layer.to(device)
    sl_model.to(device)
    sl_model.eval()

    # Define weights for loss components
    weights = torch.tensor(loss_weights, device=device)
    weights = weights / weights.sum()
    global_step = 0
    optimizer = torch.optim.Adam(nn_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=int(PATIENCE/2))
    supervised_loss_fn = nn.HuberLoss()
    best_val_loss = float("inf")  # best validation loss seen so far
    checkpoints_no_improve = 0    # consecutive checkpoints without improvement

    def _predict_refined(theta):
        # nn_model sees theta concatenated with the frozen model's prediction
        y_hat = sl_model(theta).float()
        return nn_model(torch.cat([theta, y_hat], dim=-1))

    if wandb_log:
        wandb.init(
            project=training_params.get("WANDB_PROJECT", "self_supervised_learning"),
            name=training_params.get("RUN_NAME", None),
            config=training_params
        )

    # try/finally guarantees the wandb run is closed on every exit path,
    # including the early-stopping return below.
    try:
        # Baseline: validation metrics of the supervised model on its own
        print("Validation for the supervised learning model: ")
        baseline = _ssl_validation_metrics(
            sl_model, cvx_layer, test_loader, weights, supervised_loss_fn,
            slack_penalty, constraint_penalty, device)
        print(
            f"supervised learning model: "
            f"obj_val = {baseline['obj_val']:.4f}, "
            f"slack_pen = {baseline['slack_pen']:.4f}, "
            f"y_sum_penalty = {baseline['y_sum_penalty']:.4f}, "
            f"supervised_loss = {baseline['supervised_loss']:.4f}, ")
        print("_"*50)

        for epoch in range(1, TRAINING_EPOCHS+1):
            nn_model.train()
            for batch in train_loader:
                p_batch = batch['p'].to(device)
                b_batch = batch['b'].to(device)
                a_batch = batch['a'].to(device)
                # ---- Predict y from theta = [p,b,a] ----
                theta = torch.cat([p_batch, b_batch, a_batch], dim=-1)  # (B, 2n+1)
                # The frozen model only provides inputs/targets, so no autograd
                # graph is needed for it (its parameters are not optimized).
                with torch.no_grad():
                    y_pred_hat = sl_model(theta).float()  # (B, ny)
                # construct the concatenated input
                concat_input = torch.cat([theta, y_pred_hat], dim=-1)
                y_pred = nn_model(concat_input).float()  # (B, ny)
                # ---- Solve convex subproblem given y (differentiable) ----
                x_opt, s_opt = solve_qp_with_slacks(cvx_layer, p_batch, b_batch, a_batch, y_pred)
                obj_val = ((x_opt**2).sum(dim=1) + (p_batch * x_opt).sum(dim=1)).mean()
                # Slack penalty, for constraint violation of the continuous decision variables
                slack_pen = slack_penalty(s_opt).mean()
                # Violation penalty for constraint violation with the integer decision variables
                y_sum_penalty = constraint_penalty(y_pred.sum(dim=1) - 1.0).mean()
                # Deviation from the supervised model's prediction
                deviation_loss = supervised_loss_fn(y_pred, y_pred_hat.float())
                # Total loss with balanced weights
                loss = combined_loss_fcn([obj_val, slack_pen, y_sum_penalty, deviation_loss], weights)
                if wandb_log:
                    wandb.log({
                        "train/combined_loss": loss.item(),
                        "train/obj_val": obj_val.item(),
                        "train/slack_pen": slack_pen.item(),
                        "train/y_sum_penalty": y_sum_penalty.item(),
                        "train/deviation_loss": deviation_loss.item(),
                        "step": global_step})

                # ---- Backprop ----
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(nn_model.parameters(), max_norm=1e1)
                optimizer.step()
                global_step += 1

                # ---- Checkpoint: validate, log, adapt LR, early-stop ----
                if global_step != 1 and (global_step % CHECKPOINT_AFTER) != 0:
                    continue

                val = _ssl_validation_metrics(
                    _predict_refined, cvx_layer, test_loader, weights, supervised_loss_fn,
                    slack_penalty, constraint_penalty, device)
                avg_val_loss = val["loss"]

                print(f"[epoch {epoch} | step {global_step}] "
                    f"validation: loss = {avg_val_loss:.4f}, "
                    f"obj_val = {val['obj_val']:.4f}, "
                    f"slack_pen = {val['slack_pen']:.4f}, "
                    f"y_sum_penalty = {val['y_sum_penalty']:.4f}, "
                    f"supervised_loss = {val['supervised_loss']:.4f}, ")
                # --- Log the averaged (not last-batch) validation metrics ---
                if wandb_log:
                    wandb.log({
                        "val/avg_loss": avg_val_loss,
                        "val/obj_val": val["obj_val"],
                        "val/slack_pen": val["slack_pen"],
                        "val/y_sum_penalty": val["y_sum_penalty"],
                        "val/supervised_loss": val["supervised_loss"],
                        "epoch": epoch})

                # Drive the plateau scheduler with the validation average
                # rather than the loss of the last validation batch only.
                last_lr = optimizer.param_groups[0]['lr']
                scheduler.step(avg_val_loss)
                current_lr = optimizer.param_groups[0]['lr']
                if current_lr != last_lr:
                    print(f"Learning rate updated: {last_lr:.6f} -> {current_lr:.6f}")

                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    checkpoints_no_improve = 0
                else:
                    checkpoints_no_improve += 1
                    if checkpoints_no_improve >= PATIENCE:
                        print("Early stopping triggered!")
                        return
    finally:
        if wandb_log:
            wandb.finish()

In [8]:
# Reuse `device` and `SSL_model` when both already exist in the kernel;
# otherwise (re)create them.
try:
    device
    SSL_model
except NameError:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The SSL model sees theta (2*nx+1) concatenated with the SL prediction (ny)
    SSL_model = MLPWithSTE(
        insize=2*nx+1+ny,
        outsize=ny,
        bias=True,
        linear_map=torch.nn.Linear,
        nonlin=nn.ReLU,
        hsizes=[128] * 2,
    )

# Then refine with self-supervised learning
training_params = {
    'TRAINING_EPOCHS': 5,
    'CHECKPOINT_AFTER': 20,
    'LEARNING_RATE': 1e-4,
    'WEIGHT_DECAY': 1e-5,
    'PATIENCE': 10,
}

# Loss weights, ordered as [objective, slack, constraint, supervised]
slack_weight = 1e3
constraint_weight = 1e4
supervised_weight = 1e3
cvx_layer = QP_Layer(nx=nx, ny=ny, penalty="l1", rho1=slack_weight)
loss_weights = [0.0, slack_weight, constraint_weight, supervised_weight]

train_SSL(
    SSL_model, cvx_layer, SL_model, train_loader, test_loader, training_params, loss_weights,
    device = device,
)

Validation for the supervised learning model: 
supervised learning model: obj_val = -97.7022, slack_pen = 0.0000, y_sum_penalty = 0.0030, supervised_loss = 0.0088, 
__________________________________________________
[epoch 1 | step 1] validation: loss = 0.0236, obj_val = -63.3737, slack_pen = 0.0000, y_sum_penalty = 0.0020, supervised_loss = 0.2630, 
[epoch 1 | step 20] validation: loss = 0.0230, obj_val = -21.0139, slack_pen = -0.0000, y_sum_penalty = 0.0030, supervised_loss = 0.2455, 
[epoch 1 | step 40] validation: loss = 0.0105, obj_val = -63.2487, slack_pen = 0.0000, y_sum_penalty = 0.0000, supervised_loss = 0.1265, 
[epoch 1 | step 60] validation: loss = 0.0103, obj_val = -89.4927, slack_pen = 0.0000, y_sum_penalty = 0.0060, supervised_loss = 0.0640, 
[epoch 1 | step 80] validation: loss = 0.0044, obj_val = -91.4979, slack_pen = 0.0000, y_sum_penalty = 0.0010, supervised_loss = 0.0430, 
[epoch 1 | step 100] validation: loss = 0.0027, obj_val = -92.1142, slack_pen = 0.0000, y_sum_

### Batch Solve Demo

In [None]:
# Take one mini-batch of problem parameters from the training loader
batch = next(iter(train_loader))
p_batch, b_batch, a_batch = batch['p'], batch['b'], batch['a']
B, n = p_batch.shape
M_batch = torch.ones(B, nx, ny) * 50.0  # Big-M matrix

device = p_batch.device  # work on the device the batch lives on
ny = M_batch.shape[-1]

# 4.1 Build the layer once
cvx_layer = QP_Layer(nx=n, ny=ny, penalty="l1", rho1=1e1)

nn_model = MLPWithSTE(
    insize=2*nx+1,
    outsize=ny,
    bias=True,
    linear_map=torch.nn.Linear,
    nonlin=nn.ReLU,
    hsizes=[128] * 4,
)

# 4.2 Get y from the classifier
# Concatenate inputs
theta = torch.cat([p_batch, b_batch, a_batch], dim=1)  # shape: (B, 2*nx + 1)
# Forward pass through classifier
y_pred_hard = nn_model(theta).float()  # shape: (B, ny)
y_pred_int = y_pred_hard.int()

# 4.3 Solve the convex subproblem given y
x_opt, s_opt = solve_qp_with_slacks(
    cvx_layer, p_batch, b_batch, a_batch, y_pred_hard
)
# Per-sample objective x^T x + p^T x, shape (B,)
obj_val = torch.sum(x_opt**2, dim=1) + torch.sum(p_batch * x_opt, dim=1)

# 4.4 Solve with GUROBI for comparison
x_solver, y_solver = GUROBI_solve_parallel(p_batch, b_batch, a_batch)

print("Objective values:", obj_val)
print("Slack:", s_opt)
