In [106]:
import copy
import os
import pickle
import time
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.autograd.profiler as profiler
import numpy as np

In [142]:
import torch
import numpy as np
import osqp
from scipy.sparse import csc_matrix

import time

from abc import ABC, abstractmethod

# Make float64 the default dtype for floating-point tensors created from here on.
torch.set_default_dtype(torch.float64)

class SimpleProblem(ABC):
    """
    minimize_y 1/2 * y^T Q y + p^Ty
    s.t.       Ay =  b
               Gy <= d

    Abstract base class for a batch of such problems. ``__init__`` stores the
    problem data as tensors; a subclass must assign ``_X``, ``_num``, ``_neq``,
    ``_nineq`` and ``_xdim`` and implement the abstract methods below.

    Samples are split contiguously, in storage order, into train / valid /
    test according to ``valid_frac`` and ``test_frac``.

    The partial-variable helpers (``partial_vars``, ``other_vars``,
    ``partial_unknown_vars``, ``nknowns``, ``ineq_partial_grad`` and
    ``complete_partial``) read ``_partial_vars``, ``_other_vars``,
    ``_nknowns``, ``_A_partial`` and ``_A_other_inv``. This class never
    assigns those attributes, so a subclass has to provide them before the
    helpers are used.
    """

    def __init__(self, Q, p, A, G, b, d, valid_frac=0.0833, test_frac=0.0833):
        """Copy the problem data into tensors.

        Args:
            Q, p: Objective data; ``Q.shape[0]`` is taken as the number of
                decision variables (``ydim``).
            A, b: Equality constraint matrix and right-hand side.
            G, d: Inequality constraint matrix and right-hand side.
            valid_frac: Fraction of the samples used for validation.
            test_frac: Fraction of the samples used for testing.
        """
        self._Q = torch.tensor(Q)
        self._p = torch.tensor(p)
        self._A = torch.tensor(A)
        self._G = torch.tensor(G)
        self._b = torch.tensor(b) # equality RHS
        self._d = torch.tensor(d) # inequality RHS

        # Aliases of the constraint data under more descriptive names.
        self._eq_cm = self._A
        self._ineq_cm = self._G
        self._eq_rhs = self._b
        self._ineq_rhs = self._d

        self._valid_frac = valid_frac
        self._test_frac = test_frac

        self._Y = None
        self._ydim = Q.shape[0]

        ### For Pytorch
        self._device = None

        #! Implement in child!
        self._X = None
        self._num = None
        self._neq = None
        self._nineq = None
        self._xdim = None


    ##### ABSTRACT METHODS #####

    @abstractmethod
    def eq_resid(self, X, Y):
        """Return the equality-constraint residual of the batch ``Y`` for inputs ``X``."""
        raise NotImplementedError

    @abstractmethod
    def ineq_resid(self, X, Y):
        """Return the inequality-constraint residual; positive entries are clamped in by ``ineq_dist``."""
        raise NotImplementedError

    @abstractmethod
    def opt_solve(self, X, solver_type="osqp", tol=1e-4):
        """Solve the problems described by ``X`` with a reference solver."""
        raise NotImplementedError

    @abstractmethod
    def calc_Y(self):
        """Compute and store the reference solutions for ``self.X``."""
        raise NotImplementedError


    def __str__(self):
        return "SimpleProblem-{}-{}-{}-{}".format(
            str(self.ydim), str(self.nineq), str(self.neq), str(self.num)
        )

    ##### REG METHODS #####

    @property
    def eq_cm(self):
        return self._eq_cm

    @property
    def ineq_cm(self):
        return self._ineq_cm

    @property
    def eq_rhs(self):
        return self._eq_rhs

    @property
    def ineq_rhs(self):
        return self._ineq_rhs

    @property
    def Q(self):
        return self._Q

    @property
    def p(self):
        return self._p

    @property
    def A(self):
        return self._A

    @property
    def G(self):
        return self._G

    @property
    def b(self):
        return self._b

    @property
    def d(self):
        return self._d

    @property
    def X(self):
        return self._X

    @property
    def Y(self):
        return self._Y

    @property
    def partial_vars(self):
        return self._partial_vars

    @property
    def other_vars(self):
        return self._other_vars

    @property
    def partial_unknown_vars(self):
        return self._partial_vars

    # NumPy views of the tensors above (detached and moved to the CPU first).

    @property
    def Q_np(self):
        return self.Q.detach().cpu().numpy()

    @property
    def p_np(self):
        return self.p.detach().cpu().numpy()

    @property
    def A_np(self):
        return self.A.detach().cpu().numpy()

    @property
    def G_np(self):
        return self.G.detach().cpu().numpy()

    @property
    def b_np(self):
        return self.b.detach().cpu().numpy()

    @property
    def d_np(self):
        return self.d.detach().cpu().numpy()

    @property
    def X_np(self):
        return self.X.detach().cpu().numpy()

    @property
    def Y_np(self):
        return self.Y.detach().cpu().numpy()

    @property
    def xdim(self):
        return self._xdim

    @property
    def ydim(self):
        return self._ydim

    @property
    def num(self):
        return self._num

    @property
    def neq(self):
        return self._neq

    @property
    def nineq(self):
        return self._nineq

    @property
    def nknowns(self):
        return self._nknowns

    @property
    def valid_frac(self):
        return self._valid_frac

    @property
    def test_frac(self):
        return self._test_frac

    @property
    def train_frac(self):
        return 1 - self.valid_frac - self.test_frac

    def _split_bounds(self):
        """Return ``(train_end, valid_end)``, the sample indices at which the train and valid splits stop."""
        train_end = int(self.num * self.train_frac)
        valid_end = int(self.num * (self.train_frac + self.valid_frac))
        return train_end, valid_end

    @property
    def train_indices(self):
        train_end, _ = self._split_bounds()
        return list(range(train_end))

    @property
    def valid_indices(self):
        train_end, valid_end = self._split_bounds()
        return list(range(train_end, valid_end))

    @property
    def test_indices(self):
        _, valid_end = self._split_bounds()
        return list(range(valid_end, self.num))

    @property
    def trainX(self):
        train_end, _ = self._split_bounds()
        return self.X[:train_end]

    @property
    def validX(self):
        train_end, valid_end = self._split_bounds()
        return self.X[train_end:valid_end]

    @property
    def testX(self):
        _, valid_end = self._split_bounds()
        return self.X[valid_end:]

    @property
    def trainY(self):
        train_end, _ = self._split_bounds()
        return self.Y[:train_end]

    @property
    def validY(self):
        train_end, valid_end = self._split_bounds()
        return self.Y[train_end:valid_end]

    @property
    def testY(self):
        _, valid_end = self._split_bounds()
        return self.Y[valid_end:]

    @property
    def device(self):
        return self._device

    def obj_fn(self, Y):
        """Return 1/2 y^T Q y + p^T y for every row ``y`` of ``Y``."""
        return (0.5 * (Y @ self.Q) * Y +  self.p * Y).sum(dim=1)

    def ineq_dist(self, X, Y):
        """Return the inequality residual with negative entries clamped to zero."""
        resids = self.ineq_resid(X, Y)
        return torch.clamp(resids, 0)

    def eq_grad(self, X, Y):
        """Return ``2 (Y A^T - X) A``, the gradient w.r.t. ``Y`` of the squared norm of ``Y A^T - X``."""
        return 2 * (Y @ self.A.T - X) @ self.A

    def ineq_grad(self, X, Y):
        """Return ``2 * ineq_dist(X, Y) @ G``."""
        ineq_dist = self.ineq_dist(X, Y)
        return 2 * ineq_dist @ self.G

    def ineq_partial_grad(self, X, Y):
        """Return the inequality-violation gradient expressed through the partial variables.

        The "other" variables are eliminated via ``_A_other_inv`` and
        ``_A_partial``, giving an effective inequality system in the partial
        variables only. The result has one row per row of ``X`` and ``ydim``
        columns: the partial columns hold the gradient, the other columns hold
        the matching change implied by the elimination.
        """
        G_effective = self.G[:, self.partial_vars] - self.G[:, self.other_vars] @ (
            self._A_other_inv @ self._A_partial
        )
        # The inequality right-hand side is exposed as `d` on this class.
        h_effective = self.d - (X @ self._A_other_inv.T) @ self.G[:, self.other_vars].T
        grad = (
            2
            * torch.clamp(Y[:, self.partial_vars] @ G_effective.T - h_effective, 0)
            @ G_effective
        )
        Y = torch.zeros(X.shape[0], self.ydim, device=self.device)
        Y[:, self.partial_vars] = grad
        Y[:, self.other_vars] = -(grad @ self._A_partial.T) @ self._A_other_inv.T
        return Y

    # Processes intermediate neural network output
    def process_output(self, X, Y):
        return Y

    # Solves for the full set of variables
    def complete_partial(self, X, Z):
        """Fill in the "other" variables from the partial variables ``Z`` so that ``Y A^T = X`` is targeted."""
        Y = torch.zeros(X.shape[0], self.ydim, device=self.device)
        Y[:, self.partial_vars] = Z
        Y[:, self.other_vars] = (X - Z @ self._A_partial.T) @ self._A_other_inv.T
        return Y

