In [1]:
# Automatically reload edited project modules (e.g. the data/ and models/ packages) before running code.
%load_ext autoreload
%autoreload 2
# Makes %lprun available for line-by-line profiling (not used by any cell below).
%load_ext line_profiler

In [92]:
from data.dataset import LPDataset
import torch
import numpy as np
from scipy.sparse.linalg import svds
from tqdm import tqdm
from torch_sparse import SparseTensor

# Root directory of the LP instances, relative to the notebook's working directory.
DATASET_ROOT = 'datasets/setc_eq'
dataset = LPDataset(DATASET_ROOT)

In [ ]:
# Replace every graph's node features with rank-k truncated-SVD factors of its constraint
# matrix A (left vectors -> 'cons' nodes, right vectors -> 'vals' nodes), and drop the
# 'obj' node type together with all edge types that touch it.
SVD_RANK = 5  # singular triplets kept per graph; svds requires SVD_RANK < min(A.shape)
OBJ_EDGE_TYPES = [('vals', 'to', 'obj'), ('cons', 'to', 'obj'),
                  ('obj', 'to', 'cons'), ('obj', 'to', 'vals')]

new_g = []
for g in tqdm(dataset):
    # Pin the dense shape to the node counts. Without `sparse_sizes`, SparseTensor infers
    # (max(row) + 1, max(col) + 1), which is too small whenever the last constraint row or
    # variable column has no nonzeros -- u / vt.T would then not match the node counts.
    a = SparseTensor(row=g.A_row, col=g.A_col, value=g.A_val,
                     sparse_sizes=(g['cons'].num_nodes, g['vals'].num_nodes)).to_dense().numpy()
    # svds returns (u, s, vt). Singular vectors are only defined up to sign, and the default
    # solver may start from a random vector, so features can differ in sign between runs.
    u, s, vt = svds(a, k=SVD_RANK)
    g['cons'].x = torch.from_numpy(u).float()
    # vt has one column per variable; transpose to one row per variable
    # (.copy() gives torch.from_numpy a contiguous array).
    g['vals'].x = torch.from_numpy(vt.T.copy()).float()
    # Delete only the types that exist: e.g. the HeteroData printed for dataset[0] has no
    # ('obj', 'to', 'cons') or ('obj', 'to', 'vals') edges.
    for edge_type in OBJ_EDGE_TYPES:
        if edge_type in g.edge_types:
            del g[edge_type]
    if 'obj' in g.node_types:
        del g['obj']
    new_g.append(g)


In [None]:
from torch_geometric.data.collate import collate

# Merge all processed graphs into one storage object plus per-key slice offsets
# (the same call pattern torch_geometric's InMemoryDataset.collate uses).
# increment=False keeps each graph's edge indices as-is; add_batch=False adds no batch vector.
graph_cls = type(new_g[0])
data, slices, _ = collate(
    graph_cls,
    data_list=new_g,
    increment=False,
    add_batch=False,
)

In [ ]:
# Persist the collated (data, slices) pair to data.pt in the working directory.
torch.save(obj=(data, slices), f='data.pt')

In [60]:
import os
from pathlib import Path

# Root of the large_setc dataset. Override with the LARGE_SETC_DIR environment variable;
# the default keeps the original ~/Downloads location without hard-coding one user's
# absolute home directory.
LARGE_SETC_DIR = Path(os.environ.get('LARGE_SETC_DIR', '~/Downloads/large_setc')).expanduser()
# NOTE: torch.load unpickles the file, which can execute arbitrary code -- only load trusted files.
ds = torch.load(LARGE_SETC_DIR / 'processed' / 'batch0.pt')

In [93]:
# Take the first graph for inspection (rebinds `g`, which the SVD loop above also used as its loop variable).
g = dataset[0]

In [94]:
# Display the HeteroData summary: attributes, node types and edge types with their tensor shapes.
g

HeteroData(
  x_solution=[46],
  x_feasible=[46],
  obj_solution=[1],
  c=[46],
  b=[31],
  A_row=[82],
  A_col=[82],
  A_val=[82],
  proj_matrix=[2116],
  cons={ x=[31, 2] },
  vals={ x=[46, 2] },
  obj={ x=[1, 2] },
  (cons, to, vals)={
    edge_index=[2, 82],
    edge_attr=[82, 1],
  },
  (vals, to, obj)={
    edge_index=[2, 46],
    edge_attr=[46, 1],
  },
  (cons, to, obj)={
    edge_index=[2, 31],
    edge_attr=[31, 1],
  }
)

In [112]:
# 0/1 sparsity pattern of the constraint matrix: one entry per (cons -> vals) edge,
# 1 where the edge attribute is nonzero and 0 otherwise.
# Build a new tensor instead of writing through `.squeeze()`: squeeze returns a view, so the
# old in-place assignment overwrote the graph's own edge_attr (and possibly storage it shares
# with `dataset`). squeeze(-1) also keeps a 1-D result when the graph has a single edge.
edge_attr = g['cons', 'to', 'vals'].edge_attr
vals = (edge_attr.squeeze(-1) != 0.).to(edge_attr.dtype)

