In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GNNModel(torch.nn.Module):
    """Two-layer GCN encoder with an MLP head that scores moves per graph.

    Node features pass through two ``GCNConv`` layers and are mean-pooled into
    one vector per graph. Two linear layers then map that vector to
    ``num_moves`` logits. The constructor defaults reproduce the original
    hard-coded sizes, so parameter names and shapes (and therefore saved
    ``state_dict`` files) are unchanged.

    Args:
        in_channels: Size of each node feature vector (``data.x.size(1)``).
        hidden_channels: Width of both GCN layers and of ``fc1``.
        num_moves: Number of output scores per graph.
    """

    def __init__(self, in_channels=1, hidden_channels=16, num_moves=9):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.fc1 = torch.nn.Linear(hidden_channels, hidden_channels)  # First fully connected layer
        self.fc2 = torch.nn.Linear(hidden_channels, num_moves)  # Output layer: one score per move

    @staticmethod
    def _num_graphs(batch):
        """Number of graphs described by an optional node-to-graph ``batch`` vector."""
        if batch is None or batch.numel() == 0:
            return 1
        return int(batch.max()) + 1

    @staticmethod
    def _mean_pool(x, batch):
        """Average node embeddings per graph, returning ``[num_graphs, channels]``."""
        if batch is None:
            # No batch vector: treat every node as belonging to a single graph.
            return x.mean(dim=0, keepdim=True)
        num_graphs = GNNModel._num_graphs(batch)
        sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device).index_add(0, batch, x)
        # clamp avoids 0/0 for a graph id with no nodes.
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
        return sums / counts.unsqueeze(1).to(x.dtype)

    def forward(self, data):
        """Return move logits of shape ``[num_graphs, num_moves]``.

        ``data`` must expose ``x`` and ``edge_index``. An optional ``batch``
        vector (one graph id per node) makes a multi-graph batch yield one row
        per graph instead of collapsing all nodes into a single prediction.
        """
        x, edge_index = data.x, data.edge_index
        batch = getattr(data, "batch", None)

        if edge_index.numel() == 0:
            # Edgeless input: skip the convolutions and feed an all-zero graph
            # embedding to the head. This matches the result of the original
            # zeros_like + shape-check fallback, but is created on x's device/dtype.
            pooled = x.new_zeros(self._num_graphs(batch), self.fc1.in_features)
        else:
            x = F.relu(self.conv1(x, edge_index))
            x = F.relu(self.conv2(x, edge_index))
            pooled = self._mean_pool(x, batch)

        x = F.relu(self.fc1(pooled))
        return self.fc2(x)


In [None]:
from torch_geometric.loader import DataLoader

# Seed torch's global RNG so the model's initial weights and the shuffled
# training order are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Split data into train and test sets: the first 80% of graph_data_list trains,
# the remainder tests. The list is not shuffled before slicing, so the split
# follows whatever order graph_data_list was built in.
train_size = int(0.8 * len(graph_data_list))
train_data = graph_data_list[:train_size]
test_data = graph_data_list[train_size:]

train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)


model = GNNModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()


def train(epochs=1):
    """Train the notebook-level ``model`` on ``train_loader``.

    Uses the globals ``model``, ``optimizer``, ``criterion`` and
    ``train_loader``. Prints one summary line per epoch instead of one line per
    graph, which previously flooded the cell output.

    Args:
        epochs: Number of full passes over ``train_loader``. The default of 1
            keeps the original single pass.

    Returns:
        Mean training loss over the final epoch, or ``float('nan')`` if no
        batch was seen (empty loader or ``epochs < 1``).
    """
    model.train()
    epoch_loss = float("nan")
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        num_batches = 0
        for data in train_loader:
            optimizer.zero_grad()
            out = model(data)
            # Flatten to one class index per graph so the target lines up with
            # out's rows; evaluate() compares against data.y.view(-1) the same way.
            target = data.y.view(-1).long()
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        epoch_loss = total_loss / num_batches if num_batches else float("nan")
        print(f'Epoch {epoch}/{epochs} - Train Loss: {epoch_loss:.4f}')
    return epoch_loss


