In [None]:
# !pip install torch_geometric
# !pip install plotly
# !pip install --upgrade nbformat

In [None]:
import os
import numpy as np
import torch
from torch.nn import Linear, MSELoss, BCELoss, ReLU, Dropout, Sequential
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool,SAGEConv, MessagePassing, GATv2Conv, LayerNorm
from sklearn.model_selection import train_test_split
from scipy import sparse
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (needed for 3D plotting)
import random

In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output
%matplotlib inline

In [None]:
# Let float32 matrix multiplications use faster reduced-precision kernels
# (e.g. TF32) on hardware that supports them.
torch.set_float32_matmul_precision("high")
# Ask cuDNN to benchmark candidate algorithms and reuse the fastest one.
# NOTE(review): this mainly helps cuDNN-backed ops with fixed input shapes;
# confirm it has any effect for the graph layers used in this notebook.
torch.backends.cudnn.benchmark = True


# === Dataset Loader ===
class CarGraphDataset(Dataset):
    """Graphs stored on disk as paired ``.npz`` archives.

    Every sample is made of a feature archive ``<name>.npz`` holding the
    arrays ``x``, ``center_point`` and ``scale``, plus a SciPy sparse
    adjacency matrix stored next to it as ``<name>_adj.npz``.

    Args:
        root_dir: Directory that is searched recursively for feature archives
            (every ``*.npz`` whose file name does not contain ``"_adj"``).
        indices: Optional positions into the *sorted* list of feature files.
            ``None`` selects every file; an empty sequence selects no file.
    """

    def __init__(self, root_dir, indices=None):
        self.root_dir = Path(root_dir)
        all_files = sorted(f for f in self.root_dir.rglob("*.npz") if "_adj" not in f.name)
        # Compare against None explicitly: a truthiness test would turn an
        # empty split into "use every file" and fails on numpy index arrays.
        self.files = all_files if indices is None else [all_files[i] for i in indices]
        super().__init__()

    def len(self):
        """Return the number of graphs selected for this dataset."""
        return len(self.files)

    def get(self, idx):
        """Load graph ``idx`` from disk and return it as a ``Data`` object.

        The returned object carries ``x`` (float32, used as node features by
        the model), ``edge_index`` (int64, shape ``(2, num_edges)`` built from
        the COO rows/columns of the adjacency matrix), ``center_point`` and
        ``scale`` (float32).
        NOTE(review): ``center_point``/``scale`` look like normalisation
        parameters of the geometry — confirm against the preprocessing code.
        """
        feat_file = self.files[idx]
        # Build the sibling name from the stem so only the suffix is touched.
        adj_file = feat_file.with_name(f"{feat_file.stem}_adj.npz")

        # NpzFile keeps the archive open until closed; release it promptly so
        # long epochs do not accumulate open file handles.
        with np.load(feat_file) as npz_data:
            x = torch.tensor(npz_data["x"], dtype=torch.float32)
            center_point = torch.tensor(npz_data["center_point"], dtype=torch.float32)
            scale = torch.tensor(npz_data["scale"], dtype=torch.float32)

        a = sparse.load_npz(adj_file).tocoo()
        edge_index = torch.tensor(np.vstack((a.row, a.col)), dtype=torch.long)

        return Data(x=x, edge_index=edge_index, center_point=center_point, scale=scale)


