In [None]:
# Re-import edited project modules (utils, data, transforms) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Makes %lprun available for line-by-line profiling.
%load_ext line_profiler

In [None]:
from tqdm import tqdm
from utils.evaluation import gurobi_solve_lp

import os
import torch
from scipy.linalg import LinAlgError
import numpy as np
from torch_geometric.data import Batch, HeteroData
from scipy.sparse import coo_array

from utils.evaluation import data_contraint_heuristic, data_inactive_constraints, normalize_cons

In [None]:
SEED = 1
# The generator code draws from numpy's *global* random state, so seeding only a
# dedicated RandomState would leave the generated dataset non-reproducible.
np.random.seed(SEED)
rng = np.random.RandomState(SEED)

In [None]:
root = 'datasets/gen_250_250_0.02'
# NOTE(review): the name suggests 250x250 at density 0.02, but the generator cell
# uses nrows = ncols = 200 and density = 0.025 -- confirm which is intended.
# makedirs also creates a missing parent ('datasets/'). exist_ok is deliberately left
# False so that re-running never silently overwrites previously generated batch files.
os.makedirs(os.path.join(root, 'processed'))

### Generic random sparse LP generator

In [None]:
density = 0.025
nrows = 200
ncols = 200

def surrogate_gen(num_rows=None, num_cols=None, nnz_density=None, rng=None):
    """Generate a random LP instance ``(A, b, c)`` with a known feasible point.

    A is dense with shape ``(num_rows, num_cols)`` and roughly
    ``num_rows * num_cols * nnz_density`` standard-normal non-zeros. Every row and
    every column contains at least one non-zero. ``b`` is built so that a random
    non-negative ``x_feas`` satisfies ``A @ x_feas <= b``. ``c`` is strictly non-negative.

    Args:
        num_rows: number of constraints; defaults to the notebook global ``nrows``.
        num_cols: number of variables; defaults to the notebook global ``ncols``.
        nnz_density: fraction of non-zero entries; defaults to the global ``density``.
        rng: a ``np.random.RandomState``-like object (needs ``randint``/``randn``);
            defaults to numpy's global generator, as the original code used.

    Returns:
        Tuple ``(A, b, c)`` of float64 numpy arrays.

    Raises:
        ValueError: if the density is too low to give every column a non-zero.
    """
    num_rows = nrows if num_rows is None else num_rows
    num_cols = ncols if num_cols is None else num_cols
    nnz_density = density if nnz_density is None else nnz_density
    rng = np.random if rng is None else rng

    # Build the sparsity pattern on an (m, n) grid with m <= n and transpose at the
    # end when there are more rows than columns, so A is always (num_rows, num_cols).
    m, n = min(num_rows, num_cols), max(num_rows, num_cols)
    transpose = num_rows > num_cols

    nnz = int(num_rows * num_cols * nnz_density)
    if nnz < n:
        raise ValueError(
            f'density {nnz_density} gives {nnz} non-zeros, fewer than the {n} needed '
            f'to cover every row and column of a {num_rows}x{num_cols} matrix')

    # make sure rows and cols are selected at least once
    rows = np.hstack([np.arange(m), rng.randint(0, m, (n - m,))])
    cols = np.arange(n)

    # scatter the remaining non-zeros uniformly at random
    num_rest = nnz - n
    rows_rest = rng.randint(0, m, (num_rest,))
    cols_rest = rng.randint(0, n, (num_rest,))

    values = rng.randn(nnz)

    # coo_array sums duplicate (row, col) pairs when densified
    A = coo_array((values, (np.hstack([rows, rows_rest]), np.hstack([cols, cols_rest]))),
                  shape=(m, n)).toarray()
    if transpose:
        A = A.T.copy()

    x_feas = np.abs(rng.randn(num_cols))  # non-negative feasible point
    b = A @ x_feas + np.abs(rng.randn(num_rows))  # non-negative slack => A @ x_feas <= b

    c = np.abs(rng.randn(num_cols))
    return A, b, c

# NOTE(review): not referenced by any active code in this notebook.
bounds = None

# Create the inequality-constrained LP dataset

In [None]:
ips = []
graphs = []
pkg_idx = 0
success_cnt = 0

max_iter = 15000  # maximum number of generation attempts
num = 10000       # stop once this many instances have been solved and stored
pkg_size = 1000   # graphs collated into each processed/batch{k}.pt file


def save_package(graph_list, idx):
    """Collate ``graph_list`` into one ``Batch`` and save it as ``{root}/processed/batch{idx}.pt``."""
    torch.save(Batch.from_data_list(graph_list), f'{root}/processed/batch{idx}.pt')


