In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import GCNConv, global_add_pool, BatchNorm
from torch_geometric.loader import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("✔ Using device:", device)

✔ Using device: cuda


# 📥 Chargement des données

In [2]:
dataset = torch.load("./../data/pyg_graphs.pt")
num_classes = len(set(int(data.y) for data in dataset))
print("✔ Number of classes:", num_classes)

  dataset = torch.load("./../data/pyg_graphs.pt")


✔ Number of classes: 104


# 🔀 Split

In [None]:
torch.manual_seed(42)
n = len(dataset)
train_dataset = dataset[:int(0.7 * n)]
val_dataset   = dataset[int(0.7 * n):int(0.85 * n)]
test_dataset  = dataset[int(0.85 * n):]

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=256)
test_loader  = DataLoader(test_dataset, batch_size=256)

# 🧠 GCN Model

In [4]:
class GCNWithEmbeddings(nn.Module):
    def __init__(self,
                 num_node_types: int,
                 num_spellings: int,
                 hidden_dim: int = 128,
                 num_classes: int = 104):
        super().__init__()

        # Embeddings
        self.type_embedding = nn.Embedding(num_node_types, hidden_dim)
        self.spelling_embedding = nn.Embedding(num_spellings, hidden_dim)

        # GCN stack
        self.convs = nn.ModuleList([
            GCNConv(hidden_dim * 2 + 2, hidden_dim),
            GCNConv(hidden_dim, hidden_dim),
            GCNConv(hidden_dim, hidden_dim),
            GCNConv(hidden_dim, hidden_dim),
        ])
        self.norms = nn.ModuleList([
            BatchNorm(hidden_dim) for _ in range(4)
        ])

        self.dropout = nn.Dropout(0.4)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        # x: [N, 4] = [type_id, spelling_id, is_op, is_lit]
        type_id = x[:, 0]
        spelling_id = x[:, 1]
        flags = x[:, 2:].float()  # [is_operator, is_literal]

        type_vec = self.type_embedding(type_id)
        spell_vec = self.spelling_embedding(spelling_id)

        node_vec = torch.cat([type_vec, spell_vec, flags], dim=1)

        for conv, norm in zip(self.convs, self.norms):
            node_vec = conv(node_vec, edge_index)
            node_vec = norm(node_vec)
            node_vec = F.relu(node_vec)
            node_vec = self.dropout(node_vec)

        graph_vec = global_add_pool(node_vec, batch)
        x = F.relu(self.fc1(graph_vec))
        x = self.dropout(x)
        return self.fc2(x)

num_node_types = max(int(data.x[:, 0].max()) for data in dataset) + 1
num_spellings  = max(int(data.x[:, 1].max()) for data in dataset) + 1
model = GCNWithEmbeddings(num_node_types=num_node_types, num_spellings=num_spellings, hidden_dim=128, num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# 🔁 Training loop

In [5]:
def train():
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        assert batch.y.max() < num_classes, f"Invalid y: {batch.y.max()} ≥ {num_classes}"
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

# 🧪 Évaluation

In [6]:
def evaluate(loader):
    model.eval()
    correct = 0
    total = 0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == batch.y).sum())
        total += batch.num_graphs
    return correct / total

# 🕹️ Early stopping params

In [7]:
best_val_acc = 0.0
patience = 100
patience_counter = 0
save_path = "model_best.pt"

# 🚀 Training with early stopping

In [8]:
for epoch in range(1, 10000 + 1):
    loss = train()
    val_acc = evaluate(val_loader)

    print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), save_path)
        print("✅ Model improved and saved.")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("⏹️ Early stopping triggered.")
            break

Epoch 001 | Loss: 4.8588 | Val Acc: 0.0000
Epoch 002 | Loss: 4.3238 | Val Acc: 0.0000
Epoch 003 | Loss: 4.3074 | Val Acc: 0.0000
Epoch 004 | Loss: 4.3016 | Val Acc: 0.0000
Epoch 005 | Loss: 4.2987 | Val Acc: 0.0000
Epoch 006 | Loss: 4.2968 | Val Acc: 0.0000
Epoch 007 | Loss: 4.2961 | Val Acc: 0.0000
Epoch 008 | Loss: 4.2948 | Val Acc: 0.0000
Epoch 009 | Loss: 4.2948 | Val Acc: 0.0000
Epoch 010 | Loss: 4.2943 | Val Acc: 0.0000
Epoch 011 | Loss: 4.2940 | Val Acc: 0.0000
Epoch 012 | Loss: 4.2937 | Val Acc: 0.0000
Epoch 013 | Loss: 4.2936 | Val Acc: 0.0000
Epoch 014 | Loss: 4.2935 | Val Acc: 0.0000


KeyboardInterrupt: 

# 🔍 Final evaluation

In [None]:
model.load_state_dict(torch.load(save_path))
test_acc = evaluate(test_loader)
print(f"🏁 Final Test Accuracy: {test_acc:.4f}")

  model.load_state_dict(torch.load(save_path))


🏁 Final Test Accuracy: 0.0000


In [11]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

dataset = torch.load("../data/pyg_graphs.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Moyenne des features de chaque graphe
X = torch.stack([data.x.float().mean(dim=0) for data in dataset])  # [N, 4]
y = torch.tensor([int(data.y) for data in dataset])                # [N]

# Split
n = len(X)
train_X, val_X = X[:int(0.8*n)], X[int(0.8*n):]
train_y, val_y = y[:int(0.8*n)], y[int(0.8*n):]

# Mini modèle MLP
model = nn.Sequential(
    nn.Linear(4, 128),
    nn.ReLU(),
    nn.Linear(128, 104)
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Entraînement rapide
for epoch in range(10):
    model.train()
    optimizer.zero_grad()
    out = model(train_X.to(device))
    loss = criterion(out, train_y.to(device))
    loss.backward()
    optimizer.step()

    # Validation
    model.eval()
    with torch.no_grad():
        pred = model(val_X.to(device)).argmax(dim=1)
        acc = (pred.cpu() == val_y).float().mean()
        print(f"Epoch {epoch+1:02d} | Val Acc: {acc:.4f}")


  dataset = torch.load("../data/pyg_graphs.pt")


Epoch 01 | Val Acc: 0.0000
Epoch 02 | Val Acc: 0.0000
Epoch 03 | Val Acc: 0.0000
Epoch 04 | Val Acc: 0.0000
Epoch 05 | Val Acc: 0.0000
Epoch 06 | Val Acc: 0.0000
Epoch 07 | Val Acc: 0.0000
Epoch 08 | Val Acc: 0.0000
Epoch 09 | Val Acc: 0.0000
Epoch 10 | Val Acc: 0.0000
