# ANN with AST

In [121]:
from collections import Counter

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import tree_sitter_python as tspython
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from tree_sitter import Language, Parser

PY_LANGUAGE = Language(tspython.language())
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cpu')

## Utils

In [120]:
def set_seed(seed: int = 420):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def calculate_metrics_classifier(true_label, predicted):
    f1 = f1_score(true_label, predicted)
    roc_auc = roc_auc_score(true_label, predicted)
    recall = recall_score(true_label, predicted)
    accuracy = accuracy_score(true_label, predicted)
    precision = precision_score(true_label, predicted)

    print(
        f"F1: {f1} | ROC/AUC: {roc_auc} | RECALL: {recall} | PRECISION: {precision} | ACCURACY: {accuracy}"
    )
    return recall

## Data Loading

In [93]:
set_seed()
df = pd.read_csv("../../data/generated/dataset.csv")

print(f"Total size: {len(df)}\n")

train_val_df, test_df = train_test_split(
    df, test_size=0.2, stratify=df["generated"], random_state=420
)
train_df, val_df = train_test_split(
    df, test_size=0.22, stratify=df["generated"], random_state=420
)

train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

print(f"Train size: {len(train_df)}")
print(f"Validation size: {len(val_df)}")
print(f"Test size: {len(test_df)}")

Total size: 12428

Train size: 9693
Validation size: 2735
Test size: 2486


Check that the data is balanced

In [94]:
print(f"Mean generated (train): {train_df['generated'].mean()}")
print(f"Mean generated (val): {val_df['generated'].mean()}")
print(f"Mean generated (test): {test_df['generated'].mean()}")

Mean generated (train): 0.5212008666047663
Mean generated (val): 0.5210237659963437
Mean generated (test): 0.5213193885760258


## Dataset building

In [95]:
parser = Parser(PY_LANGUAGE)
node_types = set()


def walk_tree(node, types):
    types.append(node.type)
    for child in node.children:
        walk_tree(child, types)


def code_to_feature_vector(code):
    code = code.encode("utf-8")
    tree = parser.parse(code)
    types = []
    walk_tree(tree.root_node, types)
    counts = Counter(types)
    feature_vector = [counts.get(typ, 0) for typ in node_types]

    return feature_vector


# Gather all node types
for _, row in train_df.iterrows():
    tree = parser.parse(str.encode(row["code"]))
    types = []
    walk_tree(tree.root_node, types)
    node_types.update(types)

node_types = sorted(node_types)
type_to_idx = {typ: i for i, typ in enumerate(node_types)}

Save node types for inference

In [96]:
with open("../../data/ast/node_types.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(node_types))

In [104]:
def ast_encoding(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    features_df = pd.DataFrame(df["code"].apply(code_to_feature_vector).apply(pd.Series))
    features_df.columns = list(node_types)
    return features_df, df["generated"]

In [106]:
train_features, train_target = ast_encoding(train_df)
val_features, val_target = ast_encoding(val_df)
test_features, test_target = ast_encoding(test_df)

In [111]:
class VectorDataset(Dataset):
    def __init__(self, features: pd.DataFrame, labels: pd.DataFrame):
        self.features = torch.FloatTensor(features.to_numpy())
        self.labels = torch.LongTensor(labels.to_numpy())

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


train_dataset = VectorDataset(train_features, train_target)
val_dataset = VectorDataset(val_features, val_target)
test_dataset = VectorDataset(test_features, test_target)

In [112]:
train_batch_size = 32
eval_batch_size = 128

train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(420),
)

val_loader = DataLoader(
    val_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
)

test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False)

## Model

In [116]:
class BinaryClassifier(nn.Module):
    def __init__(self, input_dim):
        super(BinaryClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.relu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return self.sigmoid(x)

## Training

In [129]:
def train_model(
    train_loader: DataLoader, model: nn.Module, optimizer: optim.Optimizer, criterion
):
    model.train()
    train_losses = []
    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.flatten().float(), labels.float())
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    return np.mean(train_losses)


def eval_model(eval_loader: DataLoader, model: nn.Module, criterion):
    model.eval()

    eval_losses = []
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in tqdm(eval_loader, desc="Evaluating"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            outputs = model(inputs)
            eval_losses.append(
                criterion(outputs.flatten().float(), labels.float()).item()
            )
            preds = (outputs > 0.5).float()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds).flatten()
    all_labels = np.array(all_labels).flatten()
    recall = calculate_metrics_classifier(all_labels, all_preds)
    return np.mean(eval_losses), recall

In [None]:
model = BinaryClassifier(input_dim=len(node_types)).to(DEVICE)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


best_recall = 0
epochs = 100
early_stopping = 5
es_index = 0
for epoch in range(1, epochs + 1):
    print(f"\nEpoch {epoch}")
    train_loss = train_model(train_loader, model, optimizer, criterion)
    print(f"Mean train loss: {train_loss:.3f}")
    eval_loss, recall = eval_model(val_loader, model, criterion)
    print(f"Mean val loss: {eval_loss:.3f}")

    if recall > best_recall:
        best_recall = recall
        torch.save(model.state_dict(), "../../data/generated/best_nn_ast.pth")
        es_index = 0

    else:
        es_index += 1

    if es_index > early_stopping:
        break


Epoch 1


Training: 100%|██████████| 303/303 [00:03<00:00, 78.36it/s] 


Mean train loss: 0.666


Evaluating: 100%|██████████| 22/22 [00:00<00:00, 67.80it/s]


F1: 0.6265249901613538 | ROC/AUC: 0.6571608410338825 | RECALL: 0.5585964912280702 | PRECISION: 0.7132616487455197 | ACCURACY: 0.6530164533820841
Mean val loss: 0.623

Epoch 2


Training: 100%|██████████| 303/303 [00:05<00:00, 53.40it/s] 


Mean train loss: 0.623


Evaluating: 100%|██████████| 22/22 [00:00<00:00, 972.36it/s]

F1: 0.6618541033434651 | ROC/AUC: 0.6773697602785589 | RECALL: 0.6112280701754386 | PRECISION: 0.7216238608119304 | ACCURACY: 0.6745886654478976
Mean val loss: 0.594





## Testing

In [132]:
best_model = BinaryClassifier(input_dim=len(node_types)).to(DEVICE)
best_model.load_state_dict(torch.load("../../data/generated/best_nn_ast.pth"))


eval_model(test_loader, best_model, criterion)

Evaluating: 100%|██████████| 20/20 [00:00<00:00, 62.77it/s]

F1: 0.6580375782881002 | ROC/AUC: 0.6733400767714494 | RECALL: 0.6080246913580247 | PRECISION: 0.7170154686078253 | ACCURACY: 0.670555108608206





(0.5991495251655579, 0.6080246913580247)