# Context manager guarantees the progress bar is closed even after `break`.
with tqdm(range(max_iter)) as pbar:
    for i in pbar:
        A, b, c = surrogate_gen()
        c = c / (np.abs(c).max() + 1.e-10)  # positive rescaling does not change the optimum
        A, b = normalize_cons(A, b)

        try:
            # Reject degenerate instances: rank-deficient A or an all-zero row/column.
            assert np.linalg.matrix_rank(A) == min(*A.shape)
            assert np.all(np.any(A, axis=1)) and np.all(np.any(A, axis=0))
            solution, duals = gurobi_solve_lp(A, b, c)
            # Reject unsolved instances and those with a zero optimal objective.
            assert solution is not None
            assert c.dot(solution) != 0.
        except (AssertionError, LinAlgError):
            continue
        else:
            heur_idx = data_contraint_heuristic(None, A, b, c)
            inactive_idx = data_inactive_constraints(A, b, solution)
            # Fraction of heuristic-selected indices that are also in inactive_idx;
            # NaN instead of a division by zero when the heuristic selects nothing.
            inactive_heur_acc = (np.isin(heur_idx, inactive_idx).mean()
                                 if len(heur_idx) else float('nan'))

            A = torch.from_numpy(A).to(torch.float)
            b = torch.from_numpy(b).to(torch.float)
            c = torch.from_numpy(c).to(torch.float)
            x = torch.from_numpy(solution).to(torch.float)

            # Bipartite constraint/variable graph: one edge per non-zero A[i, j].
            A_where = torch.where(A)
            data = HeteroData(
                cons={
                    'num_nodes': b.shape[0],
                    'x': torch.empty(b.shape[0]),
                },
                vals={
                    'num_nodes': c.shape[0],
                    'x': torch.empty(c.shape[0]),
                },
                cons__to__vals={'edge_index': torch.vstack(A_where),
                                'edge_attr': A[A_where][:, None]},
                x_solution=x,
                duals=torch.from_numpy(duals).float(),
                obj_solution=c.dot(x),
                q=c,
                b=b,
                heur_idx=torch.from_numpy(heur_idx).long(),
                inactive_idx=torch.from_numpy(inactive_idx).long(),
            )
            success_cnt += 1
            graphs.append(data)

        if len(graphs) >= pkg_size or success_cnt == num:
            save_package(graphs, pkg_idx)
            pkg_idx += 1
            graphs = []

        if success_cnt >= num:
            break

        pbar.set_postfix({'suc': success_cnt, 'inactive_heur_acc': inactive_heur_acc})

# If max_iter ran out before `num` successes, the last partial package would
# otherwise be silently dropped.
if graphs:
    save_package(graphs, pkg_idx)
    pkg_idx += 1
    graphs = []

In [None]:
from data.dataset import LPDataset

In [None]:
# Load the generated instances through the project's dataset class.
# NOTE(review): assumes LPDataset(root, 'test') reads the processed/batch*.pt files
# written by the generation loop -- confirm split name and loading in data/dataset.py.
ds = LPDataset(root, 'test')

In [None]:
# First instance, used by the transform checks below.
# Note: this rebinds `data`, which the generation loop used for its last HeteroData.
data = ds[0]

In [None]:
from transforms.lp_preserve import (DropInactiveConstraint, OracleDropInactiveConstraint, OracleDropIdleVariable,
                                    AddRedundantConstraint,
                                    ScaleConstraint, ScaleCoordinate,
                                    AddSubOrthogonalConstraint,
                                    AddDumbVariables, OracleBiasProblem)

In [None]:
from transforms.lp_preserve import ComboPreservedTransforms, ComboInterpolateTransforms

In [None]:
# Per-transform weights passed to ComboInterpolateTransforms.
# NOTE(review): presumably per-transform application probabilities -- confirm in
# transforms/lp_preserve.py. 'OracleBiasProblem': 1. is intentionally disabled.
tf_weights = {
    'DropInactiveConstraint': 0.0,
    'OracleDropInactiveConstraint': 0.0,
    'OracleDropIdleVariable': 0.9,
    'ScaleConstraint': 1.0,
    'ScaleCoordinate': 1.0,
    'AddRedundantConstraint': 0.5,
    'AddDumbVariables': 0.5,
}
# NOTE(review): the meaning of the second argument (10) is defined by
# ComboInterpolateTransforms -- confirm.
tf = ComboInterpolateTransforms(tf_weights, 10)

In [None]:
from utils.evaluation import recover_lp_from_data

In [None]:
# Rebuild the LP from the stored graph and re-solve it. Pass np.float64 exactly as
# the transform-check loop does, so both checks recover the LP the same way.
A, c, b, *_ = recover_lp_from_data(data, np.float64)
solution, duals = gurobi_solve_lp(A, b, c)

In [None]:
# Objective value of the re-solved instance (compare with data.obj_solution).
np.dot(c, solution)

In [None]:
# Apply a random transform 100 times; re-solve each transformed LP and print its
# objective next to the one stored on the transformed data. The transforms come from
# `lp_preserve`, so the two values presumably should agree -- inspect the output.
for _ in range(100):
    d1 = tf(data)
    A, c, b, *_ = recover_lp_from_data(d1, np.float64)
    solution, duals = gurobi_solve_lp(A, b, c)
    obj, transformed_obj = c.dot(solution), d1.obj_solution
    print(obj, transformed_obj)