In [1]:
import copy
import os
import pickle
import time
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.autograd.profiler as profiler
import numpy as np

from QP_problem import SimpleProblem, OriginalQPProblem, QPProblemVaryingG, QPProblemVaryingGbd

In [2]:
def create_QP_dataset(num_var, num_ineq, num_eq, num_examples):
    """Build a random QP benchmark, solve it, pickle it and return it.

    All random quantities come from NumPy's global generator seeded with 17,
    so the draw order below is part of the dataset definition and must not be
    rearranged.

    Args:
        num_var: Number of decision variables (size of the diagonal of Q).
        num_ineq: Number of inequality rows in G.
        num_eq: Number of equality rows in A.
        num_examples: Number of sampled right-hand sides (rows of X).

    Returns:
        The ``OriginalQPProblem`` instance after ``calc_Y()`` has been called.
    """
    np.random.seed(17)

    # Order of draws: Q diagonal, p, A, X, G.
    quad_cost = np.diag(np.random.random(num_var))
    lin_cost = np.random.random(num_var)
    eq_matrix = np.random.normal(loc=0, scale=1., size=(num_eq, num_var))
    eq_rhs_samples = np.random.uniform(-1, 1, size=(num_examples, num_eq))
    ineq_matrix = np.random.normal(loc=0, scale=1., size=(num_ineq, num_var))

    # Inequality RHS: row-wise L1 norm of G @ pinv(A).
    eq_pinv = np.linalg.pinv(eq_matrix)
    ineq_rhs = np.abs(ineq_matrix @ eq_pinv).sum(axis=1)

    problem = OriginalQPProblem(quad_cost, lin_cost, eq_matrix, ineq_matrix, eq_rhs_samples, ineq_rhs)
    # presumably solves every instance and fills problem.Y — confirm in QP_problem
    problem.calc_Y()
    print(len(problem.Y))

    out_path = f"./QP_data/original/random_simple_dataset_var{num_var}_ineq{num_ineq}_eq{num_eq}_ex{num_examples}"
    with open(out_path, 'wb') as f:
        pickle.dump(problem, f)

    return problem

def create_varying_G_dataset(num_var, num_ineq, num_eq, num_examples, num_varying_rows):
    """Create a QP dataset whose samples differ in the inequality matrix G.

    Q, p, A, the equality RHS ``b`` and the inequality RHS ``d`` are shared by
    every sample. Each sample gets a copy of ``G_base`` whose first
    ``num_varying_rows`` rows are replaced by fresh N(0, 1) draws.

    Args:
        num_var: Number of decision variables.
        num_ineq: Number of inequality constraints (rows of G).
        num_eq: Number of equality constraints (rows of A).
        num_examples: Number of samples, i.e. number of G matrices generated.
        num_varying_rows: How many leading rows of G are re-drawn per sample.

    Returns:
        The ``QPProblemVaryingG`` instance after ``calc_Y()`` has been called.
    """
    np.random.seed(17)
    Q = np.diag(np.random.random(num_var))
    p = np.random.random(num_var)
    A = np.random.normal(loc=0, scale=1., size=(num_eq, num_var))
    # b (equality RHS) is the same for all samples:
    b = np.random.uniform(-1, 1, size=(num_eq))
    G_base = np.random.normal(loc=0, scale=1., size=(num_ineq, num_var))
    # TODO: Can we keep d constant, if we are varying G?
    # d is derived from G_base only, so it does not follow the per-sample rows.
    d = np.sum(np.abs(G_base @ np.linalg.pinv(A)), axis=1)

    G_list = []
    # For each sample, create a different inequality constraint matrix
    for _ in range(num_examples):
        G_sample = G_base.copy()
        # Draw every varying row independently. Drawing a single (1, num_var)
        # row would be broadcast, making all varying rows of a sample identical.
        G_sample[:num_varying_rows, :] = np.random.normal(
            loc=0, scale=1., size=(num_varying_rows, num_var)
        )
        G_list.append(G_sample)

    G = np.array(G_list)
    problem = QPProblemVaryingG(Q=Q, p=p, A=A, G_base=G_base, G_varying=G, b=b, d=d, n_varying_rows=num_varying_rows)
    problem.calc_Y()
    print(len(problem.Y))

    # NOTE(review): create_varying_G_b_d_dataset writes to the same file name,
    # and num_varying_rows is not part of it — later runs overwrite this one.
    with open("./QP_data/modified/MODIFIED_random_simple_dataset_var{}_ineq{}_eq{}_ex{}".format(num_var, num_ineq, num_eq, num_examples), 'wb') as f:
        pickle.dump(problem, f)

    return problem