class OriginalQPProblem(SimpleProblem):
    """QP data set in which each sample supplies its own equality right-hand side.

    The ``b`` passed to the constructor doubles as the input ``X``: row ``i``
    is the equality RHS of sample ``i``. ``Q``, ``p``, ``A``, ``G`` and ``d``
    are shared by all samples.
    """

    def __init__(self, Q, p, A, G, b, d, valid_frac=0.1, test_frac=0.1):
        super().__init__(Q, p, A, G, b, d, valid_frac, test_frac)

        self._X = self._b
        self._num = self._X.shape[0]
        self._neq = self._A.shape[0]
        self._nineq = self._G.shape[0]
        self._xdim = self._X.shape[1]

        self.A_transpose = self._A.T
        self.G_transpose = self._G.T


    def eq_resid(self, X, Y):
        # Here, X is the RHS of the equality constraints
        return X - Y @ self.A.T

    def ineq_resid(self, X, Y):
        return Y @ self.G.T - self.d

    def opt_solve(self, X, solver_type="osqp", tol=1e-4):
        """Solve one QP per row of ``X`` with OSQP.

        Args:
            X: Tensor whose rows are equality right-hand sides.
            solver_type: Only ``"osqp"`` is implemented.
            tol: Passed to OSQP as ``eps_prim_inf``.

        Returns:
            ``(sols, total_time, parallel_time)``: ``sols`` has one row per
            sample (all-NaN where OSQP did not report ``"solved"``);
            ``total_time`` is the summed wall-clock time of the ``solve()``
            calls and ``parallel_time`` is that total divided by the number of
            samples (0.0 for an empty batch).

        Raises:
            NotImplementedError: For any other ``solver_type``.
        """
        if solver_type != "osqp":
            raise NotImplementedError

        print("running osqp")
        Q, p, A, G, d = self.Q_np, self.p_np, self.A_np, self.G_np, self.d_np
        X_np = X.detach().cpu().numpy()

        # Everything except the equality RHS is shared by all samples, so it
        # is built once instead of once per sample.
        P_sparse = csc_matrix(Q)
        A_sparse = csc_matrix(np.vstack([A, G]))
        ineq_lower = -np.ones(d.shape[0]) * np.inf

        Y = []
        total_time = 0
        for Xi in X_np:
            solver = osqp.OSQP()
            # Equalities are encoded as l == u; inequalities have no lower bound.
            solver.setup(
                P=P_sparse,
                q=p,
                A=A_sparse,
                l=np.hstack([Xi, ineq_lower]),
                u=np.hstack([Xi, d]),
                verbose=False,
                eps_prim_inf=tol,
            )
            start_time = time.time()
            results = solver.solve()
            end_time = time.time()

            total_time += end_time - start_time
            if results.info.status == "solved":
                Y.append(results.x)
            else:
                Y.append(np.ones(self.ydim) * np.nan)

        # Assembled once after the loop so the return values are defined even
        # when X has no rows.
        sols = np.array(Y) if Y else np.empty((0, self.ydim))
        parallel_time = total_time / len(X_np) if len(X_np) else 0.0

        return sols, total_time, parallel_time

    def calc_Y(self):
        """Solve every sample in ``self.X`` and keep only the solved ones.

        Updates ``_num``, ``_X`` and ``_Y`` in place.

        Returns:
            The unfiltered solution array, including the all-NaN rows of
            samples that were dropped.
        """
        Y = self.opt_solve(self.X)[0]
        feas_mask = ~np.isnan(Y).all(axis=1)
        self._num = int(feas_mask.sum())
        self._X = self._X[feas_mask]
        self._Y = torch.tensor(Y[feas_mask])
        return Y

class QPProblemVaryingG(SimpleProblem):
    """QP data set in which selected entries of the inequality matrix G vary per sample.

    Row ``i`` of ``X`` holds the values written to
    ``G[row_indices, col_indices]`` for sample ``i``. All other entries of G,
    and ``Q``, ``p``, ``A``, ``b`` and ``d``, are shared by all samples.
    """

    def __init__(self, X, Q, p, A, G, b, d, row_indices, col_indices, valid_frac=0.1, test_frac=0.1):
        super().__init__(Q, p, A, G, b, d, valid_frac, test_frac)
        # X are the varying values of G
        self._X = X
        self._num = self._X.shape[0]
        self._neq = self._A.shape[0]
        self._nineq = self._G.shape[0]
        self._xdim = self._X.shape[1]

        self.row_indices = row_indices
        self.col_indices = col_indices

    def _batched_G(self, X):
        """Return one copy of G per row of ``X`` with that row's values written into the varying entries."""
        G = self.G.expand(X.shape[0], -1, -1).clone()
        G[:, self.row_indices, self.col_indices] = X
        return G

    def eq_resid(self, X, Y):
        # Here, X is part of the inequality constraint matrix. So we don't use it
        return self.b - Y @ self.A.T

    def ineq_resid(self, X, Y):
        # Here, X is part of the inequality constraint matrix. So, we need to plug X into the inequality constraint matrix.
        G = self._batched_G(X)
        return torch.bmm(G, Y.unsqueeze(-1)).squeeze(-1) - self.d

    def opt_solve(self, X, solver_type="osqp", tol=1e-4):
        """Solve one QP per row of ``X`` with OSQP, taking the varying G entries from ``X``.

        Args:
            X: Tensor whose rows are the varying entries of G.
            solver_type: Only ``"osqp"`` is implemented.
            tol: Passed to OSQP as ``eps_prim_inf``.

        Returns:
            ``(sols, total_time, parallel_time)``: ``sols`` has one row per
            sample (all-NaN where OSQP did not report ``"solved"``);
            ``total_time`` is the summed wall-clock time of the ``solve()``
            calls and ``parallel_time`` is that total divided by the number of
            samples (0.0 for an empty batch).

        Raises:
            NotImplementedError: For any other ``solver_type``.
        """
        if solver_type != "osqp":
            raise NotImplementedError

        print("running osqp")
        Q, p, b, d = self.Q_np, self.p_np, self.b_np, self.d_np
        A = self.A_np
        # Convert explicitly instead of letting np.vstack coerce torch
        # tensors, which only works for CPU tensors that do not track grads.
        G_np = self._batched_G(X).detach().cpu().numpy()

        # Bounds and objective are the same for every sample; only G changes.
        P_sparse = csc_matrix(Q)
        my_l = np.hstack([b, -np.ones(d.shape[0]) * np.inf])
        my_u = np.hstack([b, d])

        Y = []
        total_time = 0
        for Gi in G_np:
            solver = osqp.OSQP()
            solver.setup(
                P=P_sparse,
                q=p,
                A=csc_matrix(np.vstack([A, Gi])),
                l=my_l,
                u=my_u,
                verbose=False,
                eps_prim_inf=tol,
            )
            start_time = time.time()
            results = solver.solve()
            end_time = time.time()

            total_time += end_time - start_time
            if results.info.status == "solved":
                Y.append(results.x)
            else:
                Y.append(np.ones(self.ydim) * np.nan)

        sols = np.array(Y) if Y else np.empty((0, self.ydim))
        parallel_time = total_time / len(G_np) if len(G_np) else 0.0

        return sols, total_time, parallel_time

    def calc_Y(self):
        """Solve every sample in ``self.X`` and keep only the solved ones.

        Updates ``_num``, ``_X`` and ``_Y`` in place.

        Returns:
            The unfiltered solution array, including the all-NaN rows of
            samples that were dropped.
        """
        Y = self.opt_solve(self.X)[0]
        feas_mask = ~np.isnan(Y).all(axis=1)
        self._num = int(feas_mask.sum())
        self._X = self._X[feas_mask]
        self._Y = torch.tensor(Y[feas_mask])
        return Y

