In [None]:
# Automatically reload edited local modules (e.g. `solver`, `models`) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from solver import ipm_overleaf, ipm_chapter14
import gzip
import pickle

from scipy.optimize import linprog
import numpy as np
import time

import torch
from tqdm import tqdm

In [None]:
# Load one set-cover LP instance (A, b, c) and run the interior-point solver to
# record its iterate trajectory, which is used as supervision further down.
# A is indexed as (num_constraints, num_variables) when the graph is built.
# NOTE(review): pickle.load can execute arbitrary code from the file — only load
# trusted instance files.
instance_idx = 0  # which instances/setcover/instance_<idx>.pkl.gz to load
with gzip.open(f"instances/setcover/instance_{instance_idx}.pkl.gz", "rb") as file:
    (A, b, c) = pickle.load(file)

# NOTE(review): the positional None arguments presumably fill unused
# inequality-constraint / bounds slots (linprog-style) — confirm against
# solver.ipm_overleaf's signature.
sol = ipm_overleaf(c.numpy(), None, None, A.numpy(), b.numpy(), None, max_iter = 1000, tol = 1.e-9)

In [None]:
# Unpack the solver trajectory: sol['xs'] holds one (x, l, s) triple per IPM
# iterate; stacking gives arrays with one row per iterate.
x, l, s = zip(*sol['xs'])
x = np.vstack(x)
l = np.vstack(l)
s = np.vstack(s)

# Variable-node targets: channel 0 is x, channel 1 is s. Constraint-node targets: l.
gt_vals = torch.from_numpy(np.stack([x, s], axis=-1)).to(torch.float).cuda()
gt_cons = torch.from_numpy(l).to(torch.float).cuda()

# gt_vals: every (x_i, s_i) pair is scaled to unit L2 norm.
# gt_cons: the whole (iterate, constraint) matrix is scaled by its overall L2 norm.
gt_vals = torch.nn.functional.normalize(gt_vals, p=2.0, dim=2)
gt_cons = torch.nn.functional.normalize(gt_cons, p=2.0, dim=(0, 1))

# Subsample 8 iterates as targets: every 5th of the first 31 plus the final one.
# The final index used to be hard-coded as 34, which is the last iterate only for a
# 35-step run (and out of range for shorter runs); derive it from the trajectory.
num_iters = gt_vals.shape[0]
assert num_iters > 31, f"expected at least 32 IPM iterates, got {num_iters}"
step_idx = torch.tensor([0, 5, 10, 15, 20, 25, 30, num_iters - 1], device=gt_vals.device)
gt_vals = gt_vals[step_idx]
gt_cons = gt_cons[step_idx]

In [None]:
from torch_geometric.data import HeteroData


def _unit_col(v):
    """Return a fresh (N, 1) column holding the L2-normalised 1-D tensor `v`."""
    return torch.nn.functional.normalize(v, p=2.0, dim=0)[:, None]


def _star_edges(n, to_hub):
    """Edge index joining nodes 0..n-1 with node 0 of a single-node type.

    `to_hub=True` gives spoke -> hub edges, `to_hub=False` gives hub -> spoke edges.
    """
    spokes = torch.arange(n)
    hub = torch.zeros(n, dtype=torch.long)
    return torch.vstack([spokes, hub] if to_hub else [hub, spokes])


# Tripartite LP graph: constraint nodes ('cons'), variable nodes ('vals') and a
# single objective node ('obj'). cons<->vals edges follow the nonzeros of A and carry
# the coefficient as an (E, 1) weight; vals<->obj and cons<->obj edges carry the
# L2-normalised entries of c and b respectively. Every relation exists in both directions.
num_cons, num_vars = A.shape
a_nz = torch.where(A)      # nonzeros of A as (row, col) index tensors
at_nz = torch.where(A.T)   # the same nonzeros, enumerated through A.T

# Node features are seeded random placeholders. An alternative tried earlier used
# per-row / per-column mean and std of A (and mean/std of c for the obj node).
torch.manual_seed(12)
data = HeteroData(
    cons={'x': torch.randn(num_cons, 2, dtype=torch.float)},
    vals={'x': torch.randn(num_vars, 2, dtype=torch.float)},
    obj={'x': torch.rand(1, 2, dtype=torch.float)},
    cons__to__vals={'edge_index': torch.vstack(a_nz),
                    'edge_weight': A[a_nz][:, None]},
    vals__to__cons={'edge_index': torch.vstack(at_nz),
                    'edge_weight': A.T[at_nz][:, None]},
    vals__to__obj={'edge_index': _star_edges(num_vars, to_hub=True),
                   'edge_weight': _unit_col(c)},
    obj__to__vals={'edge_index': _star_edges(num_vars, to_hub=False),
                   'edge_weight': _unit_col(c)},
    cons__to__obj={'edge_index': _star_edges(num_cons, to_hub=True),
                   'edge_weight': _unit_col(b)},
    obj__to__cons={'edge_index': _star_edges(num_cons, to_hub=False),
                   'edge_weight': _unit_col(b)},
)

