In [7]:
import hashlib
import json
import os

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
def hash_feature(value, num_bins=1000):
    """Helper function to hash a value into a fixed number of bins."""
    return int(hashlib.md5(str(value).encode()).hexdigest(), 16) % num_bins


def extract_features(node):
    # Initialize features with default values
    type_feature = [0]
    name_feature = [0]
    value_feature = [0]
    src_feature = [0, 0]
    type_desc_features = [0, 0]
    state_mutability_feature = [0]
    visibility_feature = [0]

    # Extract basic features
    node_type = node.get('nodeType', 'Unknown')
    type_feature = [hash_feature(node_type)]

    # Extract additional features if they exist
    if 'name' in node:
        name_feature = [hash_feature(node.get('name', ''))]
    if 'value' in node:
        value_feature = [hash_feature(node.get('value', ''))]

    # Extract src features (start, end, and length if available)
    if 'src' in node:
        start, length, *_ = map(int, node['src'].split(':'))
        src_feature = [start, length]

    # Extract typeDescriptions features if they exist
    if 'typeDescriptions' in node:
        type_desc = node['typeDescriptions']
        type_desc_features = [
            hash_feature(type_desc.get('typeString', '')),
            hash_feature(type_desc.get('typeIdentifier', ''))
        ]

    # Extract stateMutability if it exists
    if 'stateMutability' in node:
        state_mutability_feature = [hash_feature(node.get('stateMutability', ''))]

    # Extract visibility if it exists
    if 'visibility' in node:
        visibility_feature = [hash_feature(node.get('visibility', ''))]

    # Combine all features into a single feature vector
    features = (type_feature + name_feature + value_feature +
                src_feature + type_desc_features +
                state_mutability_feature + visibility_feature)
    return features


def generate_label_map(ast_directory):
    label_map = {}
    label_index = 0
    for category in os.listdir(ast_directory):
        category_path = os.path.join(ast_directory, category)
        if os.path.isdir(category_path):
            label_map[category] = label_index
            label_index += 1
    return label_map


def load_data(ast_directory, label_map):
    dataset = []
    labels = []
    for category in os.listdir(ast_directory):
        category_path = os.path.join(ast_directory, category)
        if os.path.isdir(category_path):
            for root, _, files in os.walk(category_path):
                for file in files:
                    if file.endswith('.json'):
                        filepath = os.path.join(root, file)
                        with open(filepath, 'r') as f:
                            ast = json.load(f)
                        data = ast_to_graph(ast)
                        label = label_map[category]
                        data.y = torch.tensor([label] * data.x.size(0), dtype=torch.long)  # Assign label to all nodes
                        dataset.append(data)
                        labels.append(label)
    print(f"Loaded {len(dataset)} samples from {ast_directory}")
    return dataset, labels


# Graph Neural Network

In [6]:
def add_masks_to_data(data, train_ratio=0.8, val_ratio=0.1):
    num_nodes = data.x.size(0)
    indices = torch.randperm(num_nodes)

    train_size = int(num_nodes * train_ratio)
    val_size = int(num_nodes * val_ratio)

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)

    train_mask[indices[:train_size]] = True
    val_mask[indices[train_size:train_size + val_size]] = True
    test_mask[indices[train_size + val_size:]] = True

    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data


def ast_to_graph(ast_json):
    graph = nx.DiGraph()
    node_id = 0

    def add_nodes_edges(node, parent=None):
        nonlocal node_id
        current_node_id = node_id
        graph.add_node(current_node_id, features=extract_features(node))
        if parent is not None:
            graph.add_edge(parent, current_node_id)
        node_id += 1
        for key, value in node.items():
            if isinstance(value, dict):
                add_nodes_edges(value, current_node_id)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        add_nodes_edges(item, current_node_id)

    add_nodes_edges(ast_json)
    edge_index = torch.tensor(list(graph.edges)).t().contiguous()
    x = torch.stack([torch.tensor(graph.nodes[n]['features'], dtype=torch.float) for n in graph.nodes])

    data = Data(x=x, edge_index=edge_index)
    return add_masks_to_data(data)


