# Setup

In [8]:
import numpy as np 
import matplotlib.pyplot as plt 
import tensorflow as tf 
from tensorflow import keras
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, Dense, Input, Conv2D, UpSampling2D, BatchNormalization
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
import cv2
import os 
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
import random
import torch
from torch_geometric.nn import GATConv
from torch_geometric.utils import softmax
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import glob
from torch_geometric.data import Data, Dataset
import pickle
import torch.optim as optim

In [None]:
# Report the PyTorch build that is currently installed in this runtime.
print("PyTorch has version {}".format(torch.__version__))
# Install PyTorch Geometric plus its compiled companions (torch-scatter,
# torch-sparse) and the OGB package.
# NOTE(review): the wheel index below is hard-coded to torch 1.13.1 + CUDA 11.6;
# confirm it matches the version printed above before relying on these wheels.
# NOTE(review): the imports cell at the top of the notebook already imports
# torch_geometric, so on a fresh runtime this cell presumably has to run first.
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-geometric
!pip install ogb

### Load Data

In [3]:
class DiskGraphDataset(Dataset):
    """Dataset that lazily loads one pickled graph per file from disk.

    Every file matching ``file_name_wildcard`` inside ``root`` is treated as
    one graph. Files are only opened when an item is requested, so the whole
    dataset never has to fit in memory.

    Args:
        root: Directory that contains the pickled graphs (also forwarded to
            the ``Dataset`` base class as its root folder).
        file_name_wildcard: Glob pattern, relative to ``root``, selecting the
            graph files (e.g. ``"train_graph_*.pkl"``).
        transform, pre_transform, pre_filter: Forwarded unchanged to the
            ``Dataset`` base class.
    """

    def __init__(self, root, file_name_wildcard, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # Join with os.path so the pattern is correct whether or not ``root``
        # ends with a path separator (plain concatenation silently produced a
        # wrong pattern when the trailing slash was missing).
        self.file_name_wildcard = os.path.join(root, file_name_wildcard.lstrip("/"))
        # glob returns matches in a file-system dependent order; sort them so
        # that index -> file is stable between runs and machines.
        self.graph_filenames = sorted(glob.glob(self.file_name_wildcard))

    def len(self):
        """Return the number of graph files that matched the pattern."""
        return len(self.graph_filenames)

    def get(self, idx):
        """Load graph ``idx`` from disk and convert its labels to a tensor.

        Relies on ``labels_to_tensor`` being defined in the notebook before
        the first item is requested.
        """
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # dataset at graph files you produced yourself.
        with open(self.graph_filenames[idx], 'rb') as f:
            graph = pickle.load(f)
        graph.y = labels_to_tensor(graph.y)
        return graph

In [4]:
from google.colab import drive
from torch_geometric.loader import DataLoader

# Mount Google Drive so the pickled graphs are reachable from this runtime.
drive.mount('/content/drive')

# Train and test graphs share one folder and are told apart by file prefix.
GRAPH_DIR = "/content/drive/My Drive/Dataset/Graphs/wm-nowm/graphs_filter_3/"
BATCH_SIZE = 16

train_graphs = DiskGraphDataset(GRAPH_DIR, "train_graph_*.pkl")
test_graphs = DiskGraphDataset(GRAPH_DIR, "test_graph_*.pkl")

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False)

Mounted at /content/drive


### Models

In [7]:
# GATConv variant with its own per-edge message computation
class CustomGATConv(GATConv):
    """GATConv subclass that overrides the ``message`` step.

    Construction arguments are passed straight through to ``GATConv``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        """Weight the ``x_j`` features by normalised attention scores.

        The two attention terms are summed, passed through a leaky ReLU,
        normalised with the grouped ``softmax`` over ``index``, and then
        subjected to dropout before scaling ``x_j``.
        """
        # NOTE(review): the slope is fixed at 0.2 here rather than read from
        # the layer's configuration -- confirm that is intended.
        raw_scores = F.leaky_relu(alpha_j + alpha_i, negative_slope=0.2)
        edge_weights = softmax(raw_scores, index, ptr, size_i)
        edge_weights = F.dropout(edge_weights, p=self.dropout, training=self.training)
        # Add a trailing axis so the weights broadcast over the feature axis.
        return x_j * edge_weights.unsqueeze(-1)

# Define the custom GAT model for node classification
class CustomGAT(torch.nn.Module):
    """Two-layer graph attention network producing 3 output values per node.

    Args:
        num_features: Number of input features per node.
        hidden_channels: Output channels *per attention head* of the first
            layer.
        heads: Number of attention heads in the first layer. Their outputs
            are concatenated, so the second layer receives
            ``hidden_channels * heads`` features.

    NOTE(review): both layers are stock ``GATConv`` instances, not
    ``CustomGATConv`` -- confirm that is intended.
    """

    def __init__(self, num_features, hidden_channels=16, heads=2):
        super().__init__()
        self.conv1 = GATConv(num_features, hidden_channels, heads=heads, dropout=0.5)
        # The input width must follow conv1's concatenated output; it was
        # previously hard-coded to 32, which only matched hidden_channels=16.
        self.conv2 = GATConv(hidden_channels * heads, 3, heads=1, concat=True, dropout=0.5)

    def forward(self, data):
        """Return the raw (un-activated) per-node outputs for ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

