In [43]:
import copy
import hashlib
import json
import os
import random
from typing import Set
from typing import Tuple, List, Dict, Any

import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.model_selection import StratifiedKFold
from sklearn.multiclass import OneVsRestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Subset
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, BatchNorm
from tqdm import tqdm
from xgboost import XGBClassifier


def warn(*args, **kwargs):
    pass


import warnings

warnings.warn = warn


# Config

In [44]:
NUM_FOLDS = 5
print(f"NUM_FOLDS set to {NUM_FOLDS}")

NUM_EPOCHS = 15
print(f"NUM_EPOCHS set to {NUM_EPOCHS}")

HIDDEN_SIZE = 128
print(f"HIDDEN_SIZE set to {HIDDEN_SIZE}")

BATCH_SIZE = 32
print(f"BATCH_SIZE set to {BATCH_SIZE}")

LR = 0.01
print(f"LR set to {LR}")

NUM_AUGMENTATIONS = 5
print(f"NUM_AUGMENTATIONS set to {NUM_AUGMENTATIONS}")

TEST_SIZE = 0.1
print(f"TEST_SIZE set to {TEST_SIZE}")

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"DEVICE set to {DEVICE}")

# Setting random seeds for reproducibility
RANDOM_SEED = 0
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
torch.backends.cudnn.benchmark = False
print(f"Random seeds set to {RANDOM_SEED}")

# Creating the log directory if it doesn't exist
LOG_DIR = os.path.join("log")
if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)
    print(f"Log directory created at {LOG_DIR}")
else:
    print(f"Log directory already exists at {LOG_DIR}")

NUM_FOLDS set to 5
NUM_EPOCHS set to 15
HIDDEN_SIZE set to 128
BATCH_SIZE set to 32
LR set to 0.01
NUM_AUGMENTATIONS set to 5
TEST_SIZE set to 0.1
DEVICE set to cpu
Random seeds set to 0
Log directory already exists at log


# Dataset and Augmentation

In [45]:
class GraphDataset(Dataset):
    def __init__(self, graphs: List[Data]):
        self.graphs = graphs

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]

In [46]:
class ASTAugmentor:
    def __init__(self):
        """Initialize the ASTAugmentor with predefined strategies."""
        self.substitutions = {'FunctionDefinition': 'ModifierDefinition'}
        self.insertions = {
            'body': {'nodeType': 'ExpressionStatement', 'expression': {'nodeType': 'Literal', 'value': '0'}}}
        self.deletions = {'ModifierDefinition'}
        self.renames = {'oldVarName': 'newVarName', 'oldFuncName': 'newFuncName'}

    def substitute_nodes(self, ast: Any, substitutions: Dict[str, str]) -> Any:
        """Substitute certain nodes in the AST with other semantically equivalent nodes."""
        if isinstance(ast, dict):
            for key, value in ast.items():
                if key in substitutions:
                    ast[key] = substitutions[key]
                else:
                    ast[key] = self.substitute_nodes(value, substitutions)
        elif isinstance(ast, list):
            for i in range(len(ast)):
                ast[i] = self.substitute_nodes(ast[i], substitutions)
        return ast

    def insert_nodes(self, ast: Any, insertions: Dict[str, Any]) -> Any:
        """Insert certain nodes into the AST."""
        if isinstance(ast, dict):
            for key, value in ast.items():
                if key in insertions:
                    ast[key] = [value, insertions[key]] if isinstance(value, list) else [value, insertions[key]]
                else:
                    ast[key] = self.insert_nodes(value, insertions)
        elif isinstance(ast, list):
            for i in range(len(ast)):
                ast[i] = self.insert_nodes(ast[i], insertions)
        return ast

    def delete_nodes(self, ast: Any, deletions: Set[str]) -> Any:
        """Delete certain nodes from the AST."""
        if isinstance(ast, dict):
            keys_to_delete = [key for key in ast if key in deletions]
            for key in keys_to_delete:
                del ast[key]
            for key, value in ast.items():
                ast[key] = self.delete_nodes(value, deletions)
        elif isinstance(ast, list):
            ast = [self.delete_nodes(item, deletions) for item in ast if item not in deletions]
        return ast

    def rename_identifiers(self, ast: Any, renames: Dict[str, str]) -> Any:
        """Rename variables/functions in the AST."""
        if isinstance(ast, dict):
            for key, value in ast.items():
                if key == 'name' and value in renames:
                    ast[key] = renames[value]
                else:
                    ast[key] = self.rename_identifiers(value, renames)
        elif isinstance(ast, list):
            for i in range(len(ast)):
                ast[i] = self.rename_identifiers(ast[i], renames)
        return ast

    def reorder_statements(self, ast: Any) -> Any:
        """Randomly reorder statements in the AST."""
        if isinstance(ast, dict) and 'body' in ast:
            if isinstance(ast['body'], list):
                random.shuffle(ast['body'])
            else:
                self.reorder_statements(ast['body'])
        elif isinstance(ast, list):
            for item in ast:
                self.reorder_statements(item)
        return ast

    def add_no_op_statements(self, ast: Any) -> Any:
        """Add no-op statements to the AST."""
        no_op_statement = {'nodeType': 'ExpressionStatement', 'expression': {'nodeType': 'Literal', 'value': '0'}}
        if isinstance(ast, dict) and 'body' in ast:
            if isinstance(ast['body'], list):
                ast['body'].append(no_op_statement)
            else:
                self.add_no_op_statements(ast['body'])
        elif isinstance(ast, list):
            for item in ast:
                self.add_no_op_statements(item)
        return ast

    def apply_augmentation(self, ast: Any) -> Any:
        """Apply random augmentations to the AST."""
        if random.random() > 0.5:
            ast = self.substitute_nodes(ast, self.substitutions)
        if random.random() > 0.5:
            ast = self.insert_nodes(ast, self.insertions)
        if random.random() > 0.5:
            ast = self.delete_nodes(ast, self.deletions)
        if random.random() > 0.5:
            ast = self.rename_identifiers(ast, self.renames)
        if random.random() > 0.5:
            ast = self.reorder_statements(ast)
        if random.random() > 0.5:
            ast = self.add_no_op_statements(ast)
        return ast

    def generate_augmented_asts(self, dataset: List[Any], num_augmentations: int = 5) -> List[Any]:
        """Generate augmented ASTs for each AST in the dataset."""
        augmented_dataset = []
        for ast in tqdm(dataset, desc="Generating augmented ASTs"):
            augmented_dataset.append(ast)
            for _ in range(num_augmentations):
                augmented_ast = self.apply_augmentation(copy.deepcopy(ast))
                augmented_dataset.append(augmented_ast)
        return augmented_dataset

