In [1]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [2]:
from scipy.sparse.linalg import svds

In [10]:
from data.dataset import LPDataset
import torch
import numpy as np

# Load the project dataset rooted at 'datasets/svd_mid_fac' (project LPDataset class).
dataset = LPDataset('datasets/svd_mid_fac')

In [None]:
import torch
import torch.nn as nn

# RBF Layer

class RBF(nn.Module):
    """
    Transforms incoming data using a given radial basis function:
    u_{i} = rbf(||x - c_{i}|| / s_{i}),  where s_{i} = exp(log_sigmas_{i})

    Arguments:
        in_features: size of each input sample
        out_features: size of each output sample
        basis_func: callable applied element-wise to the scaled distances

    Shape:
        - Input: (N, in_features) where N is an arbitrary batch size
        - Output: (N, out_features) where N is an arbitrary batch size

    Attributes:
        centres: the learnable centres of shape (out_features, in_features).
            The values are initialised from a standard normal distribution.
            Normalising inputs to have mean 0 and standard deviation 1 is
            recommended.

        log_sigmas: logarithm of the learnable scaling factors of shape
            (out_features). Initialised to 0, i.e. every scale starts at 1.

        basis_func: the radial basis function used to transform the scaled
            distances.
    """

    def __init__(self, in_features, out_features, basis_func):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialised; reset_parameters() fills both tensors.
        self.centres = nn.Parameter(torch.empty(out_features, in_features))
        self.log_sigmas = nn.Parameter(torch.empty(out_features))
        self.basis_func = basis_func
        self.reset_parameters()

    def reset_parameters(self):
        """Draw centres from N(0, 1) and set every log-scale to 0."""
        nn.init.normal_(self.centres, 0, 1)
        nn.init.constant_(self.log_sigmas, 0)

    def forward(self, input):
        """Return basis_func(||input_n - centre_i|| / exp(log_sigma_i)) of shape (N, out_features)."""
        # (N, 1, in) - (1, out, in) broadcasts to (N, out, in) without an explicit expand.
        diff = input.unsqueeze(1) - self.centres.unsqueeze(0)
        # vector_norm defines the gradient at a zero vector as 0; the previous
        # sqrt(sum(x**2)) produced NaN gradients whenever an input coincided
        # exactly with a centre.
        distances = torch.linalg.vector_norm(diff, ord=2, dim=-1)
        scales = torch.exp(self.log_sigmas).unsqueeze(0)
        return self.basis_func(distances / scales)



# RBFs

def gaussian(alpha):
    """Gaussian radial basis function, applied element-wise: exp(-alpha**2)."""
    return alpha.pow(2).neg().exp()

In [None]:
from tqdm import tqdm

In [None]:
# SparseTensor is otherwise imported only in a later cell; import it here so the
# notebook can be executed top to bottom.
from torch_sparse import SparseTensor

# Replace the 'cons' / 'vals' node features of every instance with rank-5
# truncated-SVD factors of its matrix A, and drop the 'obj' node type together
# with all edge types touching it.
new_g = []
for g in tqdm(dataset):
    # Dense copy of A. NOTE(review): the dense shape is inferred from the largest
    # row/col index, so trailing all-zero rows/columns would be dropped.
    a = SparseTensor(col=g.A_col, row=g.A_row, value=g.A_val).to_dense().numpy()
    # svds returns (U, singular values, V^T); only the singular vectors are kept.
    # NOTE(review): svds requires k < min(a.shape); component order/sign is
    # whatever svds produces and is not normalised here.
    u, _, vt = svds(a, k=5)
    g['cons'].x = torch.from_numpy(u).float()            # one row per row of A
    g['vals'].x = torch.from_numpy(vt.T.copy()).float()  # one row per column of A
    del g['obj'], g[('vals', 'to', 'obj')], g[('cons', 'to', 'obj')], g[('obj', 'to', 'cons')], g[('obj', 'to', 'vals')]
    new_g.append(g)

In [None]:
from torch_geometric.data.collate import collate