class TwoLayerGCN(torch.nn.Module):
    """Two stacked ``GCNConv`` layers producing 3 output values per node.

    Args:
        num_features: Number of input features per node.
        hidden_channels: Width of the hidden representation between layers.
    """

    def __init__(self, num_features, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, 3)

    def forward(self, data):
        """Return the raw per-node outputs for ``data`` (needs ``x`` and ``edge_index``)."""
        hidden = F.relu(self.conv1(data.x, data.edge_index))
        # Dropout is a no-op unless the module is in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, data.edge_index)

# Convert labels to tensor
def labels_to_tensor(labels):
    """Flatten labels to a 2-D float32 tensor of shape ``(-1, labels.shape[-1])``.

    All leading axes are collapsed into one; the last axis is preserved.

    Args:
        labels: A NumPy array, a nested list/tuple, or a ``torch.Tensor``.

    Returns:
        A new ``torch.float`` tensor that does not share memory with
        ``labels`` and is detached from any autograd graph.
    """
    if isinstance(labels, torch.Tensor):
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning;
        # detach + clone gives the same independent copy without it.
        copied = labels.detach().clone()
        return copied.reshape(-1, copied.shape[-1]).to(torch.float)

    # np.asarray is a no-op for arrays and lets plain lists/tuples through.
    array = np.asarray(labels)
    return torch.tensor(array.reshape(-1, array.shape[-1]), dtype=torch.float)

In [None]:
# Train the node-level model (CustomGAT unless the alternative below is enabled)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model, objective and optimiser
loss_fn_arg = "BCE"
# model = TwoLayerGCN(3, hidden_channels=16)
model = CustomGAT(3, hidden_channels=16).to(device)
# "BCE" selects binary cross-entropy on logits; anything else falls back to MSE.
loss_fn = torch.nn.BCEWithLogitsLoss() if loss_fn_arg == "BCE" else torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 10

model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch), batch.y)
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()
    # Report the mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

In [None]:
# Evaluate on the held-out graphs with the same loss used for training
model.eval()
loss = 0
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        batch_output = model(batch)
        batch_loss = loss_fn(batch_output, batch.y)
        loss += batch_loss.item()

# Mean of the per-batch losses over the test loader.
print(f"Node-level error: {loss / len(test_loader)}")

In [None]:
def predict(model, test_graphs):
    """Run ``model`` on every graph and collect the raw outputs.

    The model is switched to evaluation mode and gradients are disabled.
    Each graph is moved to the device the model's parameters live on, so the
    function does not depend on a notebook-global ``device`` variable.

    Args:
        model: A ``torch.nn.Module`` that is callable on a single graph.
        test_graphs: Iterable of graph objects exposing ``.to(device)``.

    Returns:
        A list with one output tensor per input graph, in iteration order.
        The tensors stay on the model's device.
    """
    model.eval()  # Set the model to evaluation mode

    # A parameter-free model has no device of its own; fall back to CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    predictions = []
    with torch.no_grad():  # Disable gradient computation during inference
        for graph in test_graphs:
            predictions.append(model(graph.to(target_device)))

    return predictions

# Run inference over every graph in test_graphs; predict() returns a list with
# one model output per graph, in the order the dataset yields them.
test_predictions = predict(model, test_graphs)

In [None]:
# Image size used to fold the flat node features back into a picture.
# Original author's note: only certain dimensions work due to UpSampling
# (196x196 works, 148x148 works).
width = 148
height = 148
dim = (width, height) # set the dimensions

# Show the first 25 test inputs in a 5x5 grid.
plt.figure(figsize=(25,25))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # Node features -> (width, height, 3) array; the colour conversion treats
    # the stored channel order as BGR.
    input_image = test_graphs[i].x.reshape(width, height, 3).cpu().detach().numpy()
    plt.imshow(cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB))
plt.show()

In [None]:
# Show the model output for the first 25 test graphs in a 5x5 grid.
plt.figure(figsize=(25,25))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # The model emits raw scores; sigmoid squashes them into (0, 1) before
    # they are displayed as an image.
    scores = test_predictions[i].reshape(width, height, 3)
    predicted_image = torch.sigmoid(scores).cpu().detach().numpy()
    plt.imshow(cv2.cvtColor(predicted_image, cv2.COLOR_BGR2RGB))
plt.show()