In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's own packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Register the line_profiler IPython extension (the source of the %lprun magic).
# NOTE(review): no profiling magic is used in the cells visible here.
%load_ext line_profiler

In [None]:
from solver.linprog import linprog
from tqdm import tqdm

import os
import torch
from torch_sparse import SparseTensor
from data.dataset import LPDataset
import numpy as np
from torch_geometric.data import Batch, HeteroData

In [None]:
# Root folder of the dataset written by this notebook; the batch files produced
# further down are saved into its 'processed' subfolder.
target_root = 'datasets/large_setc_ipm'
# makedirs(..., exist_ok=True) creates target_root and 'processed' in one call
# and does nothing when they already exist, so re-running this cell (or the
# whole notebook) no longer fails with FileExistsError as plain os.mkdir did.
os.makedirs(os.path.join(target_root, 'processed'), exist_ok=True)

In [None]:
# Folder of the existing dataset whose LP instances are re-solved below.
root = 'datasets/large_setc'
# Each item of `ds` is read below through its c, b, obj_solution fields and its
# ('cons', 'to', 'vals') edge store.
ds = LPDataset(root)

In [None]:
class SubSample:
    """Transform that forces ``data.gt_primals`` to have exactly ``k`` columns.

    ``gt_primals`` is indexed as ``[:, step]``; the second axis is the one
    that gets trimmed or padded.  The object is modified in place and
    returned.

    * ``k == 1``: only the last column is kept.
    * ``k`` equals the current number of columns: nothing changes.
    * ``k`` larger: the last column is repeated at the end until there
      are ``k`` columns.
    * ``k`` smaller: ``k`` columns are picked at evenly spaced positions
      between index 1 and the last index (index 0 is never selected).
    """

    def __init__(self, k):
        self.k = k

    def __call__(self, data):
        steps = data.gt_primals
        n_steps = steps.shape[1]
        wanted = self.k

        if wanted == 1:
            # Keep the final column only, preserving the 2-D layout.
            data.gt_primals = steps[:, -1:]
            return data

        if wanted == n_steps:
            return data

        if wanted > n_steps:
            # Too few columns: pad by repeating the final one.
            padding = steps[:, -1:].repeat(1, wanted - n_steps)
            data.gt_primals = torch.cat([steps, padding], dim=1)
            return data

        # Too many columns: evenly spaced picks, truncated to integer indices.
        picks = np.linspace(1, n_steps - 1, wanted).astype(np.int64)
        data.gt_primals = steps[:, picks]
        return data

In [None]:
# Re-solve every LP in `ds` with the interior-point solver, attach the recorded
# primal iterates to a fresh HeteroData graph, and write the graphs to
# `{target_root}/processed/batch{idx}.pt` in fixed-size packages.

SUBSAMPLE_STEPS = 8      # number of trajectory columns kept per instance
GRAPHS_PER_FILE = 1000   # graphs collated into each saved batch file


def _save_batch(graph_list, file_idx):
    """Collate `graph_list` into one Batch and save it as processed/batch{file_idx}.pt."""
    torch.save(Batch.from_data_list(graph_list), f'{target_root}/processed/batch{file_idx}.pt')


graphs = []
pkg_idx = 0
subsample = SubSample(SUBSAMPLE_STEPS)

for i, data in enumerate(tqdm(ds)):
    edges = data['cons', 'to', 'vals']
    num_cons = data.b.shape[0]
    num_vals = data.c.shape[0]

    # Dense constraint matrix (rows: constraints, columns: variables).
    # sparse_sizes is passed explicitly: without it the shape is inferred from
    # the largest index present, so an empty trailing row/column would produce
    # a matrix that no longer matches b and c.
    # NOTE(review): is_sorted/trust_data skip torch_sparse's ordering and
    # validity checks; this relies on edge_index already being row-sorted.
    A = SparseTensor(row=edges.edge_index[0],
                     col=edges.edge_index[1],
                     value=edges.edge_attr.squeeze(),
                     sparse_sizes=(num_cons, num_vals),
                     is_sorted=True, trust_data=True).to_dense().numpy()

    # Equality-constrained solve; the callback hands back res.x on each call.
    sol = linprog(data.c.numpy(),
                  A_ub=None,
                  b_ub=None,
                  A_eq=A, b_eq=data.b.numpy(), bounds=None,
                  method='interior-point', callback=lambda res: res.x)

    # Check the solve before using its trajectory.
    assert not np.isnan(sol['fun']), f'linprog returned a NaN objective for instance {i}'

    # One column per entry of sol.intermediate
    # (presumably the per-iteration callback results -- confirm in solver.linprog).
    x = np.stack(sol.intermediate, axis=1)
    gt_primals = torch.from_numpy(x).to(torch.float)

    newdata = HeteroData(
        cons={
            'num_nodes': num_cons,
            # Placeholder features: torch.empty leaves the values uninitialized.
            # NOTE(review): presumably replaced when the dataset is processed -- confirm.
            'x': torch.empty(num_cons),
        },
        vals={
            'num_nodes': num_vals,
            'x': torch.empty(num_vals),
        },
        cons__to__vals={'edge_index': edges.edge_index,
                        'edge_attr': edges.edge_attr},
        obj_solution=data.obj_solution,
        c=data.c,
        b=data.b,
        gt_primals=gt_primals)

    newdata = subsample(newdata)
    graphs.append(newdata)

    # Flush a full package to disk and start collecting the next one.
    if len(graphs) >= GRAPHS_PER_FILE:
        _save_batch(graphs, pkg_idx)
        pkg_idx += 1
        graphs = []

# Write whatever is left over as a final, smaller package.
if len(graphs):
    _save_batch(graphs, pkg_idx)
    graphs = []

In [None]:
# Open the newly written dataset folder; note that this rebinds `ds`, which
# previously referred to the source dataset.
# NOTE(review): presumably LPDataset picks up the batch*.pt files saved under
# processed/ -- confirm against data.dataset.LPDataset.
ds = LPDataset(target_root)