In [7]:
import copy
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Feature Extraction

In [8]:
def hash_feature(value, num_bins=1000):
    """Helper function to hash a value into a fixed number of bins."""
    return int(hashlib.md5(str(value).encode()).hexdigest(), 16) % num_bins


def extract_features(node):
    # Initialize features with default values
    type_feature = [0]
    name_feature = [0]
    value_feature = [0]
    src_feature = [0, 0]
    type_desc_features = [0, 0]
    state_mutability_feature = [0]
    visibility_feature = [0]

    # Extract basic features
    node_type = node.get('nodeType', 'Unknown')
    type_feature = [hash_feature(node_type)]

    # Extract additional features if they exist
    if 'name' in node:
        name_feature = [hash_feature(node.get('name', ''))]
    if 'value' in node:
        value_feature = [hash_feature(node.get('value', ''))]

    # Extract src features (start, end, and length if available)
    if 'src' in node:
        start, length, *_ = map(int, node['src'].split(':'))
        src_feature = [start, length]

    # Extract typeDescriptions features if they exist
    if 'typeDescriptions' in node:
        type_desc = node['typeDescriptions']
        type_desc_features = [
            hash_feature(type_desc.get('typeString', '')),
            hash_feature(type_desc.get('typeIdentifier', ''))
        ]

    # Extract stateMutability if it exists
    if 'stateMutability' in node:
        state_mutability_feature = [hash_feature(node.get('stateMutability', ''))]

    # Extract visibility if it exists
    if 'visibility' in node:
        visibility_feature = [hash_feature(node.get('visibility', ''))]

    # Combine all features into a single feature vector
    features = (type_feature + name_feature + value_feature +
                src_feature + type_desc_features +
                state_mutability_feature + visibility_feature)
    return features

# Data Loading

In [9]:
def add_masks_to_data(data, train_ratio=0.8, val_ratio=0.1):
    num_nodes = data.x.size(0)
    indices = torch.randperm(num_nodes)

    train_size = int(num_nodes * train_ratio)
    val_size = int(num_nodes * val_ratio)

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)

    train_mask[indices[:train_size]] = True
    val_mask[indices[train_size:train_size + val_size]] = True
    test_mask[indices[train_size + val_size:]] = True

    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data


def ast_to_graph(ast_json):
    graph = nx.DiGraph()
    node_id = 0

    def add_nodes_edges(node, parent=None):
        nonlocal node_id
        current_node_id = node_id
        graph.add_node(current_node_id, features=extract_features(node))
        if parent is not None:
            graph.add_edge(parent, current_node_id)
        node_id += 1
        for key, value in node.items():
            if isinstance(value, dict):
                add_nodes_edges(value, current_node_id)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        add_nodes_edges(item, current_node_id)

    add_nodes_edges(ast_json)
    edge_index = torch.tensor(list(graph.edges)).t().contiguous()
    x = torch.stack([torch.tensor(graph.nodes[n]['features'], dtype=torch.float) for n in graph.nodes])

    data = Data(x=x, edge_index=edge_index)
    return add_masks_to_data(data)


def generate_label_map(ast_directory):
    label_map = {}
    label_index = 0
    for category in os.listdir(ast_directory):
        category_path = os.path.join(ast_directory, category)
        if os.path.isdir(category_path):
            label_map[category] = label_index
            label_index += 1
    return label_map


def load_data(ast_directory, label_map):
    dataset = []
    labels = []
    for category in os.listdir(ast_directory):
        category_path = os.path.join(ast_directory, category)
        if os.path.isdir(category_path):
            for root, _, files in os.walk(category_path):
                for file in files:
                    if file.endswith('.json'):
                        filepath = os.path.join(root, file)
                        with open(filepath, 'r') as f:
                            ast = json.load(f)
                        data = ast_to_graph(ast)
                        label = label_map[category]
                        data.y = torch.tensor([label] * data.x.size(0), dtype=torch.long)  # Assign label to all nodes
                        dataset.append(data)
                        labels.append(label)
    print(f"Loaded {len(dataset)} samples from {ast_directory}")
    return dataset, labels


ast_directory = '../dataset/aisc/ast'
label_map = generate_label_map(ast_directory)
dataset, labels = load_data(ast_directory, label_map)

if len(dataset) == 0:
    print("No data loaded. Please check the dataset directory and files.")