class QPProblemVaryingGbd(SimpleProblem):
    """QP data set in which the first rows of G, and the vectors b and d, vary per sample.

    ``G_varying``, ``b`` and ``d`` carry one entry per sample along their
    first dimension. The network input ``X`` for a sample is the
    concatenation ``[flattened first n_varying_rows rows of G, b, d]``.
    ``G_base`` is a single matrix supplying the rows of G that do not vary.
    """

    def __init__(self, Q, p, A, G_base, G_varying, b, d, n_varying_rows, valid_frac=0.1, test_frac=0.1):
        super().__init__(Q, p, A, G_varying, b, d, valid_frac, test_frac)
        self.G_base = torch.tensor(G_base)
        self.n_varying_rows = n_varying_rows
        # Flatten the rows of G that are varying, to be added to the NN input.
        G_flattened = self.G[:, :n_varying_rows, :].flatten(start_dim=1)
        self._X = torch.concat([G_flattened, self.b, self.d], dim=1)
        self._num = self._X.shape[0]
        self._neq = self._A.shape[0]
        # G now has num_samples in first dimension, num_constraints in second dimension. Take second dimension!
        self._nineq = self._G.shape[1]
        self._xdim = self._X.shape[1]

    def _split_X(self, X):
        """Slice a batch of inputs into ``(flattened varying G rows, b, d)`` without copying."""
        G_size = self.n_varying_rows * self.ydim
        b_end = G_size + self.neq
        return X[:, :G_size], X[:, G_size:b_end], X[:, b_end:]

    def eq_resid(self, X, Y):
        """B is now varying, we should extract it from X"""
        # Only b is needed here, so the batched G is not rebuilt.
        _, b, _ = self._split_X(X)
        return b - Y @ self.A.T

    def rebuild_Gbd_from_X(self, X):
        """Recover ``(G, b, d)`` for every row of ``X``.

        Returns:
            ``G`` with one full inequality matrix per sample (``G_base`` with
            its first ``n_varying_rows`` rows replaced by the sample's own),
            plus the per-sample ``b`` and ``d`` slices of ``X``.
        """
        flattened_G, b, d = self._split_X(X)
        # Reshape to match the first self.n_varying_rows rows of G
        custom_G = flattened_G.reshape(X.shape[0], self.n_varying_rows, self._ydim)

        # Repeat a private copy of G_base once per sample so the in-place
        # assignment below never touches self.G_base.
        G = self.G_base.clone().unsqueeze(0).repeat(X.shape[0], 1, 1)  # (batch_size, M, P)

        # Assign custom_G to the first self.n_varying_rows rows of G
        G[:, :self.n_varying_rows, :] = custom_G
        return G, b, d

    def ineq_resid(self, X, Y):
        """
        For the ineq resid, we need to extract the first n rows of the G matrix from its flattened form X, and plug them into G.
        """
        G, b, d = self.rebuild_Gbd_from_X(X)

        # Per sample: y @ G^T - d
        return torch.bmm(Y.unsqueeze(1), G.transpose(1, 2)).squeeze(1) - d

    def opt_solve(self, X, solver_type="osqp", tol=1e-4):
        """Solve one QP per row of ``X`` with OSQP, taking G, b and d from ``X``.

        Args:
            X: Tensor of inputs laid out as described on the class.
            solver_type: Only ``"osqp"`` is implemented.
            tol: Passed to OSQP as ``eps_prim_inf``.

        Returns:
            ``(sols, total_time, parallel_time)``: ``sols`` has one row per
            sample (all-NaN where OSQP did not report ``"solved"``);
            ``total_time`` is the summed wall-clock time of the ``solve()``
            calls and ``parallel_time`` is that total divided by the number of
            samples (0.0 for an empty batch).

        Raises:
            NotImplementedError: For any other ``solver_type``.
        """
        if solver_type != "osqp":
            raise NotImplementedError

        print("running osqp")
        Q, p = self.Q_np, self.p_np
        G, b, d = self.rebuild_Gbd_from_X(X)
        G_np, b_np, d_np = G.detach().cpu().numpy(), b.detach().cpu().numpy(), d.detach().cpu().numpy()
        A = self.A_np
        P_sparse = csc_matrix(Q)

        Y = []
        total_time = 0
        for idx, Gi in enumerate(G_np):
            solver = osqp.OSQP()
            my_A = np.vstack([A, Gi])
            # Equalities are encoded as l == u; inequalities have no lower bound.
            my_l = np.hstack([b_np[idx], -np.ones(d_np[idx].shape[0]) * np.inf])
            my_u = np.hstack([b_np[idx], d_np[idx]])
            solver.setup(
                P=P_sparse,
                q=p,
                A=csc_matrix(my_A),
                l=my_l,
                u=my_u,
                verbose=False,
                eps_prim_inf=tol,
            )
            start_time = time.time()
            results = solver.solve()
            end_time = time.time()

            total_time += end_time - start_time
            if results.info.status == "solved":
                Y.append(results.x)
            else:
                Y.append(np.ones(self.ydim) * np.nan)

        sols = np.array(Y) if Y else np.empty((0, self.ydim))
        parallel_time = total_time / len(X) if len(X) else 0.0

        return sols, total_time, parallel_time

    def calc_Y(self):
        """Solve every sample in ``self.X`` and keep only the solved ones.

        Updates ``_num``, ``_X`` and ``_Y`` in place.

        Returns:
            The unfiltered solution array, including the all-NaN rows of
            samples that were dropped.
        """
        Y = self.opt_solve(self.X)[0]
        feas_mask = ~np.isnan(Y).all(axis=1)
        self._num = int(feas_mask.sum())
        self._X = self._X[feas_mask]
        self._Y = torch.tensor(Y[feas_mask])
        return Y
    

