In [51]:
from solver import ipm_overleaf, ipm_chapter14
import gzip
import pickle

from scipy.optimize import linprog
import numpy as np
import time

import torch
from tqdm import tqdm

In [2]:
with gzip.open(f"instances/setcover/instance_{0}.pkl.gz", "rb") as file:
    (A, b, c) = pickle.load(file)

sol = ipm_overleaf(c.numpy(), None, None, A.numpy(), b.numpy(), None, max_iter = 1000, tol = 1.e-9)

  4%|██████▏                                                                                                                                                                         | 35/1000 [00:03<01:45,  9.13it/s, lin_system_steps=904]


In [18]:
from torch_geometric.datasets import DBLP

dataset = DBLP(root='./data')
data = dataset[0]

In [41]:
from torch_geometric.data import HeteroData

torch.manual_seed(0)
hid = 64
data = HeteroData(cons={'x': torch.randn(A.shape[0], hid)},
                  vals={'x': torch.randn(A.shape[1], hid)},
                  obj={'x': torch.rand(1, hid)},
                  
                  cons__to__vals={'edge_index': torch.vstack(torch.where(A)),
                                'edge_weight': A[torch.where(A)]},
                  vals__to__cons={'edge_index': torch.vstack(torch.where(A.T)),
                                'edge_weight': A.T[torch.where(A.T)]},
                  vals__to__obj={'edge_index': torch.vstack([torch.arange(A.shape[1]), torch.zeros(A.shape[1], dtype=torch.long)]),
                               'edge_weight': torch.nn.functional.normalize(c, p=2.0, dim=0)},
                  obj__to__vals={'edge_index': torch.vstack([torch.zeros(A.shape[1], dtype=torch.long), torch.arange(A.shape[1])]),
                               'edge_weight': torch.nn.functional.normalize(c, p=2.0, dim=0)},
                  cons__to__obj={'edge_index': torch.vstack([torch.arange(A.shape[0]), torch.zeros(A.shape[0], dtype=torch.long)]),
                               'edge_weight': torch.nn.functional.normalize(b, p=2.0, dim=0)},
                  obj__to__cons={'edge_index': torch.vstack([torch.zeros(A.shape[0], dtype=torch.long), torch.arange(A.shape[0])]),
                               'edge_weight': torch.nn.functional.normalize(b, p=2.0, dim=0)},
                  y=torch.from_numpy(np.vstack(sol['xs']).astype(np.float32)))

In [21]:
from torch_geometric.nn.conv import MessagePassing, HeteroConv

class MyMessagePassing(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x=x)
        
        return out

    def message(self, x_j):
        return x_j

In [27]:
class MyLSTMConv(torch.nn.Module):
    def __init__(self, in_shape, hid_dim, steps):
        super().__init__()
        self.steps = steps
        self.conv = HeteroConv({
            ('cons', 'to', 'vals'): MyMessagePassing(),
            ('vals', 'to', 'cons'): MyMessagePassing(),}, aggr='sum')
        self.lin_vals = torch.nn.Linear(in_shape, hid_dim)
        self.lin_cons = torch.nn.Linear(in_shape, hid_dim)
        
        self.cell_vals = torch.nn.LSTMCell(hid_dim, hid_dim)
        self.cell_cons = torch.nn.LSTMCell(hid_dim, hid_dim)
        self.norm_vals = torch.nn.BatchNorm1d(hid_dim)
        self.norm_cons = torch.nn.BatchNorm1d(hid_dim)
        
        self.nn_vals =  torch.nn.Sequential(torch.nn.Linear(hid_dim, hid_dim), 
                                    torch.nn.ReLU(),
                                    torch.nn.Linear(hid_dim, hid_dim))
        self.nn_cons =  torch.nn.Sequential(torch.nn.Linear(hid_dim, hid_dim), 
                                    torch.nn.ReLU(),
                                    torch.nn.Linear(hid_dim, hid_dim))
        
        self.pred_vals = torch.nn.Sequential(torch.nn.Linear(hid_dim, hid_dim), 
                                    torch.nn.ReLU(),
                                    torch.nn.Linear(hid_dim, 1))

    def forward(self, x_dict, edge_index_dict):
        x_dict['vals'] = torch.relu(self.lin_vals(x_dict['vals']))
        x_dict['cons'] = torch.relu(self.lin_vals(x_dict['cons']))
        
        h_val = x_dict['vals'].new_zeros(x_dict['vals'].shape)
        c_val = x_dict['vals'].new_zeros(x_dict['vals'].shape)
        
        h_con = x_dict['cons'].new_zeros(x_dict['cons'].shape)
        c_con = x_dict['cons'].new_zeros(x_dict['cons'].shape)
        
        
        hiddens = []
        for i in range(self.steps):
            x_dict = self.conv(x_dict, edge_index_dict)
            x_dict['vals'] = self.nn_vals(x_dict['vals'])
            x_dict['cons'] = self.nn_vals(x_dict['cons'])
            h_val, c_val = self.cell_vals(x_dict['vals'], (h_val, c_val))
            h_con, c_con = self.cell_vals(x_dict['cons'], (h_con, c_con))
            hiddens.append(h_val)
        
        hiddens = torch.stack(hiddens, dim=0)
        out = self.pred_vals(hiddens)
        return out.squeeze(-1)

In [59]:
model = MyLSTMConv(64, 128, 35).to('cuda')

In [60]:
data = data.to('cuda')

In [61]:
optim = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

In [62]:
model.train()

pbar = tqdm(range(100))
for epoch in pbar:
    optim.zero_grad()
    pred = model(data.x_dict, data.edge_index_dict)
    loss = criterion(pred, data.y)
    loss.backward()
    optim.step()
    
    pbar.set_postfix({'epoch': epoch, 'loss': loss})

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 17.79it/s, epoch=99, loss=tensor(0.1419, device='cuda:0', grad_fn=<MseLossBackward0>)]
