# Solution based on AST

In [132]:
from collections import Counter

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import tree_sitter_python as tspython
from sklearn.metrics import (
    f1_score,
    mean_absolute_error,
    mean_squared_error,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from tree_sitter import Language, Parser

PY_LANGUAGE = Language(tspython.language())
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda')

## Utils

In [133]:
def set_seed(seed: int = 420):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

## Data Loading

In [134]:
set_seed()
df = pd.read_csv("../../data/generated/dataset_dist.csv")

print(f"Total size: {len(df)}\n")

train_df, val_prep = train_test_split(df, test_size=0.3, stratify=df["generated"])
valid_df, test_df = train_test_split(
    val_prep, test_size=0.3, stratify=val_prep["generated"]
)
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

print(f"Train size: {len(train_df)}")
print(f"Validation size: {len(valid_df)}")
print(f"Test size: {len(test_df)}")

Total size: 12428

Train size: 8699
Validation size: 2610
Test size: 1119


Check that the data is balanced

In [135]:
print(f"Mean generated (train): {train_df['generated'].mean()}")
print(f"Mean generated (validation): {valid_df['generated'].mean()}")
print(f"Mean generated (test): {test_df['generated'].mean()}")

Mean generated (train): 0.5126688125071848
Mean generated (validation): 0.512284674329502
Mean generated (test): 0.5121599642537981


## Dataset building

In [136]:
parser = Parser(PY_LANGUAGE)
node_types = set()


def walk_tree(node, types):
    types.append(node.type)
    for child in node.children:
        walk_tree(child, types)


def code_to_feature_vector(code: bytes, device=DEVICE) -> torch.Tensor:
    tree = parser.parse(code)
    types = []
    walk_tree(tree.root_node, types)
    counts = Counter(types)
    feature_vector = [counts.get(typ, 0) for typ in node_types]
    return torch.tensor(feature_vector, dtype=torch.float32, device=device)


# Gather all node types
for _, row in train_df.iterrows():
    tree = parser.parse(str.encode(row["code"]))
    types = []
    walk_tree(tree.root_node, types)
    node_types.update(types)

node_types = sorted(node_types)
type_to_idx = {typ: i for i, typ in enumerate(node_types)}

Save node types for inference

In [137]:
with open("../../data/ast/node_types.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(node_types))

In [138]:
class ASTDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __getitem__(self, index) -> tuple[str, torch.Tensor, float]:
        raw_code = self.dataframe["code"].iloc[index]
        return (
            raw_code,
            code_to_feature_vector(raw_code.encode("utf-8")),
            float(self.dataframe["generated"].iloc[index]),
        )

    def __len__(self):
        return len(self.dataframe)


data_train = ASTDataset(dataframe=train_df)
dataloader_train = DataLoader(data_train, batch_size=64)
data_val = ASTDataset(dataframe=valid_df)
dataloader_val = DataLoader(data_val, batch_size=128)
data_test = ASTDataset(dataframe=test_df)
dataloader_test = DataLoader(data_test, batch_size=128)

## Model definition

In [139]:
class AIDetector(nn.Module):
    def __init__(self, input_dim: int, hidden_dim1: int = 64, hidden_dim2: int = 32):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

## Training

In [None]:
def compute_metrics(predictions: list[float], labels: list[float]) -> dict:
    predictions_rounded = [round(x) for x in predictions]
    labels_rounded = [round(x) for x in labels]

    return {
        "recall": recall_score(labels_rounded, predictions_rounded),
        "roc_auc": roc_auc_score(labels_rounded, predictions_rounded),
        "f1": f1_score(labels_rounded, predictions_rounded),
        "mae": mean_absolute_error(labels, predictions),
        "mse": mean_squared_error(labels, predictions),
    }


def metrics_str(metrics: dict) -> str:
    return " | ".join([f"{key.upper()}: {value:.4f}" for key, value in metrics.items()])


def train_model(
    model: nn.Module, dataloader: DataLoader, criterion, optimizer: optim.Optimizer
):
    losses = []
    model.train()
    for _, code, label in tqdm(dataloader, desc="Training"):
        code, label = code.float().to(DEVICE), label.float().to(DEVICE)
        outputs = model(code)
        outputs = outputs.squeeze()
        loss = criterion(outputs, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return np.mean(losses)


def evaluate_model(model: nn.Module, dataloader: DataLoader) -> dict:
    model.eval()
    with torch.no_grad():
        all_predictions = []
        all_truths = []
        for _, code, label in tqdm(dataloader, desc="Validation"):
            code, label = code.float().to(DEVICE), label.float().to(DEVICE)
            outputs = model(code)
            outputs = outputs.squeeze()

            all_predictions.extend(outputs.detach().cpu().numpy().tolist())
            all_truths.extend(label.detach().cpu().numpy().tolist())

        return compute_metrics(all_predictions, all_truths)


def train_eval_loop(
    model: nn.Module,
    dataloader_train: DataLoader,
    dataloader_val: DataLoader,
    criterion,
    optimizer: optim.Optimizer,
    epochs: int = 5,
    early_stopping: int = 3,
    maximize: str = "recall",
    save_path: str = "../../data/ast/best_model.pth",
):
    best_score = 0 if maximize == "recall" else float("inf")
    no_improvement = 0
    mean_losses = []
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")
        mean_loss = train_model(model, dataloader_train, criterion, optimizer)
        metrics = evaluate_model(model, dataloader_val)

        mean_losses.append(mean_loss)
        print(f"\n{metrics_str(metrics)}\n")

        score = metrics[maximize]
        if (maximize == "recall" and score > best_score) or (
            maximize == "mae" and score < best_score
        ):
            no_improvement = 0
            best_score = score
            torch.save(model.state_dict(), save_path)
        else:
            no_improvement += 1
        if no_improvement >= early_stopping:
            print("Early stopping triggered.")
            break

    return mean_losses


set_seed()
model = AIDetector(input_dim=len(node_types)).to(DEVICE)
# criterion = nn.CrossEntropyLoss()
# criterion = nn.MSELoss()
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

losses = train_eval_loop(
    model,
    dataloader_train,
    dataloader_val,
    criterion,
    optimizer,
    epochs=50,
    early_stopping=5,
    maximize="mae",
)

Epoch 1/50


Training: 100%|██████████| 136/136 [00:08<00:00, 15.18it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.79it/s]



RECALL: 0.4721 | ROC_AUC: 0.6176 | F1: 0.5587 | MAE: 0.3745 | MSE: 0.2153

Epoch 2/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.73it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.78it/s]



RECALL: 0.6390 | ROC_AUC: 0.6615 | F1: 0.6623 | MAE: 0.3323 | MSE: 0.1977

Epoch 3/50


Training: 100%|██████████| 136/136 [00:05<00:00, 25.62it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 12.21it/s]



RECALL: 0.6243 | ROC_AUC: 0.6785 | F1: 0.6677 | MAE: 0.3218 | MSE: 0.1961

Epoch 4/50


Training: 100%|██████████| 136/136 [00:05<00:00, 26.49it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.02it/s]



RECALL: 0.4860 | ROC_AUC: 0.6702 | F1: 0.6001 | MAE: 0.3311 | MSE: 0.2152

Epoch 5/50


Training: 100%|██████████| 136/136 [00:04<00:00, 27.88it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 16.02it/s]



RECALL: 0.5522 | ROC_AUC: 0.6849 | F1: 0.6422 | MAE: 0.3130 | MSE: 0.1999

Epoch 6/50


Training: 100%|██████████| 136/136 [00:04<00:00, 27.98it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 14.35it/s]



RECALL: 0.6096 | ROC_AUC: 0.6916 | F1: 0.6707 | MAE: 0.3007 | MSE: 0.1902

Epoch 7/50


Training: 100%|██████████| 136/136 [00:05<00:00, 24.79it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 16.36it/s]



RECALL: 0.6213 | ROC_AUC: 0.7019 | F1: 0.6823 | MAE: 0.2944 | MSE: 0.1872

Epoch 8/50


Training: 100%|██████████| 136/136 [00:04<00:00, 28.56it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 14.98it/s]



RECALL: 0.6449 | ROC_AUC: 0.7064 | F1: 0.6941 | MAE: 0.2894 | MSE: 0.1844

Epoch 9/50


Training: 100%|██████████| 136/136 [00:05<00:00, 24.87it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.45it/s]



RECALL: 0.6566 | ROC_AUC: 0.7103 | F1: 0.7009 | MAE: 0.2850 | MSE: 0.1817

Epoch 10/50


Training: 100%|██████████| 136/136 [00:05<00:00, 26.40it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.66it/s]



RECALL: 0.6551 | ROC_AUC: 0.7144 | F1: 0.7032 | MAE: 0.2813 | MSE: 0.1803

Epoch 11/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.35it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.74it/s]



RECALL: 0.6860 | ROC_AUC: 0.7186 | F1: 0.7166 | MAE: 0.2771 | MSE: 0.1765

Epoch 12/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.38it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.54it/s]



RECALL: 0.6904 | ROC_AUC: 0.7280 | F1: 0.7245 | MAE: 0.2752 | MSE: 0.1762

Epoch 13/50


Training: 100%|██████████| 136/136 [00:04<00:00, 28.27it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.41it/s]



RECALL: 0.7044 | ROC_AUC: 0.7350 | F1: 0.7338 | MAE: 0.2712 | MSE: 0.1742

Epoch 14/50


Training: 100%|██████████| 136/136 [00:04<00:00, 32.61it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 16.50it/s]



RECALL: 0.6978 | ROC_AUC: 0.7373 | F1: 0.7334 | MAE: 0.2697 | MSE: 0.1751

Epoch 15/50


Training: 100%|██████████| 136/136 [00:04<00:00, 27.83it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.58it/s]



RECALL: 0.7007 | ROC_AUC: 0.7388 | F1: 0.7353 | MAE: 0.2682 | MSE: 0.1772

Epoch 16/50


Training: 100%|██████████| 136/136 [00:04<00:00, 28.12it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 14.68it/s]



RECALL: 0.6868 | ROC_AUC: 0.7374 | F1: 0.7300 | MAE: 0.2678 | MSE: 0.1752

Epoch 17/50


Training: 100%|██████████| 136/136 [00:05<00:00, 23.90it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 13.97it/s]