def stratified_split(dataset, labels):
    # Split data into training + validation and test data
    train_val_data, test_data, train_val_labels, test_labels = train_test_split(
        dataset, labels, test_size=0.1, stratify=labels, random_state=42)

    # Split training + validation into actual training and validation data
    train_data, val_data, train_labels, val_labels = train_test_split(
        train_val_data, train_val_labels, test_size=0.11, stratify=train_val_labels,
        random_state=42)  # 0.11 * 0.9 ≈ 0.1

    return train_data, val_data, test_data


class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    y_true = []
    y_pred = []
    y_score = []

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        pred = out[data.train_mask].argmax(dim=1)
        y_true.extend(data.y[data.train_mask].cpu().numpy())
        y_pred.extend(pred.cpu().numpy())
        y_score.extend(F.softmax(out[data.train_mask], dim=1).cpu().detach().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_score = np.array(y_score)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    roc_auc = roc_auc_score(y_true, y_score, average='weighted', multi_class='ovo')

    return total_loss / len(loader), accuracy, precision, recall, f1, roc_auc


def evaluate(model, loader, mask_type):
    model.eval()
    y_true = []
    y_pred = []
    y_score = []

    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
            mask = getattr(data, mask_type)
            pred = out[mask].argmax(dim=1)
            y_true.extend(data.y[mask].cpu().numpy())
            y_pred.extend(pred.cpu().numpy())
            y_score.extend(F.softmax(out[mask], dim=1).cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_score = np.array(y_score)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    roc_auc = roc_auc_score(y_true, y_score, average='weighted', multi_class='ovo')

    return accuracy, precision, recall, f1, roc_auc


def main():
    ast_directory = '../dataset/aisc/ast'
    label_map = generate_label_map(ast_directory)
    dataset, labels = load_data(ast_directory, label_map)

    if len(dataset) == 0:
        print("No data loaded. Please check the dataset directory and files.")
        return

    train_data, val_data, test_data = stratified_split(dataset, labels)

    print(f"Train data size: {len(train_data)}")
    print(f"Validation data size: {len(val_data)}")
    print(f"Test data size: {len(test_data)}")

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    test_loader = DataLoader(test_data, batch_size=32)

    num_features = train_data[0].x.shape[1]
    num_classes = len(label_map)

    model = GCN(num_features=num_features, hidden_channels=64, num_classes=num_classes).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1500):
        train_loss, train_acc, train_prec, train_rec, train_f1, train_roc_auc = train(model, train_loader, optimizer,
                                                                                      criterion)
        val_acc, val_prec, val_rec, val_f1, val_roc_auc = evaluate(model, val_loader, 'val_mask')

        print('-----------------------------------------------------------------------------------------------------')
        print(f'EPOCH: {epoch + 1} -> Loss: {train_loss:.4f}')
        print(
            f'(Train) Accuracy: {train_acc:.4f}, Precision: {train_prec:.4f}, Recall: {train_rec:.4f}, F1: {train_f1:.4f}, AUC: {train_roc_auc:.4f}')
        print(
            f'(Valid) Accuracy: {val_acc:.4f}, Precision: {val_prec:.4f}, Recall: {val_rec:.4f}, F1: {val_f1:.4f}, AUC: {val_roc_auc:.4f}')
        print('-----------------------------------------------------------------------------------------------------')

    print('-----------------------------------------------------------------------------------------------------')
    test_acc, test_prec, test_rec, test_f1, test_roc_auc = evaluate(model, test_loader, 'test_mask')
    print(
        f'(Test) Accuracy: {test_acc:.4f}, Precision: {test_prec:.4f}, Recall: {test_rec:.4f}, F1: {test_f1:.4f}, AUC: {test_roc_auc:.4f}')


if __name__ == "__main__":
    main()


Loaded 1020 samples from ../dataset/aisc/ast
Train data size: 817
Validation data size: 101
Test data size: 102


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 1 -> Loss: 217.2883
(Train) Accuracy: 0.1275, Precision: 0.1082, Recall: 0.1275, F1: 0.0913, AUC: 0.5105
(Valid) Accuracy: 0.1621, Precision: 0.0892, Recall: 0.1621, F1 Score: 0.0993, AUC: 0.5370
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 2 -> Loss: 202.4989
(Train) Accuracy: 0.1294, Precision: 0.1142, Recall: 0.1294, F1: 0.0992, AUC: 0.5139
(Valid) Accuracy: 0.1614, Precision: 0.2386, Recall: 0.1614, F1 Score: 0.1059, AUC: 0.5402
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 3 -> Loss: 190.6906
(Train) Accuracy: 0.1294, Precision: 0.1107, Recall: 0.1294, F1: 0.1033, AUC: 0.5151
(Valid) Accuracy: 0.1628, Precision: 0.1324, Recall: 0.1628, F1 Score: 0.1168, AUC: 0.5443
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 4 -> Loss: 179.4853
(Train) Accuracy: 0.1288, Precision: 0.1133, Recall: 0.1288, F1: 0.1071, AUC: 0.5154
(Valid) Accuracy: 0.1580, Precision: 0.1228, Recall: 0.1580, F1 Score: 0.1261, AUC: 0.5469
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 5 -> Loss: 172.2464
(Train) Accuracy: 0.1241, Precision: 0.1078, Recall: 0.1241, F1: 0.1054, AUC: 0.5138
(Valid) Accuracy: 0.1601, Precision: 0.1295, Recall: 0.1601, F1 Score: 0.1333, AUC: 0.5489
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 6 -> Loss: 164.6031
(Train) Accuracy: 0.1247, Precision: 0.1083, Recall: 0.1247, F1: 0.1074, AUC: 0.5146
(Valid) Accuracy: 0.1531, Precision: 0.1213, Recall: 0.1531, F1 Score: 0.1290, AUC: 0.5508
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 7 -> Loss: 158.5335
(Train) Accuracy: 0.1216, Precision: 0.1082, Recall: 0.12

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 784 -> Loss: 2.1372
(Train) Accuracy: 0.1894, Precision: 0.3043, Recall: 0.1894, F1: 0.1321, AUC: 0.6087
(Valid) Accuracy: 0.2408, Precision: 0.2792, Recall: 0.2408, F1 Score: 0.1757, AUC: 0.6660
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 785 -> Loss: 2.1372
(Train) Accuracy: 0.1888, Precision: 0.2971, Recall: 0.1888, F1: 0.1297, AUC: 0.6110
(Valid) Accuracy: 0.2450, Precision: 0.3778, Recall: 0.2450, F1 Score: 0.1780, AUC: 0.6658
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 786 -> Loss: 2.1401
(Train) Accuracy: 0.1888, Precision: 0.2998, Recall: 0.1888, F1: 0.1302, AUC: 0.6088
(Valid) Accuracy: 0.2401, Precision: 0.2773, Recall: 0.2401, F1 Score: 0.1730, AUC: 0.6654
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 787 -> Loss: 2.1365
(Train) Accuracy: 0.1887, Precision: 0.2980, Recall: 0.1887, F1: 0.1302, AUC: 0.6144
(Valid) Accuracy: 0.2422, Precision: 0.2803, Recall: 0.2422, F1 Score: 0.1736, AUC: 0.6666
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 788 -> Loss: 2.1372
(Train) Accuracy: 0.1882, Precision: 0.2964, Recall: 0.18

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 791 -> Loss: 2.1364
(Train) Accuracy: 0.1872, Precision: 0.2935, Recall: 0.1872, F1: 0.1287, AUC: 0.6145
(Valid) Accuracy: 0.2436, Precision: 0.2791, Recall: 0.2436, F1 Score: 0.1724, AUC: 0.6665
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 792 -> Loss: 2.1359
(Train) Accuracy: 0.1889, Precision: 0.2903, Recall: 0.1889, F1: 0.1301, AUC: 0.6132
(Valid) Accuracy: 0.2429, Precision: 0.2844, Recall: 0.2429, F1 Score: 0.1714, AUC: 0.6661
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 793 -> Loss: 2.1387
(Train) Accuracy: 0.1893, Precision: 0.2998, Recall: 0.1893, F1: 0.1315, AUC: 0.6122
(Valid) Accuracy: 0.2394, Precision: 0.2788, Recall: 0.2394, F1 Score: 0.1693, AUC: 0.6667
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 794 -> Loss: 2.1343
(Train) Accuracy: 0.1890, Precision: 0.2986, Recall: 0.1890, F1: 0.1324, AUC: 0.6133
(Valid) Accuracy: 0.2422, Precision: 0.2981, Recall: 0.2422, F1 Score: 0.1741, AUC: 0.6655
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 795 -> Loss: 2.1353
(Train) Accuracy: 0.1883, Precision: 0.2915, Recall: 0.18

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 796 -> Loss: 2.1383
(Train) Accuracy: 0.1890, Precision: 0.3006, Recall: 0.1890, F1: 0.1307, AUC: 0.6117
(Valid) Accuracy: 0.2429, Precision: 0.2853, Recall: 0.2429, F1 Score: 0.1743, AUC: 0.6662
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 797 -> Loss: 2.1316
(Train) Accuracy: 0.1906, Precision: 0.2955, Recall: 0.1906, F1: 0.1325, AUC: 0.6164
(Valid) Accuracy: 0.2436, Precision: 0.3709, Recall: 0.2436, F1 Score: 0.1750, AUC: 0.6670
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 798 -> Loss: 2.1298
(Train) Accuracy: 0.1907, Precision: 0.3037, Recall: 0.19

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 817 -> Loss: 2.1351
(Train) Accuracy: 0.1898, Precision: 0.3036, Recall: 0.1898, F1: 0.1327, AUC: 0.6187
(Valid) Accuracy: 0.2373, Precision: 0.2762, Recall: 0.2373, F1 Score: 0.1680, AUC: 0.6675
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 818 -> Loss: 2.1391
(Train) Accuracy: 0.1898, Precision: 0.3038, Recall: 0.1898, F1: 0.1322, AUC: 0.6103
(Valid) Accuracy: 0.2380, Precision: 0.2774, Recall: 0.2380, F1 Score: 0.1666, AUC: 0.6677
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 819 -> Loss: 2.1380
(Train) Accuracy: 0.1916, Precision: 0.3062, Recall: 0.1916, F1: 0.1356, AUC: 0.6149
(Valid) Accuracy: 0.2394, Precision: 0.2865, Recall: 0.2394, F1 Score: 0.1681, AUC: 0.6679
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 820 -> Loss: 2.1354
(Train) Accuracy: 0.1898, Precision: 0.3021, Recall: 0.1898, F1: 0.1337, AUC: 0.6172
(Valid) Accuracy: 0.2373, Precision: 0.2855, Recall: 0.2373, F1 Score: 0.1703, AUC: 0.6670
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 821 -> Loss: 2.1380
(Train) Accuracy: 0.1893, Precision: 0.3067, Recall: 0.1893, F1: 0.1327, AUC: 0.6154
(Valid) Accuracy: 0.2352, Precision: 0.2725, Recall: 0.2352, F1 Score: 0.1657, AUC: 0.6675
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 822 -> Loss: 2.1363
(Train) Accuracy: 0.1904, Precision: 0.3035, Recall: 0.1904, F1: 0.1342, AUC: 0.6133
(Valid) Accuracy: 0.2380, Precision: 0.2783, Recall: 0.2380, F1 Score: 0.1668, AUC: 0.6684
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 823 -> Loss: 2.1336
(Train) Accuracy: 0.1886, Precision: 0.2996, Recall: 0.1886, F1: 0.1319, AUC: 0.6192
(Valid) Accuracy: 0.2352, Precision: 0.3247, Recall: 0.2352, F1 Score: 0.1658, AUC: 0.6680
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 824 -> Loss: 2.1328
(Train) Accuracy: 0.1897, Precision: 0.3047, Recall: 0.1897, F1: 0.1330, AUC: 0.6153
(Valid) Accuracy: 0.2380, Precision: 0.2841, Recall: 0.2380, F1 Score: 0.1679, AUC: 0.6681
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 825 -> Loss: 2.1358
(Train) Accuracy: 0.1890, Precision: 0.2980, Recall: 0.1890, F1: 0.1317, AUC: 0.6183
(Valid) Accuracy: 0.2373, Precision: 0.3248, Recall: 0.2373, F1 Score: 0.1663, AUC: 0.6676
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 826 -> Loss: 2.1413
(Train) Accuracy: 0.1893, Precision: 0.2996, Recall: 0.1893, F1: 0.1324, AUC: 0.6179
(Valid) Accuracy: 0.2457, Precision: 0.2997, Recall: 0.2457, F1 Score: 0.1750, AUC: 0.6686
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 827 -> Loss: 2.1340
(Train) Accuracy: 0.1897, Precision: 0.3015, Recall: 0.1897, F1: 0.1321, AUC: 0.6164
(Valid) Accuracy: 0.2408, Precision: 0.2902, Recall: 0.2408, F1 Score: 0.1667, AUC: 0.6691
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 828 -> Loss: 2.1307
(Train) Accuracy: 0.1912, Precision: 0.3006, Recall: 0.19

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 833 -> Loss: 2.1349
(Train) Accuracy: 0.1901, Precision: 0.3046, Recall: 0.1901, F1: 0.1333, AUC: 0.6153
(Valid) Accuracy: 0.2373, Precision: 0.2760, Recall: 0.2373, F1 Score: 0.1634, AUC: 0.6691
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 834 -> Loss: 2.1372
(Train) Accuracy: 0.1911, Precision: 0.3041, Recall: 0.1911, F1: 0.1332, AUC: 0.6157
(Valid) Accuracy: 0.2387, Precision: 0.2842, Recall: 0.2387, F1 Score: 0.1658, AUC: 0.6697
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 835 -> Loss: 2.1346
(Train) Accuracy: 0.1889, Precision: 0.2957, Recall: 0.18

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 924 -> Loss: 2.1306
(Train) Accuracy: 0.1941, Precision: 0.2909, Recall: 0.1941, F1: 0.1380, AUC: 0.6224
(Valid) Accuracy: 0.2443, Precision: 0.2735, Recall: 0.2443, F1 Score: 0.1723, AUC: 0.6734
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 925 -> Loss: 2.1281
(Train) Accuracy: 0.1929, Precision: 0.2861, Recall: 0.1929, F1: 0.1371, AUC: 0.6257
(Valid) Accuracy: 0.2457, Precision: 0.2913, Recall: 0.2457, F1 Score: 0.1743, AUC: 0.6734
-----------------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 926 -> Loss: 2.1323
(Train) Accuracy: 0.1938, Precision: 0.2868, Recall: 0.1938, F1: 0.1360, AUC: 0.6207
(Valid) Accuracy: 0.2498, Precision: 0.2767, Recall: 0.2498, F1 Score: 0.1786, AUC: 0.6738
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 927 -> Loss: 2.1252
(Train) Accuracy: 0.1938, Precision: 0.2874, Recall: 0.1938, F1: 0.1371, AUC: 0.6288
(Valid) Accuracy: 0.2484, Precision: 0.3021, Recall: 0.2484, F1 Score: 0.1757, AUC: 0.6743
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 928 -> Loss: 2.1277
(Train) Accuracy: 0.1928, Precision: 0.2870, Recall: 0.19

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 931 -> Loss: 2.1287
(Train) Accuracy: 0.1905, Precision: 0.2775, Recall: 0.1905, F1: 0.1335, AUC: 0.6217
(Valid) Accuracy: 0.2484, Precision: 0.2801, Recall: 0.2484, F1 Score: 0.1770, AUC: 0.6746
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 932 -> Loss: 2.1255
(Train) Accuracy: 0.1928, Precision: 0.2844, Recall: 0.1928, F1: 0.1365, AUC: 0.6252
(Valid) Accuracy: 0.2477, Precision: 0.3171, Recall: 0.2477, F1 Score: 0.1757, AUC: 0.6735
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 933 -> Loss: 2.1321
(Train) Accuracy: 0.1932, Precision: 0.2885, Recall: 0.19

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 948 -> Loss: 2.1284
(Train) Accuracy: 0.1936, Precision: 0.2826, Recall: 0.1936, F1: 0.1374, AUC: 0.6274
(Valid) Accuracy: 0.2457, Precision: 0.2591, Recall: 0.2457, F1 Score: 0.1743, AUC: 0.6742
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 949 -> Loss: 2.1287
(Train) Accuracy: 0.1925, Precision: 0.2752, Recall: 0.1925, F1: 0.1337, AUC: 0.6234
(Valid) Accuracy: 0.2491, Precision: 0.2943, Recall: 0.2491, F1 Score: 0.1769, AUC: 0.6745
-----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
EPOCH: 950 -> Loss: 2.1249
(Train) Accuracy: 0.1966, Precision: 0.2920, Recall: 0.19

KeyboardInterrupt: 

# Traditional Models

In [9]:
def pad_features_to_fixed_length(features, fixed_length=4716):
    """Pad feature vectors to a fixed length."""
    padding_length = fixed_length - len(features)
    if padding_length > 0:
        return np.concatenate([features, np.zeros(padding_length)])
    else:
        return features[:fixed_length]


def extract_graph_features(graph, fixed_length=4716):
    """Flatten the features of the graph into a single vector and pad to fixed_length."""
    node_features = [graph.nodes[n]['features'] for n in graph.nodes]
    max_length = max(len(f) for f in node_features)
    padded_features = [pad_features_to_fixed_length(f, max_length) for f in node_features]
    flat_features = np.array(padded_features).flatten()
    return pad_features_to_fixed_length(flat_features, fixed_length)


def prepare_dataset(dataset, fixed_length=4716):
    """Prepare the dataset for training a Random Forest classifier."""
    features = []
    labels = []

    for data in dataset:
        graph = nx.DiGraph()
        for node_id in range(data.x.size(0)):
            graph.add_node(node_id, features=data.x[node_id].numpy())
        graph_features = extract_graph_features(graph, fixed_length)
        features.append(graph_features)

        if data.y.numel() == 1:
            labels.append(data.y.item())
        else:
            labels.append(data.y[0].item())

    return np.array(features), np.array(labels)


def train_random_forest(X_train, y_train):
    """Train a Random Forest classifier."""
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train, y_train)
    return clf