def create_varying_G_b_d_dataset(num_var, num_ineq, num_eq, num_examples, num_varying_rows):
    """Create a QP dataset whose samples differ in G, b and d.

    Q, p and A are shared. Every sample has its own equality RHS (row of
    ``B``), its own inequality matrix (``G_base`` with the first
    ``num_varying_rows`` rows re-drawn from N(0, 1)) and its own inequality
    RHS ``d``, computed from that sample's G.

    Args:
        num_var: Number of decision variables.
        num_ineq: Number of inequality constraints (rows of G).
        num_eq: Number of equality constraints (rows of A).
        num_examples: Number of samples.
        num_varying_rows: How many leading rows of G are re-drawn per sample.

    Returns:
        The ``QPProblemVaryingGbd`` instance after ``calc_Y()`` has been called.
    """
    np.random.seed(17)
    Q = np.diag(np.random.random(num_var))
    p = np.random.random(num_var)
    A = np.random.normal(loc=0, scale=1., size=(num_eq, num_var))
    # One equality RHS vector per sample.
    B = np.random.uniform(-1, 1, size=(num_examples, num_eq))
    G_base = np.random.normal(loc=0, scale=1., size=(num_ineq, num_var))

    G_list = []
    # For each sample, create a different inequality constraint matrix
    for _ in range(num_examples):
        G_sample = G_base.copy()
        # Vary the first n rows, (specified by num_varying_rows).
        G_sample[:num_varying_rows, :] = np.random.normal(loc=0, scale=1., size=(num_varying_rows, num_var))
        G_list.append(G_sample)

    # pinv(A) does not depend on the sample: compute it once instead of once
    # per example.
    A_pinv = np.linalg.pinv(A)

    # Per-sample inequality RHS: row-wise L1 norm of G_i @ pinv(A), shape (num_ineq,)
    D_list = [np.sum(np.abs(Gi @ A_pinv), axis=1) for Gi in G_list]

    G = np.array(G_list)
    D = np.stack(D_list, axis=0)  # Shape (num_examples, num_ineq)
    problem = QPProblemVaryingGbd(Q=Q, p=p, A=A, G_base=G_base, G_varying=G, b=B, d=D, n_varying_rows=num_varying_rows)
    problem.calc_Y()
    print(len(problem.Y))

    # NOTE(review): this is the same file name create_varying_G_dataset uses,
    # so the two dataset kinds overwrite each other on disk.
    with open("./QP_data/modified/MODIFIED_random_simple_dataset_var{}_ineq{}_eq{}_ex{}".format(num_var, num_ineq, num_eq, num_examples), 'wb') as f:
        pickle.dump(problem, f)

    return problem

def create_scaled_QP_problem(num_var, num_ineq, num_eq, num_examples, scale='normal'):
    """Build and solve a random QP benchmark with optionally enlarged magnitudes.

    Unlike ``create_QP_dataset`` the result is not written to disk.

    Args:
        num_var: Number of decision variables.
        num_ineq: Number of inequality constraints (rows of G).
        num_eq: Number of equality constraints (rows of A).
        num_examples: Number of sampled equality right-hand sides.
        scale: ``'normal'`` (all factors 1) or ``'large'`` (Q multiplied by
            1e9, std of A and G set to 1e3, RHS drawn from [-1e3, 1e3]).

    Returns:
        The ``OriginalQPProblem`` instance after ``calc_Y()`` has been called.

    Raises:
        ValueError: If ``scale`` is not one of the supported presets.
    """
    # (objective scale, constraint-matrix std, RHS half-width) per preset.
    presets = {
        'normal': (1, 1, 1),
        'large': (1e9, 1e3, 1e3),
    }
    # Reject unknown presets before any work is done; otherwise the scale
    # variables would be unbound and fail later with an obscure error.
    if scale not in presets:
        raise ValueError(
            "Unknown scale {!r}; expected one of {}".format(scale, sorted(presets))
        )
    obj_scale, var_scale, rhs_scale = presets[scale]

    np.random.seed(17)

    Q = np.diag(np.random.random(num_var)) * obj_scale
    p = np.random.random(num_var)
    A = np.random.normal(loc=0, scale=var_scale, size=(num_eq, num_var))
    X = np.random.uniform(-rhs_scale, rhs_scale, size=(num_examples, num_eq))
    G = np.random.normal(loc=0, scale=var_scale, size=(num_ineq, num_var))
    h = np.sum(np.abs(G@np.linalg.pinv(A)), axis=1)

    problem = OriginalQPProblem(Q, p, A, G, X, h)
    problem.calc_Y()
    print(len(problem.Y))

    # Saving is intentionally disabled for the scaled variant:
    # with open("./QP_data/original/random_simple_dataset_var{}_ineq{}_eq{}_ex{}".format(num_var, num_ineq, num_eq, num_examples), 'wb') as f:
        # pickle.dump(problem, f)

    return problem

