In [353]:
import json
import os

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def extract_features(node):
    node_type = node.get('nodeType', 'Unknown')
    depth_feature = [hash(node_type) % 1000]
    return depth_feature


def add_masks_to_data(data, train_ratio=0.8, val_ratio=0.1):
    num_nodes = data.x.size(0)
    indices = torch.randperm(num_nodes)

    train_size = int(num_nodes * train_ratio)
    val_size = int(num_nodes * val_ratio)

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)

    train_mask[indices[:train_size]] = True
    val_mask[indices[train_size:train_size + val_size]] = True
    test_mask[indices[train_size + val_size:]] = True

    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data


def ast_to_graph(ast_json):
    graph = nx.DiGraph()
    node_id = 0

    def add_nodes_edges(node, parent=None):
        nonlocal node_id
        current_node_id = node_id
        graph.add_node(current_node_id, features=extract_features(node))
        if parent is not None:
            graph.add_edge(parent, current_node_id)
        node_id += 1
        for key, value in node.items():
            if isinstance(value, dict):
                add_nodes_edges(value, current_node_id)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        add_nodes_edges(item, current_node_id)

    add_nodes_edges(ast_json)
    edge_index = torch.tensor(list(graph.edges)).t().contiguous()
    x = torch.stack([torch.tensor(graph.nodes[n]['features'], dtype=torch.float) for n in graph.nodes])

    data = Data(x=x, edge_index=edge_index)
    return add_masks_to_data(data)


def generate_label_map(ast_directory):
    label_map = {}
    label_index = 0
    for category in os.listdir(ast_directory):
        category_path = os.path.join(ast_directory, category)
        if os.path.isdir(category_path):
            label_map[category] = label_index
            label_index += 1
    return label_map


def load_data(ast_directory, label_map):
    dataset = []
    labels = []
    for category in os.listdir(ast_directory):
        category_path = os.path.join(ast_directory, category)
        if os.path.isdir(category_path):
            for root, _, files in os.walk(category_path):
                for file in files:
                    if file.endswith('.json'):
                        filepath = os.path.join(root, file)
                        with open(filepath, 'r') as f:
                            ast = json.load(f)
                        data = ast_to_graph(ast)
                        label = label_map[category]
                        data.y = torch.tensor([label] * data.x.size(0), dtype=torch.long)  # Assign label to all nodes
                        dataset.append(data)
                        labels.append(label)
    print(f"Loaded {len(dataset)} samples from {ast_directory}")
    return dataset, labels


def stratified_split(dataset, labels):
    # Split data into training + validation and test data
    train_val_data, test_data, train_val_labels, test_labels = train_test_split(
        dataset, labels, test_size=0.1, stratify=labels, random_state=42)

    # Split training + validation into actual training and validation data
    train_data, val_data, train_labels, val_labels = train_test_split(
        train_val_data, train_val_labels, test_size=0.11, stratify=train_val_labels,
        random_state=42)  # 0.11 * 0.9 ≈ 0.1

    return train_data, val_data, test_data


