In [None]:
# Automatically reload edited project modules (e.g. data.dataset) before each cell runs.
%load_ext autoreload
%autoreload 2
# Enables %lprun line-by-line profiling; not used in the cells shown here.
%load_ext line_profiler

In [None]:
from tqdm import tqdm

import os
import torch
import warnings
import numpy as np
from scipy.linalg import LinAlgError

In [None]:
# Output location for the generated dataset; batch files go to {root}/processed.
root = 'datasets/qp_200_50eq_full'
# makedirs(exist_ok=True) also creates a missing `datasets/` parent and lets the
# notebook be re-run (os.mkdir raised FileExistsError on every second run).
# Note: re-running overwrites existing processed/batch*.pt files of the same name.
os.makedirs(os.path.join(root, 'processed'), exist_ok=True)

# Create the QP instances

In [None]:
from scipy.linalg import qr
from torch_geometric.data import Batch, HeteroData, InMemoryDataset
from qpsolvers import solve_qp
from scipy.optimize import linprog

This notebook generates exactly the same dataset as the one used by the IPM-LSTM code base.
See their repository for more details: https://github.com/NetSysOpt/IPM-LSTM

In [None]:
# Problem size: number of QP variables (columns of A) and equality constraints (rows of A).
num_var, num_eq = 200, 50

In [None]:
# Fixed-seed problem data. These calls use NumPy's legacy global RNG and their
# order matters: changing, adding or reordering any draw changes every later
# array and breaks reproduction of the reference (IPM-LSTM) dataset.
np.random.seed(18)
P = np.diag(np.random.random(num_var))  # diagonal P, entries drawn from [0, 1)
q = np.random.randn(num_var)

np.random.seed(19)
A = np.random.normal(loc=0, scale=1., size=(num_eq, num_var))  # dense standard-normal (num_eq, num_var) matrix
# Optional sparsification: zeroes entries whose uniform draw exceeds 0.1 (~90% of them).
# It also consumes RNG draws, so enabling it changes `bs` below as well.
# A[np.random.rand(*A.shape) > 0.1] = 0.

# One right-hand side per row: 1000 vectors, each entry uniform in [-1, 1).
bs = np.random.uniform(-1, 1, size=(1000, num_eq))

In [None]:
# Orthonormal null-space basis of A from a full QR of A^T: Q is (num_var, num_var)
# and its columns after the first A.shape[0] are orthogonal to every row of A
# (they span the whole null space when A has full row rank).
lmat, _ = qr(A.T)
# Slice with A.shape[0] (= num_eq). The previous `m` is only defined later, inside
# the generation loop, so this cell failed with NameError on a fresh kernel.
nulls = lmat[:, A.shape[0]:]
assert np.allclose(A @ nulls, 0., atol=1e-8), "nulls is not in the null space of A"

In [None]:
# float32 torch copies of the NumPy problem data, used when building graph tensors.
A_torch, q_torch, P_torch = (torch.from_numpy(arr).float() for arr in (A, q, P))

In [None]:
# Generate the dataset. For each right-hand side b the instance is kept only if
#   (1) a max-min LP finds a feasible x with every x_i >= 0.05, and
#   (2) OSQP solves  min 0.5 x^T P x + q^T x  s.t.  A x = b, x >= 0.
# Accepted instances become HeteroData graphs, saved in packages of up to
# PKG_SIZE graphs to {root}/processed/batch{k}.pt.

num = 1000       # right-hand sides tried; generation also stops once this many succeed
PKG_SIZE = 1000  # maximum number of graphs per saved batch file
assert num <= len(bs), "num exceeds the number of pre-sampled right-hand sides in bs"

m, n = A.shape

# Graph structure depends only on A, P and nulls, which are shared by all
# instances, so it is built once instead of on every iteration.
A_where = torch.where(A_torch)
P_where = torch.where(P_torch)
A_edge_index = torch.vstack(A_where)
A_edge_attr = A_torch[A_where][:, None]
P_edge_index = torch.vstack(P_where)
P_edge_attr = P_torch[P_where][:, None]
obj_vals_edge_index = torch.vstack([torch.zeros(n).long(), torch.arange(n)])
obj_cons_edge_index = torch.vstack([torch.zeros(m).long(), torch.arange(m)])
nulls_torch = torch.from_numpy(nulls).float().reshape(-1)