In [3]:
# Numeric precision used for both networks.
DTYPE = torch.float64
# Build a real device object. The former chained assignment
# `DEVICE = torch.device = "cpu"` also rebound `torch.device` to the string
# "cpu", breaking every later `torch.device(...)` call in the process.
DEVICE = torch.device("cpu")
# Debugging aid for locating NaN/inf in backward passes; PyTorch documents it
# as slowing execution, so consider disabling it for long runs.
torch.autograd.set_detect_anomaly(True)
torch.manual_seed(42)
print(f"Running on {DEVICE}")

class CustomDataset(torch.utils.data.Dataset):
    """Dataset that iterates over ``x`` and carries the constraint data alongside.

    Only ``x`` is indexed; ``eq_cm``, ``ineq_cm``, ``eq_rhs`` and ``ineq_rhs``
    are stored as plain attributes for the training code to read directly.
    """

    def __init__(self, x, eq_cm, ineq_cm, eq_rhs, ineq_rhs):
        self.x, self.eq_cm, self.ineq_cm = x, eq_cm, ineq_cm
        self.eq_rhs, self.ineq_rhs = eq_rhs, ineq_rhs
        # Iteration cursor; not read anywhere in this class.
        self._index = 0

    def __len__(self):
        """Number of samples, taken from ``x``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the input sample at ``idx``.

        This is the per-experiment switch point: for datasets where the
        constraint data varies per sample, the item would instead be the tuple
        (x[idx], eq_cm[idx], ineq_cm[idx], eq_rhs[idx], ineq_rhs[idx]).
        """
        sample = self.x[idx]
        return sample

        
