# Node 0 ---[ R = 100Ω, V = 10V ]--- Node 1 ---[ R = 200Ω ]--- Node 2

import torch
from torch_geometric.data import Data

# Three nodes: node 0 (ground), node 1, node 2.
# Branch 0-1: 100 Ω resistor in series with a 10 V source.
# Branch 1-2: 200 Ω resistor.
#
# Every physical branch is stored as TWO directed edges, one per direction.
# MessagePassing uses flow="source_to_target" by default, so a node only
# receives messages along edges that point INTO it; a branch listed in a
# single direction would be invisible to one of its two endpoints.

# Edge index, shape [2, num_edges]: row 0 = source node j, row 1 = target node i.
edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1],
], dtype=torch.long)

# Edge features: [resistance (Ohms), source voltage (V) seen going source -> target].
# R01 = 100 Ω, R12 = 200 Ω, V01 = 10 V (voltage source from 0 to 1).
# Reversing an edge negates its source voltage but keeps its resistance.
edge_attr = torch.tensor([
    [100.0, 10.0],   # 0 -> 1: R01 with the 10 V source
    [100.0, -10.0],  # 1 -> 0: same branch, opposite direction
    [200.0, 0.0],    # 1 -> 2: R12, plain resistor
    [200.0, 0.0],    # 2 -> 1: same branch, opposite direction
], dtype=torch.float)

# Node features: one column per node holding the current voltage estimate.
x = torch.zeros((3, 1))

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# In [3]:
from torch_geometric.nn import MessagePassing
from torch.nn import Linear

class CircuitGNN(MessagePassing):
    """One relaxation step of node voltages driven by branch currents.

    Each directed edge j -> i carries ``edge_attr = [R, V_source]``. The
    message is the current flowing from node j into node i through that
    branch. Messages are summed per target node (net injected current), and
    every node voltage is moved by ``step`` times its net current.

    Args:
        step: Step size applied to the summed current in :meth:`update`.
            Defaults to 0.1, the previously hard-coded value.
    """

    def __init__(self, step=0.1):
        super().__init__(aggr='add')  # sum incoming currents per target node
        self.step = step
        # NOTE(review): this layer is never used by message() or update().
        # It is kept only so parameters() / state_dict keys stay unchanged.
        self.linear = Linear(2, 1)

    def forward(self, x, edge_index, edge_attr):
        """Return node voltages after one message-passing step.

        Args:
            x: Node voltage estimates, one row per node (the notebook uses [N, 1]).
            edge_index: [2, E] tensor of (source, target) node indices.
            edge_attr: [E, 2] tensor of [resistance, source voltage] per edge.
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Return the current entering target node i from source node j, per edge."""
        resistance, voltage = edge_attr[:, 0:1], edge_attr[:, 1:2]
        # Ohm's law across a resistor in series with a source:
        # I = (V_j + V_source - V_i) / R, positive when current flows into node i.
        return (x_j - x_i + voltage) / resistance

    def update(self, aggr_out, x):
        """Move each voltage by ``step`` times its net incoming current."""
        return x + self.step * aggr_out

# In [4]:
model = CircuitGNN()

# Optimize the input node voltages themselves (not the model's weights) so that
# one relaxation step of the model lands on the target voltages.
x = data.x.clone().requires_grad_(True)
optimizer = torch.optim.Adam([x], lr=0.01)

# Target voltages to match after one step (example values).
# NOTE(review): in the diagrammed circuit node 2 hangs off node 1 through R12
# with no return path, so at steady state V2 should equal V1; the 7.5 V target
# for node 2 is not consistent with that — confirm the intended value.
target = torch.tensor([[0.0], [10.0], [7.5]])
# Per-node weight: 1.0 = supervised, 0.0 = ignored by the loss (all supervised here).
mask = torch.tensor([1.0, 1.0, 1.0]).unsqueeze(-1)

for epoch in range(200):
    optimizer.zero_grad()
    out = model(x, data.edge_index, data.edge_attr)
    # Masked MSE averaged over supervised entries only. A plain .mean() would
    # divide by every element and shrink the loss whenever nodes are masked
    # out; with an all-ones mask the two are identical.
    weights = mask.expand_as(out)
    loss = ((out - target) ** 2 * weights).sum() / weights.sum().clamp_min(1.0)
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')


# Recorded output:
# Epoch 0, Loss: 52.0167
# Epoch 20, Loss: 49.7157
# Epoch 40, Loss: 47.4832
# Epoch 60, Loss: 45.3256
# Epoch 80, Loss: 43.2435
# Epoch 100, Loss: 41.2358
# Epoch 120, Loss: 39.3007
# Epoch 140, Loss: 37.4367
# Epoch 160, Loss: 35.6419
# Epoch 180, Loss: 33.9148