Loaded 1020 samples from ../dataset/aisc/ast


# Data Augmentation

In [10]:
def substitute_nodes(ast, substitutions):
    """
    Substitute certain nodes in the AST with other semantically equivalent nodes.
    :param ast: The AST to be modified.
    :param substitutions: A dictionary where keys are node types to be replaced, and values are the replacements.
    :return: The modified AST.
    """
    if isinstance(ast, dict):
        for key, value in ast.items():
            if key in substitutions:
                ast[key] = substitutions[key]
            else:
                ast[key] = substitute_nodes(value, substitutions)
    elif isinstance(ast, list):
        for i in range(len(ast)):
            ast[i] = substitute_nodes(ast[i], substitutions)
    return ast


def insert_nodes(ast, insertions):
    """
    Insert certain nodes into the AST.
    :param ast: The AST to be modified.
    :param insertions: A dictionary where keys are locations to insert, and values are the nodes to be inserted.
    :return: The modified AST.
    """
    if isinstance(ast, dict):
        for key, value in ast.items():
            if key in insertions:
                ast[key] = [value, insertions[key]] if isinstance(value, list) else [value, insertions[key]]
            else:
                ast[key] = insert_nodes(value, insertions)
    elif isinstance(ast, list):
        for i in range(len(ast)):
            ast[i] = insert_nodes(ast[i], insertions)
    return ast


def delete_nodes(ast, deletions):
    """
    Delete certain nodes from the AST.
    :param ast: The AST to be modified.
    :param deletions: A set of node types to be deleted.
    :return: The modified AST.
    """
    if isinstance(ast, dict):
        keys_to_delete = [key for key in ast if key in deletions]
        for key in keys_to_delete:
            del ast[key]
        for key, value in ast.items():
            ast[key] = delete_nodes(value, deletions)
    elif isinstance(ast, list):
        ast = [delete_nodes(item, deletions) for item in ast if item not in deletions]
    return ast


def rename_identifiers(ast, renames):
    """
    Rename variables/functions in the AST.
    :param ast: The AST to be modified.
    :param renames: A dictionary where keys are the original names and values are the new names.
    :return: The modified AST.
    """
    if isinstance(ast, dict):
        for key, value in ast.items():
            if key == 'name' and value in renames:
                ast[key] = renames[value]
            else:
                ast[key] = rename_identifiers(value, renames)
    elif isinstance(ast, list):
        for i in range(len(ast)):
            ast[i] = rename_identifiers(ast[i], renames)
    return ast


def reorder_statements(ast):
    if isinstance(ast, dict) and 'body' in ast:
        if isinstance(ast['body'], list):
            random.shuffle(ast['body'])
        else:
            reorder_statements(ast['body'])
    elif isinstance(ast, list):
        for item in ast:
            reorder_statements(item)
    return ast


def add_no_op_statements(ast):
    no_op_statement = {'nodeType': 'ExpressionStatement', 'expression': {'nodeType': 'Literal', 'value': '0'}}
    if isinstance(ast, dict) and 'body' in ast:
        if isinstance(ast['body'], list):
            ast['body'].append(no_op_statement)
        else:
            add_no_op_statements(ast['body'])
    elif isinstance(ast, list):
        for item in ast:
            add_no_op_statements(item)
    return ast


def apply_augmentation(ast):
    # Define your augmentation strategies
    substitutions = {'FunctionDefinition': 'ModifierDefinition'}
    insertions = {'body': {'nodeType': 'ExpressionStatement', 'expression': {'nodeType': 'Literal', 'value': '0'}}}
    deletions = {'ModifierDefinition'}
    renames = {'oldVarName': 'newVarName', 'oldFuncName': 'newFuncName'}

    # Apply augmentations randomly
    if random.random() > 0.5:
        ast = substitute_nodes(ast, substitutions)
    if random.random() > 0.5:
        ast = insert_nodes(ast, insertions)
    if random.random() > 0.5:
        ast = delete_nodes(ast, deletions)
    if random.random() > 0.5:
        ast = rename_identifiers(ast, renames)
    if random.random() > 0.5:
        ast = reorder_statements(ast)
    if random.random() > 0.5:
        ast = add_no_op_statements(ast)

    return ast