class PrimalDualTrainer():
    """Alternating trainer for a primal network and a dual network.

    Each outer iteration first trains the primal network against a frozen copy
    of the dual network (augmented-Lagrangian style loss, see ``primal_loss``),
    then trains the dual network towards a multiplier update computed from the
    frozen primal network (see ``dual_loss``). The penalty weight ``rho`` is
    increased when the constraint violation does not shrink fast enough.
    """

    def __init__(self, data, args, save_dir):
        """Store hyper-parameters and build datasets, loaders, nets and optimizers.

        Args:
            data: Problem object. This class reads ``xdim``, ``ydim``, ``nineq``,
                ``neq``, ``X``, ``eq_cm``, ``ineq_cm``, ``eq_rhs``, ``ineq_rhs``,
                ``train_indices``, ``valid_indices``, ``test_indices`` and calls
                ``obj_fn``, ``eq_resid``, ``ineq_resid`` and ``ineq_dist`` on it.
            args (dict): Hyper-parameters. Keys read here: ``outer_iterations``,
                ``inner_iterations``, ``tau``, ``rho``, ``rho_max``, ``alpha``,
                ``batch_size``, ``hidden_sizes``, ``primal_lr``, ``dual_lr``,
                ``decay``, ``patience``.
            save_dir (str): Directory that receives ``stats.dict``,
                ``primal_net.dict`` and ``dual_net.dict`` after training.
        """

        print(f"X dim: {data.xdim}")
        print(f"Y dim: {data.ydim}")

        print(f"Size of mu: {data.nineq}")
        print(f"Size of lambda: {data.neq}")

        self.data = data
        self.args = args
        self.save_dir = save_dir

        self.outer_iterations = args["outer_iterations"]
        self.inner_iterations = args["inner_iterations"]
        self.tau = args["tau"]
        self.rho = args["rho"]
        self.rho_max = args["rho_max"]
        self.alpha = args["alpha"]
        self.batch_size = args["batch_size"]
        self.hidden_sizes = args["hidden_sizes"]

        self.primal_lr = args["primal_lr"]
        self.dual_lr = args["dual_lr"]
        self.decay = args["decay"]
        self.patience = args["patience"]

        # for logging
        self.step = 0

        X = data.X
        eq_cm = data.eq_cm
        ineq_cm = data.ineq_cm
        eq_rhs = data.eq_rhs
        ineq_rhs = data.ineq_rhs

        train = data.train_indices
        valid = data.valid_indices
        test = data.test_indices

        # Training data in a data set
        #! Vary per experiment: here only X and eq_rhs are indexed per split;
        #! eq_cm, ineq_cm and ineq_rhs are passed through unsplit.
        # self.train_dataset = CustomDataset(X[train].to(DEVICE), eq_cm[train], ineq_cm[train], eq_rhs[train], ineq_rhs[train])
        self.train_dataset = CustomDataset(X[train].to(DEVICE), eq_cm, ineq_cm, eq_rhs[train], ineq_rhs)
        self.valid_dataset = CustomDataset(X[valid].to(DEVICE), eq_cm, ineq_cm, eq_rhs[valid], ineq_rhs)
        self.test_dataset = CustomDataset(X[test].to(DEVICE), eq_cm, ineq_cm, eq_rhs[test], ineq_rhs)

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        self.valid_loader = DataLoader(self.valid_dataset, batch_size=1000, shuffle=False)
        self.test_loader = DataLoader(self.test_dataset, batch_size=1000, shuffle=False)

        self.primal_net = PrimalNet(self.data, self.hidden_sizes).to(dtype=DTYPE, device=DEVICE)
        self.dual_net = DualNet(self.data, self.hidden_sizes, self.data.nineq, self.data.neq).to(dtype=DTYPE, device=DEVICE)

        self.primal_optim = torch.optim.Adam(self.primal_net.parameters(), lr=self.primal_lr)
        self.dual_optim = torch.optim.Adam(self.dual_net.parameters(), lr=self.dual_lr)

        # Add schedulers
        self.primal_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.primal_optim, mode='min', factor=self.decay, patience=self.patience
        )
        self.dual_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.dual_optim, mode='min', factor=self.decay, patience=self.patience
        )

    def train_PDL(self,):
        """Run the alternating primal/dual training loop and save the results.

        Returns:
            tuple: ``(primal_net, dual_net, stats)``. ``stats`` is currently an
            empty dict (no per-epoch statistics are recorded).

        Raises:
            Exception: Any error raised during training is printed and re-raised.
        """
        # Create the output directory before training starts, so a missing or
        # unwritable directory fails immediately instead of after the full run.
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)

        # Bound up front so the save/return below also works when
        # outer_iterations is 0.
        stats = {}

        try:
            prev_v_k = 0
            for k in range(self.outer_iterations):
                begin_time = time.time()
                # Dual net snapshot used as the fixed multiplier estimate for
                # this whole outer iteration.
                frozen_dual_net = copy.deepcopy(self.dual_net)

                # ---- Primal phase -------------------------------------------------
                for _ in range(self.inner_iterations):
                    self.step += 1
                    # Update primal net using primal loss
                    self.primal_net.train()

                    for Xtrain in self.train_loader:
                        self.primal_optim.zero_grad()
                        # The batch itself is passed as eq_rhs (x is the equality RHS here).
                        y = self.primal_net(Xtrain, Xtrain, self.train_dataset.ineq_rhs)
                        with torch.no_grad():
                            mu, lamb = frozen_dual_net(Xtrain, self.train_dataset.eq_cm)
                        batch_loss = self.primal_loss(y, self.train_dataset.eq_cm, self.train_dataset.ineq_cm, Xtrain, self.train_dataset.ineq_rhs, mu, lamb).mean()
                        batch_loss.backward()
                        self.primal_optim.step()

                    # Evaluate validation loss every epoch, and update learning rate
                    with torch.no_grad():
                        self.primal_net.eval()
                        frozen_dual_net.eval()
                        val_loss = 0
                        for Xvalid in self.valid_loader:
                            y = self.primal_net(Xvalid, Xvalid, self.valid_dataset.ineq_rhs)
                            mu, lamb = frozen_dual_net(Xvalid, self.valid_dataset.eq_cm)
                            val_loss += self.primal_loss(y, self.valid_dataset.eq_cm, self.valid_dataset.ineq_cm, Xvalid, self.valid_dataset.ineq_rhs, mu, lamb).sum()
                        # NOTE(review): per-sample losses are summed, then divided by the
                        # number of batches (not samples). Only the scale of the metric
                        # handed to the scheduler is affected.
                        val_loss /= len(self.valid_loader)
                        # Normalize by rho, so that the scheduler still works correctly if rho is increased
                        self.primal_scheduler.step(torch.sign(val_loss) * (torch.abs(val_loss) / self.rho))

                with torch.no_grad():
                    # Copy primal net into frozen primal net
                    frozen_primal_net = copy.deepcopy(self.primal_net)

                    # Calculate v_k on the full training set
                    y = frozen_primal_net(self.train_dataset.x, self.train_dataset.eq_rhs, self.train_dataset.ineq_rhs)
                    mu_k, lamb_k = frozen_dual_net(self.train_dataset.x, self.train_dataset.eq_cm)
                    v_k = self.violation(y, self.train_dataset.eq_cm, self.train_dataset.ineq_cm, self.train_dataset.eq_rhs, self.train_dataset.ineq_rhs, mu_k)

                # ---- Dual phase ---------------------------------------------------
                for _ in range(self.inner_iterations):
                    self.step += 1
                    # Update dual net using dual loss
                    self.dual_net.train()
                    frozen_primal_net.train()
                    for Xtrain in self.train_loader:
                        self.dual_optim.zero_grad()
                        mu, lamb = self.dual_net(Xtrain, self.train_dataset.eq_cm)
                        with torch.no_grad():
                            mu_k, lamb_k = frozen_dual_net(Xtrain, self.train_dataset.eq_cm)
                            y = frozen_primal_net(Xtrain, Xtrain, self.train_dataset.ineq_rhs)
                        # ! Test other loss!
                        batch_loss = self.dual_loss(y, self.train_dataset.eq_cm, self.train_dataset.ineq_cm, Xtrain, self.train_dataset.ineq_rhs, mu, lamb, mu_k, lamb_k).mean()
                        # batch_loss = self.dual_loss_changed(y, train_eq_cm, train_ineq_cm, train_eq_rhs, train_ineq_rhs, mu, lamb, mu_k, lamb_k).mean()
                        batch_loss.backward()
                        self.dual_optim.step()

                    # Evaluate validation loss every epoch, and update learning rate
                    # TODO! Does scheduler correctly decrease LR when rho is increased, if the training set is small?
                    with torch.no_grad():
                        frozen_primal_net.eval()
                        self.dual_net.eval()
                        val_loss = 0
                        for Xvalid in self.valid_loader:
                            y = frozen_primal_net(Xvalid, Xvalid, self.valid_dataset.ineq_rhs)
                            mu_valid, lamb_valid = self.dual_net(Xvalid, self.valid_dataset.eq_cm)
                            mu_k_valid, lamb_k_valid = frozen_dual_net(Xvalid, self.valid_dataset.eq_cm)
                            val_loss += self.dual_loss(y, self.valid_dataset.eq_cm, self.valid_dataset.ineq_cm, Xvalid, self.valid_dataset.ineq_rhs, mu_valid, lamb_valid, mu_k_valid, lamb_k_valid).sum()
                        val_loss /= len(self.valid_loader)
                    # Normalize by rho, so that the scheduler still works correctly if rho is increased
                    self.dual_scheduler.step(torch.sign(val_loss) * (torch.abs(val_loss) / self.rho))

                end_time = time.time()
                print("-"*40)
                print(f"Epoch {k} done. Time taken: {end_time - begin_time}. Rho: {self.rho}. Primal LR: {self.primal_optim.param_groups[0]['lr']}, Dual LR: {self.dual_optim.param_groups[0]['lr']}")

                # Update rho from the second iteration onward: grow it by alpha
                # (capped at rho_max) when the violation did not shrink by factor tau.
                if k > 0 and v_k > self.tau * prev_v_k:
                    self.rho = float(min(self.alpha * self.rho, self.rho_max))

                prev_v_k = v_k

                print("Validation set evaluate:")
                with torch.no_grad():
                    self.primal_net.eval()
                    self.dual_net.eval()
                    self.evaluate(self.valid_loader, self.valid_dataset, self.primal_net, self.dual_net)

            print("-"*40)
            print("Test set evaluate:")
            with torch.no_grad():
                self.primal_net.eval()
                self.dual_net.eval()
                self.evaluate(self.test_loader, self.test_dataset, self.primal_net, self.dual_net)

        except Exception as e:
            # Surface the error immediately (flushed), then propagate it.
            print(e, flush=True)
            raise

        with open(os.path.join(self.save_dir, 'stats.dict'), 'wb') as f:
            pickle.dump(stats, f)
        with open(os.path.join(self.save_dir, 'primal_net.dict'), 'wb') as f:
            torch.save(self.primal_net.state_dict(), f)
        with open(os.path.join(self.save_dir, 'dual_net.dict'), 'wb') as f:
            torch.save(self.dual_net.state_dict(), f)

        return self.primal_net, self.dual_net, stats

    def evaluate(self, loader, dataset, primal_net, dual_net):
        """Print mean objective, primal loss and constraint metrics over ``loader``.

        Args:
            loader: Iterable of input batches ``X``; each batch is also used as
                the equality right-hand side.
            dataset: Object providing ``eq_cm``, ``ineq_cm``, ``eq_rhs`` and
                ``ineq_rhs`` for the loss and residual computations.
            primal_net: Network called as ``primal_net(X, eq_rhs, ineq_rhs)``.
            dual_net: Network called as ``dual_net(X, eq_cm)`` returning
                ``(mu, lamb)``.
        """
        obj_values = []
        primal_losses = []
        ineq_max_vals = []
        ineq_mean_vals = []
        eq_max_vals = []
        eq_mean_vals = []

        for X in loader:

            # Forward pass through networks
            Y = primal_net(X, dataset.eq_rhs, dataset.ineq_rhs)
            mu, lamb = dual_net(X, dataset.eq_cm)

            # Compute and store per-sample metrics
            obj_values.append(self.data.obj_fn(Y).detach().cpu().numpy())
            primal_losses.append(self.primal_loss(Y, dataset.eq_cm, dataset.ineq_cm, X, dataset.ineq_rhs, mu, lamb).detach().cpu().numpy())
            ineq_dist = self.data.ineq_dist(Y, dataset.ineq_cm, dataset.ineq_rhs)
            eq_resid = self.data.eq_resid(Y, dataset.eq_cm, X)

            ineq_max_vals.append(torch.max(ineq_dist, dim=1)[0].detach().cpu().numpy())
            ineq_mean_vals.append(torch.mean(ineq_dist, dim=1).detach().cpu().numpy())
            eq_max_vals.append(torch.max(torch.abs(eq_resid), dim=1)[0].detach().cpu().numpy())
            eq_mean_vals.append(torch.mean(torch.abs(eq_resid), dim=1).detach().cpu().numpy())

        # Convert lists to arrays for easier handling
        obj_values = np.concatenate(obj_values)
        primal_losses = np.concatenate(primal_losses)
        ineq_max_vals = np.concatenate(ineq_max_vals)
        ineq_mean_vals = np.concatenate(ineq_mean_vals)
        eq_max_vals = np.concatenate(eq_max_vals)
        eq_mean_vals = np.concatenate(eq_mean_vals)

        # Print aggregated statistics (means over all samples)
        print(f"Obj: {np.mean(obj_values)}")
        print(f"Primal Loss: {np.mean(primal_losses)}")
        print(f"Ineq max: {np.mean(ineq_max_vals)}")
        print(f"Ineq mean: {np.mean(ineq_mean_vals)}")
        print(f"Eq max: {np.mean(eq_max_vals)}")
        print(f"Eq mean: {np.mean(eq_mean_vals)}")

    def primal_loss(self, y, eq_cm, ineq_cm, eq_rhs, ineq_rhs, mu, lamb):
        """Per-sample augmented-Lagrangian loss for the primal network.

        Computes ``obj + sum(mu * g) + sum(lamb * h) + rho/2 * (sum(max(g, 0)^2)
        + sum(h^2))`` where ``g`` / ``h`` are the inequality / equality
        residuals returned by ``self.data``.

        Returns:
            Tensor with one loss value per sample (sums are taken over dim 1).
        """
        obj = self.data.obj_fn(y)

        # g(y)
        ineq = self.data.ineq_resid(y, ineq_cm, ineq_rhs)
        # h(y)
        eq = self.data.eq_resid(y, eq_cm, eq_rhs)

        # ! Clamp mu?
        # Element-wise clamping of mu_i when g_i (ineq) is negative
        # mu = torch.where(ineq < 0, torch.zeros_like(mu), mu)
        # ! Clamp ineq_resid?
        # ineq = ineq.clamp(min=0)

        lagrange_ineq = torch.sum(mu * ineq, dim=1)  # Shape (batch_size,)

        lagrange_eq = torch.sum(lamb * eq, dim=1)   # Shape (batch_size,)

        # Only positive inequality residuals are penalised.
        violation_ineq = torch.sum(torch.maximum(ineq, torch.zeros_like(ineq)) ** 2, dim=1)
        violation_eq = torch.sum(eq ** 2, dim=1)
        penalty = self.rho/2 * (violation_ineq + violation_eq)

        loss = (obj + (lagrange_ineq + lagrange_eq + penalty))

        return loss

    def dual_loss(self, y, eq_cm, ineq_cm, eq_rhs, ineq_rhs, mu, lamb, mu_k, lamb_k):
        """Per-sample distance between predicted multipliers and their update targets.

        Targets are ``max(mu_k + rho * g, 0)`` for the inequality multipliers and
        ``lamb_k + rho * h`` for the equality multipliers; the loss is the sum of
        the two residual norms taken over dim 1.

        Returns:
            Tensor with one loss value per sample.
        """
        # mu = [batch, g]
        # lamb = [batch, h]

        # g(y)
        ineq = self.data.ineq_resid(y, ineq_cm, ineq_rhs) # [batch, g]
        # h(y)
        eq = self.data.eq_resid(y, eq_cm, eq_rhs)   # [batch, h]

        #! From 2nd PDL paper, fix to 1e-1, not rho
        target_mu = torch.maximum(mu_k + self.rho * ineq, torch.zeros_like(ineq))
        # target_mu = torch.maximum(mu_k + 1e-1 * ineq, torch.zeros_like(ineq))

        dual_resid_ineq = mu - target_mu # [batch, g]

        dual_resid_ineq = torch.norm(dual_resid_ineq, dim=1)  # [batch]

        # Compute the dual residuals for equality constraints
        #! From 2nd PDL paper, fix to 1e-1, not rho
        dual_resid_eq = lamb - (lamb_k + self.rho * eq)
        # dual_resid_eq = lamb - (lamb_k + 1e-1 * eq)
        dual_resid_eq = torch.norm(dual_resid_eq, dim=1)  # Norm along constraint dimension

        loss = (dual_resid_ineq + dual_resid_eq)

        return loss

    def dual_loss_changed(self, y, eq_cm, ineq_cm, eq_rhs, ineq_rhs, mu, lamb, mu_k, lamb_k):
        """Alternative dual loss built from the dual problem itself (experimental).

        Uses ``self.data.dual_obj_fn`` and ``self.data.dual_eq_resid``; ``mu_k``
        and ``lamb_k`` are accepted for signature parity with ``dual_loss`` but
        are not used. It is not called by the active training loop.

        Returns:
            Tensor with one loss value per sample.
        """
        #! We maximize the dual obj func, so to use it in the loss, take the negation.
        dual_obj = -self.data.dual_obj_fn(eq_rhs, ineq_rhs, mu, lamb)

        #! Enforced with ReLU.
        # ineq = self.data.dual_ineq_resid(mu, lamb)

        eq = self.data.dual_eq_resid(mu, lamb, eq_cm, ineq_cm)
        # Lagrange multiplier becomes y
        lagrange_eq = torch.sum(y * eq, dim=1)

        violation_eq = torch.sum(eq ** 2, dim=1)

        penalty = self.rho/2 * violation_eq

        loss = dual_obj + lagrange_eq + penalty
        # loss = dual_obj + penalty

        return loss

    def violation(self, y, eq_cm, ineq_cm, eq_rhs, ineq_rhs, mu_k):
        """Worst-case constraint violation ``v_k`` over all given samples.

        Per sample this is ``max(||h||_inf, ||max(g, -mu_k / rho)||_inf)``; the
        maximum over samples is returned.

        Returns:
            float: The largest per-sample violation.
        """
        # Calculate the equality constraint function h_x(y)
        eq = self.data.eq_resid(y, eq_cm, eq_rhs)  # Assume shape (num_samples, n_eq)

        # Calculate the infinity norm of h_x(y)
        eq_inf_norm = torch.abs(eq).max(dim=1).values  # Shape: (num_samples,)

        # Calculate the inequality constraint function g_x(y)
        ineq = self.data.ineq_resid(y, ineq_cm, ineq_rhs)  # Assume shape (num_samples, n_ineq)

        # Calculate sigma_x(y) for each inequality constraint
        sigma_y = torch.maximum(ineq, -mu_k / self.rho)  # Element-wise max

        # Calculate the infinity norm of sigma_x(y)
        sigma_y_inf_norm = torch.abs(sigma_y).max(dim=1).values  # Shape: (num_samples,)

        # Compute v_k as the maximum of the two norms
        v_k = torch.maximum(eq_inf_norm, sigma_y_inf_norm)  # Shape: (num_samples,)

        return v_k.max().item()