# Max-min LP over z = [x, t]: minimise -t  s.t.  t - x_i <= 0,  A x = b,
# 0 <= z_j <= 5 (the upper bound guards against unboundedness, e.g. for SVM-like data).
lp_c = np.concatenate([np.zeros(n), np.array([-1.])], axis=0)
lp_A_ub = np.concatenate([-np.eye(n), np.ones((n, 1))], axis=1)
lp_b_ub = np.zeros(n)
lp_A_eq = np.concatenate([A, np.zeros((m, 1))], axis=1)
qp_lb = np.zeros(n).astype(np.float64)


def _solve_instance(b):
    """Solve the LP and QP for right-hand side ``b``.

    Returns ``(x_feasible, solution, obj)`` as NumPy values, or ``None`` when the
    instance is rejected: LP failure, a feasible point with an entry below 0.05 or
    an equality residual >= 1e-6, QP failure (None), or a NaN objective.
    """
    try:
        sol = linprog(c=lp_c, A_ub=lp_A_ub, b_ub=lp_b_ub, A_eq=lp_A_eq, b_eq=b,
                      bounds=(0, 5.), method='highs')
        if not sol.success:
            return None
        x_feasible = sol.x[:-1]
        # The feasible point should not be too close to the x >= 0 boundary.
        if not (np.all(x_feasible >= 0.05) and np.abs(A @ x_feasible - b).max() < 1.e-6):
            return None
        solution = solve_qp(P, q, None, None, A, b, lb=qp_lb, ub=None, solver="osqp")
    except (AssertionError, LinAlgError):
        # Same filter as before: treat these solver-side failures as rejected
        # instances rather than aborting the whole run.
        return None
    if solution is None:
        return None
    obj = 0.5 * solution @ P @ solution + q.dot(solution)
    if np.isnan(obj):
        return None
    return x_feasible, solution, obj


def _build_graph(b, x_feasible, solution, obj):
    """Pack one accepted instance into a tripartite HeteroData graph.

    Node types: ``cons`` (m constraints), ``vals`` (n variables) and one ``obj``
    node. cons->vals edges carry the nonzeros of A, vals->vals the nonzeros of P,
    and the obj node links to every vals/cons node with weight 1.
    """
    return HeteroData(
        cons={
            'num_nodes': m,
            # Placeholder features. zeros instead of torch.empty so the saved
            # file is deterministic and never contains uninitialised memory.
            'x': torch.zeros(m),
        },
        vals={
            'num_nodes': n,
            'x': torch.zeros(n),
        },
        # we create a tripartite graph, but we may NOT use the global node
        obj={
            'num_nodes': 1,
            'x': torch.zeros(1, 1).float(),
        },
        cons__to__vals={'edge_index': A_edge_index, 'edge_attr': A_edge_attr},
        vals__to__vals={'edge_index': P_edge_index, 'edge_attr': P_edge_attr},
        obj__to__vals={'edge_index': obj_vals_edge_index,
                       'edge_attr': torch.ones(n, 1).float()},
        obj__to__cons={'edge_index': obj_cons_edge_index,
                       'edge_attr': torch.ones(m, 1).float()},
        x_solution=torch.from_numpy(solution).to(torch.float),
        x_feasible=torch.from_numpy(x_feasible).to(torch.float),
        obj_solution=obj,
        b=torch.from_numpy(b).to(torch.float),
        q=q_torch,
        nulls=nulls_torch,
    )


def _save_package(graph_list, idx):
    """Collate ``graph_list`` into one Batch and save it as processed/batch{idx}.pt under root."""
    torch.save(Batch.from_data_list(graph_list), f'{root}/processed/batch{idx}.pt')


graphs = []
pkg_idx = 0
success_cnt = 0

pbar = tqdm(range(num))
for i in pbar:
    result = _solve_instance(bs[i])
    if result is not None:
        x_feasible, solution, obj = result
        graphs.append(_build_graph(bs[i], x_feasible, solution, obj))
        success_cnt += 1
        if len(graphs) >= PKG_SIZE:
            _save_package(graphs, pkg_idx)
            pkg_idx += 1
            graphs = []
    pbar.set_postfix({'suc': success_cnt})
    if success_cnt >= num:
        break

# Flush the final partial package. Without this, nothing is saved at all when
# any instance is rejected: success_cnt never reaches num, and fewer than
# PKG_SIZE graphs accumulate.
if graphs:
    _save_package(graphs, pkg_idx)
    pkg_idx += 1
    graphs = []

In [None]:
from data.dataset import LPDataset

In [None]:
# NOTE(review): LPDataset lives in data/dataset.py (not shown); presumably it loads
# the processed/batch*.pt files written under `root` by the generation cell — confirm.
ds = LPDataset(root, transform=None)