In [None]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
from scipy.optimize import linprog
from tqdm import tqdm

import os
import torch
from scipy.linalg import LinAlgWarning, LinAlgError
from scipy.optimize._optimize import OptimizeWarning
from scipy.optimize._linprog_util import _clean_inputs, _get_Abc
import warnings
import numpy as np

from generate_instances_lp import generate_setcover, Graph, generate_indset, generate_cauctions, generate_capacited_facility_location

In [None]:
# Shared NumPy random stream used by every instance generator in this notebook;
# the fixed seed makes a fresh Restart & Run All produce the same instances.
rng = np.random.RandomState(seed=1)

In [None]:
# Output directory for this dataset; the generation loop writes its batch files
# into `<root>/processed`.
root = 'datasets/fac_60_5_0.5'

# makedirs creates any missing parent (e.g. `datasets/`) in one call, and
# exist_ok=True keeps the cell idempotent so Restart & Run All does not fail
# with FileExistsError on the second run.
os.makedirs(os.path.join(root, 'processed'), exist_ok=True)

### Set cover

In [None]:
# Box applied to every variable when the LP is later put into standard form.
bounds = (0., 1.)

# Size ranges for the sampled instances (upper ends are exclusive, as in
# RandomState.randint) and the fraction of non-zero entries requested.
nrows_l, nrows_u = 300, 400
ncols_l, ncols_u = 200, 300
density = 0.008


def surrogate_gen():
    """Sample one set-cover instance of random size.

    Returns:
        A 5-tuple ``(A_eq, b_eq, A_ub, b_ub, c)``. This problem class has no
        equality constraints, so the first two entries are ``None``; the
        remaining three come straight from ``generate_setcover``.
    """
    n_rows = rng.randint(nrows_l, nrows_u)
    n_cols = rng.randint(ncols_l, ncols_u)
    # number of non-zeros handed to the generator, truncated to an int
    n_nonzeros = int(n_rows * n_cols * density)
    A, b, c = generate_setcover(n_rows, n_cols, n_nonzeros, rng)
    return None, None, A, b, c

### Independent set

In [None]:
# Box applied to every variable when the LP is later put into standard form.
bounds = (0., 1.)


def surrogate_gen():
    """Sample one independent-set instance on a random Erdos-Renyi graph.

    Returns:
        A 5-tuple ``(A_eq, b_eq, A_ub, b_ub, c)``. There are no equality
        constraints, so the first two entries are ``None``; the remaining
        three come straight from ``generate_indset``.
    """
    edge_prob = 0.01
    # alternative (smaller) size range: rng.randint(10, 20)
    n_nodes = rng.randint(250, 300)
    er_graph = Graph.erdos_renyi(
        number_of_nodes=n_nodes,
        edge_probability=edge_prob,
        random=rng,
    )
    A, b, c = generate_indset(graph=er_graph, nnodes=n_nodes)
    return None, None, A, b, c

### Combinatorial auctions

In [None]:
# Box applied to every variable when the LP is later put into standard form.
bounds = (0., 1.)


def surrogate_gen():
    """Sample one combinatorial-auction instance of random size.

    Returns:
        A 5-tuple ``(A_eq, b_eq, A_ub, b_ub, c)``. There are no equality
        constraints, so the first two entries are ``None``; the remaining
        three come straight from ``generate_cauctions``.
    """
    # alternative (smaller) size range for both draws: rng.randint(15, 20)
    item_count = rng.randint(300, 400)
    bid_count = rng.randint(300, 400)
    A, b, c = generate_cauctions(
        n_items=item_count,
        n_bids=bid_count,
        rng=rng,
        min_value=0.5,
        max_value=1.,
        add_item_prob=0.3,
    )
    # optional override, currently disabled:
    # c = np.ones_like(c, dtype=np.float32) * -1.
    return None, None, A, b, c

### Capacitated facility location

In [None]:
# Box applied to every variable when the LP is later put into standard form.
bounds = (0., 1.)


def surrogate_gen():
    """Sample one capacitated facility-location instance.

    The number of facilities and the ratio are fixed; only the number of
    customers is drawn at random.

    Returns:
        The 5-tuple ``(A_eq, b_eq, A_ub, b_ub, c)`` produced by
        ``generate_capacited_facility_location``.
    """
    facility_count = 5
    # min would be like 0.2-ish
    cap_ratio = 0.5
    customer_count = rng.randint(60, 70)
    eq_lhs, eq_rhs, ub_lhs, ub_rhs, cost = generate_capacited_facility_location(
        n_customers=customer_count,
        n_facilities=facility_count,
        ratio=cap_ratio,
        rng=rng,
    )
    return eq_lhs, eq_rhs, ub_lhs, ub_rhs, cost

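# Create standard-form (equality-constrained) instances

The loop below calls whichever `surrogate_gen` was defined last, so run exactly one of the generator cells above before it.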

In [None]:
from scipy.linalg import qr
from torch_geometric.data import Batch, HeteroData, InMemoryDataset
from collections import namedtuple
from generate_instances import normalize_cons

_LPProblem = namedtuple('_LPProblem',
                        'c A_ub b_ub A_eq b_eq bounds x0 integrality')
_LPProblem.__new__.__defaults__ = (None,) * 7  # make c the only required arg

In [None]:
# Generate up to `num` solvable LP instances (trying at most `max_iter` random
# draws from `surrogate_gen`), turn each one into a HeteroData graph, and write
# the graphs to `<root>/processed/batch<k>.pt` in packages of `pkg_size`.
ips = []
graphs = []       # graphs buffered since the last package was written
pkg_idx = 0       # index of the next package file
success_cnt = 0   # number of instances accepted so far

max_iter = 1500
num = 1000
pkg_size = 1000   # maximum number of graphs per saved package

# Warnings are promoted to exceptions so that numerically troublesome
# instances can be skipped by the `except` clause below. catch_warnings()
# restores the previous filter state when the block exits, including when an
# uncaught exception aborts the loop.
with warnings.catch_warnings():
    warnings.filterwarnings("error")

    pbar = tqdm(range(max_iter))
    for i in pbar:
        A_eq, b_eq, A_ub, b_ub, c = surrogate_gen()
        c = c / (np.abs(c).max() + 1.e-10)  # does not change the result
        # NOTE(review): normalize_cons is a project helper; it is also called
        # with (None, None) for generators without equality constraints, so it
        # presumably passes None through -- confirm.
        A_eq, b_eq = normalize_cons(A_eq, b_eq)
        A_ub, b_ub = normalize_cons(A_ub, b_ub)

        # process LP into standard form Ax=b, x>=0
        lp = _LPProblem(c, A_ub, b_ub, A_eq, b_eq, bounds, None, None)
        lp = _clean_inputs(lp)
        A, b, c, *_ = _get_Abc(lp, 0.)

        m, n = A.shape

        try:
            # bounds=None means linprog's default, i.e. x >= 0
            res = linprog(c, A_eq=A, b_eq=b, bounds=None, method='highs')
            # Full QR of A^T: the columns of Q from index m onward span the
            # null space of A when A has full row rank.
            lmat, _ = qr(A.T)
            nulls = lmat[:, m:]
        except (LinAlgWarning, OptimizeWarning, AssertionError, LinAlgError):
            # skip instances the solver / factorization complains about
            continue
        else:
            if res.success and not np.isnan(res.fun):
                # create graph on the fly
                # Auxiliary LP over (x, t): maximize t subject to t <= x_i for
                # all i and A x = b, which yields a feasible point whose
                # smallest coordinate is as large as possible.
                sol = linprog(c=np.concatenate([np.zeros(n), np.array([-1.])], axis=0),
                              A_ub=np.concatenate([-np.eye(n), np.ones((n, 1))], axis=1),
                              b_ub=np.zeros(n),
                              A_eq=np.concatenate([A, np.zeros((m, 1))], axis=1), b_eq=b,
                              # we set upper bound in case unbounded e.g. svm
                              bounds=(0, 10.), method='highs')
                # NOTE(review): these asserts sit outside the `try`, so a
                # failure aborts the whole run instead of skipping the instance.
                assert sol.success
                x_feasible = sol.x[:-1]

                # should not be too close to 0
                assert np.all(x_feasible >= 0.05) and np.abs(A @ x_feasible - b).max() < 1.e-6

                A = torch.from_numpy(A).to(torch.float)
                b = torch.from_numpy(b).to(torch.float)
                c = torch.from_numpy(c).to(torch.float)
                x = torch.from_numpy(res.x).to(torch.float)
                x_feasible = torch.from_numpy(x_feasible).to(torch.float)

                # (row, col) indices of the non-zero entries of A
                A_where = torch.where(A)
                data = HeteroData(
                    # one node per constraint row
                    cons={
                        'num_nodes': b.shape[0],
                        'x': torch.empty(b.shape[0]),
                         },
                    # one node per variable
                    vals={
                        'num_nodes': c.shape[0],
                        'x': torch.empty(c.shape[0]),
                    },
                    # a single objective node
                    obj={
                        'num_nodes': 1,
                        'x': torch.zeros(1, 1).float(),
                    },
                    # constraint-variable edges carry the matching entry of A
                    cons__to__vals={'edge_index': torch.vstack(A_where),
                                    'edge_attr': A[A_where][:, None]},
                    # the objective node is linked to every variable ...
                    obj__to__vals={'edge_index': torch.vstack([torch.zeros(A.shape[1]).long(),
                                                               torch.arange(A.shape[1])]),
                                    'edge_attr': torch.ones(A.shape[1], 1).float()},
                    # ... and to every constraint, with unit edge attributes
                    obj__to__cons={'edge_index': torch.vstack([torch.zeros(A.shape[0]).long(),
                                                               torch.arange(A.shape[0])]),
                                    'edge_attr': torch.ones(A.shape[0], 1).float()},
                    x_solution=x,
                    x_feasible=x_feasible,
                    obj_solution=c.dot(x),
                    q=c,
                    b=b,
                    nulls=torch.from_numpy(nulls).float().reshape(-1)
                )
                success_cnt += 1
                graphs.append(data)

        # write a package once it is full or the target count has been reached
        if graphs and (len(graphs) >= pkg_size or success_cnt == num):
            torch.save(Batch.from_data_list(graphs), f'{root}/processed/batch{pkg_idx}.pt')
            pkg_idx += 1
            graphs = []

        if success_cnt >= num:
            break

        pbar.set_postfix({'suc': success_cnt})

# If `max_iter` ran out before `num` successes, the graphs collected since the
# last package would otherwise be dropped; write them out as a final package.
if graphs:
    torch.save(Batch.from_data_list(graphs), f'{root}/processed/batch{pkg_idx}.pt')
    pkg_idx += 1
    graphs = []

In [None]:
from data.dataset import LPDataset

In [None]:
# Build the project's LPDataset over the dataset directory `root`.
# NOTE(review): presumably this picks up the batch*.pt files written under
# `<root>/processed` by the generation loop -- confirm in data/dataset.py.
ds = LPDataset(root)