In [47]:
class AugmentedGraphDataset(GraphDataset):
    def __init__(self, graphs: List[Any], num_augmentations: int = 5):
        super().__init__(graphs)
        self.augmentor = ASTAugmentor()
        self.num_augmentations = num_augmentations
        self.augmented_graphs = self._augment_graphs()

    def _augment_graphs(self) -> List[Any]:
        return self.augmentor.generate_augmented_asts(self.graphs, self.num_augmentations)

    def __len__(self):
        return len(self.augmented_graphs)

    def __getitem__(self, idx):
        return self.augmented_graphs[idx]

# Feature Extraction

In [48]:
class FeatureExtractor:
    @staticmethod
    def hash_feature(value: Any, num_bins: int = 1000) -> int:
        """Helper function to hash a value into a fixed number of bins."""
        return int(hashlib.md5(str(value).encode()).hexdigest(), 16) % num_bins

    @staticmethod
    def extract_features(node: Dict) -> List[int]:
        """Extract features from a given AST node."""
        # Initialize features with default values
        name_feature, value_feature = [0], [0]
        src_feature, type_desc_features = [0, 0], [0, 0]
        state_mutability_feature, visibility_feature = [0], [0]

        # Extract basic features
        node_type = node.get('nodeType', 'Unknown')
        type_feature = [FeatureExtractor.hash_feature(node_type)]

        # Extract additional features if they exist
        if 'name' in node:
            name_feature = [FeatureExtractor.hash_feature(node.get('name', ''))]
        if 'value' in node:
            value_feature = [FeatureExtractor.hash_feature(node.get('value', ''))]

        # Extract src features (start, end, and length if available)
        if 'src' in node:
            start, length, *_ = map(int, node['src'].split(':'))
            src_feature = [start, length]

        # Extract typeDescriptions features if they exist
        if 'typeDescriptions' in node:
            type_desc = node['typeDescriptions']
            type_desc_features = [FeatureExtractor.hash_feature(type_desc.get('typeString', '')),
                                  FeatureExtractor.hash_feature(type_desc.get('typeIdentifier', ''))]

        # Extract stateMutability if it exists
        if 'stateMutability' in node:
            state_mutability_feature = [FeatureExtractor.hash_feature(node.get('stateMutability', ''))]

        # Extract visibility if it exists
        if 'visibility' in node:
            visibility_feature = [FeatureExtractor.hash_feature(node.get('visibility', ''))]

        # Combine all features into a single feature vector
        return (type_feature + name_feature + value_feature + src_feature +
                type_desc_features + state_mutability_feature + visibility_feature)


# Data Loading

In [49]:
class ASTGraphConverter:
    @staticmethod
    def add_masks_to_data(data: Data) -> Data:
        """Add train, validation, and test masks to the data."""
        num_nodes = data.x.size(0)
        indices = torch.randperm(num_nodes)

        train_size = int(num_nodes * (1 - TEST_SIZE))
        val_size = int(num_nodes * TEST_SIZE)

        train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(num_nodes, dtype=torch.bool)

        train_mask[indices[:train_size]] = True
        val_mask[indices[train_size:train_size + val_size]] = True
        test_mask[indices[train_size + val_size:]] = True

        data.train_mask = train_mask
        data.val_mask = val_mask
        data.test_mask = test_mask

        return data

    def ast_to_graph(self, ast_json: Dict) -> Data:
        """Convert an AST JSON object to a PyTorch Geometric Data object."""
        graph = nx.DiGraph()
        node_id = 0

        def add_nodes_edges(node: Dict, parent: int = None):
            nonlocal node_id
            current_node_id = node_id
            graph.add_node(current_node_id, features=FeatureExtractor.extract_features(node))
            if parent is not None:
                graph.add_edge(parent, current_node_id)
            node_id += 1
            for key, value in node.items():
                if isinstance(value, dict):
                    add_nodes_edges(value, current_node_id)
                elif isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict):
                            add_nodes_edges(item, current_node_id)

        add_nodes_edges(ast_json)
        edge_index = torch.tensor(list(graph.edges)).t().contiguous()
        x = torch.stack([torch.tensor(graph.nodes[n]['features'], dtype=torch.float) for n in graph.nodes])

        data = Data(x=x, edge_index=edge_index)
        data = self.add_masks_to_data(data)
        return data


class DataFetcher:
    def __init__(self, data_dir: str):
        """Initialize the DataLoader with the AST directory, label map, and graph converter."""
        self.data_dir = data_dir
        self.label_map = self.generate_label_map()
        self.graph_converter = ASTGraphConverter()

    def generate_label_map(self) -> Dict[str, int]:
        """Generate a label map from the directory structure."""
        print(f"Generating label map from directory: {self.data_dir}")
        label_map = {}
        label_index = 0
        for category in os.listdir(self.data_dir):
            category_path = os.path.join(self.data_dir, category)
            if os.path.isdir(category_path):
                label_map[category] = label_index
                label_index += 1
        print(f"Label map generated with {len(label_map)} labels")
        return label_map

    def load_data(self) -> Tuple[List[Data], List[int]]:
        """Load data from the AST directory and return a dataset and labels."""
        print(f"Loading data from directory: {ast_directory}")
        dataset = []
        for category in tqdm(os.listdir(self.data_dir), desc="Processing categories"):
            category_path = os.path.join(self.data_dir, category)
            if os.path.isdir(category_path):
                for root, _, files in os.walk(category_path):
                    for file in files:
                        if file.endswith('.json'):
                            filepath = os.path.join(root, file)
                            with open(filepath, 'r') as f:
                                ast = json.load(f)
                            data = self.graph_converter.ast_to_graph(ast)
                            label = self.label_map[category]
                            data.y = torch.tensor([label] * data.x.size(0),
                                                  dtype=torch.long)  # Assign label to all nodes
                            dataset.append(data)
        print(f"Loaded {len(dataset)} samples from {self.data_dir}")
        return dataset

