In [1]:
import torch
import torch.nn as nn
from torch_geometric.nn import VGAE, NNConv
from torch_geometric.utils import dropout_edge
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
import numpy as np
import pandas as pd
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import random

def set_seed(seed=42):
    """Fixe toutes les graines aléatoires pour la reproductibilité"""
    # PyTorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # si multi-GPU
    
    # NumPy
    np.random.seed(seed)
    
    # Python random
    random.seed(seed)    
    print(f"Seed fixée à: {seed}")

# Appeler au début
set_seed(42)

Seed fixée à: 42


In [3]:
def load_graphs(graph_dir):
    # Vérification du dossier
    graph_dir = Path(graph_dir)
    if not graph_dir.exists():
        raise FileNotFoundError(
            f"Le dossier '{graph_dir}' n'existe pas."
        )
    if not graph_dir.is_dir():
        raise NotADirectoryError(
            f"'{graph_dir}' n'est pas un dossier valide."
        )
    graphs = []
    # Charger tous les fichiers .pt
    for file_path in sorted(graph_dir.glob("*.pt")):
        try:
            g = torch.load(file_path)
            graphs.append(g)
            print(f"Graphe chargé : {file_path}")
        except Exception as e:
            print(f"Erreur lors du chargement de {file_path} : {e}")
    return graphs

# Charger les graphes
train_graphs_dir = "graphs/train"
test_graphs_dir = "graphs/test"

train_graphs = load_graphs(train_graphs_dir)
test_graphs = load_graphs(test_graphs_dir)

Graphe chargé : graphs/train/graph_batch_1.pt
Graphe chargé : graphs/train/graph_batch_10.pt
Graphe chargé : graphs/train/graph_batch_11.pt
Graphe chargé : graphs/train/graph_batch_12.pt
Graphe chargé : graphs/train/graph_batch_2.pt
Graphe chargé : graphs/train/graph_batch_3.pt
Graphe chargé : graphs/train/graph_batch_4.pt
Graphe chargé : graphs/train/graph_batch_5.pt
Graphe chargé : graphs/train/graph_batch_6.pt
Graphe chargé : graphs/train/graph_batch_7.pt
Graphe chargé : graphs/train/graph_batch_8.pt
Graphe chargé : graphs/train/graph_batch_9.pt
Graphe chargé : graphs/test/graph_batch_1.pt
Graphe chargé : graphs/test/graph_batch_2.pt
Graphe chargé : graphs/test/graph_batch_3.pt
Graphe chargé : graphs/test/graph_batch_4.pt
Graphe chargé : graphs/test/graph_batch_5.pt
Graphe chargé : graphs/test/graph_batch_6.pt
Graphe chargé : graphs/test/graph_batch_7.pt


  g = torch.load(file_path)


## Split temporel (train / val)

In [4]:
num_graphs = len(train_graphs)

train_ratio = 0.8

train_end = int(train_ratio * num_graphs)

val_graphs   = train_graphs[train_end:]
train_graphs = train_graphs[:train_end]


# Créer un DataLoader pour l'entraînement
# batch_size = nombre de graphes dans chaque batch
batch_size = 1

train_loader = DataLoader(
    train_graphs,
    batch_size=batch_size,
    shuffle=False
)

val_loader = DataLoader(
    val_graphs,
    batch_size=batch_size,
    shuffle=False
)

test_loader = DataLoader(
    test_graphs,
    batch_size=batch_size,
    shuffle=False
)

## Normalisation des features

In [5]:
# concaténation de tous les noeuds du TRAIN
node_xs = torch.cat([g.x for g in train_graphs], dim=0)
edge_xs = torch.cat([g.edge_attr for g in train_graphs], dim=0)

# Calcul des statistiques pour les nœuds
node_mean = node_xs.mean(dim=0)
node_std  = node_xs.std(dim=0, unbiased=False) + 1e-6

# Calcul des statistiques pour les arêtes
edge_mean = edge_xs.mean(dim=0)
edge_std  = edge_xs.std(dim=0, unbiased=False) + 1e-6

def normalize_graph(g, node_mean, node_std, edge_mean, edge_std):
    g = g.clone()
    g.x = (g.x - node_mean) / node_std
    g.edge_attr = (g.edge_attr - edge_mean) / edge_std
    return g


