In [None]:
# Notebook setup: IPython extensions used while developing this notebook.
# autoreload mode 2 re-imports modules before each cell runs, so edits to
# project modules (e.g. generate_instances) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2
# line_profiler provides the %lprun magic; no cell shown here calls it.
%load_ext line_profiler

In [None]:
from tqdm import tqdm

import os
import torch
import warnings
import numpy as np
from scipy.linalg import LinAlgError
from functools import partial

from generate_instances import generic
from scipy.linalg import null_space
from torch_geometric.data import Batch, HeteroData, InMemoryDataset
from qpsolvers import solve_qp
from scipy.optimize import linprog
from scipy.sparse import csc_matrix
from scipy.linalg import qr

In [None]:
# Seeded legacy NumPy generator; gen_func hands it to `generic` so that
# instance generation is driven by a fixed seed.
rng = np.random.RandomState(1)

In [None]:
# Output location of the generated dataset; batch files are written to
# `<root>/processed`.
root = 'datasets/foundation_sparser'
# makedirs creates missing parents (e.g. `datasets/`) and, with exist_ok=True,
# does not fail when the notebook is re-run after the folders already exist.
# NOTE(review): batch files already present in `processed/` are overwritten by
# the generation loop; remove the folder first if a clean dataset is required.
os.makedirs(os.path.join(root, 'processed'), exist_ok=True)

### foundation

In [None]:
def gen_func(A_density_lb=0.0001, A_density_ub=0.001,
             P_density_lb=0.00001, P_density_ub=0.0001,
             size_lb=2000, size_ub=3000, random_state=None):
    """Sample the size and sparsity of one random instance and generate it.

    All random draws come from a single seeded generator so that repeated
    runs of the notebook produce the same sequence of instances.

    Args:
        A_density_lb, A_density_ub: range the density of ``A`` is drawn from.
        P_density_lb, P_density_ub: range the density of ``P`` is drawn from.
        size_lb, size_ub: range the (square) problem size is drawn from;
            ``size_ub`` is exclusive.
        random_state: ``np.random.RandomState`` used for every draw. Defaults
            to the notebook-level ``rng``.

    Returns:
        Whatever ``generic`` returns; the generation loop unpacks it as
        ``(A, b, G, h, P, q, lb, ub, success)``.
    """
    gen = rng if random_state is None else random_state

    P_density = gen.uniform(P_density_lb, P_density_ub)
    nrows = ncols = gen.randint(size_lb, size_ub)
    # Floor the density of A at 2 / nrows so very small draws are not used.
    # NOTE(review): presumably this keeps roughly two non-zeros per row on
    # average (ncols == nrows) -- confirm against `generic`.
    A_density = max(gen.uniform(A_density_lb, A_density_ub), 2 / nrows)
    return generic(nrows, ncols, A_density, P_density, gen)

# create QP

In [None]:
# Generate up to `num` accepted QP instances, convert each one into a
# heterogeneous graph and write the graphs to `<root>/processed/batch<k>.pt`
# in groups of `batch_size`.
graphs = []       # graphs collected for the batch file currently being filled
pkg_idx = 0       # index of the next batch file to write
success_cnt = 0   # number of accepted instances so far

max_iter = 2000   # cap on generation attempts (accepted and rejected)
num = 1000        # target number of accepted instances
batch_size = 100  # graphs per saved batch file


def _save_batch(batch_graphs, batch_idx):
    """Collate `batch_graphs` into one Batch and save it as batch<batch_idx>.pt."""
    torch.save(Batch.from_data_list(batch_graphs), f'{root}/processed/batch{batch_idx}.pt')


