**[Graph metanetworks for processing diverse neural architectures](https://arxiv.org/abs/2312.04501)**

Our setup: given a noisy vector $z_t$, predict the denoised version $\hat{x}$ and compute the loss against the ground truth $x$.

In [2]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to the project helper modules are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [7]:
import sys
# Make the `helpers` package imported below resolvable.
# FIXME(review): hard-coded absolute, user-specific path — the notebook only
# runs on a machine with this exact layout; presumably this is the repo root.
sys.path.append('/home/cizinsky/x-to-nif/')

from helpers.dataset import ParamDataset
import helpers.gmn.model_arch_graph as mag
from helpers.gmn.graph_models import EdgeMPNN
from helpers.utils import flatten_params, unflatten_params, get_pretrained_sequential

from omegaconf import OmegaConf

import torch
import torch.nn as nn

In [3]:
# Load the training config from YAML and set its `username` field.
# NOTE(review): absolute, user-specific path — not portable across machines.
cfg = OmegaConf.load('/home/cizinsky/x-to-nif/configs/train.yaml')
cfg.username = 'cizinsky'

In [4]:
# Build the parameter dataset from the loaded config.
dataset = ParamDataset(cfg)

In [17]:
# Take the first dataset item and unpack its raw MLP parameters:
# first-layer weight, first-layer bias, second-layer weight.
sample = dataset[0]
w1, b1, w2 = sample['raw_weights']
tuple(param.shape for param in (w1, b1, w2))

(torch.Size([96, 48]), torch.Size([96]), torch.Size([12, 96]))

In [19]:
# Model definition: a two-layer MLP whose layer sizes are read off the loaded
# weight tensors instead of being hard-coded. nn.Linear stores its weight as
# (out_features, in_features), so w1 is (hidden, in) and w2 is (out, hidden).
hidden_dim, in_dim = w1.shape
out_dim = w2.shape[0]
mlp = nn.Sequential(
    nn.Linear(in_dim, hidden_dim, bias=True),
    nn.ReLU(),
    # No bias on the output layer: the sample only provides w1, b1 and w2.
    nn.Linear(hidden_dim, out_dim, bias=False),
)

# Load weights into model. no_grad is required because we copy in place into
# leaf parameters that require grad.
with torch.no_grad():
    mlp[0].weight.copy_(w1)
    mlp[0].bias.copy_(b1)
    mlp[2].weight.copy_(w2)

# Forward pass on a random input as a smoke test; the result is (1, out_dim).
x = torch.randn(1, in_dim)
y = mlp(x)
y.shape


torch.Size([1, 12])

In [15]:
# Sanity check: parameter shapes of the rebuilt MLP, to compare with the raw
# weight shapes printed for the dataset sample.
mlp[0].weight.shape, mlp[0].bias.shape, mlp[2].weight.shape

(torch.Size([96, 48]), torch.Size([96]), torch.Size([12, 96]))

In [7]:
# Convert the nn.Sequential into the `arch` description used by the graph code.
# Per the output below it is a list of length 2 whose entries start with the
# layer class followed by its parameters (the ReLU apparently gets no entry).
# NOTE(review): this dumps full weight tensors into the output; showing only
# types/shapes would keep the notebook readable.
arch = mag.sequential_to_arch(mlp)
type(arch), len(arch), arch[0], arch[1]

(list,
 2,
 [torch.nn.modules.linear.Linear, Parameter containing:
  tensor([[ 0.2215,  0.0023, -0.0210,  ...,  0.0024, -0.0186, -0.0290],
          [-0.0213, -0.0966,  0.0135,  ..., -0.0765, -0.1125,  0.0733],
          [-0.1921,  0.0981,  0.0883,  ...,  0.0058, -0.1396,  0.1041],
          ...,
          [ 0.1184,  0.0997,  0.1147,  ...,  0.0259, -0.0215, -0.0824],
          [-0.1287,  0.0377,  0.1366,  ...,  0.0547,  0.0637, -0.1206],
          [ 0.2087, -0.0422,  0.0703,  ..., -0.0186, -0.0054,  0.0456]],
         requires_grad=True), Parameter containing:
  tensor([ 0.1093, -0.1692,  0.0345,  0.0764, -0.0247,  0.1140,  0.0768, -0.0306,
           0.0280,  0.0743, -0.0265,  0.0013, -0.1078,  0.0663, -0.1150, -0.0128,
          -0.1199,  0.0737,  0.0440,  0.0583, -0.1369, -0.0431, -0.1172, -0.0818,
           0.0882, -0.0273, -0.1061, -0.0242, -0.0260,  0.0494, -0.1348, -0.0375,
           0.0815,  0.0613, -0.1269, -0.0346, -0.0781,  0.0851,  0.0308, -0.0231,
           0.0229, -0.1

48 input neurons + 96 hidden neurons + 12 output neurons + 1 bias node = 157 nodes in the graph.

In [8]:
# Turn the arch description into graph tensors: node features, edge index and
# edge features.
# NOTE: `x` is rebound here — it previously held the random MLP input from the
# forward-pass cell.
graph = mag.arch_to_graph(arch)
x, edge_index, edge_attr = graph

In [9]:
# Node-feature matrix: one row per graph node (157 expected), 3 features each.
x.shape

torch.Size([157, 3])

In [10]:
# Node count plus two sample rows of node features.
# NOTE(review): columns look like (layer index, index within layer, node type)
# — confirm against mag.arch_to_graph.
len(x), x[4], x[-1]

(157, tensor([0, 4, 0]), tensor([ 2, 11,  0]))

In [11]:
# Edge list of shape (2, num_edges). The 5856 edges match the MLP's parameter
# count: 48*96 + 96 + 96*12 (first-layer weights, biases, second-layer weights).
edge_index.shape

torch.Size([2, 5856])

In [12]:
# Edge features: one 6-dim row per edge. The grad_fn on the indexed row shows
# edge_attr is attached to an autograd graph.
# NOTE(review): the first column looks like the parameter value; the meaning
# of the remaining columns should be confirmed in mag.arch_to_graph.
type(edge_attr), edge_attr.shape, edge_attr[-1]

(torch.Tensor,
 torch.Size([5856, 6]),
 tensor([ 0.0488,  1.0000,  0.0000, -1.0000, -1.0000, -1.0000],
        grad_fn=<SelectBackward0>))

In [13]:
# Build an EdgeMPNN sized for the graph tensors above (3 node features,
# 6 edge features) and run a single forward pass on the one-graph input.
gnn = EdgeMPNN(
    node_in_dim=3,
    edge_in_dim=6,
    hidden_dim=96,
    node_out_dim=4,
    edge_out_dim=4,
    num_layers=4,
    dropout=0.0,
    reduce='mean',
)
x_out, edge_attr_out = gnn(x, edge_index, edge_attr)

In [14]:
# Output shapes: one row per node / per edge, with the last dimension set by
# node_out_dim / edge_out_dim (4 each).
x_out.shape, edge_attr_out.shape

(torch.Size([157, 4]), torch.Size([5856, 4]))

Can we do batch training?

In [11]:
# Load the graph-conditional baseline config and rebuild the dataset from it.
# NOTE: this rebinds `cfg` and `dataset` from the earlier cells.
baseline_cfg_path = '/home/cizinsky/x-to-nif/configs/graph_conditional_gram_baseline.yaml'
cfg = OmegaConf.load(baseline_cfg_path)
cfg.username = 'cizinsky'
dataset = ParamDataset(cfg)

# First five dataset items, used as a toy batch.
samples = [dataset[i] for i in range(5)]

# GNN definition (same hyperparameters as the single-graph experiment).
gnn = EdgeMPNN(
    node_in_dim=3,
    edge_in_dim=6,
    hidden_dim=96,
    node_out_dim=4,
    edge_out_dim=4,
    num_layers=4,
    dropout=0.0,
    reduce='mean',
)



In [12]:
# Convert every sampled flat weight vector into its graph representation.
mlp_dims = (48, 96, 12)  # sizes passed to unflatten_params, matching the 48-96-12 MLP
graphs = []
for sample in samples:
    # Flat parameter vector -> (w1, b1, w2) tensors.
    flat_weights = sample['weights']
    w1, b1, w2 = unflatten_params(flat_weights, *mlp_dims)

    # Rebuild the MLP, then go module -> arch description -> graph tensors.
    mlp = get_pretrained_sequential(w1, b1, w2)
    arch = mag.sequential_to_arch(mlp)
    graph = mag.arch_to_graph(arch)
    graphs.append(graph)