In [50]:
# Set the directory for the AST data
ast_directory = os.path.join("..", "dataset", "aisc", "ast")

# Load the data using the data loader
graphs = DataFetcher(ast_directory).load_data()

# Check if any data was loaded
if len(graphs) == 0:
    print("No data loaded. Please check the dataset directory and files.")
else:
    print(f"Data loaded successfully with {len(graphs)} samples.")

    # Create a custom dataset
    DATASET = GraphDataset(graphs)
    AUGMENTED_DATASET = AugmentedGraphDataset(graphs)

    # Initialize NUM_LABELS and NUM_FEATURES
    all_labels = torch.cat([data.y for data in graphs], dim=0)
    NUM_LABELS = len(torch.unique(all_labels))
    NUM_FEATURES = graphs[0].x.shape[1]  # Assuming all graphs have the same number of features
    print(f"Number of labels: {NUM_LABELS}")
    print(f"Number of features: {NUM_FEATURES}")


Generating label map from directory: ../dataset/aisc/ast
Label map generated with 11 labels
Loading data from directory: ../dataset/aisc/ast


Processing categories: 100%|██████████| 11/11 [00:03<00:00,  2.83it/s]


Loaded 2040 samples from ../dataset/aisc/ast
Data loaded successfully with 2040 samples.


Generating augmented ASTs: 100%|██████████| 2040/2040 [00:01<00:00, 1978.23it/s]

Number of labels: 11
Number of features: 9





# Model Training and Cross-validation

In [51]:
def compute_metrics(true_labels: List[Any], pred_labels: List[Any]) -> Dict[str, float]:
    """
    Compute evaluation metrics for the given true and predicted labels.

    :param true_labels: The ground truth labels.
    :param pred_labels: The predicted labels.
    :return: A dictionary containing precision, recall, F1 score, and accuracy.
    """
    return {
        "precision": precision_score(true_labels, pred_labels, average='weighted', zero_division=0),
        "recall": recall_score(true_labels, pred_labels, average='weighted', zero_division=0),
        "f1": f1_score(true_labels, pred_labels, average='weighted', zero_division=0),
        "accuracy": accuracy_score(true_labels, pred_labels)
    }


def save_results(results: List[Dict[str, Any]], filename: str) -> None:
    """
    Save the results to a CSV file.

    :param results: The results to save, typically a list of dictionaries.
    :param filename: The name of the file to save the results to.
    """
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(LOG_DIR, filename), index=False)
    print(f"All fold results saved to '{os.path.join(LOG_DIR, filename)}'")

