In [1]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [2]:
import numpy as np
import torch

In [3]:
from data.dataset import LPDataset

# Build the LP dataset from the './mis' path and keep its first sample;
# `g` is the single instance every later cell works on.
dataset = LPDataset('./mis')
g = dataset[0]

In [4]:
# Run on the GPU when torch reports CUDA support, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [27]:
from models.hetero_gnn import TripartiteHeteroGNN

In [28]:
# Instantiate the tripartite heterogeneous GNN with the hyper-parameters used
# in this notebook, then move it onto the selected device.
model = TripartiteHeteroGNN(
    conv='gcnconv',
    hid_dim=128,
    num_conv_layers=4,
    num_pred_layers=2,
    num_mlp_layers=2,
    dropout=0.,
    norm='graphnorm',
    use_res=False,
    conv_sequence='cov',
)
model = model.to(device)

In [29]:
from data.utils import collate_fn_lp

In [34]:
# Collate the single sample `g` into a batch object.
batch = collate_fn_lp([g])

# Random placeholder start points: x_start and s_start get one entry per row
# of the 'vals' features, l_start one entry per row of the 'cons' features.
batch.x_start = torch.rand(batch.x_dict['vals'].shape[0])
batch.l_start = torch.rand(batch.x_dict['cons'].shape[0])
batch.s_start = torch.rand(batch.x_dict['vals'].shape[0])

batch = batch.to(device)
# Warm-up forward pass whose result is discarded, so there is no reason to
# record an autograd graph for it.
# NOTE(review): presumably this pass exists to materialise lazily-sized layers
# before the checkpoint is loaded -- confirm against models/hetero_gnn.py.
with torch.no_grad():
    _ = model(batch)

In [33]:
# Restore the weights stored in 'best_model.pt', remapping tensors to `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.load_state_dict(torch.load('best_model.pt', map_location=device))
# Switch to evaluation mode; as the last expression of the cell, the returned
# module is also what the notebook displays below.
model.eval()

TripartiteHeteroGNN(
  (encoder): ModuleDict(
    (vals): MLP(-1, 128, 128)
    (slack): MLP(-1, 128, 128)
    (cons): MLP(-1, 128, 128)
    (obj): MLP(-1, 128, 128)
  )
  (start_pos_encoder): ModuleDict(
    (x): MLP(-1, 128, 128)
    (l): MLP(-1, 128, 128)
    (s): MLP(-1, 128, 128)
  )
  (gcns): ModuleList(
    (0-3): 4 x HeteroConv(num_relations=10)
  )
  (pred_x): MLP(-1, 128, 1)
  (pred_l): MLP(-1, 128, 1)
  (pred_s): MLP(-1, 128, 1)
)

In [37]:
from torch_sparse import SparseTensor

# Rebuild the matrix A from the (row, col, value) triplets stored on `g`,
# densify it and hand it to NumPy; b and c are converted to NumPy as well.
A = SparseTensor(
    row=g.A_row, col=g.A_col, value=g.A_val,
    is_sorted=True, trust_data=True,
).to_dense().numpy()
b, c = g.b.numpy(), g.c.numpy()

In [77]:
# Imported here as well: this cell sits before the notebook's solver-import
# cell in file order, so a top-to-bottom run would otherwise raise NameError.
from solver.customized_solver import ipm_overleaf


def _scale_to_unit_max(direction, eps=1.e-7):
    """Divide ``direction`` by its largest absolute entry.

    ``eps`` is added to the denominator (not to the quotient), so an all-zero
    direction yields zeros instead of NaN and non-zero directions are not
    shifted by a constant offset.
    """
    return direction / (direction.abs().max() + eps)


# Start point (x, lambda, s) taken from the solver's result dictionary.
# NOTE(review): the trailing arguments ('rand', 'cho', 1) are passed through
# unchanged; their meaning is defined in solver/customized_solver.py.
sol = ipm_overleaf(c, A, b, 'rand', 'cho', 1)
x, l, s = sol['x'], sol['lambd'], sol['s']