class ScaledLPProblem(SimpleProblem):
    """LP variant: the objective is ``c^T y`` with ``c = obj_coeff``.

    As in ``OriginalQPProblem``, the ``b`` passed to the constructor doubles
    as the input ``X`` (one equality right-hand side per sample). ``Q`` and
    ``p`` are stored by the base class but are not used by this class's
    objective or solver.
    """

    def __init__(self, Q, p, A, G, b, d, obj_coeff, valid_frac=0.1, test_frac=0.1):
        super().__init__(Q, p, A, G, b, d, valid_frac, test_frac)

        self._X = self._b
        self._c = torch.tensor(obj_coeff)
        self._num = self._X.shape[0]
        self._neq = self._A.shape[0]
        self._nineq = self._G.shape[0]
        self._xdim = self._X.shape[1]

    def eq_resid(self, X, Y):
        return X - Y @ self.A.T

    def ineq_resid(self, X, Y):
        return Y @ self.G.T - self.d

    def obj_fn(self, Y):
        """Return the linear objective ``c^T y`` for every row ``y`` of ``Y``."""
        return Y @ self._c.T


    def opt_solve(self, X, solver_type="osqp", tol=1e-4):
        """Solve one LP per row of ``X`` with OSQP (no quadratic term is passed).

        Args:
            X: Tensor whose rows are equality right-hand sides.
            solver_type: Only ``"osqp"`` is implemented.
            tol: Passed to OSQP as ``eps_prim_inf``.

        Returns:
            ``(sols, total_time, parallel_time)``: ``sols`` has one row per
            sample (all-NaN where OSQP did not report ``"solved"``);
            ``total_time`` is the summed wall-clock time of the ``solve()``
            calls and ``parallel_time`` is that total divided by the number of
            samples (0.0 for an empty batch).

        Raises:
            NotImplementedError: For any other ``solver_type``.
        """
        if solver_type != "osqp":
            raise NotImplementedError

        print("running osqp")
        A, G, d = self.A_np, self.G_np, self.d_np
        c = self._c.detach().cpu().numpy()
        X_np = X.detach().cpu().numpy()

        # The constraint matrix and the inequality bounds are shared by all
        # samples, so they are built once.
        A_sparse = csc_matrix(np.vstack([A, G]))
        ineq_lower = -np.ones(d.shape[0]) * np.inf

        Y = []
        total_time = 0
        for Xi in X_np:
            solver = osqp.OSQP()
            solver.setup(
                q=c,
                A=A_sparse,
                l=np.hstack([Xi, ineq_lower]),
                u=np.hstack([Xi, d]),
                verbose=False,
                eps_prim_inf=tol,
            )
            start_time = time.time()
            results = solver.solve()
            end_time = time.time()

            total_time += end_time - start_time
            if results.info.status == "solved":
                Y.append(results.x)
            else:
                Y.append(np.ones(self.ydim) * np.nan)

        # Assembled once after the loop so the return values are defined even
        # when X has no rows.
        sols = np.array(Y) if Y else np.empty((0, self.ydim))
        parallel_time = total_time / len(X_np) if len(X_np) else 0.0

        return sols, total_time, parallel_time

    def calc_Y(self):
        """Solve every sample in ``self.X`` and keep only the solved ones.

        Updates ``_num``, ``_X`` and ``_Y`` in place.

        Returns:
            The unfiltered solution array, including the all-NaN rows of
            samples that were dropped.
        """
        Y = self.opt_solve(self.X)[0]
        feas_mask = ~np.isnan(Y).all(axis=1)
        self._num = int(feas_mask.sum())
        self._X = self._X[feas_mask]
        self._Y = torch.tensor(Y[feas_mask])
        return Y

In [129]:
def create_QP_dataset(num_var, num_ineq, num_eq, num_examples):
    """Generate a random QP data set, solve it, pickle it and return it.

    NumPy's global RNG is seeded with 17, so the generated data is the same
    on every call with the same arguments. The equality right-hand sides
    ``X`` vary per example; ``Q`` (diagonal), ``p``, ``A``, ``G`` and ``h``
    are shared.

    Args:
        num_var: Number of decision variables.
        num_ineq: Number of inequality constraints.
        num_eq: Number of equality constraints.
        num_examples: Number of right-hand sides to sample.

    Returns:
        The ``OriginalQPProblem`` after ``calc_Y()``; it is also pickled
        under ``./QP_data/``.
    """
    np.random.seed(17)
    Q = np.diag(np.random.random(num_var))
    p = np.random.random(num_var)
    A = np.random.normal(loc=0, scale=1., size=(num_eq, num_var))
    X = np.random.uniform(-1, 1, size=(num_examples, num_eq))
    G = np.random.normal(loc=0, scale=1., size=(num_ineq, num_var))
    h = np.sum(np.abs(G@np.linalg.pinv(A)), axis=1)

    problem = OriginalQPProblem(Q, p, A, G, X, h)
    problem.calc_Y()
    print(len(problem.Y))

    out_path = "./QP_data/QP_simple_dataset_var{}_ineq{}_eq{}_ex{}".format(num_var, num_ineq, num_eq, num_examples)
    # Create the output directory so the solved data set is not lost to a
    # FileNotFoundError at the very last step.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump(problem, f)

    return problem

def create_varying_G_dataset(num_var, num_ineq, num_eq, num_examples, vary):
    """Creates a modified QP data set that differs in the inequality constraint matrix, instead of the RHS variables.

    NumPy's global RNG is seeded with 17. Each example supplies ``num_ineq``
    values that overwrite a fixed set of entries of ``G``.

    Args:
        num_var: Number of decision variables.
        num_ineq: Number of inequality constraints.
        num_eq: Number of equality constraints.
        num_examples: Number of examples to sample.
        vary: Which entries of ``G`` vary:
            ``'row'`` - columns ``0..num_ineq-1`` of row 0;
            ``'column'`` - rows ``0..num_ineq-1`` of column 0;
            ``'random'`` - ``num_ineq`` entries with distinct random columns
            and random rows.

    Returns:
        The ``QPProblemVaryingG`` after ``calc_Y()``; it is also pickled
        under ``./QP_data/``.

    Raises:
        ValueError: If ``vary`` is not one of the three supported modes.
    """
    # Validate before any work is done; an unknown mode would otherwise
    # surface later as an unbound-variable error.
    if vary not in ('row', 'column', 'random'):
        raise ValueError("vary must be 'row', 'column' or 'random', got {!r}".format(vary))

    np.random.seed(17)
    Q = np.diag(np.random.random(num_var))
    p = np.random.random(num_var)
    A = np.random.normal(loc=0, scale=1., size=(num_eq, num_var))
    # b is the same for all samples:
    b = np.random.uniform(-1, 1, size=(num_eq))
    G = np.random.normal(loc=0, scale=1., size=(num_ineq, num_var))
    d = np.sum(np.abs(G@np.linalg.pinv(A)), axis=1)

    X = np.random.normal(loc=0, scale=1., size=(num_examples, num_ineq))

    if vary == 'row':
        row_indices = [0] * num_ineq
        col_indices = list(range(num_ineq))
    elif vary == 'column':
        col_indices = [0] * num_ineq
        row_indices = list(range(num_ineq))
    else:  # 'random'
        col_indices = np.random.choice(num_var, num_ineq, replace=False)
        row_indices = np.random.choice(num_ineq, num_ineq, replace=True)

    problem = QPProblemVaryingG(X=torch.tensor(X), Q=Q, p=p, A=A, G=G, b=b, d=d, row_indices=row_indices, col_indices=col_indices)
    problem.calc_Y()
    print(len(problem.Y))

    out_path = "./QP_data/Varying_G_type={}_dataset_var{}_ineq{}_eq{}_ex{}".format(vary, num_var, num_ineq, num_eq, num_examples)
    # Create the output directory so the solved data set is not lost to a
    # FileNotFoundError at the very last step.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump(problem, f)

    return problem