In [52]:
class Trainer:
    """
    Trainer class for handling the training and evaluation of a model.
    """

    def __init__(self, model: torch.nn.Module):
        """
        Initialize the trainer with model, loss criterion, and optimizer.

        :param model: The neural network model to be trained.
        """
        self.__untrained_model = model
        self._model = model.to(DEVICE)
        self._optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        self._loss_fn = torch.nn.CrossEntropyLoss().to(DEVICE)

    def reset_model(self):
        """
        Reset the model to its initial untrained state.
        """
        self._model = self.__untrained_model
        self._optimizer = torch.optim.Adam(self.__untrained_model.parameters(), lr=LR)

    def _evaluate_batch(self, batch: Data) -> Tuple[float, Dict[str, float]]:
        """
        Evaluate a single batch of data.

        :param batch: A Data object containing the batch of graphs.
        :return: A tuple containing the loss and a dictionary of metrics.
        """
        # Move batch to the appropriate device (CPU/GPU)
        batch = batch.to(DEVICE)

        # Prepare the inputs for the model
        inputs, labels = batch.x, batch.y

        # Disable gradient computation for evaluation
        with torch.no_grad():
            outputs = self._model(batch)

            # Compute the loss
            loss = self._loss_fn(outputs, labels)

            # Make predictions and compute batch metrics
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()
            batch_metrics = compute_metrics(labels.cpu().numpy(), predictions)

        # Return the loss and metrics
        return loss.item(), batch_metrics

    def _train_batch(self, batch: Data) -> Tuple[float, Dict[str, float]]:
        """
        Train a single batch of data.

        :param batch: A Data object containing the batch of graphs.
        :return: A tuple containing the loss and a dictionary of metrics.
        """
        # Move batch to the appropriate device (CPU/GPU)
        batch = batch.to(DEVICE)

        # Prepare inputs for the model
        inputs, labels = batch.x, batch.y

        # Zero the parameter gradients
        self._model.zero_grad()

        # Forward pass
        outputs = self._model(batch)

        # Compute the loss
        loss = self._loss_fn(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        self._optimizer.step()

        # Make predictions and compute metrics
        predictions = torch.argmax(outputs, dim=1).detach().cpu().numpy()
        batch_metrics = compute_metrics(labels.detach().cpu().numpy(), predictions)

        return loss.item(), batch_metrics

    def run_epoch(self, dataloader: DataLoader, train_mode: bool = True) -> Tuple[float, Dict[str, float]]:
        """
        Run a single epoch of training or evaluation.

        :param dataloader: DataLoader providing the data for the epoch.
        :param train_mode: Boolean flag indicating whether to train or evaluate.
        :return: A tuple containing the average loss and a dictionary of average metrics.
        """
        # Set the mode for the epoch (Training or Testing)
        phase = 'Training' if train_mode else 'Testing'
        self._model.train() if train_mode else self._model.eval()

        losses, metrics_list = [], []

        # Iterate over the data loader
        for batch in tqdm(dataloader, desc=phase):
            loss, batch_metrics = self._train_batch(batch) if train_mode else self._evaluate_batch(batch)

            # Accumulate the loss and metrics
            losses.append(loss)
            metrics_list.append(batch_metrics)

        # Compute average loss and metrics for the epoch
        avg_loss = np.mean(losses)
        avg_metrics = {metric: np.mean([m[metric] for m in metrics_list]) for metric in metrics_list[0]}

        return avg_loss, avg_metrics


In [53]:
class CrossValidator:
    """
    CrossValidator class for handling stratified k-fold cross-validation of a model.
    """

    def __init__(self, trainer: Trainer, train_data: Dataset, test_data: Dataset):
        """
        Initialize the CrossValidator with trainer, training data, and test data.

        :param trainer: An instance of the Trainer class.
        :param train_data: The training dataset.
        :param test_data: The test dataset.
        """
        self.__trainer = trainer
        self.__train_data = train_data
        self.__test_data = test_data

    def __train_and_evaluate(self, train_dataloader: DataLoader, test_dataloader: DataLoader) -> None:
        """
        Train and evaluate the model for a specified number of epochs.

        :param train_dataloader: DataLoader for the training data.
        :param test_dataloader: DataLoader for the validation data.
        """
        for epoch in range(NUM_EPOCHS):
            print(f"\n --- Epoch {epoch + 1}/{NUM_EPOCHS} ---")

            # Train the model and print training metrics
            avg_train_loss, avg_train_metrics = self.__trainer.run_epoch(train_dataloader, train_mode=True)
            print(f"\n TRAIN | Loss: {avg_train_loss:.4f} |"
                  f" Accuracy: {avg_train_metrics['accuracy']:.4f},"
                  f" Precision: {avg_train_metrics['precision']:.4f},"
                  f" Recall: {avg_train_metrics['recall']:.4f},"
                  f" F1: {avg_train_metrics['f1']:.4f}\n")

            # Evaluate the model on the validation set and print validation metrics
            avg_test_loss, avg_test_metrics = self.__trainer.run_epoch(test_dataloader, train_mode=False)
            print(f" VALID | Loss: {avg_test_loss:.4f} |"
                  f" Accuracy: {avg_test_metrics['accuracy']:.4f},"
                  f" Precision: {avg_test_metrics['precision']:.4f},"
                  f" Recall: {avg_test_metrics['recall']:.4f},"
                  f" F1: {avg_test_metrics['f1']:.4f}\n")

    def __evaluate_on_test_set(self, test_dataloader: DataLoader) -> Dict[str, float]:
        """
        Evaluate the model on the test set.

        :param test_dataloader: DataLoader for the test data.
        :return: A dictionary of test set metrics.
        """
        avg_test_loss, avg_test_metrics = self.__trainer.run_epoch(test_dataloader, train_mode=False)

        # Print test set metrics
        print(f"\nTest Set Evaluation | Loss: {avg_test_loss:.4f} |"
              f" Accuracy: {avg_test_metrics['accuracy']:.4f},"
              f" Precision: {avg_test_metrics['precision']:.4f},"
              f" Recall: {avg_test_metrics['recall']:.4f},"
              f" F1: {avg_test_metrics['f1']:.4f}\n")

        return avg_test_metrics

    def k_fold_cv(self, log_id: str = "gcn") -> None:
        """
        Perform k-fold cross-validation.

        :param log_id: Identifier for logging purposes, typically the model name.
        """
        skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True)
        fold_metrics = []

        # Extract labels for stratification
        labels = [data.y[0].item() for data in self.__train_data]

        # Iterate over each fold
        for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(labels)), labels)):
            # Create data loaders for training and validation sets
            train_subset = Subset(self.__train_data, train_idx)
            val_subset = Subset(self.__train_data, val_idx)

            train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
            val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)

            print(f"Starting Fold {fold + 1}/{NUM_FOLDS}")

            # Train and evaluate the model for the current fold
            self.__train_and_evaluate(train_loader, val_loader)

            # Evaluate on the test set after each fold
            test_loader = DataLoader(self.__test_data, batch_size=BATCH_SIZE, shuffle=False)
            metrics = self.__evaluate_on_test_set(test_loader)
            fold_metrics.append(metrics)

            # Reset the model to untrained
            self.__trainer.reset_model()

        # Calculate average and standard deviation of each metric across all folds
        metric_keys = fold_metrics[0].keys()  # Assuming all metrics dictionaries have the same structure
        average_metrics = {key: np.mean([metric[key] for metric in fold_metrics]) for key in metric_keys}
        std_dev_metrics = {key: np.std([metric[key] for metric in fold_metrics]) for key in metric_keys}

        # Print average metrics and their standard deviations
        print("Average Metrics Over All Folds:")
        for key, value in average_metrics.items():
            print(f"{key}: {value:.4f} (±{std_dev_metrics[key]:.4f})")

        # Save metrics to CSV file
        save_results(fold_metrics, filename=f"{log_id}.csv")

# Graph Neural Network

In [54]:
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_size, num_labels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_size)
        self.bn1 = BatchNorm(hidden_size)
        self.conv2 = GCNConv(hidden_size, hidden_size)
        self.bn2 = BatchNorm(hidden_size)
        self.conv3 = GCNConv(hidden_size, num_labels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # First layer
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        # Second layer
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        # Third layer (output layer)
        x = self.conv3(x, edge_index)

        return F.log_softmax(x, dim=1)

In [55]:
model = GCN(NUM_FEATURES, HIDDEN_SIZE, NUM_LABELS)
TRAIN_DATASET, TEST_DATASET = torch.utils.data.random_split(AUGMENTED_DATASET, [1 - TEST_SIZE, TEST_SIZE],
                                                            generator=torch.Generator().manual_seed(RANDOM_SEED))

CrossValidator(Trainer(model), TRAIN_DATASET, TEST_DATASET).k_fold_cv(log_id="gcn")

Starting Fold 1/5

 --- Epoch 1/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.58it/s]



 TRAIN | Loss: 1.8350 | Accuracy: 0.4891, Precision: 0.3262, Recall: 0.4891, F1: 0.3394



Testing: 100%|██████████| 69/69 [00:01<00:00, 66.84it/s]


 VALID | Loss: 1.7140 | Accuracy: 0.4975, Precision: 0.2972, Recall: 0.4975, F1: 0.3501


 --- Epoch 2/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 30.17it/s]



 TRAIN | Loss: 1.7331 | Accuracy: 0.4985, Precision: 0.3641, Recall: 0.4985, F1: 0.3455



Testing: 100%|██████████| 69/69 [00:01<00:00, 58.73it/s]


 VALID | Loss: 1.6959 | Accuracy: 0.4965, Precision: 0.3606, Recall: 0.4965, F1: 0.3522


 --- Epoch 3/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 26.28it/s]



 TRAIN | Loss: 1.7033 | Accuracy: 0.4976, Precision: 0.3820, Recall: 0.4976, F1: 0.3474



Testing: 100%|██████████| 69/69 [00:01<00:00, 56.56it/s]


 VALID | Loss: 1.6499 | Accuracy: 0.4982, Precision: 0.3604, Recall: 0.4982, F1: 0.3506


 --- Epoch 4/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.46it/s]



 TRAIN | Loss: 1.6849 | Accuracy: 0.4977, Precision: 0.4074, Recall: 0.4977, F1: 0.3513



Testing: 100%|██████████| 69/69 [00:01<00:00, 53.82it/s]


 VALID | Loss: 1.6287 | Accuracy: 0.4948, Precision: 0.3882, Recall: 0.4948, F1: 0.3576


 --- Epoch 5/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 27.91it/s]



 TRAIN | Loss: 1.6695 | Accuracy: 0.4972, Precision: 0.4125, Recall: 0.4972, F1: 0.3508



Testing: 100%|██████████| 69/69 [00:01<00:00, 68.47it/s]


 VALID | Loss: 1.6051 | Accuracy: 0.4970, Precision: 0.4177, Recall: 0.4970, F1: 0.3585


 --- Epoch 6/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.17it/s]



 TRAIN | Loss: 1.6589 | Accuracy: 0.4968, Precision: 0.4309, Recall: 0.4968, F1: 0.3544



Testing: 100%|██████████| 69/69 [00:01<00:00, 66.38it/s]


 VALID | Loss: 1.5991 | Accuracy: 0.4974, Precision: 0.4206, Recall: 0.4974, F1: 0.3569


 --- Epoch 7/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 29.17it/s]



 TRAIN | Loss: 1.6506 | Accuracy: 0.4968, Precision: 0.4307, Recall: 0.4968, F1: 0.3555



Testing: 100%|██████████| 69/69 [00:01<00:00, 59.93it/s]


 VALID | Loss: 1.6158 | Accuracy: 0.4988, Precision: 0.3496, Recall: 0.4988, F1: 0.3404


 --- Epoch 8/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 25.58it/s]



 TRAIN | Loss: 1.6423 | Accuracy: 0.4964, Precision: 0.4515, Recall: 0.4964, F1: 0.3550



Testing: 100%|██████████| 69/69 [00:01<00:00, 66.49it/s]


 VALID | Loss: 1.5865 | Accuracy: 0.4980, Precision: 0.3813, Recall: 0.4980, F1: 0.3498


 --- Epoch 9/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.22it/s]



 TRAIN | Loss: 1.6350 | Accuracy: 0.4958, Precision: 0.4476, Recall: 0.4958, F1: 0.3557



Testing: 100%|██████████| 69/69 [00:01<00:00, 65.26it/s]


 VALID | Loss: 1.5747 | Accuracy: 0.4959, Precision: 0.4267, Recall: 0.4959, F1: 0.3594


 --- Epoch 10/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.42it/s]



 TRAIN | Loss: 1.6316 | Accuracy: 0.4963, Precision: 0.4544, Recall: 0.4963, F1: 0.3575



Testing: 100%|██████████| 69/69 [00:00<00:00, 69.44it/s]


 VALID | Loss: 1.5796 | Accuracy: 0.4981, Precision: 0.3542, Recall: 0.4981, F1: 0.3458


 --- Epoch 11/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.53it/s]



 TRAIN | Loss: 1.6263 | Accuracy: 0.4965, Precision: 0.4618, Recall: 0.4965, F1: 0.3575



Testing: 100%|██████████| 69/69 [00:01<00:00, 65.76it/s]


 VALID | Loss: 1.5710 | Accuracy: 0.4975, Precision: 0.3336, Recall: 0.4975, F1: 0.3492


 --- Epoch 12/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.40it/s]



 TRAIN | Loss: 1.6197 | Accuracy: 0.4980, Precision: 0.4629, Recall: 0.4980, F1: 0.3597



Testing: 100%|██████████| 69/69 [00:01<00:00, 63.15it/s]


 VALID | Loss: 1.5658 | Accuracy: 0.4967, Precision: 0.3711, Recall: 0.4967, F1: 0.3528


 --- Epoch 13/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 26.82it/s]



 TRAIN | Loss: 1.6209 | Accuracy: 0.4973, Precision: 0.4661, Recall: 0.4973, F1: 0.3574



Testing: 100%|██████████| 69/69 [00:01<00:00, 52.10it/s]


 VALID | Loss: 1.5679 | Accuracy: 0.4962, Precision: 0.4225, Recall: 0.4962, F1: 0.3617


 --- Epoch 14/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 29.29it/s]



 TRAIN | Loss: 1.6175 | Accuracy: 0.4963, Precision: 0.4670, Recall: 0.4963, F1: 0.3598



Testing: 100%|██████████| 69/69 [00:00<00:00, 70.38it/s]


 VALID | Loss: 1.5678 | Accuracy: 0.4985, Precision: 0.3886, Recall: 0.4985, F1: 0.3489


 --- Epoch 15/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.93it/s]



 TRAIN | Loss: 1.6135 | Accuracy: 0.4958, Precision: 0.4704, Recall: 0.4958, F1: 0.3605



Testing: 100%|██████████| 69/69 [00:00<00:00, 70.20it/s]


 VALID | Loss: 1.5614 | Accuracy: 0.4985, Precision: 0.3540, Recall: 0.4985, F1: 0.3457



Testing: 100%|██████████| 39/39 [00:00<00:00, 73.17it/s]



Test Set Evaluation | Loss: 1.5559 | Accuracy: 0.5056, Precision: 0.3684, Recall: 0.5056, F1: 0.3551

Starting Fold 2/5

 --- Epoch 1/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.22it/s]



 TRAIN | Loss: 1.6136 | Accuracy: 0.4963, Precision: 0.4697, Recall: 0.4963, F1: 0.3599



Testing: 100%|██████████| 69/69 [00:00<00:00, 71.29it/s]


 VALID | Loss: 1.5412 | Accuracy: 0.4995, Precision: 0.3946, Recall: 0.4995, F1: 0.3571


 --- Epoch 2/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 29.31it/s]



 TRAIN | Loss: 1.6101 | Accuracy: 0.4956, Precision: 0.4718, Recall: 0.4956, F1: 0.3605



Testing: 100%|██████████| 69/69 [00:01<00:00, 62.57it/s]


 VALID | Loss: 1.5403 | Accuracy: 0.5015, Precision: 0.4093, Recall: 0.5015, F1: 0.3472


 --- Epoch 3/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 26.42it/s]



 TRAIN | Loss: 1.6072 | Accuracy: 0.4957, Precision: 0.4738, Recall: 0.4957, F1: 0.3593



Testing: 100%|██████████| 69/69 [00:01<00:00, 61.82it/s]


 VALID | Loss: 1.5442 | Accuracy: 0.5020, Precision: 0.3644, Recall: 0.5020, F1: 0.3438


 --- Epoch 4/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 26.73it/s]



 TRAIN | Loss: 1.6049 | Accuracy: 0.4963, Precision: 0.4757, Recall: 0.4963, F1: 0.3588



Testing: 100%|██████████| 69/69 [00:00<00:00, 80.15it/s]


 VALID | Loss: 1.5580 | Accuracy: 0.4992, Precision: 0.3884, Recall: 0.4992, F1: 0.3622


 --- Epoch 5/15 ---


Training: 100%|██████████| 276/276 [00:11<00:00, 24.70it/s]



 TRAIN | Loss: 1.6005 | Accuracy: 0.4980, Precision: 0.4852, Recall: 0.4980, F1: 0.3640



Testing: 100%|██████████| 69/69 [00:01<00:00, 55.94it/s]


 VALID | Loss: 1.5392 | Accuracy: 0.4991, Precision: 0.4314, Recall: 0.4991, F1: 0.3650


 --- Epoch 6/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 26.30it/s]



 TRAIN | Loss: 1.6029 | Accuracy: 0.4952, Precision: 0.4827, Recall: 0.4952, F1: 0.3614



Testing: 100%|██████████| 69/69 [00:01<00:00, 63.37it/s]


 VALID | Loss: 1.5352 | Accuracy: 0.5008, Precision: 0.3922, Recall: 0.5008, F1: 0.3494


 --- Epoch 7/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.18it/s]



 TRAIN | Loss: 1.6000 | Accuracy: 0.4960, Precision: 0.4820, Recall: 0.4960, F1: 0.3626



Testing: 100%|██████████| 69/69 [00:01<00:00, 67.35it/s]


 VALID | Loss: 1.5287 | Accuracy: 0.5013, Precision: 0.3674, Recall: 0.5013, F1: 0.3468


 --- Epoch 8/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 27.05it/s]



 TRAIN | Loss: 1.5977 | Accuracy: 0.4969, Precision: 0.4869, Recall: 0.4969, F1: 0.3629



Testing: 100%|██████████| 69/69 [00:00<00:00, 71.24it/s]


 VALID | Loss: 1.5180 | Accuracy: 0.5004, Precision: 0.4057, Recall: 0.5004, F1: 0.3548


 --- Epoch 9/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 26.49it/s]



 TRAIN | Loss: 1.5978 | Accuracy: 0.4960, Precision: 0.4867, Recall: 0.4960, F1: 0.3626



Testing: 100%|██████████| 69/69 [00:01<00:00, 62.09it/s]


 VALID | Loss: 1.5348 | Accuracy: 0.5018, Precision: 0.3683, Recall: 0.5018, F1: 0.3462


 --- Epoch 10/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.18it/s]



 TRAIN | Loss: 1.5982 | Accuracy: 0.4967, Precision: 0.4918, Recall: 0.4967, F1: 0.3629



Testing: 100%|██████████| 69/69 [00:01<00:00, 61.60it/s]


 VALID | Loss: 1.5330 | Accuracy: 0.5005, Precision: 0.4371, Recall: 0.5005, F1: 0.3578


 --- Epoch 11/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 25.77it/s]



 TRAIN | Loss: 1.5988 | Accuracy: 0.4955, Precision: 0.4884, Recall: 0.4955, F1: 0.3639



Testing: 100%|██████████| 69/69 [00:00<00:00, 71.29it/s]


 VALID | Loss: 1.5371 | Accuracy: 0.5013, Precision: 0.3645, Recall: 0.5013, F1: 0.3509


 --- Epoch 12/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 27.47it/s]



 TRAIN | Loss: 1.5939 | Accuracy: 0.4969, Precision: 0.4856, Recall: 0.4969, F1: 0.3629



Testing: 100%|██████████| 69/69 [00:01<00:00, 64.62it/s]


 VALID | Loss: 1.5375 | Accuracy: 0.5001, Precision: 0.4562, Recall: 0.5001, F1: 0.3647


 --- Epoch 13/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.04it/s]



 TRAIN | Loss: 1.5958 | Accuracy: 0.4958, Precision: 0.4916, Recall: 0.4958, F1: 0.3617



Testing: 100%|██████████| 69/69 [00:01<00:00, 66.58it/s]


 VALID | Loss: 1.5187 | Accuracy: 0.5012, Precision: 0.3744, Recall: 0.5012, F1: 0.3524


 --- Epoch 14/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.48it/s]



 TRAIN | Loss: 1.5974 | Accuracy: 0.4954, Precision: 0.4830, Recall: 0.4954, F1: 0.3620



Testing: 100%|██████████| 69/69 [00:01<00:00, 65.32it/s]


 VALID | Loss: 1.5168 | Accuracy: 0.4995, Precision: 0.4023, Recall: 0.4995, F1: 0.3609


 --- Epoch 15/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 29.40it/s]



 TRAIN | Loss: 1.5918 | Accuracy: 0.4957, Precision: 0.4926, Recall: 0.4957, F1: 0.3639



Testing: 100%|██████████| 69/69 [00:01<00:00, 68.66it/s]


 VALID | Loss: 1.5430 | Accuracy: 0.5010, Precision: 0.4045, Recall: 0.5010, F1: 0.3498



Testing: 100%|██████████| 39/39 [00:00<00:00, 58.04it/s]



Test Set Evaluation | Loss: 1.5450 | Accuracy: 0.5051, Precision: 0.3998, Recall: 0.5051, F1: 0.3543

Starting Fold 3/5

 --- Epoch 1/15 ---


Training: 100%|██████████| 276/276 [00:09<00:00, 28.61it/s]



 TRAIN | Loss: 1.5935 | Accuracy: 0.4950, Precision: 0.4837, Recall: 0.4950, F1: 0.3597