RECALL: 0.6956 | ROC_AUC: 0.7458 | F1: 0.7388 | MAE: 0.2644 | MSE: 0.1722

Epoch 18/50


Training: 100%|██████████| 136/136 [00:04<00:00, 27.54it/s]
Validation: 100%|██████████| 21/21 [00:02<00:00,  7.76it/s]



RECALL: 0.6669 | ROC_AUC: 0.7443 | F1: 0.7285 | MAE: 0.2683 | MSE: 0.1759

Epoch 19/50


Training: 100%|██████████| 136/136 [00:08<00:00, 16.32it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 13.03it/s]



RECALL: 0.6684 | ROC_AUC: 0.7442 | F1: 0.7289 | MAE: 0.2645 | MSE: 0.1748

Epoch 20/50


Training: 100%|██████████| 136/136 [00:05<00:00, 24.01it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.73it/s]



RECALL: 0.6787 | ROC_AUC: 0.7465 | F1: 0.7340 | MAE: 0.2620 | MSE: 0.1729

Epoch 21/50


Training: 100%|██████████| 136/136 [00:05<00:00, 26.57it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 13.11it/s]



RECALL: 0.6890 | ROC_AUC: 0.7477 | F1: 0.7381 | MAE: 0.2607 | MSE: 0.1717

Epoch 22/50


Training: 100%|██████████| 136/136 [00:08<00:00, 15.73it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.52it/s]



RECALL: 0.6860 | ROC_AUC: 0.7498 | F1: 0.7387 | MAE: 0.2592 | MSE: 0.1715

Epoch 23/50


Training: 100%|██████████| 136/136 [00:04<00:00, 27.30it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 14.17it/s]



RECALL: 0.7125 | ROC_AUC: 0.7574 | F1: 0.7523 | MAE: 0.2514 | MSE: 0.1657

Epoch 24/50


Training: 100%|██████████| 136/136 [00:05<00:00, 26.06it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 14.44it/s]



RECALL: 0.6721 | ROC_AUC: 0.7480 | F1: 0.7330 | MAE: 0.2591 | MSE: 0.1719

Epoch 25/50


Training: 100%|██████████| 136/136 [00:05<00:00, 26.99it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 14.27it/s]



RECALL: 0.7449 | ROC_AUC: 0.7592 | F1: 0.7628 | MAE: 0.2498 | MSE: 0.1665

Epoch 26/50


Training: 100%|██████████| 136/136 [00:04<00:00, 28.52it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 16.97it/s]



RECALL: 0.7066 | ROC_AUC: 0.7601 | F1: 0.7525 | MAE: 0.2533 | MSE: 0.1694

Epoch 27/50


Training: 100%|██████████| 136/136 [00:04<00:00, 28.45it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 18.11it/s]



RECALL: 0.7206 | ROC_AUC: 0.7595 | F1: 0.7562 | MAE: 0.2518 | MSE: 0.1678

Epoch 28/50


Training: 100%|██████████| 136/136 [00:04<00:00, 30.15it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.91it/s]



RECALL: 0.7221 | ROC_AUC: 0.7630 | F1: 0.7592 | MAE: 0.2465 | MSE: 0.1642

Epoch 29/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.63it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.65it/s]



RECALL: 0.7324 | ROC_AUC: 0.7626 | F1: 0.7618 | MAE: 0.2461 | MSE: 0.1643

Epoch 30/50


Training: 100%|██████████| 136/136 [00:04<00:00, 30.61it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.10it/s]



RECALL: 0.7390 | ROC_AUC: 0.7595 | F1: 0.7614 | MAE: 0.2470 | MSE: 0.1658

Epoch 31/50


Training: 100%|██████████| 136/136 [00:05<00:00, 26.64it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 16.79it/s]



RECALL: 0.6926 | ROC_AUC: 0.7611 | F1: 0.7491 | MAE: 0.2526 | MSE: 0.1670

Epoch 32/50


Training: 100%|██████████| 136/136 [00:04<00:00, 27.40it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 16.94it/s]



RECALL: 0.7331 | ROC_AUC: 0.7657 | F1: 0.7643 | MAE: 0.2454 | MSE: 0.1628

Epoch 33/50


Training: 100%|██████████| 136/136 [00:05<00:00, 26.93it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 16.69it/s]



RECALL: 0.7176 | ROC_AUC: 0.7656 | F1: 0.7598 | MAE: 0.2451 | MSE: 0.1636

Epoch 34/50


Training: 100%|██████████| 136/136 [00:04<00:00, 30.14it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 16.22it/s]



RECALL: 0.6934 | ROC_AUC: 0.7667 | F1: 0.7535 | MAE: 0.2511 | MSE: 0.1686

Epoch 35/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.09it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 15.52it/s]



RECALL: 0.7132 | ROC_AUC: 0.7598 | F1: 0.7543 | MAE: 0.2503 | MSE: 0.1678

Epoch 36/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.72it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.17it/s]



RECALL: 0.7647 | ROC_AUC: 0.7672 | F1: 0.7738 | MAE: 0.2429 | MSE: 0.1628

Epoch 37/50


Training: 100%|██████████| 136/136 [00:04<00:00, 31.28it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.73it/s]



RECALL: 0.7941 | ROC_AUC: 0.7687 | F1: 0.7823 | MAE: 0.2401 | MSE: 0.1627

Epoch 38/50


Training: 100%|██████████| 136/136 [00:04<00:00, 31.02it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 18.11it/s]



RECALL: 0.8103 | ROC_AUC: 0.7675 | F1: 0.7855 | MAE: 0.2382 | MSE: 0.1596

Epoch 39/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.10it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.91it/s]



RECALL: 0.7919 | ROC_AUC: 0.7720 | F1: 0.7841 | MAE: 0.2365 | MSE: 0.1587

Epoch 40/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.71it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.29it/s]



RECALL: 0.7757 | ROC_AUC: 0.7715 | F1: 0.7797 | MAE: 0.2354 | MSE: 0.1584

Epoch 41/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.98it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.23it/s]



RECALL: 0.7941 | ROC_AUC: 0.7775 | F1: 0.7886 | MAE: 0.2328 | MSE: 0.1556

Epoch 42/50


Training: 100%|██████████| 136/136 [00:04<00:00, 30.96it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.23it/s]



RECALL: 0.8154 | ROC_AUC: 0.7785 | F1: 0.7944 | MAE: 0.2296 | MSE: 0.1531

Epoch 43/50


Training: 100%|██████████| 136/136 [00:04<00:00, 31.93it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.37it/s]



RECALL: 0.8316 | ROC_AUC: 0.7786 | F1: 0.7982 | MAE: 0.2314 | MSE: 0.1558

Epoch 44/50


Training: 100%|██████████| 136/136 [00:04<00:00, 31.82it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.52it/s]



RECALL: 0.8404 | ROC_AUC: 0.7802 | F1: 0.8013 | MAE: 0.2299 | MSE: 0.1569

Epoch 45/50


Training: 100%|██████████| 136/136 [00:04<00:00, 31.45it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.95it/s]