# Normalisation de tous les graphes
train_graphs = [normalize_graph(g, node_mean, node_std, edge_mean, edge_std) for g in train_graphs]
val_graphs   = [normalize_graph(g, node_mean, node_std, edge_mean, edge_std) for g in val_graphs]
test_graphs  = [normalize_graph(g, node_mean, node_std, edge_mean, edge_std) for g in test_graphs]

## Définir l’encodeur

In [6]:
class NNConvEncoder(nn.Module):
    def __init__(self, in_channels, out_channels, edge_dim, edge_latent_dim, dropout=0.1):
        super(NNConvEncoder, self).__init__()

        self.dropout = dropout

        self.edge_mlp1 = nn.Sequential(
            nn.Linear(edge_dim, edge_latent_dim),
            nn.ReLU(),
            nn.Linear(edge_latent_dim, in_channels * 2 * out_channels),
        )

        self.edge_mlp2 = nn.Sequential(
            nn.Linear(edge_dim, edge_latent_dim),
            nn.ReLU(),
            nn.Linear(edge_latent_dim, 2 * out_channels * 2 * out_channels),
        )

        self.edge_mlp_mu = nn.Sequential(
            nn.Linear(edge_dim, edge_latent_dim),
            nn.ReLU(),
            nn.Linear(edge_latent_dim, 2 * out_channels * out_channels),
        )

        self.edge_mlp_logvar = nn.Sequential(
            nn.Linear(edge_dim, edge_latent_dim),
            nn.ReLU(),
            nn.Linear(edge_latent_dim, 2 * out_channels * out_channels),
        )

        # première couche NNConv
        self.conv1 = NNConv(
            in_channels=in_channels,
            out_channels=2 * out_channels,
            nn=self.edge_mlp1,
            aggr='mean'
        )

        self.conv2 = NNConv(
            in_channels=2 * out_channels,
            out_channels=2 * out_channels,
            nn=self.edge_mlp2,
            aggr='mean'
        )

        # ===== BatchNorm =====
        self.bn1 = nn.BatchNorm1d(2 * out_channels)
        self.bn2 = nn.BatchNorm1d(2 * out_channels)

        # ===== Couches latentes =====
        self.conv_mu = NNConv(
            in_channels=2 * out_channels,
            out_channels=out_channels,
            nn=self.edge_mlp_mu,
            aggr='mean'
        )

        self.conv_logvar = NNConv(
            in_channels=2 * out_channels,
            out_channels=out_channels,
            nn=self.edge_mlp_logvar,
            aggr='mean'
        )

    def forward(self, x, edge_index, edge_attr):
        # ===== Dropout des arêtes =====
        edge_index, edge_mask = dropout_edge(edge_index, p=0.1, training=self.training)
        edge_attr = edge_attr[edge_mask]


        # ===== NNConv 1 =====
        x = self.conv1(x, edge_index, edge_attr)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # ===== NNConv 2 =====
        x = self.conv2(x, edge_index, edge_attr)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # ===== Latent mu / logvar =====
        mu = self.conv_mu(x, edge_index, edge_attr)
        logvar = self.conv_logvar(x, edge_index, edge_attr)

        return mu, logvar
    

## Créer le modèle VGAE

In [7]:
# Dimensions extraites depuis UN graphe (tous ont la même structure)
in_channels = train_graphs[0].x.shape[1]        # features par nœud
out_channels = 16                               # dimension latente
edge_dim = train_graphs[0].edge_attr.shape[1]   # features par arête
edge_latent_dim = 8                             # latent des arêtes

encoder = NNConvEncoder(in_channels, out_channels, edge_dim, edge_latent_dim)
model = VGAE(encoder)

## Définir l’optimiseur

In [8]:
optimizer = torch.optim.AdamW(model.parameters(), lr= 5e-4, weight_decay=1e-4)

## Détecter le device (GPU si dispo)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# device = torch.device('cpu')

# Déplacer le modèle
model = model.to(device)

Using device: cuda


## Early Stopping

In [9]:
patience = 50  # nombre d'époques sans amélioration
best_loss = float('inf')
counter = 0

## Entraînement du modèle

In [179]:
num_epochs = 1000
kl_anneal_epochs = 100  # montée progressive de β