class PrimalNet(nn.Module):
    """Fully connected network mapping the problem input ``x`` to a primal point ``y``.

    Architecture: ``xdim -> hidden_sizes... -> ydim`` with a ReLU after every
    hidden layer and no activation on the output layer. Linear weights use
    Kaiming-normal initialisation.
    """

    def __init__(self, data, hidden_sizes):
        """Build the layer stack.

        Args:
            data: Object providing ``xdim`` (input width) and ``ydim`` (output width).
            hidden_sizes: Sequence of hidden layer widths.
        """
        super().__init__()
        self._data = data
        self._hidden_sizes = hidden_sizes

        # Create the list of layer sizes
        layer_sizes = [data.xdim] + list(self._hidden_sizes) + [data.ydim]
        last_linear = len(layer_sizes) - 2  # index of the output Linear
        layers = []

        # Create layers dynamically based on the provided hidden_sizes
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            layers.append(nn.Linear(in_size, out_size))
            # Decide by position, not by width: a hidden layer whose width
            # happens to equal ydim must still get its activation.
            # NOTE: for such configurations this changes the Sequential indices,
            # so checkpoints saved without these ReLUs will not load.
            if i != last_linear:
                layers.append(nn.ReLU())

        # Initialize all layers (after construction, keeping the RNG draw order)
        for layer in layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight)

        self.net = nn.Sequential(*layers)

    def forward(self, x, eq_rhs, ineq_rhs):
        """Return ``net(x)``; ``eq_rhs`` and ``ineq_rhs`` are accepted but unused."""
        return self.net(x)