# Merge the processed graphs into a single storage object plus per-attribute
# slice indices (no index incrementing, no batch vector); the third return
# value is not needed.
data, slices, _ = collate(
    type(new_g[0]),
    data_list=new_g,
    increment=False,
    add_batch=False,
)

In [None]:
# Persist the collated (data, slices) pair to 'data.pt' in the current working directory.
# NOTE(review): confirm this is the location LPDataset reads processed data from.
torch.save((data, slices), 'data.pt')

In [None]:
from data.dataset import LPDataset
import torch
import numpy as np

# Re-instantiate the dataset from the same root (repeats the load done in an earlier cell).
dataset = LPDataset('datasets/svd_mid_fac')

In [11]:
# Take the instance at index 1 for inspection.
g = dataset[1]

In [None]:
# SparseTensor is otherwise imported only in a later cell; import it here so the
# notebook can be executed top to bottom.
from torch_sparse import SparseTensor

# Densify the matrix A of `g` from its COO triplets (A_row, A_col, A_val).
# NOTE(review): the dense shape is inferred from the largest row/col index, so
# trailing all-zero rows/columns would be dropped.
a = SparseTensor(col=g.A_col, row=g.A_row, value=g.A_val).to_dense()

In [None]:
# Show the distinct entries of the dense matrix (inspection only).
a.unique()

In [None]:
# Display attribute `c` of the sample graph (presumably the LP cost vector — confirm).
g.c

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
from models.hetero_gnn import TripartiteHeteroGNN
from models.cycle_model import CycleGNN

In [None]:
# Build the project's tripartite heterogeneous GNN with the hyper-parameters
# below and move it to the selected device.
gnn = TripartiteHeteroGNN(
    conv='gcnconv',
    hid_dim=256,
    num_conv_layers=4,
    num_pred_layers=2,
    num_mlp_layers=2,
    dropout=0.0,
    norm='graphnorm',
    use_res=False,
    conv_sequence='cov',
).to(device)

In [None]:
# Wrap the GNN in the project's CycleGNN and move it to `device`.
# NOTE(review): the meaning of the positional arguments 8 and 32 is defined by
# CycleGNN's signature — consider passing them as keywords.
model = CycleGNN(8, 32, gnn).to(device)

In [None]:
from data.collate_func import collate_fn_lp

In [None]:
from torch_sparse import SparseTensor

In [None]:
# Collate instance 1 into a single-graph batch on `device` and run the GNN once.
# NOTE(review): the output is discarded — presumably this initialises
# lazily-shaped layers before the checkpoint is loaded; confirm. It runs with
# autograd enabled; torch.no_grad() would avoid building an unused graph.
g = dataset[1]
batch = collate_fn_lp([g], device)
batch = batch.to(device)
_ = gnn(batch)

In [None]:
# Dense matrix A and vector b of the (un-collated) instance `g`.
# NOTE(review): is_sorted=True / trust_data=True tell torch_sparse to skip
# sorting and validation, so (row, col) must already be sorted. The dense shape
# is inferred from the largest indices.
A = SparseTensor(row=g.A_row,
                 col=g.A_col,
                 value=g.A_val, is_sorted=True,
                 trust_data=True).to_dense()
b = g.b

In [None]:
# Largest absolute entry of A @ x_start - b for the batch's x_start.
# A and b are built from the un-collated graph `g`, whereas `batch` was moved to
# `device`. Bring x_start to A's device so this check does not fail with a
# device mismatch when device == 'cuda'.
# NOTE(review): assumes x_start is 1-D (one value per column of A); a (n, 1)
# tensor would broadcast against b into an (m, m) result.
(A @ batch.x_start.to(A.device) - b).abs().max()

In [None]:
# Load trained weights from 'best_model.pt', mapping tensors onto `device`, then
# switch the model to eval mode.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints, or pass weights_only=True on torch versions that support it.
model.load_state_dict(torch.load('best_model.pt', map_location=device))
model.eval()

In [None]:
# Run CycleGNN's evaluation routine on the prepared batch.
# NOTE(review): not wrapped in torch.no_grad(); check whether `evaluation` disables autograd itself.
xs = model.evaluation(batch)