def generate_augmented_asts(dataset, num_augmentations=5):
    augmented_dataset = []
    for ast in dataset:
        augmented_dataset.append(ast)
        for _ in range(num_augmentations):
            augmented_ast = apply_augmentation(copy.deepcopy(ast))
            augmented_dataset.append(augmented_ast)
    return augmented_dataset



# Graph Neural Network

In [11]:
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


class AugmentedASTDataset(Dataset):
    def __init__(self, dataset, apply_augmentation_func, num_augmentations=5, train=True):
        self.dataset = dataset
        self.apply_augmentation_func = apply_augmentation_func
        self.num_augmentations = num_augmentations
        self.train = train

    def __len__(self):
        return len(self.dataset) * (self.num_augmentations if self.train else 1)

    def __getitem__(self, idx):
        original_idx = idx % len(self.dataset)
        ast = self.dataset[original_idx]

        if self.train:
            augmented_ast = self.apply_augmentation_func(copy.deepcopy(ast))
            return augmented_ast
        else:
            return ast


def prepare_augmented_dataloader(dataset, apply_augmentation_func, batch_size=32, num_augmentations=5, shuffle=True):
    augmented_dataset = AugmentedASTDataset(dataset, apply_augmentation_func, num_augmentations=num_augmentations,
                                            train=True)
    return DataLoader(augmented_dataset, batch_size=batch_size, shuffle=shuffle)


def prepare_dataloader(dataset, batch_size=32, shuffle=True):
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def stratified_split(dataset, labels):
    # Split data into training + validation and test data
    train_val_data, test_data, train_val_labels, test_labels = train_test_split(
        dataset, labels, test_size=0.1, stratify=labels, random_state=42)

    # Split training + validation into actual training and validation data
    train_data, val_data, train_labels, val_labels = train_test_split(
        train_val_data, train_val_labels, test_size=0.11, stratify=train_val_labels,
        random_state=42)  # 0.11 * 0.9 ≈ 0.1

    return train_data, val_data, test_data


