In [1]:
import torch 
import torch.nn as nn
import torch_geometric.nn as nng
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../")
from src.simulations import SurfaceCodeSim
from src.graph import get_batch_of_graphs
from src.models import MWPMLoss
from tqdm import tqdm

##### Initialise settings

In [2]:
reps = 3
code_sz = 3
p = 1e-3
n_shots = 1000
sim = SurfaceCodeSim(reps, code_sz, p, n_shots, seed=1)
n_epochs = 10
n_batches = 1
factor = -0.1

##### Create a simple network

In [3]:
class SplitSyndromes(nn.Module):
    
    def __init__(self):
        super().__init__()
        
    def forward(self, edges, edge_attr, detector_labels):

        node_range = torch.arange(0, detector_labels.shape[0])
        node_subset = node_range[detector_labels]
        
        valid_labels = torch.isin(edges, node_subset).sum(dim=0) == 2
        return edges[:, valid_labels], edge_attr[valid_labels, :]

class GATGNN(nn.Module):

    def __init__(self, n_heads=2, edge_dimensions=2):
        super().__init__()

        self.gat1 = nng.GATConv(
            5,
            16,
            heads=n_heads,
            concat=False,
            edge_dim=edge_dimensions,
            add_self_loops=False,
        )
        self.gat2 = nng.GATConv(
            16,
            32,
            heads=n_heads,
            concat=False,
            edge_dim=edge_dimensions,
            add_self_loops=False,
        )
        
        self.split_syndromes = SplitSyndromes()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, edges, edge_attr, detector_labels):
        # edge_attr.retain_grad()
        print(f"{edge_attr.is_leaf=}")
        print(f"{edge_attr.requires_grad=}")
        x, (edges, edge_attr) = self.gat1(
            x, edges, edge_attr, return_attention_weights=True
        )
        edge_attr.retain_grad()
        print(f"{edge_attr.is_leaf=}")
        print(f"{edge_attr.requires_grad=}")
        x = torch.nn.functional.relu(x, inplace=True)
        _, (edges, edge_attr) = self.gat2(
            x, edges, edge_attr, return_attention_weights=True
        )

        edges, edge_attr = self.split_syndromes(edges, edge_attr, detector_labels)
        edge_attr.retain_grad()
        # print(f"{edge_attr[:10, 1]=}")
        # edge_attr[:, 1] = self.sigmoid(edge_attr[:, 1])
        # print(f"{edge_attr[:10, 1]}")
        edge_attr = self.sigmoid(edge_attr)
        edge_attr.retain_grad()
        # print(edge_attr.requires_grad)
        print(edge_attr.grad)
        return edges, edge_attr

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

class GNNLayer(MessagePassing):
    def __init__(self, in_channels, edge_attr_dim, out_channels):
        super(GNNLayer, self).__init__(aggr='add')  # "Add" aggregation.
        self.linear = nn.Linear(in_channels + edge_attr_dim, out_channels)

    def forward(self, x, edge_index, edge_attr):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        # edge_attr has shape [E, edge_attr_dim]

        # Transform node feature matrix.
        x = self.linear(torch.cat([x, edge_attr], dim=1))

        # Propagate messages.
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j has shape [E, out_channels]
        return x_j

class GNNWithEdgeFeatures(nn.Module):
    def __init__(self, input_dim, edge_attr_dim, hidden_dim, edge_feat_dim):
        super(GNNWithEdgeFeatures, self).__init__()
        self.gnn_layer = GNNLayer(input_dim, edge_attr_dim, hidden_dim)
        self.edge_model = nn.Linear(hidden_dim, edge_feat_dim)

    def forward(self, x, edge_index, edge_attr):
        node_feats = self.gnn_layer(x, edge_index, edge_attr)
        edge_feats = self.edge_model(node_feats)
        return edge_feats

# Example usage:
input_dim = 32
edge_attr_dim = 8
hidden_dim = 64
edge_feat_dim = 2


##### Do a training loop

In [6]:
# model = GATGNN()
model = GNNWithEdgeFeatures(input_dim, edge_attr_dim, hidden_dim, edge_feat_dim)
model.train()
loss_fun = MWPMLoss.apply
optim = torch.optim.SGD(model.parameters(), lr=0.001)

# first_att = model.state_dict()["gat1.att_edge"]
# second_att = model.state_dict()["gat2.att_edge"]
for epoch in range(n_epochs):
    train_loss = 0
    epoch_n_graphs = 0
    # print(first_att - model.state_dict()["gat1.att"])
    # print(second_att - model.state_dict()["gat2.att"])
    for _ in range(n_batches):
        optim.zero_grad()
        syndromes, flips, n_trivial = sim.generate_syndromes(n_shots)
        x, edges, edge_attr, batch_labels, detector_labels = get_batch_of_graphs(syndromes, 20, code_sz)
        # print(edge_attr.is_leaf)
        edges, edge_attr = model(x, edges, edge_attr, detector_labels)
        node_range = torch.arange(0, x.shape[0])
        loss = loss_fun(
        edges,
        edge_attr,
        batch_labels,
        node_range,
        np.array(flips) * 1,
        factor
        )
        loss.backward()
        optim.step()
        # print(loss.grad)
        
        n_graphs = syndromes.shape[0]
        train_loss += loss.item() * n_graphs
        epoch_n_graphs += n_graphs
    train_loss /= epoch_n_graphs
    
    # print(train_loss)

TypeError: GNNWithEdgeFeatures.forward() takes 4 positional arguments but 5 were given

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, scatter

class GNNLayer(MessagePassing):
    def __init__(self, in_channels, edge_attr_dim, out_channels):
        super(GNNLayer, self).__init__(aggr='add')  # "Add" aggregation.
        self.linear = nn.Linear(in_channels + edge_attr_dim, out_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        # edge_attr has shape [E, edge_attr_dim]
        # batch has shape [N]

        # Transform node feature matrix.
        x = self.linear(torch.cat([x, edge_attr], dim=1))

        # Propagate messages.
        return self.propagate(edge_index, x=x, size=(x.size(0), x.size(0)), batch=batch)

    def message(self, x_j):
        # x_j has shape [E, out_channels]
        return x_j

class EdgeModel(nn.Module):
    def __init__(self, in_channels):
        super(EdgeModel, self).__init__()
        self.edge_predictor = nn.Linear(in_channels, 2)  # Output dimension 2

    def forward(self, x):
        return self.edge_predictor(x)

class GNNWithEdgeFeatures(nn.Module):
    def __init__(self, input_dim, edge_attr_dim, hidden_dim):
        super(GNNWithEdgeFeatures, self).__init__()
        self.gnn_layer = GNNLayer(input_dim, edge_attr_dim, hidden_dim)
        self.edge_model = EdgeModel(hidden_dim)

    def forward(self, x, edge_index, edge_attr, batch):
        node_feats = self.gnn_layer(x, edge_index, edge_attr, batch)
        edge_feats = self.edge_model(node_feats)
        return edge_feats

# Example usage:
input_dim = 32
edge_attr_dim = 2
hidden_dim = 64

# Create example input data
num_nodes = 100
num_edges = 200
num_graphs = 5
x = torch.randn(num_nodes, input_dim)
edge_index = torch.randint(0, num_nodes, (2, num_edges))
edge_attr = torch.randn(num_edges, edge_attr_dim)
batch = torch.randint(0, num_graphs, (num_nodes,))

model = GNNWithEdgeFeatures(input_dim, edge_attr_dim, hidden_dim)
output = model(x, edge_index, edge_attr, batch)
print(output.shape)  # Output shape: [num_edges, 2]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 100 but got size 200 for tensor number 1 in the list.