def evaluate():
    """Print and return the notebook-level ``model``'s accuracy on ``test_loader``.

    The prediction for each graph is the arg-max over its move scores. That
    prediction is compared with ``data.y`` flattened to one label per graph.

    Returns:
        Accuracy as a percentage in ``[0, 100]``, or ``float('nan')`` when
        ``test_loader`` yields no labels. The original raised
        ``ZeroDivisionError`` in that case.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            out = model(data)
            predicted = out.argmax(dim=1)
            target = data.y.view(-1)
            total += target.numel()
            correct += (predicted == target).sum().item()
    if total == 0:
        print('Accuracy: n/a (test set is empty)')
        return float('nan')
    accuracy = correct / total * 100
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy


# Run one training pass over train_loader, then report accuracy on test_loader.
for stage in (train, evaluate):
    stage()


Reshaped x shape: torch.Size([1, 16])
Train Loss: 2.2294745445251465
Reshaped x shape: torch.Size([1, 16])
Train Loss: 2.178105592727661
Reshaped x shape: torch.Size([1, 16])
Train Loss: 2.1311280727386475
Reshaped x shape: torch.Size([1, 16])
Train Loss: 2.074622392654419
Reshaped x shape: torch.Size([1, 16])
Train Loss: 2.0217881202697754
Reshaped x shape: torch.Size([1, 16])
Train Loss: 2.0303256511688232
Reshaped x shape: torch.Size([1, 16])
Train Loss: 1.9988377094268799
Reshaped x shape: torch.Size([1, 16])
Train Loss: 1.8722422122955322
Reshaped x shape: torch.Size([1, 16])
Train Loss: 1.9274756908416748
Reshaped x shape: torch.Size([1, 16])
Train Loss: 2.1504065990448
Reshaped x shape: torch.Size([1, 16])
Train Loss: 1.735811710357666
Reshaped x shape: torch.Size([1, 16])
Train Loss: 1.6469234228134155
Reshaped x shape: torch.Size([1, 16])
Train Loss: 2.1062493324279785
Reshaped x shape: torch.Size([1, 16])
Train Loss: 1.7335233688354492
Reshaped x shape: torch.Size([1, 16])
Tr

In [None]:
# Persist only the learned parameters (the state_dict), not the class itself.
# Reloading therefore requires constructing a GNNModel first and then calling
# load_state_dict on it.
MODEL_PATH = 'gnn_model.pth'
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved as {MODEL_PATH}")

Model saved as gnn_model.pth


In [None]:
import numpy as np
import ipywidgets as widgets
from IPython.display import display
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

In [None]:
class GNNModel(torch.nn.Module):
    """Inference-time definition of the move-scoring GCN (9 scores per graph).

    NOTE(review): this redefinition shadows the earlier training-cell class of
    the same name. Layer names and sizes match, so a state_dict from that class
    loads here. Keep the two forward passes in sync, or better, define the model
    once and reuse it.

    ``forward`` returns a 1-D tensor of 9 scores (no batch dimension). The
    training-cell forward returns a ``[1, 9]`` tensor instead.
    """

    def __init__(self):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(1, 16)
        self.conv2 = GCNConv(16, 16)
        self.fc1 = torch.nn.Linear(16, 16)
        self.fc2 = torch.nn.Linear(16, 9)

    def forward(self, data):
        """Score the 9 moves for a single graph exposing ``x`` and ``edge_index``."""
        x, edge_index = data.x, data.edge_index
        if edge_index.numel() == 0:
            # The training-cell forward never ran the convolutions on edgeless
            # graphs; it fed an all-zero embedding to the head instead. Mirror
            # that here so inference matches what the weights were fit on. This
            # also avoids a NaN mean for a graph with no nodes.
            x = x.new_zeros(self.fc1.in_features)
        else:
            x = F.relu(self.conv1(x, edge_index))
            x = F.relu(self.conv2(x, edge_index))
            x = torch.mean(x, dim=0)  # mean-pool nodes -> one 16-dim graph vector
        x = F.relu(self.fc1(x))   # Pass through first fully connected layer
        x = self.fc2(x)           # Pass through second fully connected layer
        return x