def evaluate_random_forest(clf, X, y):
    """Evaluate the Random Forest classifier."""
    y_pred = clf.predict(X)
    y_score = clf.predict_proba(X)

    accuracy = accuracy_score(y, y_pred)
    precision = precision_score(y, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y, y_pred, average='weighted', zero_division=0)
    roc_auc = roc_auc_score(y, y_score, average='weighted', multi_class='ovo')

    return accuracy, precision, recall, f1, roc_auc


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ast_directory = '../dataset/aisc/ast'
    label_map = generate_label_map(ast_directory)
    dataset, labels = load_data(ast_directory, label_map)

    if len(dataset) == 0:
        print("No data loaded. Please check the dataset directory and files.")
        return

    # Prepare dataset for Random Forest with a fixed length of 4716 features
    fixed_length = 4716
    features, labels = prepare_dataset(dataset, fixed_length)

    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
    all_fold_metrics = {
        'accuracy': [],
        'precision': [],
        'recall': [],
        'f1': [],
        'roc_auc': []
    }

    for fold, (train_index, test_index) in enumerate(skf.split(features, labels)):
        X_train, X_test = features[train_index], features[test_index]
        y_train, y_test = labels[train_index], labels[test_index]

        # Train Random Forest classifier
        clf = train_random_forest(X_train, y_train)

        # Evaluate on the test set
        test_acc, test_prec, test_rec, test_f1, test_roc_auc = evaluate_random_forest(clf, X_test, y_test)

        # Store the results
        all_fold_metrics['accuracy'].append(test_acc)
        all_fold_metrics['precision'].append(test_prec)
        all_fold_metrics['recall'].append(test_rec)
        all_fold_metrics['f1'].append(test_f1)
        all_fold_metrics['roc_auc'].append(test_roc_auc)

        print(
            f"Fold {fold + 1} - (Test) Accuracy: {test_acc:.4f}, Precision: {test_prec:.4f}, Recall: {test_rec:.4f}, F1 Score: {test_f1:.4f}, ROC-AUC: {test_roc_auc:.4f}")

    # Calculate average metrics
    avg_accuracy = np.mean(all_fold_metrics['accuracy'])
    avg_precision = np.mean(all_fold_metrics['precision'])
    avg_recall = np.mean(all_fold_metrics['recall'])
    avg_f1 = np.mean(all_fold_metrics['f1'])
    avg_roc_auc = np.mean(all_fold_metrics['roc_auc'])

    print(f"\nAverage Metrics over 10 folds:")
    print(f"Accuracy .... : {avg_accuracy:.4f}")
    print(f"Precision ... : {avg_precision:.4f}")
    print(f"Recall ...... : {avg_recall:.4f}")
    print(f"F1 Score .... : {avg_f1:.4f}")
    print(f"ROC-AUC ..... : {avg_roc_auc:.4f}")


