### Imports

In [None]:
import os
import pathlib

import codetiming
import numpy as np
import torch


# Prefer the GPU whenever CUDA is usable; otherwise run everything on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loaded torch. Using *{device}* device.")

### Load dataset

In [None]:
from torch_geometric.loader import DataLoader
from my_graphs_dataset import GraphDataset
from algebraic_connectivity_dataset import ConnectivityDataset

# Graph sizes to load (keys 3..8), each mapped to a selection value.
# NOTE(review): -1 presumably means "take every graph of that size" — confirm
# against GraphDataset.
selected_graph_sizes = {graph_size: -1 for graph_size in range(3, 9)}
# Larger sizes are currently left out:
#   9:  100000,
#   10: 100000
dataset_loader = GraphDataset(selection=selected_graph_sizes)
dataset = ConnectivityDataset(pathlib.Path.cwd() / "Dataset", dataset_loader)

# General information
print(
    "",
    f'Dataset: {dataset}:',
    '====================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    sep="\n",
)

# Shuffle, split, batch and load data.
torch.manual_seed(12345)  # fixed seed, set right before the shuffle

dataset = dataset.shuffle()

# The first 80% of the shuffled dataset is used for training, the rest for testing.
train_size = round(len(dataset) * 0.8)
train_dataset, test_dataset = dataset[:train_size], dataset[train_size:]

print(
    "",
    f'Number of training graphs: {len(train_dataset)}',
    f'Number of test graphs: {len(test_dataset)}',
    sep="\n",
)

# TODO: Batch size?
# Only the training loader reshuffles its data; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=10000)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=2500)

# Walk the training loader once and describe every batch it yields.
print("", "Batches:", sep="\n")
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:', '=======', sep="\n")
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data, end="\n\n")

### Models

#### Basic GCN

In [None]:
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    """Three-layer graph convolutional network with one output value per graph.

    Node features go through three ``GCNConv`` layers (ReLU after the first
    two), are mean-pooled per graph and finally mapped to a single value by a
    linear layer.

    Args:
        hidden_channels: Output width of every graph-convolution layer.

    Note:
        The input feature size is read from the notebook-level ``dataset``
        (``dataset.num_features``), so ``dataset`` must already exist when the
        model is instantiated.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        in_channels = dataset.num_features
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        """Compute one output row per graph in the batch.

        Args:
            x: Node features, passed to the first convolution.
            edge_index: Graph connectivity, passed to every convolution.
            batch: Node-to-graph assignment used by the mean-pool readout.

        Returns:
            The output of the final linear layer applied to the pooled
            embeddings (last dimension of size 1).
        """
        # 1. Node embeddings: conv -> ReLU -> conv -> ReLU -> conv
        hidden = torch.relu(self.conv1(x, edge_index))
        hidden = torch.relu(self.conv2(hidden, edge_index))
        hidden = self.conv3(hidden, edge_index)

        # 2. Readout: average the node embeddings of each graph
        pooled = global_mean_pool(hidden, batch)  # [batch_size, hidden_channels]

        # 3. Linear head producing the final value per graph
        return self.lin(pooled)

# Build the network, then move its parameters to the selected device.
model = GCN(hidden_channels=20)
model = model.to(device)
print(model)


### Train loop

In [None]:
# L1 (absolute error) loss as the objective; Adam over all model parameters.
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train():
    """Run one optimisation pass over every batch of ``train_loader``.

    Works on the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. The model parameters are updated in
    place; nothing is returned.
    """
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        data = data.to(device)  # Move to CUDA if available.

        # Clear gradients *before* the backward pass: anything left over from
        # an interrupted cell or an earlier run must not leak into this update.
        optimizer.zero_grad()

        out = model(data.x, data.edge_index, data.batch)  # Single forward pass.

        # squeeze(-1) removes only the trailing size-1 output dimension, so a
        # batch containing a single graph stays 1-D instead of collapsing to a
        # 0-d tensor that no longer matches the target's shape.
        # NOTE(review): assumes data.y holds one target per graph — confirm
        # against the dataset.
        loss = criterion(out.squeeze(-1), data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.


def test(loader):
    model.eval()

    with torch.no_grad():
        for data in loader:  # Iterate in batches over the training/test dataset.
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out.squeeze(), data.y).item()  # Compute the loss.
    return loss


num_epochs = 100
# Per-epoch loss history; slot ``epoch - 1`` is filled after each epoch.
train_losses, test_losses = np.zeros(num_epochs), np.zeros(num_epochs)

with codetiming.Timer():  # time the whole training run
    for epoch in range(1, num_epochs + 1):
        train()
        # Evaluate on both splits after every optimisation pass.
        train_loss, test_loss = test(train_loader), test(test_loader)
        train_losses[epoch - 1], test_losses[epoch - 1] = train_loss, test_loss

        # Report progress every tenth epoch only.
        if not epoch % 10:
            print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

### Plot the training curve

In [None]:
import matplotlib.pyplot as plt
# Draw both loss histories on the same axes, one labelled curve per split.
for curve, curve_label in ((train_losses, 'Train Loss'), (test_losses, 'Test Loss')):
    plt.plot(curve, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

### Save the model

In [None]:
# torch.save(model.state_dict(), "model.pth")
# print("Saved PyTorch Model State to model.pth")


### Make predictions with loaded model

In [None]:
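# NOTE(review): the commented-out snippet below refers to `NeuralNetwork`,
# `classes` and `test_data`, none of which are defined in the visible cells of
# this notebook; it looks like an image-classification template and would need
# adapting to the GCN regression model before being enabled.
# model = NeuralNetwork().to(device)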
# model.load_state_dict(torch.load("model.pth"))

# classes = [
#     "T-shirt/top",
#     "Trouser",
#     "Pullover",
#     "Dress",
#     "Coat",
#     "Sandal",
#     "Shirt",
#     "Sneaker",
#     "Bag",
#     "Ankle boot",
# ]

# model.eval()
# x, y = test_data[0][0], test_data[0][1]
# with torch.no_grad():
#     x = x.to(device)
#     pred = model(x)
#     predicted, actual = classes[pred[0].argmax(0)], classes[y]
#     print(f'Predicted: "{predicted}", Actual: "{actual}"')