pbar = tqdm(range(max_iter))
for i in pbar:
    try:
        A, b, G, h, P, q, lb, ub, success = gen_func()
        # Check the generator's flag before touching its outputs, so a failed
        # generation is rejected here instead of depending on what `generic`
        # returned for P in that case.
        assert success
        P = csc_matrix(P)

        m, n = A.shape
        # Auxiliary LP over (x, t): maximise t subject to x_i >= t and A x = b,
        # with every variable bounded to [0, 10]. The optimum is a feasible
        # point whose smallest coordinate is as large as possible.
        sol = linprog(c=np.concatenate([np.zeros(n), np.array([-1.])], axis=0),
                      A_ub=np.concatenate([-np.eye(n), np.ones((n, 1))], axis=1),
                      b_ub=np.zeros(n),
                      A_eq=np.concatenate([A, np.zeros((m, 1))], axis=1), b_eq=b,
                      # we set upper bound in case unbounded e.g. svm
                      bounds=(0, 10.), method='highs')
        assert sol.success
        x_feasible = sol.x[:-1]  # drop the auxiliary variable t

        # should not be too close to 0
        assert np.all(x_feasible >= 0.05) and np.abs(A @ x_feasible - b).max() < 1.e-6

        # Full QR of A^T; the columns of Q after the first m are kept.
        # NOTE(review): presumably an orthonormal basis of null(A), which
        # assumes A has full row rank -- confirm. The slice is empty when A is square.
        lmat, _ = qr(A.T)
        nulls = lmat[:, m:]

        solution = solve_qp(P, q, G, h, A, b, lb=lb, ub=ub, solver="osqp")
        assert solution is not None
        obj = 0.5 * solution @ P @ solution + q.dot(solution)
        assert not np.isnan(obj)
    except (AssertionError, LinAlgError):
        # Rejected instance: try again with a fresh draw.
        continue
    else:
        A = torch.from_numpy(A).to(torch.float)
        b = torch.from_numpy(b).to(torch.float)
        q = torch.from_numpy(q).to(torch.float)
        solution = torch.from_numpy(solution).to(torch.float)
        x_feasible = torch.from_numpy(x_feasible).to(torch.float)

        # use sparse mat here
        P = P.tocoo()
        A_where = torch.where(A)  # (row, col) indices of the non-zeros of A

        data = HeteroData(
            # Node features of constraints and variables are placeholders.
            # They are zero-filled rather than torch.empty so that no
            # uninitialised memory ends up in the saved files.
            cons={
                'num_nodes': A.shape[0],
                'x': torch.zeros(A.shape[0]),
                 },
            vals={
                'num_nodes': A.shape[1],
                'x': torch.zeros(A.shape[1]),
            },
            obj={
                    'num_nodes': 1,
                    'x': torch.zeros(1, 1).float(),
                },
            # One edge per non-zero of A, weighted by the entry.
            cons__to__vals={'edge_index': torch.vstack(A_where),
                            'edge_attr': A[A_where][:, None]},
            # One edge per stored entry of P, weighted by the entry.
            vals__to__vals={'edge_index': torch.from_numpy(np.vstack([P.row, P.col])),
                            'edge_attr': torch.from_numpy(P.data[:, None]).float()},
            # The single objective node is linked to every variable and every
            # constraint with unit weight.
            obj__to__vals={'edge_index': torch.vstack([torch.zeros(A.shape[1]).long(),
                                                       torch.arange(A.shape[1])]),
                            'edge_attr': torch.ones(A.shape[1], 1).float()},
            obj__to__cons={'edge_index': torch.vstack([torch.zeros(A.shape[0]).long(),
                                                       torch.arange(A.shape[0])]),
                            'edge_attr': torch.ones(A.shape[0], 1).float()},
            x_solution=solution,
            x_feasible=x_feasible,
            obj_solution=obj,
            b=b,
            q=q,
            nulls=torch.from_numpy(nulls).float().reshape(-1)
        )
        success_cnt += 1
        graphs.append(data)

    # Write a batch file when it is full or the target count has been reached.
    if len(graphs) >= batch_size or success_cnt == num:
        _save_batch(graphs, pkg_idx)
        pkg_idx += 1
        graphs = []

    if success_cnt >= num:
        break

    pbar.set_postfix({'suc': success_cnt})

pbar.close()

# If the attempt budget ran out before `num` instances were accepted, the last
# partially filled batch has not been written yet; save it instead of dropping it.
if graphs:
    _save_batch(graphs, pkg_idx)
    pkg_idx += 1
    graphs = []

if success_cnt < num:
    warnings.warn(f'only {success_cnt} of {num} instances were generated '
                  f'within {max_iter} attempts')

In [None]:
from data.dataset import LPDataset

In [None]:
# Instantiate the project's LPDataset on the directory populated above, with no transform.
# NOTE(review): presumably this picks up the `processed/batch*.pt` files
# written by the generation loop -- confirm against data/dataset.py.
ds = LPDataset(root, transform=None)