In [1]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [3]:
from data.dataset import LPDataset

dataset = LPDataset('./mis_ineq')
g = dataset[1]

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
from models.hetero_gnn import TripartiteHeteroGNN

In [6]:
model = TripartiteHeteroGNN(conv='gcnconv',
                            hid_dim=128,
                            num_conv_layers=4,
                            num_pred_layers=2,
                            num_mlp_layers=2,
                            dropout=0.,
                            norm='graphnorm',
                            use_res=False,
                            conv_sequence='cov').to(device)

In [7]:
from data.utils import collate_fn_lp

In [8]:
batch = collate_fn_lp([g])

batch.x_start = torch.rand(batch.x_dict['vals'].shape[0])

batch = batch.to(device)
_ = model(batch)

In [9]:
model.load_state_dict(torch.load('best_model.pt', map_location=device))
model.eval()

TripartiteHeteroGNN(
  (encoder): ModuleDict(
    (vals): MLP(-1, 128, 128)
    (cons): MLP(-1, 128, 128)
    (obj): MLP(-1, 128, 128)
  )
  (start_pos_encoder): MLP(-1, 128, 128)
  (gcns): ModuleList(
    (0-3): 4 x HeteroConv(num_relations=6)
  )
  (pred_x): MLP(-1, 128, 1)
)

In [10]:
from torch_sparse import SparseTensor

A = SparseTensor(row=g.A_row,
                 col=g.A_col,
                 value=g.A_val, is_sorted=True,
                 trust_data=True).to_dense().numpy()
b = g.b.numpy()
c = g.c.numpy()

In [11]:
batch.x_start = torch.rand(batch.x_dict['vals'].shape[0])

# direction
x_direction = batch.x_solution - batch.x_start
batch.x_label = x_direction / x_direction.abs().max() + 1.e-7

In [12]:
grad_x = model(batch)

In [13]:
from trainer import Trainer

In [14]:
trainer = Trainer(device, 'l2', 1)

In [15]:
trainer.get_loss(grad_x, batch.x_label, batch['vals'].batch)

(tensor(0.0205, grad_fn=<MeanBackward0>), tensor(0.0237))

In [16]:
from scipy.linalg import pinv

A_inv = pinv(A)
A_inv_b = A_inv @ b

In [30]:
x = np.random.rand(batch.x_dict['vals'].shape[0])

pbar = range(20)
for iteration in pbar:
    print(f'obj: {c.dot(x)}')
    batch.x_start = torch.from_numpy(x).to(torch.float).to(device)

    with torch.no_grad():
        grad_x = model(batch)
    label = batch.x_solution - batch.x_start
    label /= label.abs().max() + 1.e-7

    print(f'eval: {trainer.get_loss(grad_x, label, batch["vals"].batch)}')

    grad_x = grad_x.cpu().numpy()

    alpha = 1.
    gradx_mask_lower = grad_x < 0
    if np.any(gradx_mask_lower):
        alpha = min(alpha, (-x[gradx_mask_lower] / grad_x[gradx_mask_lower]).min())
    gradx_mask_upper = grad_x > 0
    if np.any(gradx_mask_upper):
        alpha = min(alpha, ((1. - x[gradx_mask_upper]) / grad_x[gradx_mask_upper]).min())
    gradx_mask_upper = np.logical_and(grad_x > 0, x < A_inv_b)
    if np.any(gradx_mask_upper):
        alpha = min(alpha, ((A_inv_b[gradx_mask_upper] - x[gradx_mask_upper]) / grad_x[gradx_mask_upper]).min())

    x = x + alpha * grad_x

obj: -5.41257196852499
eval: (tensor(0.0072), tensor(0.0089))
obj: -6.069479041281429
eval: (tensor(0.0015), tensor(0.0030))
obj: -6.50649143314482
eval: (tensor(0.0040), tensor(0.0101))
obj: -6.506491435568028
eval: (tensor(0.0040), tensor(0.0101))
obj: -6.875293129232892
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: -6.87529314748343
eval: (tensor(0.0145), tensor(0.0489))
obj: 

In [92]:
from solver.linprog import linprog

In [110]:
linprog(c, A_ub=A, b_ub=b, bounds=(0, 1), options={'maxiter': 2}).x

array([0.18108461, 0.19329581, 0.19798445, 0.18446541, 0.50433196,
       0.50988041, 0.84183662, 0.5591294 , 0.55011799, 0.83908535,
       0.84605545, 0.84183662, 0.83908535])