In [316]:
import json
import os

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, SAGEConv
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_undirected, add_self_loops
from transformers import AutoModel, AutoConfig

In [317]:
"""
def extract_features(node, depth=0):
    # List of node types to one-hot encode
    node_types = [
        'FunctionDefinition', 'VariableDeclaration', 'Literal', 'ExpressionStatement',
        'IfStatement', 'ForStatement', 'WhileStatement', 'ReturnStatement', 'Block',
        'Assignment', 'BinaryOperation', 'UnaryOperation', 'ParameterList',
        'FunctionCall', 'Identifier', 'IndexAccess', 'MemberAccess', 'Unknown'
    ]

    # One-hot encoding for node type
    node_type = node.get('nodeType', 'Unknown')
    type_vector = [1 if node_type == t else 0 for t in node_types]

    # Depth feature
    depth_feature = [depth]

    # Name feature (hashed)
    name_feature = [hash(node.get('name', '')) % 1000] if 'name' in node else [0]

    # Value feature (hashed)
    value = node.get('value', None)
    if isinstance(value, dict):
        value_feature = [hash(str(value)) % 1000]
    elif value is not None:
        value_feature = [hash(str(value)) % 1000]
    else:
        value_feature = [0]

    # Children count
    children_count = len(node.get('children', []))
    children_feature = [children_count]

    # Position features (line and column, if available)
    position = node.get('src', None)
    if position:
        line, column = map(int, position.split(':')[:2])
        position_feature = [line, column]
    else:
        position_feature = [0, 0]

    # Type descriptions (if available)
    type_descriptions = node.get('typeDescriptions', {})
    type_descriptions_feature = [hash(str(type_descriptions)) % 1000]

    # Other custom properties
    other_properties = ['visibility', 'stateMutability', 'constant', 'payable']
    other_features = [hash(node.get(prop, '')) % 1000 if prop in node else 0 for prop in other_properties]

    # Combine all features
    return type_vector + depth_feature + name_feature + value_feature + children_feature + position_feature + type_descriptions_feature + other_features


def ast_to_graph(ast_json):
    graph = nx.DiGraph()
    node_id = 0

    def add_nodes_edges(node, parent=None, depth=0):
        nonlocal node_id
        current_node_id = node_id
        graph.add_node(current_node_id, features=extract_features(node, depth))
        if parent is not None:
            graph.add_edge(parent, current_node_id)
        node_id += 1
        for key, value in node.items():
            if isinstance(value, dict):
                add_nodes_edges(value, current_node_id, depth + 1)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        add_nodes_edges(item, current_node_id, depth + 1)

    add_nodes_edges(ast_json)
    edge_index = torch.tensor(list(graph.edges)).t().contiguous()
    x = torch.stack([torch.tensor(graph.nodes[n]['features'], dtype=torch.float) for n in graph.nodes])

    data = Data(x=x, edge_index=edge_index)
    return add_masks_to_data(data)

"""


def extract_features(node):
    node_type = node.get('nodeType', 'Unknown')
    return [hash(node_type) % 1000]  # Simple feature vector from node type


def ast_to_pyg_data(ast_json):
    nodes = []
    edges = []
    node_id = 0
    node_map = {}

    def add_nodes_edges(node, parent=None):
        nonlocal node_id
        current_node_id = node_id
        node_map[id(node)] = current_node_id
        nodes.append(extract_features(node))
        if parent is not None:
            edges.append((node_map[id(parent)], current_node_id))
        node_id += 1
        for key, value in node.items():
            if isinstance(value, dict):
                add_nodes_edges(value, node)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        add_nodes_edges(item, node)

    add_nodes_edges(ast_json)
    x = torch.tensor(nodes, dtype=torch.float)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    edge_index, _ = add_self_loops(edge_index)
    edge_index = to_undirected(edge_index)
    return Data(x=x, edge_index=edge_index)


def add_masks_to_data(data, train_ratio=0.8, val_ratio=0.1):
    num_nodes = data.x.size(0)
    indices = torch.randperm(num_nodes)

    train_size = int(num_nodes * train_ratio)
    val_size = int(num_nodes * val_ratio)

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)

    train_mask[indices[:train_size]] = True
    val_mask[indices[train_size:train_size + val_size]] = True
    test_mask[indices[train_size + val_size:]] = True

    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data


def generate_label_map(ast_directory):
    label_map = {}
    label_index = 0
    for category in os.listdir(ast_directory):
        category_path = os.path.join(ast_directory, category)
        if os.path.isdir(category_path):  # Ensure it's a directory
            label_map[category] = label_index
            label_index += 1
    return label_map


def load_data(ast_directory, label_map):
    dataset, labels = [], []
    for root, dirs, files in os.walk(ast_directory):
        for file in files:
            if file.endswith('.json'):
                filepath = os.path.join(root, file)
                with open(filepath, 'r') as f:
                    ast = json.load(f)
                data = ast_to_pyg_data(ast)
                label_folder = os.path.basename(root)
                label = label_map[label_folder]
                data.y = torch.tensor([label] * data.x.size(0), dtype=torch.long)  # Assign label to all nodes
                dataset.append(data)
                labels.append(label)
    return dataset, labels


