<a href="https://colab.research.google.com/github/k-kovani/Student_Projects/blob/main/GNN_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import random_split

class GraphDataset(Dataset):
    """In-memory dataset of paired ``(input_graph, target_graph)`` samples.

    Each item is a 2-tuple, so a collated batch can be unpacked as
    ``input_batch, target_batch``.

    Args:
        input_graphs: Sequence of input ``Data`` objects.
        target_graphs: Sequence of target ``Data`` objects, aligned
            index-by-index with ``input_graphs``.

    Raises:
        ValueError: If the two sequences differ in length.
    """

    def __init__(self, input_graphs, target_graphs):
        # The PyG base class sets up state (``transform``, ``_indices``) that
        # its ``__getitem__``/``__len__`` read; without this call, indexing via
        # ``random_split``/``DataLoader`` raises AttributeError.
        super().__init__()
        if len(input_graphs) != len(target_graphs):
            raise ValueError(
                "input_graphs and target_graphs must have the same length, "
                f"got {len(input_graphs)} and {len(target_graphs)}"
            )
        self.input_graphs = input_graphs
        self.target_graphs = target_graphs

    def len(self):
        """Return the number of (input, target) pairs."""
        return len(self.input_graphs)

    def get(self, idx):
        """Return the ``(input_graph, target_graph)`` pair at position ``idx``."""
        return self.input_graphs[idx], self.target_graphs[idx]

class GNN(pl.LightningModule):
    """Three-layer GCN that regresses per-node target features.

    Args:
        input_dim: Number of input node features.
        hidden_dim: Width of the two hidden GCN layers.
        output_dim: Number of regressed features per node.
        lr: Adam learning rate (default 0.01).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, lr=0.01):
        super().__init__()
        # Store constructor arguments in ``self.hparams`` and in checkpoints.
        self.save_hyperparameters()
        self.lr = lr
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """Map node features ``x`` to per-node predictions using ``edge_index``."""
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # No activation on the final layer: outputs are unbounded regression values.
        return self.conv3(x, edge_index)

    def _shared_step(self, batch):
        """Return the node-level MSE loss and the number of graphs in ``batch``."""
        input_graph, target_graph = batch
        output = self(input_graph.x, input_graph.edge_index)
        # NOTE(review): assumes each target graph has the same nodes, in the
        # same order, as its input graph -- only ``target_graph.x`` is used.
        loss = F.mse_loss(output, target_graph.x)
        # A collated PyG Batch exposes ``num_graphs``; fall back to 1 for a
        # single un-batched graph.
        num_graphs = getattr(input_graph, 'num_graphs', 1)
        return loss, num_graphs

    def training_step(self, batch, batch_idx):
        loss, num_graphs = self._shared_step(batch)
        # Pass batch_size explicitly so Lightning does not have to infer it
        # from the tensors inside the graph batch.
        self.log('train_loss', loss, batch_size=num_graphs)
        return loss

    def validation_step(self, batch, batch_idx):
        # Without this method Lightning skips the validation loop entirely,
        # leaving the held-out split unused.
        loss, num_graphs = self._shared_step(batch)
        self.log('val_loss', loss, batch_size=num_graphs, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

class GraphDataModule(pl.LightningDataModule):
    """Split a graph dataset into train/validation subsets and serve loaders.

    Args:
        dataset: Dataset to split.
        batch_size: Graphs per batch for both loaders.
        train_fraction: Fraction of samples used for training; the rest go to
            validation (default 0.8).
        seed: Optional seed for the split generator. ``None`` uses torch's
            global RNG.

    Raises:
        ValueError: If ``train_fraction`` is not in ``(0, 1]``.
    """

    def __init__(self, dataset, batch_size=32, train_fraction=0.8, seed=None):
        super().__init__()
        if not 0.0 < train_fraction <= 1.0:
            raise ValueError(f"train_fraction must be in (0, 1], got {train_fraction}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.train_fraction = train_fraction
        self.seed = seed
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        """Create the train/validation split (only on the first call)."""
        # setup() runs for every Trainer entry point (fit, validate, test, ...).
        # Re-splitting on each call could move training samples into the
        # validation set, so keep the first split.
        if self.train_dataset is not None:
            return
        num_samples = len(self.dataset)
        train_size = int(self.train_fraction * num_samples)
        val_size = num_samples - train_size
        split_kwargs = {}
        if self.seed is not None:
            split_kwargs['generator'] = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            self.dataset, [train_size, val_size], **split_kwargs
        )

    # NOTE(review): ``DataLoader`` here comes from ``torch_geometric.data``;
    # newer PyG versions expose it as ``torch_geometric.loader.DataLoader``.
    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

# Example data (replace with your actual data).
# Seed Python/NumPy/torch RNGs so the synthetic graphs, the train/val split and
# the weight initialisation are reproducible across runs.
pl.seed_everything(42)

NUM_GRAPHS = 100
NUM_NODES = 10
NUM_EDGES = 20
NUM_FEATURES = 16

input_graphs = [
    Data(
        x=torch.randn(NUM_NODES, NUM_FEATURES),
        edge_index=torch.randint(0, NUM_NODES, (2, NUM_EDGES)),
    )
    for _ in range(NUM_GRAPHS)
]
# Only target_graph.x is used as the regression target; the target edge_index
# is not read by the model's loss.
target_graphs = [
    Data(
        x=torch.randn(NUM_NODES, NUM_FEATURES),
        edge_index=torch.randint(0, NUM_NODES, (2, NUM_EDGES)),
    )
    for _ in range(NUM_GRAPHS)
]
dataset = GraphDataset(input_graphs, target_graphs)

# Initialize model, data module, and trainer.
# Input/output widths must match the node-feature size of the example data.
model = GNN(input_dim=NUM_FEATURES, hidden_dim=32, output_dim=NUM_FEATURES)
data_module = GraphDataModule(dataset, batch_size=32)
trainer = pl.Trainer(max_epochs=100)

# Train the model
trainer.fit(model, data_module)