In [None]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
from data.dataset import LPDataset
import torch
import numpy as np

# Load the LP dataset and keep its first instance as the working example `g`
# (used by the batching and dense-matrix cells below).
DATA_ROOT = './mis_eq'
dataset = LPDataset(DATA_ROOT)
first_index = 0
g = dataset[first_index]

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
from models.hetero_gnn import TripartiteHeteroGNN

In [None]:
# Architecture hyper-parameters for the tripartite heterogeneous GNN.
# NOTE(review): presumably these must match the configuration that produced
# the checkpoint loaded later (`best_model.pt`) — confirm before changing.
model_config = dict(
    conv='gcnconv',
    hid_dim=128,
    num_conv_layers=6,
    num_pred_layers=2,
    num_mlp_layers=2,
    dropout=0.,
    norm='graphnorm',
    use_res=False,
    conv_sequence='cov',
)
model = TripartiteHeteroGNN(**model_config)
model = model.to(device)

In [None]:
from data.utils import collate_fn_lp

In [None]:
# Collate the single example graph into a batch and give it a random start point
# (one value per 'vals' node).
batch = collate_fn_lp([g])
num_vals = batch.x_dict['vals'].shape[0]
batch.x_start = torch.rand(num_vals)

# Move the whole batch to `device`, then run one throw-away forward pass.
# NOTE(review): presumably this materializes lazily-shaped parameters before
# `load_state_dict` in the next cell — confirm against TripartiteHeteroGNN.
batch = batch.to(device)
_ = model(batch)

In [None]:
# Restore trained weights (tensors mapped onto `device`) and switch the model to
# inference mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load('best_model.pt', map_location=device)
model.load_state_dict(checkpoint)
model.eval()

In [None]:
from torch_sparse import SparseTensor

# Dense NumPy copies of the LP data; they are used as `A_ub`, `b_ub` and `c` in
# the linprog cell and by the pseudo-inverse / rollout cells.
b = g.b.numpy()
c = g.c.numpy()

# Give the shape explicitly. Without `sparse_sizes`, SparseTensor infers it from
# the largest row/col index present, so a trailing all-zero constraint row or a
# variable that appears in no constraint would silently shrink A and break the
# later shape checks (A @ x, pinv(A) @ b against x).
# NOTE(review): is_sorted/trust_data assume the COO entries are row-sorted and
# in range — confirm against how LPDataset builds A_row/A_col.
A = SparseTensor(row=g.A_row,
                 col=g.A_col,
                 value=g.A_val,
                 sparse_sizes=(b.shape[0], c.shape[0]),
                 is_sorted=True,
                 trust_data=True).to_dense().numpy()

In [None]:
# Fresh random start point. It is created directly on `device` because `batch`
# has already been moved there; a CPU tensor would clash with the rest of the
# batch (and with x_solution below) when running on CUDA.
batch.x_start = torch.rand(batch.x_dict['vals'].shape[0], device=device)

# Target direction: from the start point towards the stored solution, scaled so
# its largest-magnitude entry is ~1. The epsilon belongs in the denominator
# (guarding against x_start == x_solution). Operator precedence previously added
# it to the quotient instead; this now matches the label in the rollout cell.
x_direction = batch.x_solution - batch.x_start
batch.x_label = x_direction / (x_direction.abs().max() + 1.e-7)

In [None]:
# Predicted update direction for the current `batch.x_start` (autograd enabled).
gnn_output = model(batch)
grad_x = gnn_output

In [None]:
from trainer import Trainer

In [None]:
# In the visible notebook the trainer is only used for its `get_loss` helper.
# NOTE(review): positional args are presumably (device, loss type, <int option>)
# — confirm against Trainer's signature.
trainer_args = (device, 'l2', 1)
trainer = Trainer(*trainer_args)