Testing: 100%|██████████| 69/69 [00:01<00:00, 56.60it/s]


 VALID | Loss: 1.5234 | Accuracy: 0.5020, Precision: 0.4375, Recall: 0.5020, F1: 0.3491


 --- Epoch 2/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 26.09it/s]



 TRAIN | Loss: 1.5926 | Accuracy: 0.4969, Precision: 0.4840, Recall: 0.4969, F1: 0.3628



Testing: 100%|██████████| 69/69 [00:01<00:00, 63.34it/s]


 VALID | Loss: 1.5171 | Accuracy: 0.5010, Precision: 0.4012, Recall: 0.5010, F1: 0.3554


 --- Epoch 3/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 26.07it/s]



 TRAIN | Loss: 1.5902 | Accuracy: 0.4958, Precision: 0.4821, Recall: 0.4958, F1: 0.3633



Testing: 100%|██████████| 69/69 [00:01<00:00, 54.95it/s]


 VALID | Loss: 1.5200 | Accuracy: 0.5021, Precision: 0.4714, Recall: 0.5021, F1: 0.3564


 --- Epoch 4/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 26.13it/s]



 TRAIN | Loss: 1.5906 | Accuracy: 0.4966, Precision: 0.4889, Recall: 0.4966, F1: 0.3642



Testing: 100%|██████████| 69/69 [00:01<00:00, 60.43it/s]


 VALID | Loss: 1.5131 | Accuracy: 0.5018, Precision: 0.4262, Recall: 0.5018, F1: 0.3545


 --- Epoch 5/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 27.07it/s]



 TRAIN | Loss: 1.5860 | Accuracy: 0.4959, Precision: 0.4888, Recall: 0.4959, F1: 0.3632



Testing: 100%|██████████| 69/69 [00:01<00:00, 60.14it/s]


 VALID | Loss: 1.5153 | Accuracy: 0.5017, Precision: 0.4027, Recall: 0.5017, F1: 0.3477


 --- Epoch 6/15 ---


Training: 100%|██████████| 276/276 [00:11<00:00, 24.24it/s]



 TRAIN | Loss: 1.5881 | Accuracy: 0.4952, Precision: 0.4906, Recall: 0.4952, F1: 0.3609



Testing: 100%|██████████| 69/69 [00:01<00:00, 67.86it/s]


 VALID | Loss: 1.5246 | Accuracy: 0.5010, Precision: 0.4247, Recall: 0.5010, F1: 0.3597


 --- Epoch 7/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 27.52it/s]



 TRAIN | Loss: 1.5902 | Accuracy: 0.4958, Precision: 0.4824, Recall: 0.4958, F1: 0.3609



Testing: 100%|██████████| 69/69 [00:01<00:00, 61.90it/s]


 VALID | Loss: 1.5298 | Accuracy: 0.5013, Precision: 0.4006, Recall: 0.5013, F1: 0.3527


 --- Epoch 8/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 25.51it/s]



 TRAIN | Loss: 1.5856 | Accuracy: 0.4947, Precision: 0.4889, Recall: 0.4947, F1: 0.3624



Testing: 100%|██████████| 69/69 [00:01<00:00, 62.65it/s]


 VALID | Loss: 1.5298 | Accuracy: 0.5006, Precision: 0.4182, Recall: 0.5006, F1: 0.3596


 --- Epoch 9/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 25.88it/s]



 TRAIN | Loss: 1.5870 | Accuracy: 0.4960, Precision: 0.4901, Recall: 0.4960, F1: 0.3638



Testing: 100%|██████████| 69/69 [00:01<00:00, 57.84it/s]


 VALID | Loss: 1.5309 | Accuracy: 0.5013, Precision: 0.4417, Recall: 0.5013, F1: 0.3547


 --- Epoch 10/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 25.23it/s]



 TRAIN | Loss: 1.5869 | Accuracy: 0.4958, Precision: 0.4863, Recall: 0.4958, F1: 0.3620



Testing: 100%|██████████| 69/69 [00:01<00:00, 54.99it/s]


 VALID | Loss: 1.5246 | Accuracy: 0.5010, Precision: 0.4478, Recall: 0.5010, F1: 0.3513


 --- Epoch 11/15 ---


Training: 100%|██████████| 276/276 [00:11<00:00, 24.00it/s]



 TRAIN | Loss: 1.5864 | Accuracy: 0.4958, Precision: 0.4965, Recall: 0.4958, F1: 0.3646



Testing: 100%|██████████| 69/69 [00:01<00:00, 42.57it/s]


 VALID | Loss: 1.5093 | Accuracy: 0.5013, Precision: 0.4196, Recall: 0.5013, F1: 0.3595


 --- Epoch 12/15 ---


Training: 100%|██████████| 276/276 [00:11<00:00, 24.99it/s]



 TRAIN | Loss: 1.5858 | Accuracy: 0.4952, Precision: 0.4898, Recall: 0.4952, F1: 0.3626



Testing: 100%|██████████| 69/69 [00:01<00:00, 54.62it/s]


 VALID | Loss: 1.5353 | Accuracy: 0.5019, Precision: 0.4260, Recall: 0.5019, F1: 0.3560


 --- Epoch 13/15 ---


Training: 100%|██████████| 276/276 [00:11<00:00, 24.09it/s]



 TRAIN | Loss: 1.5823 | Accuracy: 0.4959, Precision: 0.4895, Recall: 0.4959, F1: 0.3656



Testing: 100%|██████████| 69/69 [00:01<00:00, 57.01it/s]


 VALID | Loss: 1.5149 | Accuracy: 0.5010, Precision: 0.4121, Recall: 0.5010, F1: 0.3525


 --- Epoch 14/15 ---


Training: 100%|██████████| 276/276 [00:11<00:00, 23.24it/s]



 TRAIN | Loss: 1.5844 | Accuracy: 0.4962, Precision: 0.4945, Recall: 0.4962, F1: 0.3654



Testing: 100%|██████████| 69/69 [00:01<00:00, 54.99it/s]


 VALID | Loss: 1.5187 | Accuracy: 0.5018, Precision: 0.3520, Recall: 0.5018, F1: 0.3431


 --- Epoch 15/15 ---


Training: 100%|██████████| 276/276 [00:11<00:00, 24.81it/s]



 TRAIN | Loss: 1.5801 | Accuracy: 0.4964, Precision: 0.4919, Recall: 0.4964, F1: 0.3648