In [209]:
# Sparse (num_cons x num_vals) matrix over the cons -> vals edges, holding the 0/1 `vals`.
# Explicit sparse_sizes keep trailing empty rows/columns. Dropping is_sorted=True /
# trust_data=True lets SparseTensor sort and validate the indices itself, instead of
# silently building wrong CSR data if edge_index is not row-major sorted.
cons_vals_index = g['cons', 'to', 'vals'].edge_index
A = SparseTensor(row=cons_vals_index[0],
                 col=cons_vals_index[1],
                 value=vals,
                 sparse_sizes=(g['cons'].num_nodes, g['vals'].num_nodes))

In [89]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [81]:
from models.hetero_gnn import BipartiteHeteroGNN
from models.cycle_model import CycleGNN

In [98]:
# Hyper-parameters of the bipartite GNN backbone, gathered in one dict so they are easy to tweak.
gnn_config = dict(
    conv='gcnconv',
    hid_dim=128,
    num_conv_layers=18,
    num_pred_layers=2,
    num_mlp_layers=2,
    norm='graphnorm',
)
gnn = BipartiteHeteroGNN(**gnn_config)

In [None]:
# Wrap the backbone in CycleGNN and move it to `device`.
# NOTE(review): the meaning of the positional arguments 8 and 32 is not visible here -- document them.
model = CycleGNN(8, 32, gnn)
model = model.to(device)

In [86]:
from data.collate_func import collate_fn_lp_bi

In [95]:
# Smoke test: collate a single graph (dataset[1]) into a batch on `device` and run one forward pass.
g = dataset[1]
batch = collate_fn_lp_bi([g], device).to(device)
_ = gnn(batch)

In [99]:
# Display the GNN's raw output for the batch (a 1-D tensor in the recorded run).
gnn(batch)

tensor([ 0.0671,  0.1391,  0.0421,  0.0383,  0.1096,  0.1109,  0.1260,  0.0447,
         0.0438,  0.0730,  0.2452,  0.0170,  0.1443,  0.0845,  0.0517, -0.0920,
        -0.0650, -0.0288, -0.0592, -0.0203, -0.0484, -0.0271, -0.0754, -0.0108,
        -0.0155, -0.0404, -0.0895, -0.0228,  0.0087, -0.0519, -0.0470,  0.0087,
         0.0040, -0.0273,  0.0062,  0.0126,  0.0367,  0.0381,  0.0300, -0.0236,
        -0.0520, -0.0261,  0.0113, -0.0321,  0.0005,  0.0310,  0.0426,  0.0098],
       grad_fn=<SqueezeBackward0>)

In [100]:
# Probe the objective encoder on the raw cost vector c, reshaped to one scalar feature per row.
# `g` comes from `dataset` and may still be on the CPU, while gnn may have been moved to
# `device` (e.g. by model.to(device)); send the input to wherever gnn's parameters live.
gnn.obj_encoder(g.c[:, None].to(next(gnn.parameters()).device))

tensor([[-0.1916, -0.2368, -0.3732,  ...,  0.4810, -0.0725, -0.0372],
        [-0.1953, -0.2352, -0.3730,  ...,  0.4907, -0.0721, -0.0389],
        [-0.0693, -0.2285, -0.3954,  ...,  0.2495, -0.0195,  0.0701],
        ...,
        [-0.0274, -0.2057, -0.3914,  ...,  0.2352, -0.0187,  0.0756],
        [-0.0274, -0.2057, -0.3914,  ...,  0.2352, -0.0187,  0.0756],
        [-0.0274, -0.2057, -0.3914,  ...,  0.2352, -0.0187,  0.0756]],
       grad_fn=<AddmmBackward0>)

In [None]:
# Feasibility check: max |A x - b| for the graph collated into `batch` (g = dataset[1]).
# The earlier `A` was built from dataset[0] with 0/1 values and `b` was never defined, so
# rebuild the real constraint matrix and right-hand side from `g` itself.
x_start = batch.x_start.reshape(-1, 1)  # NOTE(review): assumes one start value per variable -- confirm
A_lp = SparseTensor(row=g.A_row, col=g.A_col, value=g.A_val,
                    sparse_sizes=(g.b.numel(), g.c.numel())).to(x_start.device)
b = g.b.to(x_start.device).reshape(-1, 1)
(A_lp @ x_start - b).abs().max()

In [None]:
# Restore the best checkpoint onto `device`, then switch the model to eval mode.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
best_state = torch.load('best_model.pt', map_location=device)
model.load_state_dict(best_state)
model.eval()

In [None]:
# Run the model's own evaluation routine on the single-graph batch built above.
# NOTE(review): whether `evaluation` disables autograd internally is not visible here;
# wrap the call in `torch.no_grad()` if it does not.
xs = model.evaluation(batch)