In [318]:
ast_directory = '../dataset/aisc/ast'
label_map = generate_label_map(ast_directory)

# Load all data
dataset, labels = load_data(ast_directory, label_map)

In [319]:
def stratified_split(dataset, labels):
    # Split data into training + validation and test data
    train_val_data, test_data, train_val_labels, test_labels = train_test_split(
        dataset, labels, test_size=0.1, stratify=labels, random_state=42)

    # Split training + validation into actual training and validation data
    train_data, val_data, train_labels, val_labels = train_test_split(
        train_val_data, train_val_labels, test_size=0.11, stratify=train_val_labels,
        random_state=42)  # 0.11 * 0.9 ≈ 0.1

    return train_data, val_data, test_data

In [320]:
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


class GAT(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels, num_classes, num_heads=1):
        super(GAT, self).__init__()
        self.conv1 = GATConv(num_node_features, hidden_channels, heads=num_heads, dropout=0.6)
        self.conv2 = GATConv(hidden_channels * num_heads, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


class GraphSAGE(torch.nn.Module):
    def __init__(self, in_feats, hidden_feats, out_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, hidden_feats, 'mean')
        self.conv2 = SAGEConv(hidden_feats, out_feats, 'mean')

    def forward(self, g, inputs):
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    y_true = []
    y_pred = []

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        pred = out[data.train_mask].argmax(dim=1)
        y_true.extend(data.y[data.train_mask].cpu().numpy())
        y_pred.extend(pred.cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    accuracy = (y_true == y_pred).mean()
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    roc_auc = roc_auc_score(y_true, y_pred, average='weighted', multi_class='ovo')

    return total_loss / len(loader), accuracy, precision, recall, f1, roc_auc


def evaluate(model, loader, mask_type):
    model.eval()
    y_true = []
    y_pred = []

    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
            mask = getattr(data, mask_type)
            pred = out[mask].argmax(dim=1)
            y_true.extend(data.y[mask].cpu().numpy())
            y_pred.extend(pred.cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    accuracy = (y_true == y_pred).mean()
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    roc_auc = roc_auc_score(y_true, y_pred, average='weighted', multi_class='ovo')

    return accuracy, precision, recall, f1, roc_auc


In [321]:
def prepare_graphormer_inputs(data, num_heads):
    # Add batch dimension and ensure types
    input_nodes = data.x.unsqueeze(0).long()  # shape: [1, num_nodes, feature_size]
    input_edges = data.edge_index.unsqueeze(0).long()  # shape: [1, 2, num_edges]

    num_nodes = data.num_nodes

    # Initialize attention bias for Graphormer
    attn_bias = torch.zeros((1, num_heads, num_nodes + 1, num_nodes + 1),
                            dtype=torch.float)  # shape: [1, num_heads, num_nodes+1, num_nodes+1]

    # Degree tensors with batch dimension
    in_degree = torch.zeros((1, num_nodes + 1), dtype=torch.long)
    out_degree = torch.zeros((1, num_nodes + 1), dtype=torch.long)
    for node in data.edge_index[1]:
        in_degree[0][node + 1] += 1  # +1 to accommodate the CLS token
    for node in data.edge_index[0]:
        out_degree[0][node + 1] += 1  # +1 to accommodate the CLS token

    # Spatial position tensor, initialized to zero
    spatial_pos = torch.zeros((1, num_nodes + 1, num_nodes + 1), dtype=torch.long)

    # Attention edge type tensor
    attn_edge_type = torch.zeros((1, input_edges.size(2)), dtype=torch.long)

    # Debug prints
    print(f"input_nodes shape: {input_nodes.shape}, dtype: {input_nodes.dtype}")
    print(f"input_edges shape: {input_edges.shape}, dtype: {input_edges.dtype}")
    print(f"attn_bias shape: {attn_bias.shape}, dtype: {attn_bias.dtype}")
    print(f"in_degree shape: {in_degree.shape}, dtype: {in_degree.dtype}")
    print(f"out_degree shape: {out_degree.shape}, dtype: {out_degree.dtype}")
    print(f"spatial_pos shape: {spatial_pos.shape}, dtype: {spatial_pos.dtype}")
    print(f"attn_edge_type shape: {attn_edge_type.shape}, dtype: {attn_edge_type.dtype}")

    return {
        'input_nodes': input_nodes,
        'input_edges': input_edges,
        'attn_bias': attn_bias,
        'in_degree': in_degree,
        'out_degree': out_degree,
        'spatial_pos': spatial_pos,
        'attn_edge_type': attn_edge_type
    }


In [322]:
def train(model, loader, optimizer, criterion, num_heads):
    model.train()
    total_loss = 0
    y_true = []
    y_pred = []

    for data in loader:
        data = data.to(device)
        inputs = prepare_graphormer_inputs(data, num_heads)
        optimizer.zero_grad()
        outputs = model(**inputs)
        print(f"outputs logits shape: {outputs.logits.shape}")
        print(f"data.y shape: {data.y.shape}, dtype: {data.y.dtype}")
        loss = criterion(outputs.logits, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        pred = outputs.logits.argmax(dim=1)
        y_true.extend(data.y.cpu().numpy())
        y_pred.extend(pred.cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    accuracy = (y_true == y_pred).mean()
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    roc_auc = roc_auc_score(y_true, y_pred, average='weighted', multi_class='ovo')

    return total_loss / len(loader), accuracy, precision, recall, f1, roc_auc


def evaluate(model, loader, num_heads):
    model.eval()
    y_true = []
    y_pred = []

    for data in loader:
        data = data.to(device)
        inputs = prepare_graphormer_inputs(data, num_heads)
        with torch.no_grad():
            outputs = model(**inputs)

        pred = outputs.logits.argmax(dim=1)
        y_true.extend(data.y.cpu().numpy())
        y_pred.extend(pred.cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    accuracy = (y_true == y_pred).mean()
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    roc_auc = roc_auc_score(y_true, y_pred, average='weighted', multi_class='ovo')

    return accuracy, precision, recall, f1, roc_auc


In [323]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Split data
train_data, val_data, test_data = stratified_split(dataset, labels)

# Convert to DataLoader
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(val_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=True)

num_features = len(extract_features({}))  # Determine number of features dynamically
num_features = 1
num_classes = len(label_map)  # Number of classes from the label map

# model = GCN(num_features=num_features, hidden_channels=16, num_classes=num_classes).to(device)
# model = GAT(num_node_features=num_features, hidden_channels=16, num_classes=num_classes, num_heads=4).to(device)
model_name = "clefourrier/graphormer-base-pcqm4mv2"
config = AutoConfig.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, config=config).to(device)

# Extract the number of attention heads from the configuration
num_heads = config.num_attention_heads

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(200):
    train_loss, train_acc, train_prec, train_rec, train_f1, train_roc_auc = train(model, train_loader, optimizer,
                                                                                  criterion, num_heads)
    val_acc, val_prec, val_rec, val_f1, val_roc_auc = evaluate(model, val_loader, num_heads)

    print(
        f'Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}, Train Precision: {train_prec:.4f}, Train Recall: {train_rec:.4f}, Train F1 Score: {train_f1:.4f}, Train ROC-AUC: {train_roc_auc:.4f}')
    print(
        f'Val Accuracy: {val_acc:.4f}, Val Precision: {val_prec:.4f}, Val Recall: {val_rec:.4f}, Val F1 Score: {val_f1:.4f}, Val ROC-AUC: {val_roc_auc:.4f}')

test_acc, test_prec, test_rec, test_f1, test_roc_auc = evaluate(model, test_loader, num_heads)
print(
    f'Test Accuracy: {test_acc:.4f}, Test Precision: {test_prec:.4f}, Test Recall: {test_rec:.4f}, Test F1 Score: {test_f1:.4f}, Test ROC-AUC: {test_roc_auc:.4f}')

"""
for epoch in range(200):
    train_loss = train(model, train_loader, optimizer, criterion)
    accuracy, precision, recall, f1, roc_auc = evaluate(model, val_loader, 'val_mask')
    print(
        f'Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Val Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, ROC-AUC: {roc_auc:.4f}')

accuracy, precision, recall, f1, roc_auc = evaluate(model, test_loader, 'test_mask')
print(
    f'Test Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, ROC-AUC: {roc_auc:.4f}')

"""

Some weights of GraphormerModel were not initialized from the model checkpoint at clefourrier/graphormer-base-pcqm4mv2 and are newly initialized: ['graph_encoder.emb_layer_norm.bias', 'graph_encoder.emb_layer_norm.weight', 'graph_encoder.graph_attn_bias.edge_dis_encoder.weight', 'graph_encoder.graph_attn_bias.edge_encoder.weight', 'graph_encoder.graph_attn_bias.graph_token_virtual_distance.weight', 'graph_encoder.graph_attn_bias.spatial_pos_encoder.weight', 'graph_encoder.graph_node_feature.atom_encoder.weight', 'graph_encoder.graph_node_feature.graph_token.weight', 'graph_encoder.graph_node_feature.in_degree_encoder.weight', 'graph_encoder.graph_node_feature.out_degree_encoder.weight', 'graph_encoder.layers.0.fc1.bias', 'graph_encoder.layers.0.fc1.weight', 'graph_encoder.layers.0.fc2.bias', 'graph_encoder.layers.0.fc2.weight', 'graph_encoder.layers.0.final_layer_norm.bias', 'graph_encoder.layers.0.final_layer_norm.weight', 'graph_encoder.layers.0.self_attn.k_proj.bias', 'graph_encoder

input_nodes shape: torch.Size([1, 2192, 1]), dtype: torch.int64
input_edges shape: torch.Size([1, 2, 6544]), dtype: torch.int64
attn_bias shape: torch.Size([1, 32, 2193, 2193]), dtype: torch.float32
in_degree shape: torch.Size([1, 2193]), dtype: torch.int64
out_degree shape: torch.Size([1, 2193]), dtype: torch.int64
spatial_pos shape: torch.Size([1, 2193, 2193]), dtype: torch.int64
attn_edge_type shape: torch.Size([1, 6544]), dtype: torch.int64


RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor