In [9]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np
import pickle

# GCN architecture -- must be the same as the class used at training time.
# The checkpoint is loaded as a whole pickled model object
# (torch.load(..., weights_only=False)), so unpickling looks this class up by
# name and restores its attributes directly: do not rename the class or the
# `conv1` / `conv2` attributes.
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network: GCNConv -> ReLU -> GCNConv.

    Args:
        in_channels: size of each input node feature vector.
        hidden_channels: size of the intermediate node representation.
        out_channels: size of each output node vector.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return one output row per node.

        Args:
            x: node feature matrix fed to the first convolution.
            edge_index: graph connectivity, shared by both convolutions.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# Load the trained model and the evaluation data.
# SECURITY NOTE(review): both loads below unpickle arbitrary Python objects.
# `weights_only=False` turns off torch's restricted unpickler, and
# `pickle.load` has no restriction at all, so either file can execute code
# when loaded. Only run this on files from a trusted source.
# No `map_location` is given, so tensors come back on the device they were
# saved from.
model = torch.load('gcn_recommender.pth', weights_only=False)

with open('evaluation_data.pkl', 'rb') as f:
    data = pickle.load(f)

# The lookups below only need the in-memory dict, not the file handle.
graph_data = data['graph_data']
test_idx = data['test_idx']
y_test = data['y_test']

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Inference: a single forward pass over the whole graph, without autograd.
model.eval()
with torch.no_grad():
    node_embeddings = model(graph_data.x, graph_data.edge_index)
    # Each test edge is scored with the output row of the node in row 1 of
    # edge_index (the target node under PyG's default source->target flow).
    # Select the test edges first, then gather rows. This gives the same
    # result as node_embeddings[edge_index[1]][test_idx] but does not build
    # an embedding row for every edge in the graph.
    # NOTE(review): the source node (edge_index[0]) plays no part here, so all
    # test edges sharing a target node get the same prediction -- confirm this
    # matches how the model was trained.
    test_targets = graph_data.edge_index[1][test_idx]
    y_pred_test = node_embeddings[test_targets]


In [11]:
# Convert both tensors to NumPy once; every metric below reuses these arrays.
y_test_np, y_pred_test_np = (t.cpu().numpy() for t in (y_test, y_pred_test))

# Root-mean-squared error of the raw predictions.
mse = mean_squared_error(y_test_np, y_pred_test_np)
rmse = np.sqrt(mse)

In [12]:
# Mean absolute error of the raw predictions.
mae = mean_absolute_error(y_true=y_test_np, y_pred=y_pred_test_np)

In [13]:
# Report the regression metrics.
for label, value in (("RMSE :", rmse), ("MAE  :", mae)):
    print(label, value)


RMSE : 2.6488353689207464
MAE  : 2.636859178543091


In [14]:
# Binarise targets and predictions with one shared cut-off, so both sides of
# the classification metrics are guaranteed to use the same rule.
# NOTE(review): the value 0 comes from the original code; the right cut-off
# depends on how y_test is encoded -- confirm against the data pipeline.
CLS_THRESHOLD = 0

# Ground-truth labels, reusing the NumPy array converted earlier.
y_true_cls = (y_test_np >= CLS_THRESHOLD).astype(int)

# Predicted labels.
# NOTE(review): the recorded results (recall 1.0, precision ~0.50) are
# consistent with nearly every prediction landing at or above the threshold;
# inspect the distribution of y_pred_test_np before trusting these metrics.
y_pred_cls = (y_pred_test_np >= CLS_THRESHOLD).astype(int)

In [15]:
# Fraction of edges predicted positive that really are positive.
precision = precision_score(y_true=y_true_cls, y_pred=y_pred_cls)


In [16]:
# Fraction of truly positive edges that were predicted positive.
recall = recall_score(y_true=y_true_cls, y_pred=y_pred_cls)


In [17]:
# Harmonic mean of precision and recall.
f1 = f1_score(y_true=y_true_cls, y_pred=y_pred_cls)


In [18]:
# Report the classification metrics.
for label, value in (
    ("Precision :", precision),
    ("Recall    :", recall),
    ("F1-score  :", f1),
):
    print(label, value)


Precision : 0.49768211920529803
Recall    : 1.0
F1-score  : 0.6646031395091754


In [19]:
# Collect every metric in one ordered mapping and print each to 4 decimals.
results = dict([
    ("RMSE", rmse),
    ("MAE", mae),
    ("Precision", precision),
    ("Recall", recall),
    ("F1-score", f1),
])

for metric_name, metric_value in results.items():
    print("{}: {:.4f}".format(metric_name, metric_value))


RMSE: 2.6488
MAE: 2.6369
Precision: 0.4977
Recall: 1.0000
F1-score: 0.6646
