In [20]:
# %%
import torch
import torch.nn.functional as F
import numpy as np
from scipy import sparse
from pathlib import Path
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split


In [21]:

# %%
class CarGraphDataset(Dataset):
    """Graph dataset backed by paired ``.npz`` files found under a directory.

    Every ``*.npz`` file below ``path`` whose name does not contain ``"_adj"``
    is treated as a node-feature archive (the array is read from its ``'x'``
    key); every file whose name contains ``"_adj"`` is loaded as a SciPy
    sparse adjacency matrix.  Both lists are sorted and then paired by
    position, so sample ``i`` uses ``files_x[i]`` and ``files_adj[i]``.

    Args:
        path: Root directory that is searched recursively for ``*.npz`` files.
        transform: Forwarded unchanged to ``Dataset.__init__``.
        pre_transform: Forwarded unchanged to ``Dataset.__init__``.

    Raises:
        ValueError: If the number of feature files differs from the number of
            adjacency files; positional pairing would otherwise attach the
            wrong adjacency to a graph or fail later with an ``IndexError``.
    """

    def __init__(self, path, transform=None, pre_transform=None):
        super().__init__(None, transform, pre_transform)
        self.path = Path(path)

        # Walk the tree once; filtering an already-sorted list keeps each
        # sub-list in the same order as sorting it on its own would.
        npz_files = sorted(self.path.rglob("*.npz"))
        self.files_x = [f for f in npz_files if "_adj" not in f.name]
        self.files_adj = [f for f in npz_files if "_adj" in f.name]

        if len(self.files_x) != len(self.files_adj):
            raise ValueError(
                f"Found {len(self.files_x)} feature files but "
                f"{len(self.files_adj)} adjacency files under {self.path}; "
                "each graph needs exactly one of each."
            )

    def len(self):
        """Return the number of graphs (one per feature file)."""
        return len(self.files_x)

    def get(self, idx):
        """Load graph ``idx`` and return it as a ``Data`` object.

        Node features are converted to ``float32``; the edge index is built
        from the row/column indices of the adjacency matrix in COO form.
        """
        # NpzFile keeps the archive open until closed, so release it as soon
        # as the array has been copied into a tensor.
        with np.load(self.files_x[idx]) as x_data:
            x = torch.tensor(x_data['x'], dtype=torch.float32)

        adjacency = sparse.load_npz(self.files_adj[idx]).tocoo()
        edge_index = torch.tensor(
            np.vstack((adjacency.row, adjacency.col)), dtype=torch.long
        )
        return Data(x=x, edge_index=edge_index)

In [22]:

# %%
class GraphAutoEncoder(torch.nn.Module):
    """Graph auto-encoder with a single graph-level bottleneck.

    Encoder: ``GCNConv`` -> ReLU -> dropout (p=0.5) -> ``Linear``, giving one
    latent vector per node.  The node latents are mean-pooled per graph, and
    the decoder (a single ``Linear``) maps each node's *graph* embedding back
    to the input feature space.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the GCN layer.
        latent_dim: Size of the latent embedding.
    """

    def __init__(self, in_channels, hidden_channels, latent_dim):
        super().__init__()
        self.encoder_gnn = GCNConv(in_channels, hidden_channels)
        self.encoder_lin = torch.nn.Linear(hidden_channels, latent_dim)
        self.decoder = torch.nn.Linear(latent_dim, in_channels)
        self.dropout = torch.nn.Dropout(p=0.5)

    def encode(self, x, edge_index):
        """Return per-node latent vectors for features ``x`` and ``edge_index``."""
        hidden = F.relu(self.encoder_gnn(x, edge_index))
        hidden = self.dropout(hidden)
        return self.encoder_lin(hidden)

    def decode(self, z_graph, batch):
        """Decode node features from graph embeddings.

        ``z_graph[batch]`` repeats each graph's embedding once per node that
        belongs to it, so the result has one row per node.
        """
        return self.decoder(z_graph[batch])

    def forward(self, data):
        """Run the auto-encoder on a (batched) graph.

        Args:
            data: Object exposing ``x``, ``edge_index`` and optionally
                ``batch`` (node -> graph assignment).

        Returns:
            Tuple ``(z_node, z_graph, x_hat, edge_index)``: per-node latents,
            per-graph pooled latents, reconstructed node features, and the
            input edge index passed through unchanged.
        """
        x, edge_index = data.x, data.edge_index
        batch = getattr(data, "batch", None)
        if batch is None:
            # A single un-batched graph has no assignment vector. Indexing
            # with None would add a dimension instead of broadcasting the
            # graph embedding to the nodes, so treat every node as graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        z_node = self.encode(x, edge_index)
        z_graph = global_mean_pool(z_node, batch)
        x_hat = self.decode(z_graph, batch)
        return z_node, z_graph, x_hat, edge_index