In [None]:
#Attention v2 + VAE
class GraphAutoEncoder(torch.nn.Module):
    """Node-level variational autoencoder with a GATv2 encoder and MLP decoder.

    The encoder stacks three ``GATv2Conv`` layers, each followed by ``ReLU``
    and ``LayerNorm``, producing ``hidden_channels``, ``2 * hidden_channels``
    and ``4 * hidden_channels`` features per node. Two linear heads map the
    encoder output to the mean and log-variance of a per-node latent of size
    ``latent_dim``; the decoder maps a latent vector back to ``in_channels``.

    Args:
        in_channels: Number of input (and reconstructed) features per node.
        hidden_channels: Base encoder width. Should be even: the second layer
            uses ``2 * hidden_channels // 4`` channels per head with 4 heads.
        latent_dim: Size of the per-node latent vector.
    """

    def __init__(self, in_channels, hidden_channels, latent_dim):
        super().__init__()

        # Layers 2 and 3 use 4 concatenated heads, so the per-head width is a
        # quarter of the layer's total output width.
        self.encoder = Sequential(
            GATv2Conv(in_channels, hidden_channels), ReLU(),
            LayerNorm(hidden_channels),
            GATv2Conv(hidden_channels, 2*hidden_channels //4, heads=4, concat=True,residual=True), ReLU(),
            LayerNorm(2*hidden_channels),
            GATv2Conv(2*hidden_channels, 4*hidden_channels //4, heads=4, concat=True, residual=True), ReLU(),
            LayerNorm(4*hidden_channels)
        )
        self.encoder_mu = Linear(4 * hidden_channels, latent_dim)
        self.encoder_logvar = Linear(4 * hidden_channels, latent_dim)
        self.decoder_lin = Sequential(
            Linear(latent_dim, 2*hidden_channels), ReLU(),
            Linear(2*hidden_channels, hidden_channels), ReLU(),
            Linear(hidden_channels, in_channels)
        )

    def encode(self, x, edge_index, batch=None):
        """Encode node features into a sampled latent.

        Args:
            x: Node feature matrix with ``in_channels`` columns.
            edge_index: Edge list of shape ``(2, num_edges)``.
            batch: Graph id of every node; ``None`` means all nodes belong to
                a single graph.

        Returns:
            ``(mu, logvar, z, z_graph)`` where ``mu``/``logvar`` parameterise
            the per-node Gaussian, ``z`` is one reparameterised sample per
            node and ``z_graph`` is the per-graph mean of ``z``. Noise is
            drawn in both training and evaluation mode.
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        for layer in self.encoder:
            if isinstance(layer, MessagePassing):
                x = layer(x, edge_index)
            elif isinstance(layer, LayerNorm):
                # PyG's LayerNorm defaults to graph mode; without the batch
                # vector it would pool statistics over every graph in the
                # mini-batch, so results would depend on batch composition.
                x = layer(x, batch)
            else:
                x = layer(x)

        mu = self.encoder_mu(x)
        logvar = self.encoder_logvar(x)
        # Reparameterisation trick: z = mu + eps * sigma, eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        z_graph = global_mean_pool(z, batch)
        return mu, logvar, z, z_graph

    def decode(self, z):
        """Map latent vectors back to node features with the MLP decoder."""
        return self.decoder_lin(z)

    def forward(self, x, edge_index, batch=None):
        """Run encode + decode.

        Returns:
            The 5-tuple ``(mu, logvar, z, x_hat, z_graph)``; ``x_hat`` is the
            reconstruction of ``x`` decoded from the per-node sample ``z``.
        """
        mu, logvar, z, z_graph = self.encode(x, edge_index, batch)
        x_hat = self.decode(z)
        return mu, logvar, z, x_hat, z_graph


In [None]:
# === Training loop ===
def kl_divergence(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and a standard normal.

    The closed-form term is summed over every element of the inputs and the
    total is divided by the size of the first dimension of ``mu``.
    """
    num_rows = mu.size(0)
    elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * elementwise.sum() / num_rows

def beta():
    """Return the weight applied to the KL term of the loss (currently constant)."""
    kl_weight = 1.0  # Consider annealing in future
    return kl_weight

def run_epoch(model, loader, optimizer, device, train=True):
    """Run one pass over ``loader`` and return the mean loss per batch.

    The loss is the sum of three terms: MSE between reconstructed and input
    node features, binary cross-entropy on the edges present in the batch
    (inner-product decoder, target 1 for every existing edge; no negative
    edges are sampled), and the KL term weighted by ``beta()``.

    Args:
        model: Module returning ``(mu, logvar, z, x_hat, z_graph)``.
        loader: Iterable of mini-batches exposing ``x``, ``edge_index`` and
            ``batch`` and supporting ``len()``.
        optimizer: Optimizer stepped after every batch when ``train`` is true;
            unused otherwise.
        device: Device every batch is moved to.
        train: ``True`` to update the weights, ``False`` to only evaluate.

    Returns:
        Average of the per-batch losses as a float, or ``0.0`` for an empty
        loader.
    """
    model.train(train)

    num_batches = len(loader)
    if num_batches == 0:
        return 0.0

    # The loss modules are stateless, so build them once per epoch.
    feature_criterion = MSELoss()
    edge_criterion = BCELoss()
    total_loss = 0.0

    # Gradients are only needed when the weights are updated; switching them
    # off for validation/test avoids building autograd graphs.
    with torch.set_grad_enabled(train):
        for batch in loader:
            batch = batch.to(device)
            x, edge_index, batch_idx = batch.x, batch.edge_index, batch.batch

            if train:
                optimizer.zero_grad()

            mu, logvar, z, x_hat, z_graph = model(x, edge_index, batch=batch_idx)

            # Reconstruct adjacency (only for existing edges)
            z_i = z[edge_index[0]]
            z_j = z[edge_index[1]]
            dot_products = (z_i * z_j).sum(dim=1)
            # Clamp keeps the probabilities strictly inside (0, 1).
            adj_pred = torch.sigmoid(dot_products).clamp(min=1e-7, max=1 - 1e-7)
            adj_true = torch.ones_like(adj_pred)

            # Loss terms
            loss_x = feature_criterion(x_hat, x)
            loss_a = edge_criterion(adj_pred, adj_true)
            loss_kl = kl_divergence(mu, logvar)
            loss = loss_x + loss_a + beta() * loss_kl

            if train:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()

    return total_loss / num_batches

In [None]:
# === Visualize original vs reconstructed geometry ===
def plot_reconstruction(model, dataset, device, sample_idx=None, max_nodes=100):
    """Scatter original vs. reconstructed node features for one graph.

    Up to the first three feature columns are drawn, one panel per column,
    for a random subset of at most ``max_nodes`` nodes. The figure is written
    to ``reconstruction_sample.png`` in the working directory and then closed
    (it is not displayed inline).

    Args:
        model: Module returning ``(mu, logvar, z, x_hat, z_graph)``.
        dataset: Indexable collection of graphs with ``x`` and ``edge_index``.
        device: Device used for the forward pass.
        sample_idx: Index of the graph to plot; ``None`` picks one at random.
        max_nodes: Maximum number of nodes shown per panel.
    """
    model.eval()
    idx = random.randrange(len(dataset)) if sample_idx is None else sample_idx
    data = dataset[idx].to(device)

    with torch.no_grad():
        # The model returns five values; only the reconstruction is plotted.
        # A missing batch vector (None) is treated by the model as "one graph".
        _, _, _, x_hat, _ = model(data.x, data.edge_index, getattr(data, "batch", None))

    x_orig = data.x.cpu().numpy()
    x_recon = x_hat.cpu().numpy()

    # Sample up to max_nodes
    num_nodes = x_orig.shape[0]
    sampled_indices = np.random.choice(num_nodes, min(max_nodes, num_nodes), replace=False)
    positions = range(len(sampled_indices))

    # squeeze=False always yields a 2D array of axes, so a single feature
    # column needs no special casing.
    num_panels = min(3, x_orig.shape[1])
    fig, axes = plt.subplots(1, num_panels, figsize=(15, 5), squeeze=False)

    for i, ax in enumerate(axes[0]):
        ax.scatter(positions, x_orig[sampled_indices, i], label="Original", alpha=0.7)
        ax.scatter(positions, x_recon[sampled_indices, i], label="Reconstructed", alpha=0.7, marker='x')
        ax.set_title(f"Feature {i}")
        ax.legend()

    fig.suptitle(f"Scatter Comparison: Sample #{idx}")
    fig.tight_layout()

    save_path = Path("reconstruction_sample.png")
    print("Saving plot to", save_path.resolve())
    # Save and close this specific figure rather than whichever is "current".
    fig.savefig(save_path)
    plt.close(fig)

In [None]:
def plot_latent_space(model, dataset, device):
    """Project the graph-level latent of every graph to 2D with PCA and plot it.

    Each graph is encoded on its own; the pooled latent ``z_graph`` returned
    by the model is collected, reduced to two principal components and drawn
    as a scatter plot coloured by the graph's position in ``dataset``. The
    figure is saved to ``latent_projection.png`` and closed.

    Args:
        model: Module returning ``(mu, logvar, z, x_hat, z_graph)``.
        dataset: Iterable of graphs with ``x`` and ``edge_index``.
        device: Device used for the forward passes.

    Raises:
        ValueError: If ``dataset`` yields fewer than two graphs, since a
            two-component PCA needs at least two samples.
    """
    model.eval()
    graph_embeddings = []

    with torch.no_grad():
        for data in dataset:
            data = data.to(device)
            # Single graph: every node belongs to graph 0.
            batch = torch.zeros(data.x.size(0), dtype=torch.long, device=device)
            # The model returns five values; only the pooled latent is used.
            _, _, _, _, z_graph = model(data.x, data.edge_index, batch)
            graph_embeddings.append(z_graph.squeeze(0).cpu().numpy())

    if len(graph_embeddings) < 2:
        raise ValueError("plot_latent_space needs at least two graphs for a 2D PCA projection")

    zs = np.stack(graph_embeddings)
    zs_2d = PCA(n_components=2).fit_transform(zs)
    colors = np.arange(len(zs))

    fig, ax = plt.subplots(figsize=(8, 6))
    scatter = ax.scatter(zs_2d[:, 0], zs_2d[:, 1], c=colors, cmap='viridis', s=40, edgecolor='k')
    fig.colorbar(scatter, ax=ax, label="Graph index")
    ax.set_title("2D PCA of Graph-Level Latent Embeddings")
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    fig.tight_layout()
    fig.savefig("latent_projection.png")
    plt.close(fig)

In [None]:
import plotly.graph_objects as go
import plotly.io as pio
# Select the default Plotly renderer used by fig.show() in this notebook.
pio.renderers.default = 'notebook'  # or 'notebook_connected'

def plot_geometry_comparison_interactive(x_orig_np, x_recon_np, sample_idx=None):
    """Display two point sets in one interactive 3D scatter plot.

    Columns 0, 1 and 2 of each array are used as the x, y and z coordinates.
    The original points are drawn in blue, the reconstructed ones in red.

    Args:
        x_orig_np: 2D array with at least three columns (original points).
        x_recon_np: 2D array with at least three columns (reconstruction).
        sample_idx: Value shown in the figure title.
    """
    point_sets = (
        (x_orig_np, 'blue', 'Original'),
        (x_recon_np, 'red', 'Reconstructed'),
    )

    fig = go.Figure()
    for points, colour, label in point_sets:
        trace = go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker={'size': 2, 'color': colour},
            name=label,
        )
        fig.add_trace(trace)

    axis_titles = {'xaxis_title': 'X', 'yaxis_title': 'Y', 'zaxis_title': 'Z'}
    tight_margin = {'l': 0, 'r': 0, 'b': 0, 't': 30}
    fig.update_layout(
        title=f"Interactive Geometry Comparison - Sample #{sample_idx}",
        scene=axis_titles,
        margin=tight_margin,
    )

    fig.show()

def plot_losses(train_losses, val_losses):
    """Redraw the train/validation loss curves in place of the previous cell output.

    The latest value of each non-empty series is written next to its last
    point. The previous output of the cell is cleared once the new figure is
    ready to be shown.

    Args:
        train_losses: Sequence of per-epoch training losses.
        val_losses: Sequence of per-epoch validation losses.
    """
    # (values, legend label, annotation colour, annotation offset in points)
    curves = (
        (train_losses, 'Train', 'blue', (-10, 10)),
        (val_losses, 'Val', 'orange', (-10, -15)),
    )

    clear_output(wait=True)
    plt.figure(figsize=(10, 5))

    for values, legend_label, _, _ in curves:
        plt.plot(values, label=legend_label)

    # Annotate the last values
    for values, _, text_colour, offset in curves:
        if not values:
            continue
        last_epoch = len(values) - 1
        latest = values[-1]
        plt.annotate(
            f"{latest:.6f}",
            (last_epoch, latest),
            textcoords="offset points",
            xytext=offset,
            ha='center',
            color=text_colour,
        )

    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.title("Training vs Validation Loss")
    plt.tight_layout()
    plt.show()


In [None]:
# === Main execution ===
path = "Graphs"
SEED = 42
NUM_EPOCHS = 100

# Seed the RNGs used below so weight initialisation, shuffling and the VAE
# noise are repeatable across "Restart & Run All".
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Count feature files with the same rule CarGraphDataset uses. The glob
# "*[!_adj].npz" is a character class ("last character is not _, a, d or j"),
# not a suffix exclusion, so it can miscount and yield out-of-range indices.
num_graphs = sum(1 for f in Path(path).rglob("*.npz") if "_adj" not in f.name)
if num_graphs == 0:
    raise FileNotFoundError(f"No feature .npz files found under {Path(path).resolve()}")
all_indices = list(range(num_graphs))

# 70 % train, 15 % validation, 15 % test.
train_idx, valtest_idx = train_test_split(all_indices, test_size=0.3, random_state=SEED)
val_idx, test_idx = train_test_split(valtest_idx, test_size=0.5, random_state=SEED)

train_set = CarGraphDataset(path, indices=train_idx)
val_set = CarGraphDataset(path, indices=val_idx)
test_set = CarGraphDataset(path, indices=test_idx)

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4)
test_loader = DataLoader(test_set, batch_size=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GraphAutoEncoder(in_channels=10, hidden_channels=64, latent_dim=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

writer = SummaryWriter(log_dir="runs/graph_ae")
best_val_loss = float("inf")
os.makedirs("checkpoints", exist_ok=True)

train_losses = []
val_losses = []

try:
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss = run_epoch(model, train_loader, optimizer, device, train=True)
        val_loss = run_epoch(model, val_loader, optimizer, device, train=False)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        writer.add_scalars("Loss", {"train": train_loss, "val": val_loss}, epoch)
        # plot_losses clears the cell output, so print afterwards to keep the
        # latest epoch line visible below the figure.
        plot_losses(train_losses, val_losses)
        print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

        # Keep only the weights with the lowest validation loss so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "checkpoints/best_model.pt")
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()

In [None]:
# === Test ===
# Evaluate the checkpoint with the lowest validation loss rather than the
# last-epoch weights. map_location lets a checkpoint saved on a GPU load on a
# CPU-only machine as well.
model.load_state_dict(torch.load("checkpoints/best_model.pt", map_location=device))
test_loss = run_epoch(model, test_loader, optimizer, device, train=False)
print(f"Final Test Loss: {test_loss:.6f}")

# === Visualizations ===
#plot_latent_space(model, test_set, device)
plot_reconstruction(model, test_set, device, sample_idx=8, max_nodes=100)


In [None]:
# === Visualize full geometry ===
sample_idx = 8  # Or None for random
if sample_idx is None:
    # Resolve "random" to a concrete index so it can be used for lookup and
    # shown in the plot title.
    sample_idx = random.randrange(len(test_set))
data = test_set[sample_idx].to(device)

model.eval()
with torch.no_grad():
    # The model returns (mu, logvar, z, x_hat, z_graph); only the
    # reconstruction x_hat is needed here.
    _, _, _, x_hat, _ = model(data.x, data.edge_index)

# Keep the first three feature columns for the 3D plot.
# NOTE(review): assumes these columns are the x/y/z coordinates — confirm
# against the feature layout produced by preprocessing.
x_orig_np = data.x[:, :3].cpu().numpy()
x_recon_np = x_hat[:, :3].cpu().numpy()

plot_geometry_comparison_interactive(x_orig_np, x_recon_np, sample_idx)