In [None]:
# Reload edited project modules automatically before each cell execution,
# so changes to imported .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from tqdm import tqdm

import os
import torch
from torch_sparse import SparseTensor
from data.dataset import LPDataset
import numpy as np
from torch_geometric.data import Batch, HeteroData, Data

import scipy.sparse
from ortools.pdlp import solve_log_pb2
from ortools.pdlp import solvers_pb2
from ortools.pdlp.python import pdlp

In [None]:
# Root directory of the PDLP-solved copy of the dataset; the batch files
# written further down go into its `processed/` subdirectory.
target_root = 'datasets/mini_pdlp'

# makedirs creates any missing parent directories, and exist_ok=True keeps the
# cell idempotent: re-running the notebook no longer fails with FileExistsError.
# Batch files from a previous run are left in place and are overwritten by name.
os.makedirs(os.path.join(target_root, 'processed'), exist_ok=True)

In [None]:
# Source dataset whose LP instances are re-solved with PDLP below.
root = 'datasets/mini'
# LPDataset is a project class (data.dataset); iterating `ds` later yields one
# graph-encoded LP instance at a time.
ds = LPDataset(root)

The PDLP solver configuration below follows the OR-Tools sample program: https://github.com/google/or-tools/blob/stable/ortools/pdlp/samples/simple_pdlp_program.py

In [None]:
# Solver configuration reused for every PDLP solve in this notebook.
# The values follow the upstream sample program, which describes them as
# re-assigning the defaults; they are spelled out so they are easy to tune.
params = solvers_pb2.PrimalDualHybridGradientParams()

# --- Termination criteria ---------------------------------------------------
optimality_criteria = params.termination_criteria.simple_optimality_criteria
optimality_criteria.eps_optimal_absolute = 1e-6
optimality_criteria.eps_optimal_relative = 1e-6
# No wall-clock limit: stop only on the optimality tolerances above.
params.termination_criteria.time_sec_limit = float('inf')

# --- Execution and logging --------------------------------------------------
params.presolve_options.use_glop = False
params.verbosity_level = 0
params.num_threads = 1

In [None]:
# Solve every LP instance in `ds` with PDLP and store it, together with the
# PDLP primal/dual solution, as a new HeteroData graph. Graphs are collated
# and written to `{target_root}/processed/batch{k}.pt` in fixed-size chunks.

# Number of instances collated into a single saved batch file.
GRAPHS_PER_FILE = 1000


def _save_batch(graph_list, batch_index):
    """Collate `graph_list` into one Batch and write it to `processed/batch{batch_index}.pt`."""
    torch.save(Batch.from_data_list(graph_list),
               f'{target_root}/processed/batch{batch_index}.pt')


graphs = []
pkg_idx = 0

for i, data in enumerate(tqdm(ds)):
    edges = data['cons', 'to', 'vals']
    num_cons = data.b.shape[0]
    num_vars = data.c.shape[0]

    # Constraint matrix: rows are constraints, columns are variables.
    # The shape is passed explicitly; without `sparse_sizes` it is inferred
    # from the largest row/column index, so a trailing constraint or variable
    # without any nonzero would yield a matrix smaller than len(b) x len(c).
    # `is_sorted`/`trust_data` skip validation, i.e. the edge list is taken to
    # be already sorted and in range.
    A = SparseTensor(row=edges.edge_index[0],
                     col=edges.edge_index[1],
                     value=edges.edge_attr.squeeze(),
                     sparse_sizes=(num_cons, num_vars),
                     is_sorted=True, trust_data=True).to_scipy('csc')

    # LP handed to PDLP: objective c, equality constraints A x = b (lower and
    # upper constraint bounds are both b), and variables bounded to x >= 0.
    lp = pdlp.QuadraticProgram()
    lp.objective_vector = data.c.tolist()
    lp.constraint_lower_bounds = data.b.tolist()
    lp.constraint_upper_bounds = data.b.tolist()
    lp.variable_lower_bounds = [0] * num_vars
    lp.variable_upper_bounds = [np.inf] * num_vars
    lp.constraint_matrix = A

    result = pdlp.primal_dual_hybrid_gradient(lp, params)
    solve_log = result.solve_log

    # Abort on any instance that is not solved to optimality, and say which one.
    assert solve_log.termination_reason == solve_log_pb2.TERMINATION_REASON_OPTIMAL, (
        f'PDLP did not terminate optimally on instance {i} '
        f'(termination_reason={solve_log.termination_reason})')

    newdata = HeteroData(
        cons={
            'num_nodes': num_cons,
            # Zero-filled placeholder features; torch.empty would persist
            # whatever uninitialized memory it happened to receive.
            'x': torch.zeros(num_cons),
        },
        vals={
            'num_nodes': num_vars,
            'x': torch.zeros(num_vars),
        },
        cons__to__vals={'edge_index': edges.edge_index,
                        'edge_attr': edges.edge_attr},
        obj_solution=data.obj_solution,
        c=data.c,
        b=data.b,
        # Stored exactly as returned by the solver, without conversion.
        # NOTE(review): confirm the consumer of these files expects that type.
        primal_solution=result.primal_solution,
        dual_solution=result.dual_solution)

    graphs.append(newdata)

    # Flush a full chunk to disk so memory use stays bounded.
    if len(graphs) >= GRAPHS_PER_FILE:
        _save_batch(graphs, pkg_idx)
        pkg_idx += 1
        graphs = []

# Write whatever is left over as a final, smaller batch.
if len(graphs):
    _save_batch(graphs, pkg_idx)
    graphs = []

In [None]:
# Load the freshly written PDLP dataset to check that it can be read back.
# NOTE(review): this rebinds `ds`; re-running the solve cell after this one
# would iterate the PDLP copy instead of the source dataset.
ds = LPDataset(target_root)