In [3]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, Dense, Input, Conv2D, UpSampling2D, BatchNormalization
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
import cv2
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
import random

In [4]:
import torch
import os
print("PyTorch has version {}".format(torch.__version__))
# Install torch geometric
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-geometric
!pip install ogb

PyTorch has version 1.13.1+cu116
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.1%2Bpt113cu116-cp39-cp39-linux_x86_64.whl (9.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.4/9.4 MB[0m [31m35.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.1+pt113cu116
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_sparse-0.6.17%2Bpt113cu116-cp39-cp39-linux_x86_64.whl (4.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m15

In [54]:
import torch
import numpy as np
from torch_geometric.data import Data

# Convert an image into a PyG graph.
# Each pixel is a node, connected by edges to its 4-connected neighbors (left, right, up, down).
def image_to_graph(image):
    """Convert an image into a PyG graph with 4-connected pixel neighbourhoods.

    Every pixel becomes one node with row-major id ``i * w + j``. Each node gets
    a directed edge to its left, right, upper and lower neighbour (emitted in
    that order, node by node) whenever that neighbour lies inside the image.

    Args:
        image: Array-like exposing ``shape``, ``ndim`` and ``reshape`` (e.g. a
            NumPy array) of shape ``(h, w, channels)``, or ``(h, w)`` for a
            single-channel image.

    Returns:
        ``Data`` whose ``x`` is a float tensor of shape ``(h * w, channels)``
        (``channels`` is 1 for 2-D input) and whose ``edge_index`` is a long
        tensor of shape ``(2, num_edges)``; ``(2, 0)`` when there are no edges.

    Raises:
        ValueError: If ``image`` is neither 2- nor 3-dimensional.
    """
    if image.ndim not in (2, 3):
        raise ValueError(
            "image must have shape (h, w) or (h, w, channels), got {}".format(tuple(image.shape))
        )

    h, w = image.shape[:2]
    num_nodes = h * w
    # A 2-D image has one value per pixel; give it an explicit channel axis so
    # that x always has exactly one row per node.
    channels = image.shape[2] if image.ndim == 3 else 1

    # Grid coordinates of every node, listed in node-id (row-major) order.
    rows = torch.arange(h, dtype=torch.long).repeat_interleave(w)
    cols = torch.arange(w, dtype=torch.long).repeat(h)
    node_ids = torch.arange(num_nodes, dtype=torch.long)

    # Candidate neighbour ids per node, one column per direction:
    # left, right, up, down.
    offsets = torch.tensor([-1, 1, -w, w], dtype=torch.long)
    candidates = node_ids.unsqueeze(1) + offsets
    in_bounds = torch.stack(
        [cols > 0, cols < w - 1, rows > 0, rows < h - 1], dim=1
    )

    # nonzero() enumerates the mask in row-major order, i.e. node by node and,
    # within a node, direction by direction. Because node ids are 0..n-1 the
    # row index of a hit is the source node id itself.
    src, direction = torch.nonzero(in_bounds, as_tuple=True)
    dst = candidates[src, direction]
    edge_index = torch.stack([src, dst], dim=0).contiguous()

    x = torch.tensor(image.reshape(num_nodes, channels), dtype=torch.float)
    return Data(x=x, edge_index=edge_index)

# Convert an image into a PyG graph.
# Each pixel is a node, connected by edges to every other pixel inside the square window defined by `half_filter_dim`.
# E.g. `half_filter_dim` = 1 gives a 3x3 window; `half_filter_dim` = 2 gives a 5x5 window.
def image_to_graph_with_filter(image, half_filter_dim):
    """Convert an image into a PyG graph using a square neighbourhood window.

    Every pixel becomes one node with row-major id ``i * w + j``. Each node gets
    a directed edge to every other pixel whose row and column both differ by at
    most ``half_filter_dim`` and that lies inside the image. Edges are emitted
    node by node; within a node the window is scanned row offset first, then
    column offset, both ascending.

    Args:
        image: Array-like exposing ``shape``, ``ndim`` and ``reshape`` (e.g. a
            NumPy array) of shape ``(h, w, channels)``, or ``(h, w)`` for a
            single-channel image.
        half_filter_dim: Integer half-width of the window; the window spans
            ``2 * half_filter_dim + 1`` pixels per side. Values ``<= 0`` give a
            graph without edges.

    Returns:
        ``Data`` whose ``x`` is a float tensor of shape ``(h * w, channels)``
        (``channels`` is 1 for 2-D input) and whose ``edge_index`` is a long
        tensor of shape ``(2, num_edges)``; ``(2, 0)`` when there are no edges.

    Raises:
        ValueError: If ``image`` is neither 2- nor 3-dimensional.
    """
    if image.ndim not in (2, 3):
        raise ValueError(
            "image must have shape (h, w) or (h, w, channels), got {}".format(tuple(image.shape))
        )

    h, w = image.shape[:2]
    num_nodes = h * w
    # A 2-D image has one value per pixel; give it an explicit channel axis so
    # that x always has exactly one row per node.
    channels = image.shape[2] if image.ndim == 3 else 1

    # (row, col) offsets of the window in scan order, centre pixel excluded.
    reach = range(-half_filter_dim, half_filter_dim + 1)
    offset_pairs = [(dr, dc) for dr in reach for dc in reach if (dr, dc) != (0, 0)]
    # reshape keeps a well-formed (0, 2) tensor when the window is empty.
    offsets = torch.tensor(offset_pairs, dtype=torch.long).reshape(-1, 2)

    # Grid coordinates of every node, listed in node-id (row-major) order.
    rows = torch.arange(h, dtype=torch.long).repeat_interleave(w)
    cols = torch.arange(w, dtype=torch.long).repeat(h)

    # Neighbour coordinates for every (node, offset) pair via broadcasting.
    nbr_rows = rows.unsqueeze(1) + offsets[:, 0]
    nbr_cols = cols.unsqueeze(1) + offsets[:, 1]
    in_bounds = (nbr_rows >= 0) & (nbr_rows < h) & (nbr_cols >= 0) & (nbr_cols < w)

    # nonzero() enumerates the mask in row-major order, i.e. node by node and,
    # within a node, offset by offset. Because node ids are 0..n-1 the row
    # index of a hit is the source node id itself.
    src, slot = torch.nonzero(in_bounds, as_tuple=True)
    dst = (nbr_rows * w + nbr_cols)[src, slot]
    edge_index = torch.stack([src, dst], dim=0).contiguous()

    x = torch.tensor(image.reshape(num_nodes, channels), dtype=torch.float)
    return Data(x=x, edge_index=edge_index)

def plotImg(image):
    """Show `image` in a newly created matplotlib figure."""
    fig = plt.figure()
    # Draw on the figure's current axes and register the result as pyplot's
    # current image, which is what `plt.imshow` does internally.
    axes = fig.gca()
    plt.sci(axes.imshow(image))
    plt.show()

In [55]:
# Example usage: build the graph of a random 3x3 image with 3 channels and
# show its edge list as the cell output.
filter_dim = 1  # half-width 1 -> each pixel is linked to its 3x3 window
image = np.random.rand(3, 3, 3)
graph = image_to_graph_with_filter(image, half_filter_dim=filter_dim)
graph.edge_index

In [59]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

class TwoLayerGCN(torch.nn.Module):
    """Two stacked ``GCNConv`` layers producing one output vector per node.

    The forward pass is ``conv1 -> ReLU -> dropout -> conv2``; no activation is
    applied to the final layer's output.

    Args:
        num_features: Size of each input node feature vector.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution, i.e. the number of
            values predicted per node. Defaults to 3, the previously
            hard-coded width.
        dropout: Dropout probability applied between the two layers while the
            module is in training mode. Defaults to 0.5, the previously
            hard-coded value.
    """

    def __init__(self, num_features, hidden_channels, out_channels=3, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Return per-node predictions for ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            The output of the second convolution, one row per node.
        """
        x, edge_index = data.x, data.edge_index

        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is only active when self.training is True (model.train()).
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)

# Convert labels to PyG graph format
def labels_to_pyg_graph(labels):
    """Flatten a label array into a 2-D float tensor.

    All leading axes of ``labels`` are collapsed into one, while the size of
    the last axis is kept as the number of columns.
    """
    num_columns = labels.shape[-1]
    flattened = labels.reshape(-1, num_columns)
    return torch.tensor(flattened, dtype=torch.float)

# Generate synthetic data for demonstration
def generate_synthetic_data(num_images, img_size, num_channels=3, seed=None):
    """Generate random square images and equally shaped random targets.

    Values are drawn uniformly from ``[0, 1)``. All images are drawn first,
    then all labels.

    Args:
        num_images: Number of image/label pairs to generate.
        img_size: Height and width of every image.
        num_channels: Size of the last axis of every array. Defaults to 3.
        seed: When ``None`` (default) values come from NumPy's global random
            state, so ``np.random.seed`` still controls the result. Otherwise a
            private ``RandomState(seed)`` is used and the global state is left
            untouched, making the call reproducible on its own.

    Returns:
        Tuple ``(images, labels)`` of two lists, each holding ``num_images``
        arrays of shape ``(img_size, img_size, num_channels)``.
    """
    # np.random and RandomState expose the same `rand` API, so one code path
    # serves both the global and the privately seeded generator.
    rng = np.random if seed is None else np.random.RandomState(seed)
    shape = (img_size, img_size, num_channels)
    images = [rng.rand(*shape) for _ in range(num_images)]
    labels = [rng.rand(*shape) for _ in range(num_images)]
    return images, labels

In [69]:
# torch_geometric.data.DataLoader is deprecated in PyG 2.x; the loader now
# lives in torch_geometric.loader.
from torch_geometric.loader import DataLoader
import torch.optim as optim

# Parameters
num_train_images = 100
num_test_images = 20
img_size = 5
filter_dim = 1
batch_size = 1
seed = 0

# Seed every RNG this cell relies on (synthetic data, weight initialisation,
# shuffling) so that re-running the cell reproduces the same numbers.
np.random.seed(seed)
torch.manual_seed(seed)

# Generate synthetic training and test data
train_images, train_labels = generate_synthetic_data(num_train_images, img_size)
test_images, test_labels = generate_synthetic_data(num_test_images, img_size)

# Convert images and labels to PyG graphs
train_graphs = [image_to_graph_with_filter(img, filter_dim) for img in train_images]
test_graphs = [image_to_graph_with_filter(img, filter_dim) for img in test_images]

train_labels = [labels_to_pyg_graph(lbl) for lbl in train_labels]
test_labels = [labels_to_pyg_graph(lbl) for lbl in test_labels]

# Store each target on its graph as `y`. Zipping graphs with separate label
# tensors makes the loader stack the labels along a new batch axis, so they no
# longer line up with the model's per-node output (this only worked through
# broadcasting at batch_size=1). As a graph attribute, `y` is concatenated
# node-wise together with `x`, so any batch size is valid.
for graph, target in zip(train_graphs + test_graphs, train_labels + test_labels):
    graph.y = target

# Create DataLoaders
train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

# Initialize the model, loss function, and optimizer
model = TwoLayerGCN(3, hidden_channels=16)
loss_fn = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training
num_epochs = 10
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = loss_fn(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Reported value is the mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

# Testing
model.eval()  # disables dropout for evaluation
mse_loss = 0
with torch.no_grad():
    for batch in test_loader:
        out = model(batch)
        mse_loss += loss_fn(out, batch.y).item()

print(f"Node-level mean squared error: {mse_loss / len(test_loader)}")

Epoch 1/10, Loss: 0.11046109244227409
Epoch 2/10, Loss: 0.08723216339945793
Epoch 3/10, Loss: 0.08432250633835793
Epoch 4/10, Loss: 0.08340477593243122
Epoch 5/10, Loss: 0.08294453978538513
Epoch 6/10, Loss: 0.08302599519491195
Epoch 7/10, Loss: 0.08278303354978561
Epoch 8/10, Loss: 0.08273020081222057
Epoch 9/10, Loss: 0.08271180391311646
Epoch 10/10, Loss: 0.08269989624619484
Node-level mean squared error: 0.08439605236053467