In [None]:
# Loss of the prediction against the direction label.
# NOTE(review): `batch['vals'].batch` is presumably the node-to-graph assignment
# vector used to reduce per graph — confirm in Trainer.get_loss.
vals_batch_vector = batch['vals'].batch
trainer.get_loss(grad_x, batch.x_label, vals_batch_vector)

In [None]:
from scipy.linalg import pinv

# Moore-Penrose pseudo-inverse of the dense constraint matrix; pinv(A) b is the
# minimum-norm least-squares solution of A x = b. The rollout cell uses it as an
# extra per-coordinate step limit.
A_inv = pinv(A)
A_inv_b = np.dot(A_inv, b)

In [None]:
# Roll the GNN out as an iterative solver. At each step:
#   1. predict a direction from the current point;
#   2. report its loss against the true direction;
#   3. take the largest step (alpha <= 1) that keeps x in [0, 1] and does not
#      overshoot A_inv_b on coordinates currently below it.
# The objective c^T x is recorded before every step.
N_ITERS = 20  # same budget as the `maxiter` passed to linprog further down
rng = np.random.default_rng(0)  # seeded so the objective trace is reproducible

x = rng.random(batch.x_dict['vals'].shape[0])
obj_trace = []

for iteration in range(N_ITERS):
    obj = c.dot(x)
    print(f'obj: {obj}')
    obj_trace.append(obj)
    batch.x_start = torch.from_numpy(x).to(torch.float).to(device)

    with torch.no_grad():
        grad_x = model(batch)

    # Reference direction towards the stored solution, scaled to max-abs ~1.
    label = batch.x_solution - batch.x_start
    label /= label.abs().max() + 1.e-7
    print(f'eval: {trainer.get_loss(grad_x, label, batch["vals"].batch)}')

    # NOTE(review): assumes the model output is 2-D with the final prediction in
    # the last column — confirm against TripartiteHeteroGNN.
    grad_x = grad_x.cpu().numpy()[:, -1]

    # Step length: the tightest of the box bounds and the A_inv_b cap.
    alpha = 1.
    decreasing = grad_x < 0
    if np.any(decreasing):
        alpha = min(alpha, (-x[decreasing] / grad_x[decreasing]).min())
    increasing = grad_x > 0
    if np.any(increasing):
        alpha = min(alpha, ((1. - x[increasing]) / grad_x[increasing]).min())
    below_target = np.logical_and(increasing, x < A_inv_b)
    if np.any(below_target):
        alpha = min(alpha, ((A_inv_b[below_target] - x[below_target]) / grad_x[below_target]).min())
    # Rounding can leave x a hair outside [0, 1]. Never let that become a
    # negative (backwards) step, and snap the iterate back into the box.
    alpha = max(alpha, 0.)

    x = np.clip(x + alpha * grad_x, 0., 1.)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Objective value of each GNN iterate against a fixed reference level.
# NOTE(review): -9.506 is hard-coded — presumably the optimal objective of this
# instance (dataset[0]); confirm it, or derive it from the linprog result below.
REFERENCE_OBJ = -9.506

fig, ax = plt.subplots()
ax.plot(obj_trace, label='GNN iterates')
ax.axhline(y=REFERENCE_OBJ, color='r', linestyle='dashed', label='reference')
ax.set(title='Objective value per GNN iteration', xlabel='iteration', ylabel='c^T x')
ax.legend();

In [None]:
from solver.linprog import linprog

In [None]:
# Baseline: run the project's LP solver with the same 20-iteration budget and
# stack the iterates it recorded along a new leading axis.
# NOTE(review): `.intermediate` and the use of the callback's return value are
# specific to solver.linprog (SciPy's linprog does not use the callback's return)
# — confirm against that module.
lp_result = linprog(c,
                    A_ub=A,
                    b_ub=b,
                    bounds=(0, 1),
                    options={'maxiter': 20},
                    callback=lambda res: res.x)
xs = np.stack(lp_result.intermediate, axis=0)

In [None]:
# Objective value c^T x of every recorded linprog iterate (one per row of xs).
xs @ c