def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    y_true = []
    y_pred = []
    y_score = []

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        pred = out[data.train_mask].argmax(dim=1)
        y_true.extend(data.y[data.train_mask].cpu().numpy())
        y_pred.extend(pred.cpu().numpy())
        y_score.extend(F.softmax(out[data.train_mask], dim=1).cpu().detach().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_score = np.array(y_score)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    roc_auc = roc_auc_score(y_true, y_score, average='weighted', multi_class='ovo')

    return total_loss / len(loader), accuracy, precision, recall, f1, roc_auc


def evaluate(model, loader, mask_type):
    model.eval()

    y_true = []
    y_pred = []
    y_score = []

    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
            mask = getattr(data, mask_type)
            pred = out[mask].argmax(dim=1)
            y_true.extend(data.y[mask].cpu().numpy())
            y_pred.extend(pred.cpu().numpy())
            y_score.extend(F.softmax(out[mask], dim=1).cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_score = np.array(y_score)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    roc_auc = roc_auc_score(y_true, y_score, average='weighted', multi_class='ovo')

    return accuracy, precision, recall, f1, roc_auc


In [12]:
train_data, val_data, test_data = stratified_split(dataset, labels)

print(f"Train data size ........ : {len(train_data)}")
print(f"Validation data size ... : {len(val_data)}")
print(f"Test data size ......... : {len(test_data)}")

train_loader = prepare_augmented_dataloader(train_data, apply_augmentation, batch_size=32, num_augmentations=5)
val_loader = prepare_dataloader(val_data, batch_size=32)
test_loader = prepare_dataloader(test_data, batch_size=32)

num_features = train_data[0].x.shape[1]
num_classes = len(label_map)

model = GCN(num_features=num_features, hidden_channels=64, num_classes=num_classes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1500):
    train_loss, train_acc, train_prec, train_rec, train_f1, train_roc_auc = train(model, train_loader, optimizer,
                                                                                  criterion)
    val_acc, val_prec, val_rec, val_f1, val_roc_auc = evaluate(model, val_loader, 'val_mask')

    print('-----------------------------------------------------------------------------------------------------')
    print(f'EPOCH: {epoch + 1} -> Loss: {train_loss:.4f}')
    print(
        f'(Train) Accuracy: {train_acc:.4f}, Precision: {train_prec:.4f}, Recall: {train_rec:.4f}, F1: {train_f1:.4f}, AUC: {train_roc_auc:.4f}')
    print(
        f'(Valid) Accuracy: {val_acc:.4f}, Precision: {val_prec:.4f}, Recall: {val_rec:.4f}, F1: {val_f1:.4f}, AUC: {val_roc_auc:.4f}')
    print('-----------------------------------------------------------------------------------------------------')

print('-----------------------------------------------------------------------------------------------------')
test_acc, test_prec, test_rec, test_f1, test_roc_auc = evaluate(model, test_loader, 'test_mask')
print(
    f'(Test) Accuracy: {test_acc:.4f}, Precision: {test_prec:.4f}, Recall: {test_rec:.4f}, F1: {test_f1:.4f}, AUC: {test_roc_auc:.4f}')


Train data size ........ : 817
Validation data size ... : 101
Test data size ......... : 102


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


-----------------------------------------------------------------------------------------------------
EPOCH: 1 -> Loss: 22.8029
(Train) Accuracy: 0.1410, Precision: 0.1227, Recall: 0.1410, F1: 0.0975, AUC: 0.5065
(Valid) Accuracy: 0.1907, Precision: 0.1220, Recall: 0.1907, F1: 0.0649, AUC: 0.5105
-----------------------------------------------------------------------------------------------------


KeyboardInterrupt: 

# Traditional Models

In [13]:
def pad_features_to_fixed_length(features, fixed_length=4716):
    """Pad feature vectors to a fixed length."""
    padding_length = fixed_length - len(features)
    if padding_length > 0:
        return np.concatenate([features, np.zeros(padding_length)])
    else:
        return features[:fixed_length]


def extract_graph_features(graph, fixed_length=4716):
    node_features = [graph.nodes[n]['features'] for n in graph.nodes]
    max_length = max(len(f) for f in node_features)
    padded_features = [pad_features_to_fixed_length(f, max_length) for f in node_features]
    flat_features = np.array(padded_features).flatten()
    flat_features = pad_features_to_fixed_length(flat_features, fixed_length)
    feature_to_node_map = []
    for node_id in range(len(graph.nodes)):
        node_start_idx = node_id * max_length
        node_end_idx = node_start_idx + max_length
        feature_to_node_map.append((node_start_idx, node_end_idx, node_id))
    return flat_features, feature_to_node_map


def prepare_dataset(dataset, fixed_length=4716):
    features = []
    labels = []
    feature_node_mappings = []
    for data in dataset:
        graph = nx.DiGraph()
        for node_id in range(data.x.size(0)):
            graph.add_node(node_id, features=data.x[node_id].numpy())
        graph_features, feature_to_node_map = extract_graph_features(graph, fixed_length)
        features.append(graph_features)
        feature_node_mappings.append(feature_to_node_map)
        if data.y.numel() == 1:
            labels.append(data.y.item())
        else:
            labels.append(data.y[0].item())
    return np.array(features), np.array(labels), feature_node_mappings


def evaluate_classifier(clf, X_train, y_train, X_test, y_test):
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    y_score = clf.predict_proba(X_test) if hasattr(clf, "predict_proba") else clf.decision_function(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_test, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)
    roc_auc = roc_auc_score(y_test, y_score, average='weighted', multi_class='ovo') if hasattr(clf,
                                                                                               "predict_proba") else None
    return accuracy, precision, recall, f1, roc_auc

In [15]:
# Prepare dataset with a fixed length of 4716 features
fixed_length = 4716
features, labels, _ = prepare_dataset(dataset, fixed_length)

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
classifiers = {
    'RandomForest': RandomForestClassifier(n_estimators=100, random_state=42),
    'SVM': SVC(probability=True, random_state=42),
    'DecisionTree': DecisionTreeClassifier(random_state=42),
    'GaussianNB': GaussianNB(),
    'GradientBoosting': GradientBoostingClassifier(n_estimators=100, random_state=42)
}
metrics = {
    'RandomForest': {'accuracy': [], 'precision': [], 'recall': [], 'f1': [], 'roc_auc': []},
    'SVM': {'accuracy': [], 'precision': [], 'recall': [], 'f1': [], 'roc_auc': []},
    'DecisionTree': {'accuracy': [], 'precision': [], 'recall': [], 'f1': [], 'roc_auc': []},
    'GaussianNB': {'accuracy': [], 'precision': [], 'recall': [], 'f1': [], 'roc_auc': []},
    'GradientBoosting': {'accuracy': [], 'precision': [], 'recall': [], 'f1': [], 'roc_auc': []}
}

for name, clf in classifiers.items():

    print(f"*** {name} ***\n")

    for fold, (train_index, test_index) in enumerate(skf.split(features, labels)):
        X_train, X_test = features[train_index], features[test_index]
        y_train, y_test = labels[train_index], labels[test_index]

        # Generate augmented training data
        train_dataset = [dataset[i] for i in train_index]
        augmented_train_dataset = generate_augmented_asts(train_dataset, num_augmentations=5)
        X_train_augmented, y_train_augmented, _ = prepare_dataset(augmented_train_dataset, fixed_length)

        acc, prec, rec, f1, roc_auc = evaluate_classifier(clf, X_train_augmented, y_train_augmented, X_test, y_test)
        metrics[name]['accuracy'].append(acc)
        metrics[name]['precision'].append(prec)
        metrics[name]['recall'].append(rec)
        metrics[name]['f1'].append(f1)
        metrics[name]['roc_auc'].append(roc_auc)

        print(
            f"Fold {fold + 1} - Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}, AUC: {roc_auc:.4f}")

    print("........................................................\n")

print("--------------------------------------------------------")

for name in classifiers.keys():
    avg_accuracy = np.mean(metrics[name]['accuracy'])
    avg_precision = np.mean(metrics[name]['precision'])
    avg_recall = np.mean(metrics[name]['recall'])
    avg_f1 = np.mean(metrics[name]['f1'])
    avg_roc_auc = np.mean([x for x in metrics[name]['roc_auc'] if x is not None])

    print(f"\n{name} Average Metrics over 10 folds:")
    print(f"Accuracy .... : {avg_accuracy:.4f}")
    print(f"Precision ... : {avg_precision:.4f}")
    print(f"Recall ...... : {avg_recall:.4f}")
    print(f"F1 Score .... : {avg_f1:.4f}")
    print(f"ROC-AUC ..... : {avg_roc_auc:.4f}")


*** RandomForest ***

Fold 1 - Accuracy: 0.9216, Precision: 0.9383, Recall: 0.9216, F1: 0.9227, AUC: 0.9959
Fold 2 - Accuracy: 0.9216, Precision: 0.9264, Recall: 0.9216, F1: 0.9213, AUC: 0.9973
Fold 3 - Accuracy: 0.9118, Precision: 0.9193, Recall: 0.9118, F1: 0.9113, AUC: 0.9955


KeyboardInterrupt: 

# Explainability

In [21]:
import json
import numpy as np
import networkx as nx
from sklearn.ensemble import RandomForestClassifier
import hashlib

global_node_id = 0


def reset_global_node_id():
    global global_node_id
    global_node_id = 0


def assign_unique_id(node):
    global global_node_id
    node['_id'] = global_node_id
    global_node_id += 1


def traverse_and_assign_ids(node):
    if isinstance(node, dict):
        assign_unique_id(node)
        for key, value in node.items():
            traverse_and_assign_ids(value)
    elif isinstance(node, list):
        for item in node:
            traverse_and_assign_ids(item)


def load_data(ast_directory, label_map):
    dataset = []
    labels = []
    for root, dirs, files in os.walk(ast_directory):
        for file in files:
            if file.endswith('.json'):
                filepath = os.path.join(root, file)
                with open(filepath, 'r') as f:
                    ast = json.load(f)
                reset_global_node_id()
                traverse_and_assign_ids(ast)
                dataset.append(ast)
                label_folder = root.split(os.sep)[-1]
                label = label_map[label_folder]
                labels.append(label)
    return dataset, labels


def hash_feature(value, num_bins=1000):
    return int(hashlib.md5(str(value).encode()).hexdigest(), 16) % num_bins


def extract_features(node):
    node_type = node.get('nodeType', 'Unknown')
    type_feature = [hash_feature(node_type)]
    name_feature = [hash_feature(node.get('name', ''))] if 'name' in node else [0]
    value_feature = [hash_feature(node.get('value', ''))] if 'value' in node else [0]
    src_feature = [0, 0]
    if 'src' in node:
        start, length, *_ = map(int, node['src'].split(':'))
        src_feature = [start, length]
    type_desc_features = []
    if 'typeDescriptions' in node:
        type_desc = node['typeDescriptions']
        type_desc_features.append(hash_feature(type_desc.get('typeString', '')))
        type_desc_features.append(hash_feature(type_desc.get('typeIdentifier', '')))
    state_mutability_feature = [hash_feature(node.get('stateMutability', ''))] if 'stateMutability' in node else [0]
    visibility_feature = [hash_feature(node.get('visibility', ''))] if 'visibility' in node else [0]
    features = type_feature + name_feature + value_feature + src_feature + type_desc_features + state_mutability_feature + visibility_feature
    return features


def pad_features_to_fixed_length(features, fixed_length=4716):
    padding_length = fixed_length - len(features)
    if padding_length > 0:
        return np.concatenate([features, np.zeros(padding_length)])
    else:
        return features[:fixed_length]


def extract_graph_features_with_mapping(ast, fixed_length=4716):
    graph = nx.DiGraph()
    node_features = []
    feature_to_node_map = []

    for node in ast:
        if isinstance(node, dict) and '_id' in node:
            node_id = node['_id']
            features = extract_features(node)
            graph.add_node(node_id, features=features)
            node_features.append(features)
            feature_to_node_map.append((len(node_features) - 1, len(node_features), node_id))

    max_length = max(len(f) for f in node_features)
    padded_features = [pad_features_to_fixed_length(f, max_length) for f in node_features]
    flat_features = np.array(padded_features).flatten()
    flat_features = pad_features_to_fixed_length(flat_features, fixed_length)

    return flat_features, feature_to_node_map


def prepare_dataset_with_mapping(dataset, fixed_length=4716):
    features = []
    labels = []
    feature_node_mappings = []

    for ast in dataset:
        graph_features, feature_to_node_map = extract_graph_features_with_mapping(ast, fixed_length)
        features.append(graph_features)
        feature_node_mappings.append(feature_to_node_map)
        labels.append(ast[0]['y'])

    return np.array(features), np.array(labels), feature_node_mappings


def get_important_nodes(feature_importances, feature_to_node_map, num_top_features=10):
    important_features_indices = np.argsort(feature_importances)[-num_top_features:]
    important_nodes = set()
    for feature_idx in important_features_indices:
        for start_idx, end_idx, node_id in feature_to_node_map:
            if start_idx <= feature_idx < end_idx:
                important_nodes.add(node_id)
                break
    return important_nodes


def extract_lines_of_code_from_nodes(ast, important_nodes):
    lines_of_code = set()

    def traverse_ast(node):
        if isinstance(node, dict):
            node_id = node['_id']
            if node_id in important_nodes:
                if 'src' in node:
                    start, length, *_ = map(int, node['src'].split(':'))
                    lines_of_code.add((start, length))
            for key, value in node.items():
                traverse_ast(value)
        elif isinstance(node, list):
            for item in node:
                traverse_ast(item)

    traverse_ast(ast)
    return lines_of_code


def highlight_important_lines(ast, feature_importances, feature_to_node_map, num_top_features=10):
    important_nodes = get_important_nodes(feature_importances, feature_to_node_map, num_top_features)
    important_lines = extract_lines_of_code_from_nodes(ast, important_nodes)
    return important_lines


def get_feature_importance(clf):
    if hasattr(clf, 'feature_importances_'):
        return clf.feature_importances_
    else:
        raise ValueError(f"Model of type {type(clf)} does not support feature importances.")


In [22]:
def main():
    ast_directory = '../dataset/aisc/ast'
    label_map = generate_label_map(ast_directory)
    dataset, labels = load_data(ast_directory, label_map)

    if len(dataset) == 0:
        print("No data loaded. Please check the dataset directory and files.")
        return

    # Prepare dataset with a fixed length of 4716 features and get feature-to-node mappings
    fixed_length = 4716
    features, labels, feature_node_mappings = prepare_dataset_with_mapping(dataset, fixed_length)

    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(features, labels)

    feature_importances = get_feature_importance(clf)

    important_lines_per_ast = []
    for i, ast in enumerate(dataset):
        important_lines = highlight_important_lines(ast, feature_importances, feature_node_mappings[i])
        important_lines_per_ast.append(important_lines)

    for i, important_lines in enumerate(important_lines_per_ast):
        print(f"AST {i} Important lines: {important_lines}")


if __name__ == "__main__":
    main()


ValueError: max() arg is an empty sequence