for epoch in range(1, num_epochs + 1):
    model.train()
    train_loss = 0.0

    beta = min(1.0, epoch / kl_anneal_epochs)
    
    # ===== ENTRAÎNEMENT SUR LE PASSÉ =====
    for t, data in enumerate(train_loader):
        assert data.edge_index.dtype == torch.long
        assert data.edge_index.min() >= 0
        assert data.edge_index.max() < data.num_nodes
        # Chaque data = 1 graphe temporel
        data = data.to(device)

        optimizer.zero_grad()
        # Encodage VGAE (NNConv utilise edge_attr)
        z = model.encode(
            data.x,
            data.edge_index,
            data.edge_attr
        )
        # Loss VGAE = reconstruction + KL
        recon_loss = model.recon_loss(z, data.edge_index) / data.num_nodes  # normalisation
        kl_loss = model.kl_loss() / data.num_nodes  # normalisation
        loss = recon_loss + beta * kl_loss
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # ===== VALIDATION SUR LE FUTUR (SANS BACKPROP) =====
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)

            z = model.encode(
                data.x,
                data.edge_index,
                data.edge_attr
            )

            recon_loss = model.recon_loss(z, data.edge_index) / data.num_nodes  # normalisation
            kl_loss = model.kl_loss() / data.num_nodes

            val_loss += (recon_loss + kl_loss).item()

    print(
        f"Epoch {epoch:03d} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f}"
    )

    # ===== EARLY STOPPING TEMPOREL =====
    if val_loss < best_loss:
        best_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), "best_vgae_nnconv.pt")
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping triggered at epoch {epoch}")
            break

Epoch 001 | Train Loss: 21.7973 | Val Loss: 837.8252
Epoch 002 | Train Loss: 37.8543 | Val Loss: 485.2616
Epoch 003 | Train Loss: 38.9671 | Val Loss: 280.2459
Epoch 004 | Train Loss: 36.5308 | Val Loss: 207.6094
Epoch 005 | Train Loss: 44.9456 | Val Loss: 188.2070
Epoch 006 | Train Loss: 26.7884 | Val Loss: 110.3214
Epoch 007 | Train Loss: 28.5280 | Val Loss: 79.5940
Epoch 008 | Train Loss: 23.9417 | Val Loss: 56.6817
Epoch 009 | Train Loss: 22.1403 | Val Loss: 46.8785
Epoch 010 | Train Loss: 19.2110 | Val Loss: 36.1800
Epoch 011 | Train Loss: 14.1380 | Val Loss: 28.2002
Epoch 012 | Train Loss: 20.7256 | Val Loss: 25.5023
Epoch 013 | Train Loss: 17.2331 | Val Loss: 19.8589
Epoch 014 | Train Loss: 11.4548 | Val Loss: 15.1961
Epoch 015 | Train Loss: 23.2446 | Val Loss: 12.5397
Epoch 016 | Train Loss: 11.7561 | Val Loss: 10.0946
Epoch 017 | Train Loss: 9.8955 | Val Loss: 7.6002
Epoch 018 | Train Loss: 10.6506 | Val Loss: 6.1832
Epoch 019 | Train Loss: 7.8236 | Val Loss: 5.4311
Epoch 020 |

## Charger le modèle

In [10]:
# Charger le meilleur modèle
model.load_state_dict(torch.load("best_vgae_nnconv.pt"))


  model.load_state_dict(torch.load("best_vgae_nnconv.pt"))


<All keys matched successfully>

## Extraire les embeddings

In [11]:
train_rows = []

model.eval()
with torch.no_grad():
    for data in train_loader:
        data = data.to(device)
        z = model.encode(data.x, data.edge_index, data.edge_attr).cpu().numpy()
        tx_id = data.tx_id.cpu().numpy()
        y = data.y.cpu().numpy()
        for i in range(len(tx_id)):
            row = {"transaction_id": tx_id[i], "label": y[i]}
            for j in range(z.shape[1]):
                row[f"z_{j}"] = z[i, j]
            train_rows.append(row)

df_train_emb = pd.DataFrame(train_rows)
df_train_emb.to_csv("train_embeddings.csv", index=False)


In [12]:
test_rows = []

with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        z = model.encode(data.x, data.edge_index, data.edge_attr).cpu().numpy()
        tx_id = data.tx_id.cpu().numpy()
        y = data.y.cpu().numpy()
        for i in range(len(tx_id)):
            row = {"transaction_id": tx_id[i], "label": y[i]}
            for j in range(z.shape[1]):
                row[f"z_{j}"] = z[i, j]
            test_rows.append(row)

df_test_emb = pd.DataFrame(test_rows)
df_test_emb.to_csv("test_embeddings.csv", index=False)
