In [7]:
import torch
import numpy as np
from models.training import easyTrainer, weights_to_dataset
from models.nODE import nODE, make_nODE_from_parameters
import matplotlib.pyplot as plt
from torch_geometric.data import Data
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader

In [8]:
adj_list = []
adj_list.append(np.array([[0.,0.],[0.,0.]]))
adj_list.append(np.array([[0,1.],[0,0]]))
adj_list.append(np.array([[0,-1.],[0,0]]))
adj_list.append(np.array([[0,1.],[1.,0]]))
adj_list.append(np.array([[0,1.],[-1.,0]]))
adj_list.append(np.array([[0,-1.],[-1.,0]]))

In [97]:
def generate_data(adj,y,ndata=2):
    ODE_dim = 2
    Gamma = np.array([-1., -1.])
    integration_time = 1
    Win = adj
    Wout = np.array([[2., 0], [0, 2.]])
    bin = np.array([[2.], [2.]])
    bout = np.array([[2.], [2.]])

    train_data, test_data = weights_to_dataset(integration_time, Gamma, Win=Win, bin=bin, Wout=Wout, bout=bout, batch_size = ndata)

    edge_index = torch.tensor([[0, 1],
                           [1, 0]], dtype=torch.long)
    
    for i, (x1, x_batch) in enumerate(train_data):

        x = torch.tensor([[x_batch[0][0]],[x_batch[0][1]]], dtype=torch.float)
        print(x)

        y = torch.tensor([y])

        data = Data(x= x, edge_index=edge_index,y=y)

    return data

data_list = []
for graph_ind in [0,1,2,3,4,5]:
    for i in range(0,100):
        ind = 0
        data = generate_data(adj_list[graph_ind],graph_ind)
        data_list.append(data)
loader = DataLoader(data_list, batch_size=32, shuffle=True)

test_data_list = []
for graph_ind in [0,1,2,3,4,5]:
    for i in range(0,20):
        ind = 0
        data = generate_data(adj_list[graph_ind],graph_ind)
        test_data_list.append(data)
test_loader = DataLoader(data_list, batch_size=32, shuffle=True)

for data in loader:
    print(data)

tensor([[3.5238],
        [3.4998]])
tensor([[2.6306],
        [4.3126]])
tensor([[4.2947],
        [3.6514]])
tensor([[3.7095],
        [3.5534]])
tensor([[2.6554],
        [3.7597]])
tensor([[3.6674],
        [2.6982]])
tensor([[3.9313],
        [2.6215]])
tensor([[3.9487],
        [3.5241]])
tensor([[3.2754],
        [2.9827]])
tensor([[3.6479],
        [3.6681]])
tensor([[3.9232],
        [2.7520]])
tensor([[2.5128],
        [3.2993]])
tensor([[2.8682],
        [2.7866]])
tensor([[3.2961],
        [4.2220]])
tensor([[2.5801],
        [3.6323]])
tensor([[3.6695],
        [3.4791]])
tensor([[2.6563],
        [3.9072]])
tensor([[3.9024],
        [4.1360]])
tensor([[3.9935],
        [3.0322]])
tensor([[2.7724],
        [2.7108]])
tensor([[2.6650],
        [4.2479]])
tensor([[3.1642],
        [2.8029]])
tensor([[2.7922],
        [4.1156]])
tensor([[3.3967],
        [3.6958]])
tensor([[3.2327],
        [2.8276]])
tensor([[3.8031],
        [3.5974]])
tensor([[2.8145],
        [3.8094]])
t

In [98]:
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(-1, 16)
        self.conv2 = GCNConv(16, 6)
        self.linear = torch.nn.Linear(6, 6)  # Linear layer to output 6x1 vector

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        # Here we aggregate the output from all nodes
        x = x.mean(dim=0, keepdim=True)
        x = self.linear(x)

        return F.log_softmax(x, dim=1)

In [116]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data_list[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    loss = 0
    for data_point in data_list:
        data = data_point.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = loss + F.nll_loss(out, data.y)
    loss.backward()
    optimizer.step()
    print(loss)

tensor(1089.5558, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1077.3984, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1069.6880, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1064.8035, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1061.6661, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1059.6599, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1057.8455, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1055.5507, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1052.5637, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1049.0549, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1045.2450, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1041.2072, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1036.8551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1032.1329, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1027.0415, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1021.6819, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1016.0675, device='cuda:0', grad_fn=<AddBackward0

In [127]:
correct = 0
total = 0
for data_point in test_data_list:
    data = data_point.to(device)
    out = model(data)
    loss = F.nll_loss(out, data.y)
    pred = torch.argmax(out)
    if pred == data.y[0]:
        correct = correct + 1
    total = total + 1

print(correct/total)

0.35