class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    y_true = []
    y_pred = []
    y_score = []

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        pred = out[data.train_mask].argmax(dim=1)
        y_true.extend(data.y[data.train_mask].cpu().numpy())
        y_pred.extend(pred.cpu().numpy())
        y_score.extend(F.softmax(out[data.train_mask], dim=1).cpu().detach().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_score = np.array(y_score)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    roc_auc = roc_auc_score(y_true, y_score, average='weighted', multi_class='ovo')

    return total_loss / len(loader), accuracy, precision, recall, f1, roc_auc



def evaluate(model, loader, mask_type):
    model.eval()
    y_true = []
    y_pred = []
    y_score = []

    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
            mask = getattr(data, mask_type)
            pred = out[mask].argmax(dim=1)
            y_true.extend(data.y[mask].cpu().numpy())
            y_pred.extend(pred.cpu().numpy())
            y_score.extend(F.softmax(out[mask], dim=1).cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_score = np.array(y_score)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    roc_auc = roc_auc_score(y_true, y_score, average='weighted', multi_class='ovo')

    return accuracy, precision, recall, f1, roc_auc


def main():
    ast_directory = '../dataset/aisc/ast'
    label_map = generate_label_map(ast_directory)
    dataset, labels = load_data(ast_directory, label_map)

    if len(dataset) == 0:
        print("No data loaded. Please check the dataset directory and files.")
        return

    train_data, val_data, test_data = stratified_split(dataset, labels)

    print(f"Train data size: {len(train_data)}")
    print(f"Validation data size: {len(val_data)}")
    print(f"Test data size: {len(test_data)}")

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    test_loader = DataLoader(test_data, batch_size=32)

    num_features = train_data[0].x.shape[1]
    num_classes = len(label_map)

    model = GCN(num_features=num_features, hidden_channels=64, num_classes=num_classes).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(200):
        train_loss, train_acc, train_prec, train_rec, train_f1, train_roc_auc = train(model, train_loader, optimizer,
                                                                                      criterion)
        val_acc, val_prec, val_rec, val_f1, val_roc_auc = evaluate(model, val_loader, 'val_mask')

        print(
            f'Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}, Train Precision: {train_prec:.4f}, Train Recall: {train_rec:.4f}, Train F1 Score: {train_f1:.4f}, Train ROC-AUC: {train_roc_auc:.4f}')
        print(
            f'Val Accuracy: {val_acc:.4f}, Val Precision: {val_prec:.4f}, Val Recall: {val_rec:.4f}, Val F1 Score: {val_f1:.4f}, Val ROC-AUC: {val_roc_auc:.4f}')

    test_acc, test_prec, test_rec, test_f1, test_roc_auc = evaluate(model, test_loader, 'test_mask')
    print(
        f'Test Accuracy: {test_acc:.4f}, Test Precision: {test_prec:.4f}, Test Recall: {test_rec:.4f}, Test F1 Score: {test_f1:.4f}, Test ROC-AUC: {test_roc_auc:.4f}')


if __name__ == "__main__":
    main()


Loaded 2040 samples from ../dataset/aisc/ast
Train data size: 1634
Validation data size: 202
Test data size: 204


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 1, Train Loss: 28.4755, Train Accuracy: 0.2815, Train Precision: 0.2777, Train Recall: 0.2815, Train F1 Score: 0.2792, Train ROC-AUC: 0.5013
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.4947


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 2, Train Loss: 1.9740, Train Accuracy: 0.4930, Train Precision: 0.2614, Train Recall: 0.4930, Train F1 Score: 0.3343, Train ROC-AUC: 0.4957
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.4967


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 3, Train Loss: 1.9264, Train Accuracy: 0.5012, Train Precision: 0.2528, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4916
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.4967


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 4, Train Loss: 1.8871, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4819
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.4969


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 5, Train Loss: 1.8557, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4925
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.4966


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 6, Train Loss: 1.8359, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4940
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5049


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 7, Train Loss: 1.8200, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4876
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 8, Train Loss: 1.8330, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4737
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 9, Train Loss: 1.8361, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4513
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 10, Train Loss: 1.8328, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4479
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 11, Train Loss: 1.8437, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4506
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 12, Train Loss: 1.8539, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4496
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 13, Train Loss: 1.8349, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4485
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 14, Train Loss: 1.8284, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4419
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 15, Train Loss: 1.8490, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4465
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 16, Train Loss: 1.8406, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4528
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 17, Train Loss: 1.8029, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4581
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 18, Train Loss: 1.8371, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4393
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 19, Train Loss: 1.8074, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4467
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 20, Train Loss: 1.8323, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4562
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 21, Train Loss: 1.8501, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4477
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 22, Train Loss: 1.8325, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4539
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 23, Train Loss: 1.8070, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4602
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 24, Train Loss: 1.8316, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4562
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 25, Train Loss: 1.8335, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4577
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 26, Train Loss: 1.8252, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4572
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 27, Train Loss: 1.8361, Train Accuracy: 0.5012, Train Precision: 0.2512, Train Recall: 0.5012, Train F1 Score: 0.3347, Train ROC-AUC: 0.4473
Val Accuracy: 0.5061, Val Precision: 0.2561, Val Recall: 0.5061, Val F1 Score: 0.3401, Val ROC-AUC: 0.5000


KeyboardInterrupt: 