# The batch carries the start point as float tensors on `device`.
batch.x_start = torch.from_numpy(x).to(torch.float).to(device)
batch.l_start = torch.from_numpy(l).to(torch.float).to(device)
batch.s_start = torch.from_numpy(s).to(torch.float).to(device)

# direction: labels are the displacement from the start point to the stored
# solution, scaled so the largest absolute component is (just under) one.
x_direction = batch.x_solution - batch.x_start
batch.x_label = _scale_to_unit_max(x_direction)
l_direction = batch.l_solution - batch.l_start
batch.l_label = _scale_to_unit_max(l_direction)
s_direction = batch.s_solution - batch.s_start
batch.s_label = _scale_to_unit_max(s_direction)

In [78]:
# Forward pass on the batch prepared above; the model output is unpacked into
# three predictions, named here for x, lambda and s. Gradients are tracked
# because the call is not wrapped in torch.no_grad().
grad_x, grad_lambda, grad_s = model(batch)

In [79]:
from trainer import Trainer

In [80]:
# Trainer instance, used below only for its get_loss method.
# NOTE(review): the meaning of the 'l2' and 1 arguments is defined in
# trainer.py -- confirm there.
trainer = Trainer(device, 'l2', 1)

In [82]:
# Loss of the predicted x update against batch.x_label; the third argument is
# the `batch` attribute of the 'vals' node store. The call returns a pair of
# tensors (shown below), the first of which carries a grad_fn.
trainer.get_loss(grad_x, batch.x_label, batch['vals'].batch)

(tensor(0.0459, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.1091, device='cuda:0'))

In [84]:
# Same loss for the lambda prediction, using the `batch` attribute of the
# 'cons' node store.
trainer.get_loss(grad_lambda, batch.l_label, batch['cons'].batch)

(tensor(0.1019, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.1817, device='cuda:0'))

In [85]:
# Same loss for the s prediction; like x, it uses the `batch` attribute of the
# 'vals' node store.
trainer.get_loss(grad_s, batch.s_label, batch['vals'].batch)

(tensor(0.0383, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.0786, device='cuda:0'))

In [69]:
from solver.customized_solver import smart_start, mu, ipm_overleaf

In [75]:
# Initial (x, lambda, s) taken from the solver's result dictionary.
sol = ipm_overleaf(c, A, b, 'rand', 'cho', 1)
x, l, s = (sol[key] for key in ('x', 'lambd', 's'))

pbar = range(20)
for iteration in pbar:
    # Expose the current iterate to the network as float tensors on `device`.
    batch.x_start, batch.l_start, batch.s_start = (
        torch.from_numpy(arr).to(torch.float).to(device) for arr in (x, l, s)
    )

    # Predict the update direction without autograd and bring it back to NumPy.
    with torch.no_grad():
        grad_x, grad_lambda, grad_s = (
            out.cpu().numpy() for out in model(batch)
        )

    # Step length: at most 1, further capped by min(-v_i / dv_i) over the
    # entries whose direction is negative, for both x and s.
    # NOTE(review): the cap lets an entry land exactly on zero, after which the
    # ratio (and hence alpha) is zero whenever that entry's direction stays
    # negative -- consistent with the constant objective printed below.
    gradx_mask = grad_x < 0
    grads_mask = grad_s < 0
    alpha = 1.
    if gradx_mask.any():
        alpha = min(alpha, np.min(-x[gradx_mask] / grad_x[gradx_mask]))
    if grads_mask.any():
        alpha = min(alpha, np.min(-s[grads_mask] / grad_s[grads_mask]))
    alpha_x = alpha_s = alpha_l = alpha

    # Out-of-place updates on purpose: fresh arrays are created each iteration.
    x, l, s = (
        x + alpha_x * grad_x,
        l + alpha_l * grad_lambda,
        s + alpha_s * grad_s,
    )

    # Objective value c.x after the step.
    print(np.dot(c, x))

-10.881860935684543
-11.434326656252395
-11.517511233311396
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494
-11.517511231817494


In [76]:
import warnings
from collections import namedtuple

import numpy as np
from scipy.linalg import LinAlgError
from scipy.linalg import cho_factor, cho_solve, lstsq
from scipy.sparse import spmatrix
from scipy.sparse.linalg import cg as sp_cg