def create_varying_G_b_d_dataset(num_var, num_ineq, num_eq, num_examples, num_varying_rows):
    """Creates a QP data set in which part of G, and the vectors b and d, differ per example.

    NumPy's global RNG is seeded with 17. Every example gets its own equality
    right-hand side, its own first ``num_varying_rows`` rows of ``G`` (the
    remaining rows come from a shared base matrix) and the ``d`` computed
    from its ``G``.

    Args:
        num_var: Number of decision variables.
        num_ineq: Number of inequality constraints.
        num_eq: Number of equality constraints.
        num_examples: Number of examples to sample.
        num_varying_rows: Number of leading rows of ``G`` resampled per example.

    Returns:
        The ``QPProblemVaryingGbd`` after ``calc_Y()``; it is also pickled
        under ``./QP_data/modified/``.
    """
    np.random.seed(17)
    Q = np.diag(np.random.random(num_var))
    p = np.random.random(num_var)
    A = np.random.normal(loc=0, scale=1., size=(num_eq, num_var))
    # One equality right-hand side per example:
    B = np.random.uniform(-1, 1, size=(num_examples, num_eq))
    G_base = np.random.normal(loc=0, scale=1., size=(num_ineq, num_var))

    G_list = []
    # For each sample, create a different inequality constraint matrix
    for _ in range(num_examples):
        G_sample = G_base.copy()
        # Vary the first n rows, (specified by num_varying_rows).
        G_sample[:num_varying_rows, :] = np.random.normal(loc=0, scale=1., size=(num_varying_rows, num_var))
        G_list.append(G_sample)

    # A is shared, so its pseudo-inverse is computed once rather than per example.
    A_pinv = np.linalg.pinv(A)
    # One inequality right-hand side of shape (num_ineq,) per example.
    D_list = [np.sum(np.abs(Gi @ A_pinv), axis=1) for Gi in G_list]

    G = np.array(G_list)
    D = np.stack(D_list, axis=0)  # Shape (num_examples, num_ineq)
    problem = QPProblemVaryingGbd(Q=Q, p=p, A=A, G_base=G_base, G_varying=G, b=B, d=D, n_varying_rows=num_varying_rows)
    problem.calc_Y()
    print(len(problem.Y))

    out_path = "./QP_data/modified/MODIFIED_random_simple_dataset_var{}_ineq{}_eq{}_ex{}".format(num_var, num_ineq, num_eq, num_examples)
    # Create the output directory so the solved data set is not lost to a
    # FileNotFoundError at the very last step.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump(problem, f)

    return problem

def create_scaled_QP_problem(num_var, num_ineq, num_eq, num_examples, scale='normal'):
    """Generate and solve a random QP data set whose coefficients are optionally scaled up.

    NumPy's global RNG is seeded with 17. Unlike the other dataset builders,
    the result is only returned; nothing is written to disk.

    Args:
        num_var: Number of decision variables.
        num_ineq: Number of inequality constraints.
        num_eq: Number of equality constraints.
        num_examples: Number of right-hand sides to sample.
        scale: ``'normal'`` uses unit scales. ``'large'`` multiplies ``Q`` by
            1e9, draws ``A`` and ``G`` with standard deviation 1e3 and draws
            the equality right-hand sides from ``[-1e3, 1e3]``.

    Returns:
        The ``OriginalQPProblem`` after ``calc_Y()``.

    Raises:
        ValueError: If ``scale`` is neither ``'normal'`` nor ``'large'``.
    """
    if scale == 'normal':
        obj_scale = 1
        var_scale = 1
        rhs_scale = 1
    elif scale == 'large':
        obj_scale = 1e9
        var_scale = 1e3
        rhs_scale = 1e3
    else:
        # Without this branch an unknown scale surfaces later as an
        # unbound-variable error.
        raise ValueError("scale must be 'normal' or 'large', got {!r}".format(scale))

    np.random.seed(17)

    Q = np.diag(np.random.random(num_var)) * obj_scale
    p = np.random.random(num_var)
    A = np.random.normal(loc=0, scale=var_scale, size=(num_eq, num_var))
    X = np.random.uniform(-rhs_scale, rhs_scale, size=(num_examples, num_eq))
    G = np.random.normal(loc=0, scale=var_scale, size=(num_ineq, num_var))
    h = np.sum(np.abs(G@np.linalg.pinv(A)), axis=1)

    problem = OriginalQPProblem(Q, p, A, G, X, h)
    problem.calc_Y()
    print(len(problem.Y))

    return problem

In [178]:
DTYPE = torch.float64
# Build a device object by *calling* torch.device. Writing
# `DEVICE = torch.device = "cpu"` is a chained assignment that also rebinds
# the `torch.device` attribute to a string, breaking every later
# `torch.device(...)` call in the session.
DEVICE = torch.device("cpu")
# Anomaly detection adds extra checks to autograd; it is a debugging aid.
torch.autograd.set_detect_anomaly(True)
torch.manual_seed(42)
print(f"Running on {DEVICE}")
        
