In [None]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
from tqdm import tqdm

import os
import torch
import warnings
import numpy as np
from scipy.linalg import LinAlgError
from functools import partial

from generate_instances import generic, postprocess, soft_svm, portfolio

In [None]:
# One seeded legacy RandomState, shared by whichever instance generator below
# is bound to `gen_func`, so instance generation is reproducible.
rng = np.random.RandomState(seed=1)

In [None]:
# Output location for the generated dataset. The name appears to encode the
# portfolio generator settings (800 assets, density 0.01) — keep it in sync
# with whichever generator cell is run below.
root = 'datasets/qp_port_800_0.01_sma2'

# makedirs creates missing parents (e.g. `datasets/` on a fresh checkout) and,
# with exist_ok=True, allows the notebook to be re-run. Note that the
# generation loop overwrites same-named `batch{i}.pt` files in `processed/`;
# stale files with higher batch indices from an earlier run are left in place.
os.makedirs(os.path.join(root, 'processed'), exist_ok=True)

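Run exactly one of the three generator cells below. Each one assigns `gen_func`, so whichever cell runs last is the generator used to build the dataset.

### generic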

In [None]:
# Generic sparse QP generator. Running this cell (re)binds `gen_func`, which
# the generation loop calls once per attempt; the last generator cell run wins.
nrows = ncols = 400  # presumably the shape of A — confirm in generate_instances.generic
A_density = 0.008    # density parameter for A, forwarded to `generic`
P_density = 0.01     # density parameter for P, forwarded to `generic`

gen_func = partial(
    generic,
    rng=rng,
    nrows=nrows,
    ncols=ncols,
    A_density=A_density,
    P_density=P_density,
)

### soft margin SVM

In [None]:
# Soft-margin SVM generator. Running this cell (re)binds `gen_func`, which the
# generation loop calls once per attempt; the last generator cell run wins.
# The sizes use SVM-specific names so they do not collide with `num`, which the
# generation cell uses for the target number of instances.
svm_num_points = 800  # Number of points
svm_dim = 800         # Dimension of the points
lamb = 0.5            # regularization parameter (hardness of the margin)
svm_density = 0.01    # density parameter forwarded to `soft_svm`

gen_func = partial(soft_svm, nums=svm_num_points, dims=svm_dim, lamb=lamb,
                   density=svm_density, rng=rng)

### portfolio

In [None]:
# Portfolio-optimization generator. Running this cell (re)binds `gen_func`,
# which the generation loop calls once per attempt; the last generator cell
# run wins. The dataset directory name in `root` mirrors these settings.
density = 0.01   # density parameter forwarded to `portfolio`
n_assets = 800   # number of assets forwarded to `portfolio`

gen_func = partial(
    portfolio,
    rng=rng,
    n_assets=n_assets,
    density=density,
)

# Generate, solve, and save QP instances

In [None]:
from scipy.linalg import null_space
from torch_geometric.data import Batch, HeteroData, InMemoryDataset
from qpsolvers import solve_qp
from scipy.optimize import linprog

In [None]:
graphs = []        # accepted instances not yet written to disk
pkg_idx = 0        # index of the next batch file to write
success_cnt = 0    # number of instances accepted so far

max_iter = 2000    # maximum number of generation attempts
num = 1000         # target number of accepted instances
batch_size = 1000  # maximum number of instances per saved batch file


def _reject_unless(cond):
    """Raise AssertionError unless `cond` holds.

    Used instead of `assert` so the rejection checks still run when Python is
    started with -O (which strips `assert` statements).
    """
    if not cond:
        raise AssertionError


def _sample_valid_qp():
    """Draw one QP from `gen_func`, validate it, and solve it.

    Returns:
        (A, b, P, q, x_feasible, solution, obj, nulls) as numpy values.

    Raises:
        AssertionError / LinAlgError when the instance must be rejected; the
        generation loop treats both as "draw another instance".
    """
    A, b, G, h, P, q, lb, ub, success = gen_func()
    _reject_unless(success)

    # Find a well-interior feasible point by maximising t over [x; t] s.t.
    #   x_i >= t for all i  (rows of A_ub: -x_i + t <= 0),
    #   A x = b,  x >= 0,  t >= 0.
    m, n = A.shape
    sol = linprog(c=np.concatenate([np.zeros(n), np.array([-1.])], axis=0),
                  A_ub=np.concatenate([-np.eye(n), np.ones((n, 1))], axis=1),
                  b_ub=np.zeros(n),
                  A_eq=np.concatenate([A, np.zeros((m, 1))], axis=1), b_eq=b,
                  bounds=(0, None), method='highs')
    # Check success *before* reading sol.x: when HiGHS fails, sol.x is None and
    # slicing it would raise a TypeError that escapes the rejection loop.
    _reject_unless(sol.success)
    x_feasible = sol.x[:-1]

    # The feasible point should not be too close to the x >= 0 boundary.
    _reject_unless(np.all(x_feasible >= 0.1)
                   and np.abs(A @ x_feasible - b).max() < 1.e-6)

    solution = solve_qp(P, q, G, h, A, b, lb=lb, ub=ub, solver="cvxopt")
    _reject_unless(solution is not None)
    obj = 0.5 * solution @ P @ solution + q.dot(solution)
    _reject_unless(not np.isnan(obj))

    # SVD-based null space is expensive, so only compute it for instances that
    # passed every other check.
    nulls = null_space(A)
    return A, b, P, q, x_feasible, solution, obj, nulls


def _to_heterodata(A, b, P, q, x_feasible, solution, obj, nulls):
    """Pack one validated QP into a HeteroData graph.

    Rows of A become 'cons' nodes and columns become 'vals' nodes. Nonzeros of
    A become cons->vals edges and nonzeros of P become vals->vals edges, with
    the matrix entry as a 1-dim edge attribute. Vectors are stored as float32.
    """
    A = torch.from_numpy(A).to(torch.float)
    P = torch.from_numpy(P).to(torch.float)
    A_where = torch.where(A)  # (row_idx, col_idx) of nonzeros
    P_where = torch.where(P)
    return HeteroData(
        # Node 'x' is uninitialised (torch.empty) — presumably filled in
        # downstream; TODO confirm.
        cons={
            'num_nodes': A.shape[0],
            'x': torch.empty(A.shape[0]),
        },
        vals={
            'num_nodes': A.shape[1],
            'x': torch.empty(A.shape[1]),
        },
        cons__to__vals={'edge_index': torch.vstack(A_where),
                        'edge_attr': A[A_where][:, None]},
        vals__to__vals={'edge_index': torch.vstack(P_where),
                        'edge_attr': P[P_where][:, None]},
        x_solution=torch.from_numpy(solution).to(torch.float),
        x_feasible=torch.from_numpy(x_feasible).to(torch.float),
        obj_solution=obj,
        b=torch.from_numpy(b).to(torch.float),
        q=torch.from_numpy(q).to(torch.float),
        nulls=torch.from_numpy(nulls).float().reshape(-1),
    )


def _save_batch(graphs, pkg_idx):
    """Collate `graphs` into one Batch and save it to `{root}/processed/batch{pkg_idx}.pt`."""
    torch.save(Batch.from_data_list(graphs), f'{root}/processed/batch{pkg_idx}.pt')


with tqdm(range(max_iter)) as pbar:
    for _ in pbar:
        try:
            instance = _sample_valid_qp()
        except (AssertionError, LinAlgError):
            # Rejection sampling: discard this instance and draw another one.
            continue

        graphs.append(_to_heterodata(*instance))
        success_cnt += 1
        pbar.set_postfix({'suc': success_cnt})

        if len(graphs) >= batch_size:
            _save_batch(graphs, pkg_idx)
            pkg_idx += 1
            graphs = []

        if success_cnt >= num:
            break

# Write the final partial batch. This also covers running out of `max_iter`
# attempts before `num` instances were accepted — otherwise those accepted
# instances would never reach disk.
if graphs:
    _save_batch(graphs, pkg_idx)
    pkg_idx += 1
    graphs = []

In [None]:
from data.dataset import LPDataset

In [None]:
# Load the generated dataset from `root`. Assumes LPDataset reads the
# `batch*.pt` files written to `{root}/processed` by the generation cell —
# TODO confirm against data.dataset.LPDataset.
ds = LPDataset(root, transform=None)