In [10]:
import json
import os

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm

from src.classes.utils.DebugLogger import DebugLogger

logger = DebugLogger(use_panel_for_errors=True)

In [11]:
def accuracy(path_to_log, gt_category):
    files = os.listdir(path_to_log)
    correct = 0
    total = len(files)

    for filename in files:
        path_to_file = os.path.join(path_to_log, filename)
        content = json.load(open(path_to_file, 'r', encoding='utf-8'))
        classification = content["classification"]
        print(classification)
        if classification.lower() == gt_category:
            correct += 1

    accuracy = correct / total
    logger.info(f"Classification Accuracy: {accuracy:.2%}")

In [12]:
path_to_log = "../log/test_reentrant"
accuracy(path_to_log, gt_category="reentrant")

Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Safe
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Safe
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant


In [13]:
path_to_log = "../log/test_safe"
accuracy(path_to_log, gt_category="safe")

Safe
Safe
Safe
Safe
Safe
Safe
Reentrant
Reentrant
Safe
Safe
Safe
Safe
Safe
Safe
Reentrant
Reentrant
Reentrant
Safe
Safe
Safe
Safe
Reentrant
Safe
Reentrant
Safe
Safe
Safe
Reentrant
Reentrant
Safe
Reentrant
Safe
Safe
Reentrant
Reentrant
Safe
Reentrant
Safe
Safe
Reentrant
Reentrant
Reentrant
Reentrant
Reentrant
Safe
Safe
Reentrant
Reentrant
Reentrant
Safe
Safe
Safe
Reentrant
Reentrant
Reentrant
Safe
Safe
Safe
Reentrant
Reentrant
Reentrant
Safe
Reentrant
Safe
Reentrant
Reentrant
Reentrant
Reentrant
Safe
Safe
Safe
Safe
Reentrant
Safe


In [14]:
def load_json(file_path):
    """Load a JSON file and return its contents."""
    logger.debug(f"Loading JSON file: {file_path}")
    with open(file_path, 'r') as f:
        return json.load(f)


def get_ground_truth_from_path(contract_path):
    """Determine the ground truth label based on the parent folder name."""
    label = None
    if "reentrant" in contract_path:
        label = "reentrant"
    elif "safe" in contract_path:
        label = "safe"
    logger.debug(f"Determined ground truth for {contract_path}: {label}")
    return label


def collect_predictions(contract_dir):
    """Collect labels from the 3 JSON files (excluding classification.json)."""
    logger.debug(f"Collecting predictions from directory: {contract_dir}")
    predictions = []
    for file_name in os.listdir(contract_dir):
        file_path = os.path.join(contract_dir, file_name)
        if file_name.endswith(".json") and file_name != "classification.json":
            data = load_json(file_path)
            if "label" in data:
                predictions.append(data["label"])
    logger.debug(f"Collected predictions: {predictions}")
    return predictions


def evaluate(base_dir):
    """Compute accuracy, precision, recall, and F1-score for contract classification."""
    logger.info(f"Starting evaluation on base directory: {base_dir}")
    y_true_analysis = []
    y_pred_analysis = []
    y_true_classification = []
    y_pred_classification = []

    contract_paths = [os.path.join(root) for root, _, _ in os.walk(base_dir) if
                      any(sub in root for sub in ["reentrant", "safe"])]

    for root in tqdm(contract_paths, desc="Processing contracts"):
        ground_truth = get_ground_truth_from_path(root)
        classification_file = os.path.join(root, "classification.json")

        if os.path.exists(classification_file):
            logger.debug(f"Processing classification file: {classification_file}")
            classification_data = load_json(classification_file)
            contract_label = classification_data.get("classification").lower()

            if contract_label:
                y_true_classification.append(ground_truth)
                y_pred_classification.append(contract_label)
                logger.debug(f"Classification result: {contract_label}, Ground truth: {ground_truth}")

            predictions = collect_predictions(root)
            y_true_analysis.extend([ground_truth] * len(predictions))
            y_pred_analysis.extend(predictions)

    # Compute metrics for analysis files
    accuracy_analysis = accuracy_score(y_true_analysis, y_pred_analysis)
    precision_analysis = precision_score(y_true_analysis, y_pred_analysis, pos_label="reentrant", average='binary')
    recall_analysis = recall_score(y_true_analysis, y_pred_analysis, pos_label="reentrant", average='binary')
    f1_analysis = f1_score(y_true_analysis, y_pred_analysis, pos_label="reentrant", average='binary')

    # Compute metrics for classification file
    accuracy_classification = accuracy_score(y_true_classification, y_pred_classification)
    precision_classification = precision_score(y_true_classification, y_pred_classification, pos_label="reentrant",
                                               average='binary')
    recall_classification = recall_score(y_true_classification, y_pred_classification, pos_label="reentrant",
                                         average='binary')
    f1_classification = f1_score(y_true_classification, y_pred_classification, pos_label="reentrant", average='binary')

    # Log results
    logger.info("Analysis Files Metrics:")
    logger.info(f"Accuracy: {accuracy_analysis:.4f}")
    logger.info(f"Precision: {precision_analysis:.4f}")
    logger.info(f"Recall: {recall_analysis:.4f}")
    logger.info(f"F1 Score: {f1_analysis:.4f}")

    logger.info("\nClassification File Metrics:")
    logger.info(f"Accuracy: {accuracy_classification:.4f}")
    logger.info(f"Precision: {precision_classification:.4f}")
    logger.info(f"Recall: {recall_classification:.4f}")
    logger.info(f"F1 Score: {f1_classification:.4f}")

    return {
        "analysis": (accuracy_analysis, precision_analysis, recall_analysis, f1_analysis),
        "classification": (accuracy_classification, precision_classification, recall_classification, f1_classification)
    }

In [15]:
base_directory = "../log/contracts_analysis_both_20250225_104133"  # Change this to your actual base directory
evaluate(base_directory)

Processing contracts:   0%|          | 0/96 [00:00<?, ?it/s]

Processing contracts:   4%|▍         | 4/96 [00:00<00:03, 24.49it/s]

Processing contracts:   7%|▋         | 7/96 [00:00<00:03, 24.50it/s]

Processing contracts:  14%|█▎        | 13/96 [00:00<00:02, 37.72it/s]

Processing contracts:  19%|█▉        | 18/96 [00:00<00:02, 38.90it/s]

Processing contracts:  24%|██▍       | 23/96 [00:00<00:02, 28.33it/s]

Processing contracts:  29%|██▉       | 28/96 [00:00<00:02, 32.70it/s]

Processing contracts:  34%|███▍      | 33/96 [00:00<00:01, 35.87it/s]

Processing contracts:  43%|████▎     | 41/96 [00:01<00:01, 45.81it/s]

Processing contracts:  52%|█████▏    | 50/96 [00:01<00:00, 56.91it/s]

Processing contracts:  65%|██████▍   | 62/96 [00:01<00:00, 70.70it/s]

Processing contracts:  75%|███████▌  | 72/96 [00:01<00:00, 77.70it/s]

Processing contracts:  85%|████████▌ | 82/96 [00:01<00:00, 83.75it/s]

Processing contracts:  95%|█████████▍| 91/96 [00:01<00:00, 83.69it/s]

Processing contracts: 100%|██████████| 96/96 [00:01<00:00, 56.98it/s]


{'analysis': (1.0, 1.0, 1.0, 1.0),
 'classification': (0.723404255319149,
  0.78125,
  0.5681818181818182,
  0.6578947368421053)}