class PrimalDualTrainer():

    def __init__(self, data, args, save_dir):
        """Build the datasets, loaders, networks, optimizers and LR schedulers.

        Args:
            data: Problem object. This constructor reads ``xdim``, ``ydim``,
                ``nineq``, ``neq``, ``X`` and the ``train_indices`` /
                ``valid_indices`` / ``test_indices`` splits from it, and
                passes it on to ``PrimalNet`` and ``DualNet``.
            args (dict): Hyper-parameters. The keys read here are
                ``outer_iterations``, ``inner_iterations``, ``tau``, ``rho``,
                ``rho_max``, ``alpha``, ``batch_size``, ``hidden_sizes``,
                ``primal_lr``, ``dual_lr``, ``decay`` and ``patience``; each
                is stored as an attribute of the same name.
            save_dir: Stored as ``self.save_dir``.
        """
        print(f"X dim: {data.xdim}")
        print(f"Y dim: {data.ydim}")

        print(f"Size of mu: {data.nineq}")
        print(f"Size of lambda: {data.neq}")

        self.data, self.args, self.save_dir = data, args, save_dir

        # Mirror every hyper-parameter onto the trainer under its own name.
        hyperparameter_keys = (
            "outer_iterations", "inner_iterations",
            "tau", "rho", "rho_max", "alpha",
            "batch_size", "hidden_sizes",
            "primal_lr", "dual_lr", "decay", "patience",
        )
        for key in hyperparameter_keys:
            setattr(self, key, args[key])

        # for logging
        self.step = 0

        # Only the inputs X are wrapped in datasets.
        #! Vary per experiment
        inputs = data.X
        split_indices = (data.train_indices, data.valid_indices, data.test_indices)
        self.train_dataset, self.valid_dataset, self.test_dataset = (
            TensorDataset(inputs[indices].to(DEVICE)) for indices in split_indices
        )

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        self.valid_loader = DataLoader(self.valid_dataset, batch_size=1000, shuffle=False)
        self.test_loader = DataLoader(self.test_dataset, batch_size=1000, shuffle=False)

        placement = dict(dtype=DTYPE, device=DEVICE)
        self.primal_net = PrimalNet(self.data, self.hidden_sizes).to(**placement)
        self.dual_net = DualNet(self.data, self.hidden_sizes, self.data.nineq, self.data.neq).to(**placement)

        self.primal_optim = torch.optim.Adam(self.primal_net.parameters(), lr=self.primal_lr)
        self.dual_optim = torch.optim.Adam(self.dual_net.parameters(), lr=self.dual_lr)

        def plateau_scheduler(optimizer):
            # Both nets use the same decay-on-plateau policy.
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=self.decay, patience=self.patience
            )

        self.primal_scheduler = plateau_scheduler(self.primal_optim)
        self.dual_scheduler = plateau_scheduler(self.dual_optim)

    def train_PDL(self, max_violation_save_thresholds=[0.005, 0.006, 0.007, 0.008, 0.009, 0.01]):
        try:
            best_val_losses = [0] * len(max_violation_save_thresholds)
            prev_v_k = 0
            training_time = 0
            stats = {}
            stats["training_time"] = {}
            for k in range(self.outer_iterations):
                begin_time = time.time()
                frozen_dual_net = copy.deepcopy(self.dual_net)
                # self.logger.log_rho_vk(self.rho, prev_v_k, self.step)
                for l1 in range(self.inner_iterations):
                    self.step += 1
                    # Update primal net using primal loss
                    self.primal_net.train()

                    # Accumulate training loss over all batches
                    for Xtrain in self.train_loader:
                        Xtrain = Xtrain[0]
                        batch_start = time.time()
                        self.primal_optim.zero_grad()
                        y = self.primal_net(Xtrain)
                        with torch.no_grad():
                            mu, lamb = frozen_dual_net(Xtrain)
                        batch_loss = self.primal_loss(Xtrain, y, mu, lamb).mean()
                        batch_loss.backward()
                        self.primal_optim.step()
                        training_time += time.time() - batch_start


                    # Evaluate validation loss every epoch, and update learning rate
                    with torch.no_grad():
                        self.primal_net.eval()
                        frozen_dual_net.eval()
                        obj_val_mean, val_loss_mean, ineq_max, ineq_mean, eq_max, eq_mean = self.evaluate(self.valid_dataset.tensors[0], self.primal_net, self.dual_net)    
                        
                        # Normalize by rho, so that the schedular still works correctly if rho is increased
                        self.primal_scheduler.step(torch.sign(val_loss_mean) * (torch.abs(val_loss_mean) / self.rho))

                        # Save if best model:
                        for i in range(len(max_violation_save_thresholds)):
                            if ineq_max < max_violation_save_thresholds[i] \
                            and eq_max < max_violation_save_thresholds[i] \
                            and obj_val_mean < best_val_losses[i]:
                                print(f"Saving new model with obj: {obj_val_mean}, eq_max: {eq_max}, ineq_max: {ineq_max}, eq_mean: {eq_mean}, ineq_mean: {ineq_mean}")
                                with open(os.path.join(self.save_dir, f'{max_violation_save_thresholds[i]}_primal_net.dict'), 'wb') as f:
                                    torch.save(self.primal_net.state_dict(), f)
                                with open(os.path.join(self.save_dir, f'{max_violation_save_thresholds[i]}_dual_net.dict'), 'wb') as f:
                                    torch.save(self.dual_net.state_dict(), f)
                                best_val_losses[i] = obj_val_mean

                                stats["training_time"][f"{max_violation_save_thresholds[i]}"] = training_time
                
                with torch.no_grad():
                    # Copy primal net into frozen primal net
                    frozen_primal_net = copy.deepcopy(self.primal_net)
                    X = self.train_dataset.tensors[0]
                    # Calculate v_k
                    y = frozen_primal_net(X)
                    mu_k, lamb_k = frozen_dual_net(X)
                    v_k = self.violation(X, y, mu_k)

                for l in range(self.inner_iterations):
                    self.step += 1
                    # Update dual net using dual loss
                    self.dual_net.train()
                    frozen_primal_net.train()
                    for Xtrain in self.train_loader:
                        Xtrain = Xtrain[0]
                        batch_start = time.time()
                        self.dual_optim.zero_grad()
                        mu, lamb = self.dual_net(Xtrain)
                        with torch.no_grad():
                            mu_k, lamb_k = frozen_dual_net(Xtrain)
                            y = frozen_primal_net(Xtrain)
                        batch_loss = self.dual_loss(Xtrain, y, mu, lamb, mu_k, lamb_k).mean()
                        batch_loss.backward()
                        self.dual_optim.step()
                        training_time += time.time() - batch_start

                    # Evaluate validation loss every epoch, and update learning rate
                    with torch.no_grad():
                        frozen_primal_net.eval()
                        self.dual_net.eval()
                        obj_val_mean, val_loss_mean, ineq_max, ineq_mean, eq_max, eq_mean = self.evaluate(self.valid_dataset.tensors[0], self.primal_net, self.dual_net)    
                        # Normalize by rho, so that the schedular still works correctly if rho is increased
                        self.dual_scheduler.step(torch.sign(val_loss_mean) * (torch.abs(val_loss_mean) / self.rho))

                end_time = time.time()
                print("-"*40)
                print(f"Epoch {k} done. Time taken: {end_time - begin_time}. Rho: {self.rho}. Primal LR: {self.primal_optim.param_groups[0]['lr']}, Dual LR: {self.dual_optim.param_groups[0]['lr']}")

                # Update rho from the second iteration onward.
                if k > 0 and v_k > self.tau * prev_v_k:
                    self.rho = np.min([self.alpha * self.rho, self.rho_max])

                prev_v_k = v_k
            
                print(f"Validation set evaluate:")
                with torch.no_grad():
                    self.primal_net.eval()
                    self.dual_net.eval()
                    obj_val_mean, val_loss_mean, ineq_max, ineq_mean, eq_max, eq_mean = self.evaluate(self.valid_dataset.tensors[0], self.primal_net, self.dual_net)
                    print(f"obj_val_mean: {obj_val_mean}, val_loss_mean: {val_loss_mean}, ineq_max: {ineq_max}, ineq_mean: {ineq_mean}, eq_max: {eq_max}, eq_mean: {eq_mean}")    
                    
            print("-"*40)
            print(f"Test set evaluate:")
            with torch.no_grad():
                self.primal_net.eval()
                self.dual_net.eval()
                obj_val_mean, test_loss_mean, ineq_max, ineq_mean, eq_max, eq_mean = self.evaluate(self.test_dataset.tensors[0], self.primal_net, self.dual_net)
                print(f"obj_val_mean: {obj_val_mean}, val_loss_mean: {test_loss_mean}, ineq_max: {ineq_max}, ineq_mean: {ineq_mean}, eq_max: {eq_max}, eq_mean: {eq_mean}")
        
        except Exception as e:
            print(e, flush=True)
            # # Ensure writer is closed even if an exception occurs
            # if self.logger:
            #     self.logger.close()
            raise

        with open(os.path.join(self.save_dir, 'stats.dict'), 'wb') as f:
            pickle.dump(stats, f)
        with open(os.path.join(self.save_dir, 'primal_net.dict'), 'wb') as f:
            torch.save(self.primal_net.state_dict(), f)
        with open(os.path.join(self.save_dir, 'dual_net.dict'), 'wb') as f:
            torch.save(self.dual_net.state_dict(), f)

        return self.primal_net, self.dual_net, stats

    def evaluate(self, X, primal_net, dual_net):
        """Summarise objective, primal loss and constraint residuals for ``X``.

        Both networks are run on ``X``; the primal prediction is scored with
        ``self.data`` and ``self.primal_loss``. This method does not disable
        autograd or switch the networks to eval mode itself -- callers are
        expected to do that. Every returned value is detached.

        Returns:
            A 6-tuple of scalar tensors, in this order:
            mean of ``obj_fn``, mean of ``primal_loss``,
            mean of the row-wise max of ``ineq_dist``,
            mean of the row-wise mean of ``ineq_dist``,
            mean of the row-wise max of ``|eq_resid|``,
            mean of the row-wise mean of ``|eq_resid|``.
        """
        y_hat = primal_net(X)
        mu_hat, lamb_hat = dual_net(X)

        ineq_gap = self.data.ineq_dist(X, y_hat)
        eq_gap = self.data.eq_resid(X, y_hat)
        objective = self.data.obj_fn(y_hat)
        lagrangian = self.primal_loss(X, y_hat, mu_hat, lamb_hat)

        # Equality residuals are signed; only their magnitude matters here.
        eq_abs = eq_gap.abs()

        stats = (
            objective,
            lagrangian,
            ineq_gap.max(dim=1).values,
            ineq_gap.mean(dim=1),
            eq_abs.max(dim=1).values,
            eq_abs.mean(dim=1),
        )
        return tuple(stat.detach().mean() for stat in stats)



    def primal_loss(self, x, y, mu, lamb):
        """Augmented-Lagrangian value used as the loss for the primal network.

        With ``g = ineq_resid(x, y)`` and ``h = eq_resid(x, y)`` the result is

            obj_fn(y) + sum(mu * g) + sum(lamb * h)
                      + rho/2 * (sum(max(g, 0)^2) + sum(h^2))

        where every ``sum`` runs along dim 1. No reduction over the remaining
        dimension is applied here.
        """
        objective = self.data.obj_fn(y)
        g = self.data.ineq_resid(x, y)
        h = self.data.eq_resid(x, y)

        # Multiplier terms use mu and g as given: neither is clamped here.
        multiplier_terms = (mu * g).sum(dim=1) + (lamb * h).sum(dim=1)

        # Quadratic penalty: only the positive part of g counts as a
        # violation, while any non-zero h does.
        g_positive = torch.maximum(g, torch.zeros_like(g))
        squared_violation = (g_positive ** 2).sum(dim=1) + (h ** 2).sum(dim=1)
        penalty = self.rho / 2 * squared_violation

        return objective + (multiplier_terms + penalty)

    def dual_loss(self, x, y, mu, lamb, mu_k, lamb_k):
        """Distance between the predicted multipliers and their update targets.

        Targets are built from the previous multipliers ``mu_k`` / ``lamb_k``
        and the residuals of ``y``:

            mu target   = max(mu_k + rho * g, 0)    with g = ineq_resid(x, y)
            lamb target = lamb_k + rho * h          with h = eq_resid(x, y)

        The loss is the norm along dim 1 of ``mu - mu_target`` plus the norm
        along dim 1 of ``lamb - lamb_target``. No reduction over the remaining
        dimension is applied here.

        NOTE: the original author flagged a fixed step of 1e-1 (from the second
        PDL paper) as an alternative to ``rho`` in both targets; ``rho`` is
        what is used.
        """
        g = self.data.ineq_resid(x, y)  # [batch, g]
        h = self.data.eq_resid(x, y)    # [batch, h]

        # Inequality multipliers are projected onto the non-negative orthant.
        mu_target = torch.maximum(mu_k + self.rho * g, torch.zeros_like(g))
        lamb_target = lamb_k + self.rho * h

        ineq_term = (mu - mu_target).norm(dim=1)      # [batch]
        eq_term = (lamb - lamb_target).norm(dim=1)    # [batch]
        return ineq_term + eq_term

    def violation(self, x, y, mu_k):
        """Worst-case constraint violation over all samples, as a Python float.

        For each row the infinity norm of ``h = eq_resid(x, y)`` is compared
        with the infinity norm of ``sigma = max(g, -mu_k / rho)``, where
        ``g = ineq_resid(x, y)``; the larger one is kept. The value returned
        is the maximum of that quantity across rows.
        """
        h = self.data.eq_resid(x, y)    # presumably (num_samples, n_eq)
        g = self.data.ineq_resid(x, y)  # presumably (num_samples, n_ineq)

        # sigma replaces g by -mu_k / rho wherever g lies below that value.
        sigma = torch.maximum(g, -mu_k / self.rho)

        worst_eq = h.abs().max(dim=1)[0]
        worst_sigma = sigma.abs().max(dim=1)[0]

        per_sample = torch.maximum(worst_eq, worst_sigma)
        return per_sample.max().item()