class DualNet(nn.Module):
    """MLP that predicts the multipliers ``(mu, lamb)`` from the problem input.

    The trunk is ``xdim -> hidden_sizes...`` with ReLU after each layer,
    followed by a single zero-initialised output layer of width
    ``mu_size + lamb_size``, so a fresh network outputs all zeros.
    """

    def __init__(self, data, hidden_sizes, mu_size, lamb_size):
        """Assemble trunk and output head.

        Args:
            data: Object providing ``xdim`` (input width).
            hidden_sizes: List of hidden layer widths (at least one entry).
            mu_size: Number of inequality multipliers.
            lamb_size: Number of equality multipliers.
        """
        super().__init__()
        self._data = data
        self._hidden_sizes = hidden_sizes
        self._mu_size = mu_size
        self._lamb_size = lamb_size

        # Trunk: Linear/ReLU pairs, Linear modules sit at the even positions.
        widths = [data.xdim] + self._hidden_sizes
        # widths = [2*data.xdim + 1000] + self._hidden_sizes
        trunk = []
        for pos in range(len(widths) - 1):
            trunk.extend((nn.Linear(widths[pos], widths[pos + 1]), nn.ReLU()))

        # Kaiming-normal init happens only once all trunk layers exist.
        for linear in trunk[::2]:
            nn.init.kaiming_normal_(linear.weight)

        # Output head starts at exactly zero (weights and biases).
        head = nn.Linear(self._hidden_sizes[-1], self._mu_size + self._lamb_size)
        nn.init.zeros_(head.weight)
        nn.init.zeros_(head.bias)
        self.out_layer = head

        self.net = nn.Sequential(*trunk, self.out_layer)

    def forward(self, x, *args):
        """Return ``(mu, lamb)``: the first ``mu_size`` columns and the remainder.

        Extra positional arguments are accepted and ignored.
        """
        raw = self.net(x)
        #! ReLU to enforce nonnegativity in mu. Test with it.
        #! Does this work with zero initialization?
        # out_mu = torch.relu(raw[:, :self._mu_size])
        cut = self._mu_size
        return raw[:, :cut], raw[:, cut:]