RECALL: 0.8287 | ROC_AUC: 0.7811 | F1: 0.7993 | MAE: 0.2292 | MSE: 0.1548

Epoch 46/50


Training: 100%|██████████| 136/136 [00:04<00:00, 30.39it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 16.55it/s]



RECALL: 0.8382 | ROC_AUC: 0.7859 | F1: 0.8048 | MAE: 0.2270 | MSE: 0.1543

Epoch 47/50


Training: 100%|██████████| 136/136 [00:04<00:00, 29.29it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.41it/s]



RECALL: 0.7978 | ROC_AUC: 0.7857 | F1: 0.7955 | MAE: 0.2272 | MSE: 0.1519

Epoch 48/50


Training: 100%|██████████| 136/136 [00:04<00:00, 32.20it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.28it/s]



RECALL: 0.8206 | ROC_AUC: 0.7867 | F1: 0.8014 | MAE: 0.2232 | MSE: 0.1488

Epoch 49/50


Training: 100%|██████████| 136/136 [00:04<00:00, 31.16it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.06it/s]



RECALL: 0.8213 | ROC_AUC: 0.7859 | F1: 0.8010 | MAE: 0.2249 | MSE: 0.1511

Epoch 50/50


Training: 100%|██████████| 136/136 [00:04<00:00, 30.20it/s]
Validation: 100%|██████████| 21/21 [00:01<00:00, 17.29it/s]


RECALL: 0.8309 | ROC_AUC: 0.7854 | F1: 0.8028 | MAE: 0.2235 | MSE: 0.1505






## Testing

In [145]:
def test_model(model: nn.Module, dataloader: DataLoader) -> tuple[pd.DataFrame, dict]:
    model.eval()
    with torch.no_grad():
        all_predictions = []
        all_truths = []
        all_codes = []
        for real_code, code, label in tqdm(dataloader, desc="Validation"):
            code, label = code.float().to(DEVICE), label.float().to(DEVICE)
            outputs = model(code)
            outputs = outputs.squeeze()

            all_predictions.extend(outputs.detach().cpu().numpy().tolist())
            all_truths.extend(label.detach().cpu().numpy().tolist())
            all_codes.extend(real_code)

        test_df = pd.DataFrame(
            {"code": all_codes, "real": all_truths, "predicted": all_predictions}
        )
        return test_df, compute_metrics(all_predictions, all_truths)

In [146]:
best_model = AIDetector(input_dim=len(node_types)).to(DEVICE)
best_model.load_state_dict(torch.load("../../data/ast/best_model.pth"))
test_df, test_metrics = test_model(best_model, dataloader_test)
print(metrics_str(test_metrics))

test_df.to_csv("../../data/ast/test_results.csv")

Validation: 100%|██████████| 9/9 [00:00<00:00, 13.53it/s]

RECALL: 0.7955 | ROC_AUC: 0.7935 | F1: 0.8003 | MAE: 0.2208 | MSE: 0.1495





## Inference

In [147]:
def detect_ai_code(code: str) -> float:
    with open("../../data/ast/node_types.txt", "r", encoding="utf-8") as f:
        node_types_loaded = f.readlines()

    loaded_model = AIDetector(input_dim=len(node_types_loaded))
    loaded_model.load_state_dict(torch.load("../../data/ast/best_model.pth"))

    code_vectorized = code_to_feature_vector(
        code.encode("utf-8"), device=torch.device("cpu")
    ).unsqueeze(0)
    with torch.no_grad():
        prediction = loaded_model(code_vectorized).squeeze().cpu().item()
    return prediction

In [148]:
code1 = """
a,b = map(int, input().split())
if a > b:
    return 1
return 0
"""

code2 = """
x, y = map(int, input().split())
return int(x>y)
"""

code3 = """
l = map(int, input().split())
if l[0] > l[1] :
    return 1
else:
    return 0
"""
for c in [code1, code2, code3]:
    print(f"{'*' * 15}\nCODE:\n{c}\nPREDICTION: {detect_ai_code(c):.4f}\n")

***************
CODE:

a,b = map(int, input().split())
if a > b:
    return 1
return 0

PREDICTION: 0.5264

***************
CODE:

x, y = map(int, input().split())
return int(x>y)

PREDICTION: 0.9335

***************
CODE:

l = map(int, input().split())
if l[0] > l[1] :
    return 1
else:
    return 0

PREDICTION: 0.9611