class PrimalNet(nn.Module):
    """Fully connected network mapping ``data.xdim`` inputs to ``data.ydim`` outputs.

    The network is ``Linear -> ReLU`` for every entry of ``hidden_sizes``,
    followed by a final ``Linear`` with no activation. All linear weights are
    initialised with Kaiming-normal; biases keep PyTorch's default init.

    Args:
        data: object exposing integer ``xdim`` and ``ydim`` attributes.
        hidden_sizes: sequence (list or tuple) of hidden layer widths; may be
            empty, in which case the network is a single linear map.
    """

    def __init__(self, data, hidden_sizes):
        super().__init__()
        self._data = data
        self._hidden_sizes = hidden_sizes

        # list(...) lets callers pass a tuple as well as a list.
        layer_sizes = [data.xdim] + list(self._hidden_sizes) + [data.ydim]
        last_linear = len(layer_sizes) - 2  # index of the output Linear
        layers = []

        for idx, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            layers.append(nn.Linear(in_size, out_size))
            # Decide on position, not width: comparing out_size with
            # data.ydim would drop the activation from any hidden layer
            # whose width happens to equal ydim.
            if idx != last_linear:
                layers.append(nn.ReLU())

        # Initialise after all layers exist, so random numbers are drawn in
        # the same order as before (construct everything, then re-init).
        for layer in layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight)

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the predicted primal variables for input batch ``x``."""
        return self.net(x)

class DualNet(nn.Module):
    """Fully connected network predicting the multipliers ``(mu, lamb)`` from ``x``.

    The body is ``Linear -> ReLU`` for every entry of ``hidden_sizes``
    (Kaiming-normal weights). The output layer has ``mu_size + lamb_size``
    units, no activation, and zero-initialised weights and biases, so a freshly
    built network outputs all-zero multipliers.

    The output layer is registered both as ``self.out_layer`` and as the last
    element of ``self.net``; the state dict therefore holds it under both
    prefixes. This layout is kept so existing checkpoints still load.

    Args:
        data: object exposing an integer ``xdim`` attribute.
        hidden_sizes: sequence (list or tuple) of hidden layer widths; may be
            empty, in which case the output layer reads ``x`` directly.
        mu_size: number of leading outputs returned as ``mu``.
        lamb_size: number of trailing outputs returned as ``lamb``.
    """

    def __init__(self, data, hidden_sizes, mu_size, lamb_size):
        super().__init__()
        self._data = data
        self._hidden_sizes = hidden_sizes
        self._mu_size = mu_size
        self._lamb_size = lamb_size

        # list(...) lets callers pass a tuple as well as a list.
        layer_sizes = [data.xdim] + list(self._hidden_sizes)
        layers = []
        for in_size, out_size in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers.append(nn.Linear(in_size, out_size))
            layers.append(nn.ReLU())

        for layer in layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight)

        # layer_sizes[-1] is the last hidden width, or data.xdim when there
        # are no hidden layers (hidden_sizes[-1] would raise IndexError).
        self.out_layer = nn.Linear(layer_sizes[-1], self._mu_size + self._lamb_size)
        nn.init.zeros_(self.out_layer.weight)
        nn.init.zeros_(self.out_layer.bias)
        layers.append(self.out_layer)

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(mu, lamb)``: the first ``mu_size`` columns and the rest."""
        out = self.net(x)
        out_mu = out[:, :self._mu_size]
        out_lamb = out[:, self._mu_size:]
        return out_mu, out_lamb