Running on cpu


Main Script:

In [4]:
# Problem size: 100 variables, 50 inequality and 50 equality constraints.
num_var, num_ineq, num_eq = 100, 50, 50
# Number of sampled problem instances.
num_examples = 10000

# Directory passed to the trainer for its output files.
save_dir = "benchmark_experiment_output"

# Generates, solves and pickles the benchmark (this is the slow step).
data = create_QP_dataset(
    num_var,
    num_ineq,
    num_eq,
    num_examples,
)

running osqp
10000


In [5]:
# Hyper-parameters for the primal-dual run.
args = dict(
    # Loop sizes
    outer_iterations=10,
    inner_iterations=500,
    # Penalty schedule: rho is multiplied by alpha (capped at rho_max) when the
    # violation does not drop below tau times the previous one.
    tau=0.8,
    rho=0.5,
    rho_max=5000,
    alpha=10,
    # Network / optimisation
    batch_size=100,
    hidden_sizes=[500, 500],
    primal_lr=1e-4,
    dual_lr=1e-4,
    decay=0.99,
    patience=10,
    # Not read by PrimalDualTrainer in this notebook
    corrEps=1e-4,
    train=0.8,
    valid=0.1,
    test=0.1,
)

trainer = PrimalDualTrainer(data, args, save_dir)
# with profiler.profile(record_shapes=True) as prof:
primal_net, dual_net, stats = trainer.train_PDL()

# print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

X dim: 50
Y dim: 100
Size of mu: 50
Size of lambda: 50
----------------------------------------
Epoch 0 done. Time taken: 733.6000733375549. Rho: 0.5. Primal LR: 0.0001, Dual LR: 6.491026283684025e-05
Validation set evaluate:
Obj: -16.738317089384445
Primal Loss: -14.247675130564444
Ineq max: 0.4433426452503992
Ineq mean: 0.04646640769533794
Eq max: 0.8352204685184776
Eq mean: 0.1757246383265363
----------------------------------------
Epoch 1 done. Time taken: 694.607549905777. Rho: 0.5. Primal LR: 6.361854860638712e-05, Dual LR: 4.255901233886549e-05
Validation set evaluate:
Obj: -15.01993373317896
Primal Loss: -14.908558875539267
Ineq max: 0.04429696080577451
Ineq mean: 0.0033768584008977746
Eq max: 0.07430474929410694
Eq mean: 0.022267517506934432
----------------------------------------
Epoch 2 done. Time taken: 708.927197933197. Rho: 0.5. Primal LR: 4.04731972678324e-05, Dual LR: 2.8186069554046354e-05
Validation set evaluate:
Obj: -15.093072320725486
Primal Loss: -14.88708394631