In [25]:

# %%
class LitGraphAutoEncoder(pl.LightningModule):
    """Lightning wrapper that trains ``GraphAutoEncoder``.

    The objective is the sum of a feature reconstruction term (MSE between
    decoded and input node features) and an edge term that pushes the
    inner-product score of every existing edge towards 1.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Hidden size of the encoder's GCN layer.
        latent_dim: Size of the latent embedding.
        lr: Learning rate for Adam.
    """

    def __init__(self, in_channels, hidden_channels, latent_dim, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.model = GraphAutoEncoder(in_channels, hidden_channels, latent_dim)

    def forward(self, data):
        """Delegate to the wrapped ``GraphAutoEncoder``."""
        return self.model(data)

    def compute_adj_loss(self, z_node, edge_index):
        """Binary cross-entropy of inner-product edge scores against target 1.

        Only the edges listed in ``edge_index`` are scored (no negative
        samples are drawn).  Returns a scalar tensor; zero when there are no
        edges.
        """
        if edge_index.numel() == 0:
            # Averaging over zero edges would give NaN and poison the total loss.
            return z_node.new_zeros(())

        src, dst = edge_index[0], edge_index[1]
        logits = (z_node[src] * z_node[dst]).sum(dim=1)
        # Working on logits avoids the saturation of sigmoid + BCE, which
        # loses gradient for large-magnitude scores.
        return F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))

    def _shared_step(self, batch, stage):
        """Compute the combined loss for one batch and log it as ``{stage}_loss``."""
        z_node, _, x_hat, edge_index = self(batch)
        loss = F.mse_loss(x_hat, batch.x) + self.compute_adj_loss(z_node, edge_index)
        self.log(
            f"{stage}_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            # Lightning cannot reliably infer the batch size from a graph
            # batch, so state it explicitly for correct epoch averaging.
            batch_size=getattr(batch, "num_graphs", 1),
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss (logged per epoch as ``train_loss``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss (logged per epoch as ``val_loss``)."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Return the test loss (logged per epoch as ``test_loss``)."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate from ``hparams``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [None]:

# %%
# ---- TRAINING SECTION ---- #
# Everything a reader might tune lives here rather than inline below.
DATA_DIR = Path("/Users/koutsavd/PycharmProjects/Geometry_GNN/Graphs")
SEED = 42
TRAIN_FRACTION = 0.8
BATCH_SIZE = 4
MAX_EPOCHS = 100
HIDDEN_CHANNELS = 32
LATENT_DIM = 64

# Seed Python, NumPy and torch so weight init, shuffling and dropout repeat
# across "Restart & Run All".
pl.seed_everything(SEED, workers=True)

dataset = CarGraphDataset(DATA_DIR)
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
if train_size == 0:
    # Covers both an empty directory and a single graph; without this the
    # shuffling sampler fails later with a much less helpful message.
    raise ValueError(
        f"Need at least 2 graphs under {DATA_DIR} to build a train/val split, "
        f"found {len(dataset)}."
    )

# A dedicated generator keeps the split stable regardless of how much global
# RNG state was consumed before this cell ran.
train_set, val_set = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

# Keep only the checkpoint with the lowest validation loss.
checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    filename="best-graph-ae-{epoch:02d}-{val_loss:.4f}",
    save_weights_only=False,
    verbose=True
)

progress_cb = TQDMProgressBar(refresh_rate=10)
logger = TensorBoardLogger("lightning_logs", name="graph_autoencoder")

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_cb, progress_cb],
    logger=logger,
    log_every_n_steps=1,
)

# Take the input width from the data instead of hard-coding it, so the model
# always matches the feature files that were actually loaded.
in_channels = dataset[0].num_node_features
model = LitGraphAutoEncoder(
    in_channels=in_channels,
    hidden_channels=HIDDEN_CHANNELS,
    latent_dim=LATENT_DIM,
)
trainer.fit(model, train_loader, val_loader)

In [None]:
# Shutdown cell: end the kernel process immediately with exit status 0.
# os._exit bypasses normal interpreter cleanup (no atexit handlers, no
# buffered-stream flushing), so run this only once training has finished.
import os
import IPython

os._exit(0)