# Explicit NumPy/SciPy iteration on (x, lambd, s): 20 passes starting from
# `smart_start`, printing the objective c.x after every pass that completes.
sigma = 0.3  # factor multiplying _mu in `rhs` and in `grad_x`
lin_solver = 'cho'  # dense solve strategy: 'cho' (Cholesky) or 'lstsq'

x, lambd, s = smart_start(A, b, c, 'rand')
# NOTE(review): `mu` is an imported helper; presumably a complementarity
# measure of (x, s) -- confirm in solver/customized_solver.py.
_mu = mu(x, s)

pbar = range(20)
for iteration in pbar:
    try:
        # Elementwise 1/s, with 1e-7 added to s to keep the reciprocal finite.
        s_inv = (s + 1.e-7) ** -1
        # Elementwise x / s.
        xs_inv = x * s_inv
        # Scale column j of A by x_j / s_j (sparse and dense spellings).
        # In this notebook A was densified earlier, so the dense branch runs.
        if isinstance(A, spmatrix):
            A_XS_inv = A.multiply(xs_inv[None])
        else:
            A_XS_inv = A * xs_inv[None]
        # M = A diag(x/s) A^T, the matrix of the linear system solved below.
        M = A_XS_inv @ A.transpose()
        rhs = b - A @ x + A_XS_inv @ c - M @ lambd - A @ s_inv * sigma * _mu

        # solve M @ grad_lambda = rhs
        if isinstance(M, spmatrix):
            # Conjugate gradient; [0] keeps the solution and drops the info flag.
            grad_lambda = sp_cg(M, rhs)[0]
        else:
            if lin_solver == 'cho':
                c_and_lower = cho_factor(M)
                grad_lambda = cho_solve(c_and_lower, rhs)
            elif lin_solver == 'lstsq':
                grad_lambda = lstsq(M, rhs)[0]
            else:
                # Not caught by the except clause below, so it propagates.
                raise NotImplementedError

        # A^T (lambd + grad_lambda) is shared by the s and x directions.
        AT_lambda_plut_dlambda = A.transpose() @ (lambd + grad_lambda)
        grad_s = - AT_lambda_plut_dlambda - s + c
        grad_x = s_inv * sigma * _mu + xs_inv * (AT_lambda_plut_dlambda - c)

        # Step length: at most 1, further capped by min(-v_i / dv_i) over the
        # entries of x and s whose direction is negative. The cap allows an
        # entry to land exactly on zero (no step-back factor is applied).
        alpha = 1.
        gradx_mask = grad_x < 0
        if np.any(gradx_mask):
            alpha = min(alpha, (-x[gradx_mask] / grad_x[gradx_mask]).min())
        grads_mask = grad_s < 0
        if np.any(grads_mask):
            alpha = min(alpha, (-s[grads_mask] / grad_s[grads_mask]).min())
        # The same step length is used for all three variable groups.
        alpha_l = alpha_s = alpha_x = alpha

        x = x + alpha_x * grad_x
        lambd = lambd + alpha_l * grad_lambda
        s = s + alpha_s * grad_s
        _mu = mu(x, s)

        # Objective value c.x after the step.
        print(c.dot(x))

    # On a numerical failure the pass is abandoned (x, lambd, s keep whatever
    # values they had when the error was raised) and later passes use lstsq.
    # NOTE(review): NumPy raises FloatingPointError only when its error mode is
    # set to 'raise' (np.seterr / np.errstate); no such call is visible in
    # this cell.
    except (LinAlgError, FloatingPointError, ValueError, ZeroDivisionError):
        warnings.warn(f'Instability occured at iter {iteration}, turning to lstsq')
        lin_solver = 'lstsq'

-7.310022502596555
-6.675549841136554
-6.347937096078859
-6.971093664843341
-7.238996728733692
-7.296180736528512
-7.319735853176443
-7.32925907873007
-7.332097043511313
-7.332961337408215
-7.333221624327923
-7.333299810510066
-7.3333232756146
-7.33333031576442
-7.333332427751247
-7.3333330649950526
-7.333333249535123
-7.333333306987967
-7.333333325367692
-7.333333335584685