Testing: 100%|██████████| 69/69 [00:01<00:00, 52.05it/s]


 VALID | Loss: 1.5073 | Accuracy: 0.5009, Precision: 0.4364, Recall: 0.5009, F1: 0.3552



Testing: 100%|██████████| 39/39 [00:00<00:00, 54.96it/s]



Test Set Evaluation | Loss: 1.5108 | Accuracy: 0.5042, Precision: 0.4316, Recall: 0.5042, F1: 0.3607

Starting Fold 4/5

 --- Epoch 1/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 25.33it/s]



 TRAIN | Loss: 1.5811 | Accuracy: 0.4963, Precision: 0.4876, Recall: 0.4963, F1: 0.3656



Testing: 100%|██████████| 69/69 [00:01<00:00, 53.87it/s]


 VALID | Loss: 1.5132 | Accuracy: 0.4979, Precision: 0.3862, Recall: 0.4979, F1: 0.3462


 --- Epoch 2/15 ---


Training: 100%|██████████| 276/276 [00:10<00:00, 25.14it/s]



 TRAIN | Loss: 1.5803 | Accuracy: 0.4960, Precision: 0.4980, Recall: 0.4960, F1: 0.3634



Testing: 100%|██████████| 69/69 [00:01<00:00, 64.00it/s]


 VALID | Loss: 1.5284 | Accuracy: 0.4962, Precision: 0.3947, Recall: 0.4962, F1: 0.3536


 --- Epoch 3/15 ---


Training:  64%|██████▍   | 177/276 [00:06<00:03, 28.02it/s]


KeyboardInterrupt: 

# Traditional Models

In [None]:
class ClassifiersPoolEvaluator:
    """
    ClassifiersPoolEvaluator class for evaluating a pool of classifiers using graph features and stratified k-fold cross-validation.
    """

    def __init__(self, graph_dataset):
        """
        Initialize the ClassifiersPoolEvaluator with a dictionary of classifiers.

        :param graph_dataset: The graph dataset to be used for evaluation.
        """
        # Define a dictionary of classifiers to evaluate
        self.__classifiers = {
            "svm": OneVsRestClassifier(SVC(kernel='linear', probability=True)),
            "random_forest": OneVsRestClassifier(RandomForestClassifier(n_estimators=100, random_state=RANDOM_SEED)),
            "gradient_boosting": OneVsRestClassifier(
                GradientBoostingClassifier(n_estimators=100, learning_rate=LR, max_depth=3)),
            "logistic_regression": OneVsRestClassifier(LogisticRegression(random_state=RANDOM_SEED)),
            "knn": OneVsRestClassifier(KNeighborsClassifier(n_neighbors=5)),
            "xgboost": OneVsRestClassifier(
                XGBClassifier(use_label_encoder=False, eval_metric='mlogloss', random_state=RANDOM_SEED))
        }

        # Extract features and labels from the graph dataset
        self.X, self.y = self._extract_features_and_labels(graph_dataset)

    def _extract_features_and_labels(self, graph_dataset) -> (np.ndarray, np.ndarray):
        """
        Extract features and labels from the graph dataset.

        :param graph_dataset: The graph dataset.
        :return: A tuple of features and labels as numpy arrays.
        """
        features = []
        labels = []

        for data in graph_dataset:
            features.append(data.x.numpy())  # Assuming `data.x` contains the features as a tensor
            labels.append(data.y[0].numpy())  # Assuming `data.y` contains the labels as a tensor

        # Determine the maximum shape for padding over the second dimension
        max_length = max(f.shape[0] for f in features)

        # Pad features to ensure they all have the same shape over the second dimension
        padded_features = np.array([np.pad(f, ((0, max_length - f.shape[0]), (0, 0)), 'constant') for f in features])

        # Convert labels to a numpy array and flatten if necessary
        labels = np.array(labels)

        return padded_features, labels

    def __evaluate_fold(self, classifier: OneVsRestClassifier, train_index: List[int], test_index: List[int],
                        fold_num: int) -> Dict[str, float]:
        """
        Evaluate a classifier on a single fold of cross-validation.

        :param classifier: The classifier to be evaluated.
        :param train_index: Indices for the training data.
        :param test_index: Indices for the test data.
        :param fold_num: The fold number.
        :return: A dictionary of computed metrics.
        """
        X_train, X_test = self.X[train_index], self.X[test_index]
        y_train, y_test = self.y[train_index], self.y[test_index]

        # Flatten the features to be compatible with classifiers
        X_train = X_train.reshape(X_train.shape[0], -1)
        X_test = X_test.reshape(X_test.shape[0], -1)

        # Train the classifier on the training data
        classifier.fit(X_train, y_train)
        # Make predictions on the test data
        predictions = classifier.predict(X_test)

        # Compute metrics using the provided utility function
        metrics = compute_metrics(y_test, predictions)
        print(f"Results for fold {fold_num} | "
              f"Precision: {metrics['precision']:.4f}, "
              f"Recall: {metrics['recall']:.4f}, "
              f"F1: {metrics['f1']:.4f}")
        return metrics

    def __stratified_k_fold_cv(self, classifier: OneVsRestClassifier) -> pd.DataFrame:
        """
        Perform stratified k-fold cross-validation on a given classifier.

        :param classifier: The classifier to be evaluated.
        :return: A DataFrame containing the results of each fold.
        """
        skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=RANDOM_SEED)
        # Evaluate the classifier on each fold and collect the results
        results = []
        for fold_num, (train_index, test_index) in enumerate(skf.split(self.X, self.y), 1):
            metrics = self.__evaluate_fold(classifier, train_index, test_index, fold_num)
            results.append(metrics)
        # Return the results as a DataFrame
        return pd.DataFrame(results)

    def pool_evaluation(self) -> None:
        """
        Run the evaluation for each classifier defined in self.__classifiers.
        """
        # Run the evaluation for each classifier defined in self.__classifiers
        for classifier_name, classifier in self.__classifiers.items():
            print(f"\nTesting classifier: {classifier_name}\n")
            # Evaluate the classifier and get the metrics DataFrame
            metrics_df = self.__stratified_k_fold_cv(classifier)
            # Save the results using the provided utility function
            save_results(metrics_df, f"{classifier_name}.csv")

In [None]:
evaluator = ClassifiersPoolEvaluator(DATASET)
evaluator.pool_evaluation()