Running on cpu


Main Script:

In [168]:
# Sizes handed to the dataset constructors: decision variables, inequality
# constraints, equality constraints, and number of generated instances.
num_var, num_ineq, num_eq = 100, 50, 50
num_examples = 10000

# Hyper-parameters handed to PrimalDualTrainer as a plain dict.
# NOTE(review): train/valid/test look like split fractions (they sum to 1.0);
# confirm against the trainer, since 0.001 of 10000 examples is only 10.
args = dict(
    outer_iterations=10,
    inner_iterations=500,
    tau=0.8,
    rho=0.5,
    rho_max=5000,
    alpha=10,
    batch_size=100,
    hidden_sizes=[500, 500],
    primal_lr=1e-4,
    dual_lr=1e-4,
    decay=0.99,
    patience=10,
    corrEps=1e-4,
    train=0.001,
    valid=0.001,
    test=0.998,
)

# Root directory under which each experiment writes its own sub-folder.
save_dir = "benchmark_experiment_output"

In [160]:
# Baseline QP dataset.
original_data = create_QP_dataset(
    num_var, num_ineq, num_eq, num_examples
)

# Three further datasets from create_varying_G_dataset, one per `vary` mode.
# Kept as separate statements so an earlier result stays bound if a later
# (slow) build fails.
varying_cm_row_data = create_varying_G_dataset(
    num_var, num_ineq, num_eq, num_examples, vary='row'
)
varying_cm_column_data = create_varying_G_dataset(
    num_var, num_ineq, num_eq, num_examples, vary='column'
)
varying_cm_random_data = create_varying_G_dataset(
    num_var, num_ineq, num_eq, num_examples, vary='random'
)


running osqp
10000
running osqp
10000
running osqp
10000
running osqp
10000


In [180]:
# Baseline run (disabled). Re-enable to (re)create original_primal_net and
# original_dual_net.
# original_trainer = PrimalDualTrainer(original_data, args, os.path.join(save_dir, 'original'))
# original_primal_net, original_dual_net, _ = original_trainer.train_PDL()

# Passed to train_PDL by every varying-G run below.
# NOTE(review): presumably the max-violation levels at which train_PDL
# checkpoints a model -- confirm against train_PDL.
max_violation_save_thresholds = [0.07, 0.08, 0.09, 0.1, 0.15, 0.2]

# Row-varying run (disabled). Re-enable to (re)create row_primal_net and
# row_dual_net.
# row_trainer = PrimalDualTrainer(varying_cm_row_data, args, os.path.join(save_dir, 'row'))
# row_primal_net, row_dual_net, _ = row_trainer.train_PDL(max_violation_save_thresholds)

# Column-varying run; artefacts go to <save_dir>/column.
column_trainer = PrimalDualTrainer(
    varying_cm_column_data,
    args,
    os.path.join(save_dir, 'column'),
)
col_primal_net, col_dual_net, _ = column_trainer.train_PDL(
    max_violation_save_thresholds
)

# Randomly-varying run; artefacts go to <save_dir>/random. Its trainer is only
# constructed after the column run has finished.
random_trainer = PrimalDualTrainer(
    varying_cm_random_data,
    args,
    os.path.join(save_dir, 'random'),
)
random_primal_net, random_dual_net, _ = random_trainer.train_PDL(
    max_violation_save_thresholds
)

X dim: 50
Y dim: 100
Size of mu: 50
Size of lambda: 50
----------------------------------------
Epoch 0 done. Time taken: 775.1512258052826. Rho: 0.5. Primal LR: 0.0001, Dual LR: 6.361854860638712e-05
Validation set evaluate:
obj_val_mean: -16.458329635287434, val_loss_mean: -13.759737007095561, ineq_max: 0.7137197410857282, ineq_mean: 0.03306536300423989, eq_max: 0.6119022824228045, eq_mean: 0.17830756601771347
Saving new model with obj: -14.48941995040052, eq_max: 0.06487787776394617, ineq_max: 0.16200696198847647, eq_mean: 0.022306579775646845, ineq_mean: 0.0039226063106737995
Saving new model with obj: -14.440236794918773, eq_max: 0.06113034792477724, ineq_max: 0.1245159184839725, eq_mean: 0.021086430468376723, ineq_mean: 0.00294780348802249
Saving new model with obj: -14.295682306696476, eq_max: 0.06203103653624711, ineq_max: 0.09875123152654931, eq_mean: 0.020167471459301704, ineq_mean: 0.002299832410359756
Saving new model with obj: -14.299839324744182, eq_max: 0.060197481925799

In [177]:
# Score every trained primal network on the full dataset it was trained on.
# Each printed row is: mean known objective, mean predicted objective, mean
# relative objective gap, then the mean over samples of max / mean inequality
# distance and of max / mean absolute equality residual.
#
# NOTE: original_primal_net and row_primal_net are only defined by training
# lines that are currently commented out in the training cell, so this cell
# fails with NameError on a fresh kernel until those runs are re-enabled.
for model, data in [
    (original_primal_net, original_data),
    (row_primal_net, varying_cm_row_data),
    (col_primal_net, varying_cm_column_data),
    (random_primal_net, varying_cm_random_data),
]:
    # Pure inference: without no_grad the forward pass over the whole dataset
    # would build and retain an autograd graph that is never used.
    with torch.no_grad():
        Y_pred = model(data.X)
        obj_known = data.obj_fn(data.Y).cpu().numpy()
        obj_pred = data.obj_fn(Y_pred).cpu().numpy()
        obj_gap = ((obj_known - obj_pred)/obj_known).mean()

        ineq_dist = data.ineq_dist(data.X, Y_pred)
        eq_abs = torch.abs(data.eq_resid(data.X, Y_pred))

        ineq_max_vals = torch.max(ineq_dist, dim=1)[0].cpu().numpy().mean()
        ineq_mean_vals = torch.mean(ineq_dist, dim=1).cpu().numpy().mean()
        eq_max_vals = torch.max(eq_abs, dim=1)[0].cpu().numpy().mean()
        eq_mean_vals = torch.mean(eq_abs, dim=1).cpu().numpy().mean()

    print(obj_known.mean(), obj_pred.mean(), obj_gap, ineq_max_vals, ineq_mean_vals, eq_max_vals, eq_mean_vals)

-15.03574072137141 -14.159263032697682 0.05836505624044085 0.00017314761295659365 4.273874887204891e-06 0.0013226164497767041 0.0004215035947090046
-15.571241360476339 -15.346537498605803 0.014434519866945456 0.029326313612115122 0.0008174214563382847 0.0019247002016841645 0.0005937570470690621
-15.353276156344776 -14.371511671264603 0.06354776121051121 0.02001070985463688 0.0004715649132288072 0.0008491038598132157 0.00028472452067880653
-15.609351985154028 -14.481196481463954 0.07228137394973655 0.05245784645211027 0.0012259314204511058 0.0015708231500230776 0.0004784480383306505


In [166]:

# Reload and show the statistics pickled under the 'original' output folder
# (presumably the stats.dict written by train_PDL for the baseline run).
# pickle.load can execute arbitrary code, so only open files you produced.
with open("benchmark_experiment_output/original/stats.dict", "rb") as stats_file:
    stats = pickle.load(stats_file)
print(stats)

{'time': {}}
