In [None]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
from tqdm import tqdm

import os
import torch
import numpy as np

In [None]:
from data.dataset import LPDataset

# Source QP dataset and the destination for the copy augmented with IPM trajectories.
# Paths are relative to the notebook's working directory.
root = 'datasets/qp_svm_50_30_0.1'
new_root = root + '_ipm2'
# makedirs creates new_root and new_root/processed in one call. exist_ok=True keeps the
# cell idempotent: plain os.mkdir raised FileExistsError on every re-run, which broke
# Restart & Run All. A stale processed/batch0.pt is overwritten by the save cell below.
os.makedirs(os.path.join(new_root, 'processed'), exist_ok=True)

ds = LPDataset(root, transform=None)

In [None]:
from data.utils import recover_qp_from_data
from qpsolvers.conversions.linear_from_box_inequalities import linear_from_box_inequalities
from qpsolvers.solvers.cvxopt_ import __to_cvxopt
from solver.qp import coneqp

from torch_geometric.data import Batch, HeteroData

In [None]:
# For every QP instance: convert it to cvxopt form, solve it with the trajectory-recording
# `coneqp`, and rebuild the graph with an extra `obj` node, the solution, and a
# fixed-length solver trajectory.
step_length = 8  # number of trajectory snapshots kept per instance


def _qp_to_cvxopt(data):
    """Recover the QP stored in ``data`` and convert its matrices to cvxopt matrices.

    The box bounds (lb, ub) are merged into (G, h) with qpsolvers'
    ``linear_from_box_inequalities``. Returns ``(P, q, G, h, A, b)``. The G/h and A/b
    pairs are converted only when both members are present; otherwise they are
    passed through unchanged (possibly ``None``).
    """
    P, q, A, b, G, h, lb, ub = recover_qp_from_data(data, np.float64)
    G, h = linear_from_box_inequalities(G, h, lb, ub, use_sparse=False)
    P, q = __to_cvxopt(P), __to_cvxopt(q)
    if G is not None and h is not None:
        G, h = __to_cvxopt(G), __to_cvxopt(h)
    if A is not None and b is not None:
        A, b = __to_cvxopt(A), __to_cvxopt(b)
    return P, q, G, h, A, b


def _resample_trajectory(trajectory, step_length):
    """Return ``trajectory`` with exactly ``step_length`` entries along axis 1.

    Axis 1 indexes solver iterates. Shorter trajectories are right-padded by repeating
    the last iterate. Longer ones are subsampled at evenly spaced (floored) indices,
    which always keeps the first and last iterate.
    """
    num_steps = trajectory.shape[1]
    if num_steps < step_length:
        pad_width = [(0, 0)] * trajectory.ndim
        pad_width[1] = (0, step_length - num_steps)
        # mode='edge' repeats the final iterate, like the previous tile+concatenate.
        return np.pad(trajectory, pad_width, mode='edge')
    idx = np.linspace(0, num_steps - 1, step_length).astype(np.int64)
    return trajectory[:, idx]


def _obj_edges(num_targets):
    """Edges from the single ``obj`` node (index 0) to each of ``num_targets`` nodes, with unit weight."""
    edge_index = torch.vstack([torch.zeros(num_targets, dtype=torch.long),
                               torch.arange(num_targets)])
    return {'edge_index': edge_index,
            'edge_attr': torch.ones(num_targets, 1, dtype=torch.float32)}


def _build_graph(data, x_solution, trajectory):
    """Copy ``data`` into a new HeteroData with an added ``obj`` node and solver targets.

    The ``obj`` node has a single zero feature and links to every ``vals`` and ``cons``
    node. ``x_solution`` and ``trajectory`` are numpy arrays, stored as float32 tensors.
    """
    num_cons = data['cons'].num_nodes
    num_vals = data['vals'].num_nodes
    cons_vals = data[('cons', 'to', 'vals')]
    vals_vals = data[('vals', 'to', 'vals')]
    return HeteroData(
        cons={'num_nodes': num_cons, 'x': data['cons'].x},
        vals={'num_nodes': num_vals, 'x': data['vals'].x},
        obj={'num_nodes': 1, 'x': torch.zeros(1, 1).float()},
        cons__to__vals={'edge_index': cons_vals.edge_index,
                        'edge_attr': cons_vals.edge_attr},
        vals__to__vals={'edge_index': vals_vals.edge_index,
                        'edge_attr': vals_vals.edge_attr},
        obj__to__vals=_obj_edges(num_vals),
        obj__to__cons=_obj_edges(num_cons),
        x_solution=torch.from_numpy(x_solution).float(),
        x_feasible=data.x_feasible,
        trajectory=torch.from_numpy(trajectory).float(),
        obj_solution=data.obj_solution,
        b=data.b,
        q=data.q,
    )


new_graphs = []
for data in tqdm(ds):
    P, q, G, h, A, b = _qp_to_cvxopt(data)
    # Start the solver from the stored feasible point.
    initvals_dict = {"x": __to_cvxopt(data.x_feasible.numpy().astype(np.float64))}
    res = coneqp(P, q, G=G, h=h, A=A, b=b, initvals=initvals_dict)
    x_solution = np.array(res['x']).flatten()

    # Entry 0 of the recorded trajectory is skipped (presumably the initial point;
    # TODO confirm in solver.qp.coneqp).
    iterates = list(res['trajectory'][1:])
    if not iterates:
        # No iterate was recorded past entry 0, and np.stack raises on an empty list.
        # Fall back to the final solution. This assumes the iterates are 1-D like
        # x_solution; confirm against coneqp.
        iterates = [x_solution]
    trajectory = _resample_trajectory(np.stack(iterates, axis=1), step_length)

    new_graphs.append(_build_graph(data, x_solution, trajectory))

In [None]:
# Collate all converted graphs into one Batch and store it as the processed file of new_root.
batch_path = os.path.join(new_root, 'processed', 'batch0.pt')
torch.save(Batch.from_data_list(new_graphs), batch_path)

In [None]:
# Reopen the converted dataset from new_root as a sanity check that the saved file loads.
ds = LPDataset(
    new_root,
    transform=None,
)