In [None]:
# Flatten the heterogeneous graph into one homogeneous graph. The node_type
# vector is needed later to split per-node results back by type.
data_homo = data.to_homogeneous(add_node_type=True, add_edge_type=True)

In [None]:
# Remove the edge weights that to_homogeneous merged from every edge type, so the
# positional-encoding transforms applied later only see edge_index.
# NOTE(review): presumably because these weights are (E, 1) columns rather than the
# 1-D edge_weight those transforms expect — confirm.
# Guarded so that re-running this cell does not raise AttributeError.
if 'edge_weight' in data_homo:
    del data_homo.edge_weight

In [None]:
# Number of homogeneous-graph nodes with node_type 0. Node types are expected to be
# numbered in data.node_types order, where 'cons' was added first, so this should
# equal the number of constraints.
data_homo.node_type.eq(0).sum()

In [None]:
from torch_geometric.transforms import AddRandomWalkPE, AddLaplacianEigenvectorPE

In [None]:
# Random-walk positional encodings (walk_length=3) on the homogeneous graph, with
# the first column dropped. NOTE(review): presumably dropped because the 1-step
# return probability is zero on a graph without self-loops — confirm.
# NOTE(review): `rw` is not referenced by the following cells shown here.
rw_transform = AddRandomWalkPE(walk_length=3)
rw = rw_transform(data_homo)['random_walk_pe'][:, 1:]

In [None]:
# k=5 Laplacian-eigenvector positional encodings, one row per homogeneous-graph node.
# NOTE(review): PyG's transform randomly flips eigenvector signs using the torch RNG,
# so the result depends on the seed set earlier — confirm for the installed version.
lap_transform = AddLaplacianEigenvectorPE(k=5)
lap = lap_transform(data_homo)['laplacian_eigenvector_pe']

In [None]:
# Replace the random placeholder node features with the Laplacian PEs.
# Rows are selected by node type instead of a hard-coded constraint count (314 only
# matched one instance). to_homogeneous numbers node types in data.node_types order.
def _standardize(feat):
    """Column-wise zero-mean / unit-std standardisation."""
    return (feat - feat.mean(0)) / feat.std(0)


node_type_id = {name: i for i, name in enumerate(data.node_types)}
data['cons'].x = _standardize(lap[data_homo.node_type == node_type_id['cons']])
data['vals'].x = _standardize(lap[data_homo.node_type == node_type_id['vals']])
# The objective is a single node, so a per-column std is undefined; keep it raw.
data['obj'].x = lap[data_homo.node_type == node_type_id['obj']]

In [None]:
from models import DeepHeteroGNN

In [None]:
# Heterogeneous GNN over the (cons, vals, obj) graph, placed on the GPU.
# in_shape follows the width of the positional encodings used as node features, so
# changing k in the PE cell no longer silently mismatches the model input size.
# NOTE(review): num_layers=8 appears to mirror the 8 subsampled IPM iterates in
# gt_vals — confirm against DeepHeteroGNN's output layout.
model = DeepHeteroGNN(in_shape=lap.shape[1],
                      hid_dim=256,
                      num_layers=8,
                      dropout=0.,
                      share_weight=False,
                      use_norm=False,
                      use_res=False).to('cuda')

In [None]:
# Move every node-feature and edge tensor of the heterogeneous graph to the GPU,
# alongside the model and the ground-truth targets.
data = data.to(device='cuda')

In [None]:
# Mean-squared-error regression loss, optimised with Adam at learning rate 1e-3.
criterion = torch.nn.MSELoss(reduction='mean')
optim = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
model.train()

# Full-batch training on the single loaded instance for 1000 steps.
# Only channel 0 is supervised; in gt_vals that channel is x (np.stack([x, s])).
# NOTE(review): the predicted `cons` and channel 1 (s) are not used by the loss —
# confirm this is intended.
pbar = tqdm(range(1000))
for epoch in pbar:
    optim.zero_grad()
    vals, cons = model(data.x_dict, data.edge_index_dict)
    pred_x, target_x = vals[..., 0], gt_vals[..., 0]
    loss = criterion(pred_x, target_x)
    loss.backward()
    optim.step()

    # Report the step index and the current loss on the progress bar.
    pbar.set_postfix(epoch=epoch, loss=loss.item())