if __name__ == "__main__":
    main()



Loaded 1020 samples from ../dataset/aisc/ast
Fold 1 - (Test) Accuracy: 0.8725, Precision: 0.9005, Recall: 0.8725, F1 Score: 0.8707, ROC-AUC: 0.9930
Fold 2 - (Test) Accuracy: 0.9412, Precision: 0.9475, Recall: 0.9412, F1 Score: 0.9408, ROC-AUC: 0.9962
Fold 3 - (Test) Accuracy: 0.9118, Precision: 0.9189, Recall: 0.9118, F1 Score: 0.9119, ROC-AUC: 0.9958
Fold 4 - (Test) Accuracy: 0.9412, Precision: 0.9506, Recall: 0.9412, F1 Score: 0.9399, ROC-AUC: 0.9961
Fold 5 - (Test) Accuracy: 0.9608, Precision: 0.9626, Recall: 0.9608, F1 Score: 0.9607, ROC-AUC: 0.9988
Fold 6 - (Test) Accuracy: 0.9412, Precision: 0.9460, Recall: 0.9412, F1 Score: 0.9404, ROC-AUC: 0.9924
Fold 7 - (Test) Accuracy: 0.9216, Precision: 0.9321, Recall: 0.9216, F1 Score: 0.9215, ROC-AUC: 0.9942
Fold 8 - (Test) Accuracy: 0.9314, Precision: 0.9378, Recall: 0.9314, F1 Score: 0.9314, ROC-AUC: 0.9944
Fold 9 - (Test) Accuracy: 0.9118, Precision: 0.9237, Recall: 0.9118, F1 Score: 0.9116, ROC-AUC: 0.9948
Fold 10 - (Test) Accuracy: 0