In [4]:
!pip install scikit-learn

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
from inference import *

# PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Model imports
from model import SmartContractTransformer

# Optional but useful imports
import numpy as np
from tqdm import tqdm  # for progress bars
import logging

from transformers import AutoTokenizer
import json
import os
import pandas as pd
from typing import Dict, List, Tuple, Any
import re
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score

  _torch_pytree._register_pytree_node(


# 1. Load Validation Dataset:

In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def parse_solidity_to_ast(code: str) -> Dict[str, Any]:
    """
    Parse Solidity code into a simplified AST structure
    """
    def extract_contract_info(code: str) -> Dict[str, Any]:
        # Extract contract name
        contract_match = re.search(r'contract\s+(\w+)', code)
        contract_name = contract_match.group(1) if contract_match else "Unknown"
        
        # Extract functions
        functions = []
        function_pattern = r'function\s+(\w+)\s*\(([^)]*)\)\s*(?:public|private|internal|external)?\s*(?:view|pure|payable)?\s*(?:returns\s*\(([^)]*)\))?\s*{'
        for match in re.finditer(function_pattern, code):
            func_name = match.group(1)
            params = match.group(2).split(',') if match.group(2) else []
            returns = match.group(3).split(',') if match.group(3) else []
            
            functions.append({
                'name': func_name,
                'parameters': [p.strip() for p in params],
                'returns': [r.strip() for r in returns]
            })
        
        # Extract state variables
        variables = []
        var_pattern = r'(?:uint|address|string|bool|mapping)\s+(?:\w+)\s+(\w+)'
        for match in re.finditer(var_pattern, code):
            variables.append(match.group(1))
        
        return {
            'type': 'Contract',
            'name': contract_name,
            'functions': functions,
            'variables': variables
        }
    
    try:
        # Clean the code
        code = re.sub(r'//.*?\n|/\*.*?\*/', '', code)  # Remove comments
        code = re.sub(r'\s+', ' ', code)  # Normalize whitespace
        
        # Parse the code
        ast = extract_contract_info(code)
        return ast
    except Exception as e:
        print(f"Error parsing code: {str(e)}")
        return None

def prepare_code2vec_input(ast: Dict[str, Any]) -> List[str]:
    """
    Convert AST to codeBert input format
    """
    paths = []
    
    def extract_paths(node: Dict[str, Any], current_path: List[str] = None):
        if current_path is None:
            current_path = []
            
        # Add current node to path
        if 'name' in node:
            current_path.append(node['name'])
            
        # Process functions
        if 'functions' in node:
            for func in node['functions']:
                func_path = current_path + [func['name']]
                paths.append(' '.join(func_path))
                
                # Add parameter paths
                for param in func['parameters']:
                    param_path = func_path + [param]
                    paths.append(' '.join(param_path))
                
                # Add return paths
                for ret in func['returns']:
                    ret_path = func_path + [ret]
                    paths.append(' '.join(ret_path))
        
        # Process variables
        if 'variables' in node:
            for var in node['variables']:
                var_path = current_path + [var]
                paths.append(' '.join(var_path))
    
    extract_paths(ast)
    return paths

class SmartContractVulnerabilityDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        tokenizer: AutoTokenizer,
        max_length: int = 1024,
        split: str = "train",
        vulnerability_types: List[str] = None
    ):
        """
        Args:
            data_path: Path to the CSV file containing the dataset
            tokenizer: Tokenizer for encoding the source code
            max_length: Maximum sequence length
            split: "train" or "val" to specify which split to load
            vulnerability_types: List of vulnerability types to consider
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split
        self.vulnerability_types = vulnerability_types or [
            'ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE'
        ]
        
        # Load the dataset
        self.data = self._load_dataset(data_path)
        
    def _load_dataset(self, data_path: str) -> List[Dict]:
        """Load and preprocess the dataset from CSV"""
        dataset = []
        
        # Read the CSV file
        df = pd.read_csv(data_path)
        
        # Split into train/val if needed
        if self.split == "train":
            df = df.sample(frac=0.8, random_state=42)
        else:
            df = df.sample(frac=0.2, random_state=42)
        
        # Process each contract
        for _, row in df.iterrows():
            try:
                source_code = row['source_code']
                contract_name = row['contract_name']
                
                # Parse AST and get paths
                ast = parse_solidity_to_ast(source_code)
                ast_paths = prepare_code2vec_input(ast) if ast else []
                ast_path_text = ' '.join(ast_paths)
                
                # Split source code into lines
                lines = source_code.split('\n')
                
                # Create token-to-line mapping
                token_to_line = []
                current_line = 0
                
                # Tokenize each line separately to maintain mapping
                for line in lines:
                    line_tokens = self.tokenizer.encode(line, add_special_tokens=False)
                    token_to_line.extend([current_line] * len(line_tokens))
                    current_line += 1
                
                # Add special tokens
                token_to_line = [0] + token_to_line + [0]  # [CLS] and [SEP] tokens
                
                # Truncate if too long
                if len(token_to_line) > self.max_length:
                    token_to_line = token_to_line[:self.max_length]
                
                # Pad if too short
                if len(token_to_line) < self.max_length:
                    token_to_line.extend([0] * (self.max_length - len(token_to_line)))
                
                # Create multi-label line labels for each vulnerability type
                line_labels = self._create_multi_label_line_labels(source_code, row)
                
                # Create contract-level vulnerability labels
                contract_labels = self._create_contract_vulnerability_labels(row)
                
                # Tokenize the source code
                encoding = self.tokenizer(
                    source_code,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Tokenize AST paths
                ast_encoding = self.tokenizer(
                    ast_path_text,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Convert line labels to tensor and ensure consistent shape
                vuln_tensor = torch.zeros((len(self.vulnerability_types), self.max_length), dtype=torch.long)
                for i, labels in enumerate(line_labels):
                    if len(labels) > self.max_length:
                        labels = labels[:self.max_length]
                    vuln_tensor[i, :len(labels)] = torch.tensor(labels, dtype=torch.long)
                
                # Convert contract labels to tensor
                contract_vuln_tensor = torch.tensor(contract_labels, dtype=torch.long)
                
                # Convert token_to_line to tensor
                token_to_line_tensor = torch.tensor(token_to_line, dtype=torch.long)
                
                # Ensure attention masks are boolean
                attention_mask = encoding['attention_mask'].squeeze(0).bool()
                ast_attention_mask = ast_encoding['attention_mask'].squeeze(0).bool()
                
                # Ensure input_ids are the right length
                input_ids = encoding['input_ids'].squeeze(0)
                ast_input_ids = ast_encoding['input_ids'].squeeze(0)
                
                if len(input_ids) > self.max_length:
                    input_ids = input_ids[:self.max_length]
                if len(ast_input_ids) > self.max_length:
                    ast_input_ids = ast_input_ids[:self.max_length]
                
                # Pad if necessary
                if len(input_ids) < self.max_length:
                    input_ids = torch.nn.functional.pad(input_ids, (0, self.max_length - len(input_ids)))
                if len(ast_input_ids) < self.max_length:
                    ast_input_ids = torch.nn.functional.pad(ast_input_ids, (0, self.max_length - len(ast_input_ids)))
                
                dataset.append({
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'ast_input_ids': ast_input_ids,
                    'ast_attention_mask': ast_attention_mask,
                    'vulnerable_lines': vuln_tensor,
                    'contract_vulnerabilities': contract_vuln_tensor,
                    'token_to_line': token_to_line_tensor,
                    'source_code': source_code,
                    'contract_name': contract_name
                })
            except Exception as e:
                print(f"Error processing contract {contract_name}: {str(e)}")
                continue
        
        return dataset
    
    def _create_contract_vulnerability_labels(self, row: pd.Series) -> List[int]:
        """Create contract-level vulnerability labels"""
        contract_labels = []
        for vuln_type in self.vulnerability_types:
            # Check if contract has this vulnerability type
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Contract is vulnerable if it has any vulnerable lines
            has_vulnerability = len(vuln_lines) > 0
            contract_labels.append(1 if has_vulnerability else 0)
        
        return contract_labels
    
    def _create_multi_label_line_labels(self, source_code: str, row: pd.Series) -> List[List[int]]:
        """Create multi-label line labels for each vulnerability type"""
        total_lines = len(source_code.split('\n'))
        line_labels = {vuln_type: [0] * total_lines for vuln_type in self.vulnerability_types}
        
        # Process each vulnerability type
        for vuln_type in self.vulnerability_types:
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Process each vulnerable line/snippet
            for line_or_snippet in vuln_lines:
                if isinstance(line_or_snippet, int):
                    # If it's a line number, mark that line
                    if 0 <= line_or_snippet < total_lines:
                        line_labels[vuln_type][line_or_snippet] = 1
                else:
                    # If it's a code snippet, find matching lines
                    source_lines = source_code.split('\n')
                    for i, line in enumerate(source_lines):
                        # Clean both the line and snippet for comparison
                        clean_line = re.sub(r'\s+', ' ', line.strip())
                        clean_snippet = re.sub(r'\s+', ' ', str(line_or_snippet).strip())
                        if clean_snippet in clean_line:
                            line_labels[vuln_type][i] = 1
        
        # Convert to list format
        return [line_labels[vuln_type] for vuln_type in self.vulnerability_types]
    
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> Dict:
        return self.data[idx]

def custom_collate_fn(batch):
    """
    Custom collate function to handle variable length inputs
    """
    # Get the maximum length in this batch for each type of tensor
    max_input_len = max(item['input_ids'].size(0) for item in batch)
    
    # Pad all tensors to their respective maximum lengths
    padded_batch = {
        'input_ids': torch.stack([
            torch.nn.functional.pad(item['input_ids'], (0, max_input_len - item['input_ids'].size(0)))
            for item in batch
        ]),
        'attention_mask': torch.stack([
            torch.nn.functional.pad(item['attention_mask'], (0, max_input_len - item['attention_mask'].size(0)))
            for item in batch
        ]),
        'ast_input_ids': torch.stack([item['ast_input_ids'] for item in batch]),
        'ast_attention_mask': torch.stack([item['ast_attention_mask'] for item in batch]),
        'vulnerable_lines': torch.stack([item['vulnerable_lines'] for item in batch]),
        'contract_vulnerabilities': torch.stack([item['contract_vulnerabilities'] for item in batch]),
        'token_to_line': torch.stack([item['token_to_line'] for item in batch]),
        'source_code': [item['source_code'] for item in batch],
        'contract_name': [item['contract_name'] for item in batch]
    }
    
    return padded_batch

def create_dataloaders(
    data_path: str,
    tokenizer: AutoTokenizer,
    batch_size: int = 8,
    max_length: int = 1024,
    num_workers: int = 4,
    vulnerability_types: List[str] = None
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Create train and validation dataloaders
    
    Args:
        data_path: Path to the CSV file containing the dataset
        tokenizer: Tokenizer for encoding the source code
        batch_size: Batch size for training
        max_length: Maximum sequence length
        num_workers: Number of workers for data loading
        vulnerability_types: List of vulnerability types to consider
    
    Returns:
        Tuple of (train_dataloader, val_dataloader)
    """
    # Create datasets
    train_dataset = SmartContractVulnerabilityDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length,
        split="train",
        vulnerability_types=vulnerability_types
    )
    
    val_dataset = SmartContractVulnerabilityDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length,
        split="val",
        vulnerability_types=vulnerability_types
    )
    
    # Create dataloaders with custom collate function
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn
    )
    
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn
    )
    
    return train_dataloader, val_dataloader

In [3]:
from transformers import AutoTokenizer

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

# Create dataloaders
train_dataloader, val_dataloader = create_dataloaders(
    data_path="contract_sources_with_vulnerabilities_2048_token_size.csv",
    tokenizer=tokenizer,
    batch_size=8,
    max_length=1024,
    vulnerability_types=['ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE']
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1211 > 512). Running this sequence through the model will result in indexing errors


In [4]:
val_dataloader.dataset.data[0]['vulnerable_lines'][7]

tensor([0, 0, 0,  ..., 0, 0, 0])

# 2. Load Model and SmartContractAnalyzer

In [5]:
def calculate_precision(true_labels: np.ndarray, pred_labels: np.ndarray) -> float:
    """Calculate precision for vulnerability detection."""
    if np.sum(pred_labels) == 0:
        return 0.0
    return np.sum((true_labels == 1) & (pred_labels == 1)) / np.sum(pred_labels)

def calculate_recall(true_labels: np.ndarray, pred_labels: np.ndarray) -> float:
    """Calculate recall for vulnerability detection."""
    if np.sum(true_labels) == 0:
        return 0.0
    return np.sum((true_labels == 1) & (pred_labels == 1)) / np.sum(true_labels)

def calculate_f1_score(precision: float, recall: float) -> float:
    """Calculate F1 score."""
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)

def calculate_line_accuracy(true_line_vulns: np.ndarray, pred_line_vulns: Dict) -> float:
    """Calculate line-level accuracy (simplified)."""
    try:
        # Convert predicted line vulnerabilities to array format
        pred_array = np.zeros_like(true_line_vulns)
        for line_num, line_vulns in pred_line_vulns.items():
            if line_num < pred_array.shape[0]:
                for vuln_idx, is_vuln in enumerate(line_vulns.values()):
                    if vuln_idx < pred_array.shape[1]:
                        pred_array[line_num, vuln_idx] = 1 if is_vuln else 0
        
        return np.mean(pred_array == true_line_vulns)
    except:
        return 0.

def get_vulnerability_details(analyzer, true_vulns: np.ndarray, pred_vulns: np.ndarray, 
                            probabilities: List[float]) -> Dict[str, Any]:
    """Get detailed vulnerability analysis."""
    details = {}
    for i, vuln_type in enumerate(analyzer.vulnerability_types):
        details[vuln_type] = {
            'true_positive': bool(true_vulns[i] == 1 and pred_vulns[i] == 1),
            'false_positive': bool(true_vulns[i] == 0 and pred_vulns[i] == 1),
            'false_negative': bool(true_vulns[i] == 1 and pred_vulns[i] == 0),
            'true_negative': bool(true_vulns[i] == 0 and pred_vulns[i] == 0),
            'probability': probabilities[i],
            'true_label': int(true_vulns[i]),
            'predicted_label': int(pred_vulns[i])
        }
    return details
    

In [6]:
DATA_PATH = "contract_sources_with_vulnerabilities_2048_token_size.csv"
MODEL_PATH = "checkpoints_v5_2048_output/best_model_augmented_gan_epoch_106.pt"
TOKENIZER_NAME = "microsoft/codebert-base"
VULNERABILITY_TYPES = ['ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE']
DEVICE = torch.device("cuda")
MODEL_LINE_CODE_VULNERABILITY_THRESHOLD = 0.2

In [7]:
analyzer = SmartContractAnalyzer(
    model_path=MODEL_PATH,
    device="cuda" if torch.cuda.is_available() else "cpu",
    use_gan=True
)



DEBUG: Initialized line feature extractor layer 1 with small random weights
DEBUG: Initialized line feature extractor layer 3 with small random weights
DEBUG: Initialized custom line feature extractor with small weights
Model loaded from checkpoints_v5_2048_output/best_model_augmented_gan_epoch_106.pt
Training epoch: 106
Best validation loss: 0.773994717746973
Training config: GAN=True


In [19]:
print(val_dataloader.dataset.data[5]['source_code'])

/**
 * Source Code first verified at https://etherscan.io on Tuesday, December 18, 2018
 (UTC) */

pragma solidity ^0.4.23;

library SafeMath {

  /**
  * @dev Multiplies two numbers, throws on overflow.
  */
  function mul(uint256 a, uint256 b) internal pure returns (uint256 c) {

    if (a == 0) {
      return 0;
    }

    c = a * b;
    assert(c / a == b);
    return c;
  }

  /**
  * @dev Integer division of two numbers, truncating the quotient.
  */
  function div(uint256 a, uint256 b) internal pure returns (uint256) {

    return a / b;
  }

  /**
  * @dev Subtracts two numbers, throws on overflow (i.e. if subtrahend is greater than minuend).
  */
  function sub(uint256 a, uint256 b) internal pure returns (uint256) {
    assert(b <= a);
    return a - b;
  }

  /**
  * @dev Adds two numbers, throws on overflow.
  */
  function add(uint256 a, uint256 b) internal pure returns (uint256 c) {
    c = a + b;
    assert(c >= a);
    return c;
  }
}


contract ERC20Basic {
    
  functi

In [8]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from notebook_utils import (
    collect_validation_results,
    compute_contract_level_metrics,
    compute_line_level_metrics
)
from inference import SmartContractAnalyzer

def get_line_positions(validation_results, threshold=0.2):
    """
    Get line positions of vulnerable lines in array format.
    """
    print(f"🔍 Analyzing line positions with threshold {threshold}...")
    
    line_positions = {
        'contracts': [],
        'summary': {
            'total_contracts': len(validation_results['line_level']['true_labels']),
            'total_true_vulnerable_lines': 0,
            'total_detected_vulnerable_lines': 0,
            'total_correct_detections': 0
        }
    }
    
    vuln_types = validation_results['metadata']['vulnerability_types']
    
    for contract_idx in range(len(validation_results['line_level']['true_labels'])):
        try:
            # Get data for this contract
            true_labels = validation_results['line_level']['true_labels'][contract_idx]  # Shape: (8, 1024)
            pred_labels = validation_results['line_level']['predicted_labels'][contract_idx]  # Shape: (8, 1024)
            
            # Get source code lines
            if contract_idx < len(validation_results['line_level']['line_mappings']):
                source_lines = validation_results['line_level']['line_mappings'][contract_idx]
            else:
                source_lines = [f"Line {i}" for i in range(1024)]
            
            contract_name = validation_results['contract_level']['contract_names'][contract_idx]
            
            # Analyze each vulnerability type
            vuln_positions = {}
            
            for vuln_idx, vuln_type in enumerate(vuln_types):
                if vuln_idx < true_labels.shape[0]:
                    # Get true and predicted vulnerabilities for this type
                    true_vulns = true_labels[vuln_idx, :]  # Shape: (1024,)
                    pred_vulns = pred_labels[vuln_idx, :]  # Shape: (1024,)
                    
                    # Find true vulnerable line positions
                    true_positions = []
                    for line_idx in range(min(len(source_lines), 1024)):
                        if true_vulns[line_idx] == 1:
                            true_positions.append(line_idx)
                    
                    # Find detected vulnerable line positions
                    detected_positions = []
                    for line_idx in range(min(len(source_lines), 1024)):
                        if pred_vulns[line_idx] == 1:
                            detected_positions.append(line_idx)
                    
                    # Find correctly detected positions (intersection)
                    correct_positions = list(set(true_positions) & set(detected_positions))
                    
                    vuln_positions[vuln_type] = {
                        'TRUE': true_positions,
                        'PREDICTED': detected_positions,
                        'CORRECT': correct_positions,
                        'counts': {
                            'true_count': len(true_positions),
                            'predicted_count': len(detected_positions),
                            'correct_count': len(correct_positions)
                        }
                    }
                    
                    # Update summary
                    line_positions['summary']['total_true_vulnerable_lines'] += len(true_positions)
                    line_positions['summary']['total_detected_vulnerable_lines'] += len(detected_positions)
                    line_positions['summary']['total_correct_detections'] += len(correct_positions)
            
            # Store contract analysis
            contract_analysis = {
                'contract_name': contract_name,
                'contract_idx': contract_idx,
                'vulnerability_positions': vuln_positions,
                'overall_summary': {
                    'total_true_lines': sum(vuln_positions[vuln]['counts']['true_count'] for vuln in vuln_positions),
                    'total_predicted_lines': sum(vuln_positions[vuln]['counts']['predicted_count'] for vuln in vuln_positions),
                    'total_correct_lines': sum(vuln_positions[vuln]['counts']['correct_count'] for vuln in vuln_positions)
                }
            }
            
            line_positions['contracts'].append(contract_analysis)
            
        except Exception as e:
            print(f"  ⚠️  Error analyzing contract {contract_idx}: {str(e)}")
            continue
    
    return line_positions

# Initialize analyzer with latest checkpoint
analyzer = SmartContractAnalyzer(
    model_path=MODEL_PATH,
    device="cuda" if torch.cuda.is_available() else "cpu",
    use_gan=True
)

# Step 1: Collect validation results with specific threshold
threshold = 0.1  # You can change this to any threshold you want
print(f"🚀 Using threshold: {threshold}")

validation_results = collect_validation_results(
    analyzer=analyzer,
    val_dataloader=val_dataloader,
    threshold=threshold,
    max_contracts=10,
    generate_contracts=False  # Skip generation for speed
)

# Step 2: Get line positions
line_positions = get_line_positions(validation_results, threshold=threshold)

# Step 3: Print results in clean array format
print("\n" + "=" * 80)
print(f"📊 LINE POSITION ANALYSIS (THRESHOLD {threshold})")
print("=" * 80)

summary = line_positions['summary']
print(f"\n📈 Overall Statistics:")
print(f"  Total Contracts: {summary['total_contracts']}")
print(f"  Total True Vulnerable Lines: {summary['total_true_vulnerable_lines']}")
print(f"  Total Detected Vulnerable Lines: {summary['total_detected_vulnerable_lines']}")
print(f"  Total Correctly Detected Lines: {summary['total_correct_detections']}")

if summary['total_true_vulnerable_lines'] > 0:
    recall = summary['total_correct_detections'] / summary['total_true_vulnerable_lines']
    print(f"  Overall Recall: {recall:.2%}")

if summary['total_detected_vulnerable_lines'] > 0:
    precision = summary['total_correct_detections'] / summary['total_detected_vulnerable_lines']
    print(f"  Overall Precision: {precision:.2%}")

print(f"\n🔍 LINE POSITIONS BY CONTRACT:")
print("=" * 80)

for contract in line_positions['contracts']:
    print(f"\n📋 Contract: {contract['contract_name']}")
    print(f"  Overall: TRUE={contract['overall_summary']['total_true_lines']}, "
          f"PREDICTED={contract['overall_summary']['total_predicted_lines']}, "
          f"CORRECT={contract['overall_summary']['total_correct_lines']}")
    
    for vuln_type, positions in contract['vulnerability_positions'].items():
        true_count = positions['counts']['true_count']
        predicted_count = positions['counts']['predicted_count']
        correct_count = positions['counts']['correct_count']
        
        if true_count > 0 or predicted_count > 0:
            print(f"  {vuln_type}:")
            print(f"    TRUE: {positions['TRUE']}")
            print(f"    PREDICTED: {positions['PREDICTED']}")
            print(f"    CORRECT: {positions['CORRECT']}")
            print(f"    Counts: TRUE={true_count}, PREDICTED={predicted_count}, CORRECT={correct_count}")

# Step 4: Save results
import json

json_results = {
    'threshold': threshold,
    'summary': line_positions['summary'],
    'contracts': [
        {
            'contract_name': contract['contract_name'],
            'contract_idx': contract['contract_idx'],
            'overall_summary': contract['overall_summary'],
            'vulnerability_positions': contract['vulnerability_positions']
        }
        for contract in line_positions['contracts']
    ]
}

with open(f'line_positions_threshold_{threshold}.json', 'w') as f:
    json.dump(json_results, f, indent=2)

print(f"\n✅ Results saved to: line_positions_threshold_{threshold}.json") 



DEBUG: Initialized line feature extractor layer 1 with small random weights
DEBUG: Initialized line feature extractor layer 3 with small random weights
DEBUG: Initialized custom line feature extractor with small weights
Model loaded from checkpoints_v5_2048_output/best_model_augmented_gan_epoch_106.pt
Training epoch: 106
Best validation loss: 0.773994717746973
Training config: GAN=True
🚀 Using threshold: 0.1
🚀 Starting comprehensive validation analysis...
📊 Processing 10 contracts...
    Contract 0: actual_lines=90
    Contract 0: pred_line_probs_array shape: (8, 1024)
    Contract 0: pred_line_labels_array shape: (8, 1024)
    Contract 0: found 4 vulnerable lines
    Contract 1: actual_lines=101
    Contract 1: pred_line_probs_array shape: (8, 1024)
    Contract 1: pred_line_labels_array shape: (8, 1024)
    Contract 1: found 7 vulnerable lines
    Contract 2: actual_lines=88
    Contract 2: pred_line_probs_array shape: (8, 1024)
    Contract 2: pred_line_labels_array shape: (8, 1024)

In [11]:
from notebook_utils import (
    debug_line_predictions,
    analyze_vulnerable_line_probabilities,
    check_model_line_predictions,
    print_probability_analysis
)

debug_line_predictions(validation_results, contract_idx=0)

🔍 Debugging line predictions for contract 0...
📊 Contract 0 shapes:
  True labels: (8, 1024)
  Pred labels: (8, 1024)
  Pred probs: (8, 1024)
📏 Actual lines in contract: 90

🔍 Vulnerability Analysis:
  ARTHM:
    True vulnerable lines: 3
    Predicted vulnerable lines: 0.0
    Max probability: 0.0000
    Min probability: 0.0000
    Mean probability: 0.0000
    Vulnerable line indices: [ 4  5 19]
      Line 4: prob=0.0000, pred=0.0
      Line 5: prob=0.0000, pred=0.0
      Line 19: prob=0.0000, pred=0.0
  DOS:
    True vulnerable lines: 0
    Predicted vulnerable lines: 0.0
    Max probability: 0.0000
    Min probability: 0.0000
    Mean probability: 0.0000
  LE:
    True vulnerable lines: 0
    Predicted vulnerable lines: 0.0
    Max probability: 0.0000
    Min probability: 0.0000
    Mean probability: 0.0000
  RENT:
    True vulnerable lines: 0
    Predicted vulnerable lines: 0.0
    Max probability: 0.0000
    Min probability: 0.0000
    Mean probability: 0.0000
  TimeM:
    True vul

In [8]:
prob_analysis = analyze_vulnerable_line_probabilities(validation_results)
print_probability_analysis(prob_analysis)

📊 Analyzing probability scores for vulnerable lines...
✅ Probability analysis completed!
📊 VULNERABLE LINE PROBABILITY ANALYSIS

📈 Overall Statistics:
  Total Vulnerable Lines: 42
  Mean Probability (Vulnerable Lines): 0.0000
  Mean Probability (All Lines): 0.0000

🎯 Probability Distribution (Vulnerable Lines):
  High Confidence (>0.8): 0 (0.0%)
  Medium Confidence (0.5-0.8): 0 (0.0%)
  Low Confidence (<0.5): 42 (100.0%)

🔍 Per-Vulnerability Type Analysis:
  ARTHM:
    Vulnerable Lines: 41
    Mean Probability (Vulnerable): 0.0000
    Max Probability (Vulnerable): 0.0000
    Min Probability (Vulnerable): 0.0000
    Mean Probability (All Lines): 0.0000

  DOS:
    Vulnerable Lines: 0
    Mean Probability (All Lines): 0.0000

  LE:
    Vulnerable Lines: 0
    Mean Probability (All Lines): 0.0000

  RENT:
    Vulnerable Lines: 0
    Mean Probability (All Lines): 0.0000

  TimeM:
    Vulnerable Lines: 1
    Mean Probability (Vulnerable): 0.0000
    Max Probability (Vulnerable): 0.0000
    

In [9]:
sample_contract = val_dataloader.dataset.data[0]['source_code']
check_model_line_predictions(analyzer, sample_contract)

🔍 Testing model line-level predictions on sample contract...
Contract vuln logits shape: torch.Size([1, 8])
Line vuln logits shape: torch.Size([1, 1024, 8])
Number of lines in contract: 90
Contract predictions shape: (1, 8)
Line predictions shape: (1, 1024, 8)
📊 Model output keys: ['contract_vulnerabilities', 'line_vulnerabilities', 'contract_probabilities', 'line_probabilities']
📋 Line vulnerabilities type: <class 'dict'>
📋 Number of lines with predictions: 90
  Line 0: {'ARTHM': False, 'DOS': False, 'LE': False, 'RENT': False, 'TimeM': False, 'TimeO': False, 'Tx-Origin': False, 'UE': False}
  Line 1: {'ARTHM': False, 'DOS': False, 'LE': False, 'RENT': False, 'TimeM': False, 'TimeO': False, 'Tx-Origin': False, 'UE': False}
  Line 2: {'ARTHM': False, 'DOS': False, 'LE': False, 'RENT': False, 'TimeM': False, 'TimeO': False, 'Tx-Origin': False, 'UE': False}
  Line 3: {'ARTHM': False, 'DOS': False, 'LE': False, 'RENT': False, 'TimeM': False, 'TimeO': False, 'Tx-Origin': False, 'UE': False

In [30]:
len(val_dataloader.dataset.data)

506

In [9]:
import torch
import numpy as np
import json
from typing import Dict, List, Any, Tuple
from sklearn.metrics import (
    precision_recall_curve, 
    roc_curve, 
    auc, 
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score
)
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import sys
import os

In [26]:
def compute_metrics_for_threshold(y_true: np.ndarray, y_pred: np.ndarray, y_probs: np.ndarray) -> Dict[str, float]:
    """
    Compute comprehensive metrics for a given threshold.
    
    Args:
        y_true: True labels
        y_pred: Predicted labels (binary)
        y_probs: Predicted probabilities
        
    Returns:
        Dictionary with computed metrics
    """
    if len(y_true) == 0 or np.sum(y_true) == 0:
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0,
            'pr_auc': 0.0,
            'roc_auc': 0.0
        }
    
    # Basic metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    
    # AUC metrics
    try:
        precision_curve, recall_curve, _ = precision_recall_curve(y_true, y_probs)
        pr_auc = auc(recall_curve, precision_curve)
    except:
        pr_auc = 0.0
    
    try:
        fpr, tpr, _ = roc_curve(y_true, y_probs)
        roc_auc = auc(fpr, tpr)
    except:
        roc_auc = 0.0
    
    return {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
        'pr_auc': float(pr_auc),
        'roc_auc': float(roc_auc)
    }

def evaluate_model_performance(
    model_path: str,
    val_dataloader: DataLoader,
    contract_thresholds: Dict[str, float] = None,
    line_thresholds: Dict[str, float] = None,
    max_contracts: int = None,
    output_file: str = "comprehensive_evaluation_results.json"
) -> Dict[str, Any]:
    """
    Comprehensive evaluation function that computes detailed performance metrics.
    
    Args:
        model_path: Path to the trained model
        val_dataloader: Validation dataloader
        contract_thresholds: Dictionary of thresholds for each vulnerability type at contract level
                           e.g., {'ARTHM': 0.2, 'DOS': 0.3, 'LE': 0.1, ...}
        line_thresholds: Dictionary of thresholds for each vulnerability type at line level
                        e.g., {'ARTHM': 0.2, 'DOS': 0.3, 'LE': 0.1, ...}
        max_contracts: Maximum number of contracts to evaluate (None for all)
        output_file: Output JSON file path
        
    Returns:
        Dictionary with comprehensive evaluation results
    """
    print("🚀 Starting comprehensive model evaluation...")
    
    # Initialize analyzer
    print("🔧 Setting up analyzer...")
    analyzer = SmartContractAnalyzer(
        model_path=model_path,
        device="cuda" if torch.cuda.is_available() else "cpu",
        use_gan=True
    )
    
    # Set default thresholds if not provided
    if contract_thresholds is None:
        contract_thresholds = {vuln_type: 0.2 for vuln_type in analyzer.vulnerability_types}
    if line_thresholds is None:
        line_thresholds = {vuln_type: 0.2 for vuln_type in analyzer.vulnerability_types}
    
    print(f"✅ Components initialized. Evaluating on {len(val_dataloader.dataset)} contracts")
    print(f"📋 Contract thresholds: {contract_thresholds}")
    print(f"📍 Line thresholds: {line_thresholds}")
    
    # Results storage
    results = {
        'model_info': {
            'model_path': model_path,
            'contract_thresholds': contract_thresholds,
            'line_thresholds': line_thresholds,
            'total_contracts': len(val_dataloader.dataset),
            'vulnerability_types': analyzer.vulnerability_types
        },
        'contract_level': {
            'overall': {},
            'per_vulnerability': {}
        },
        'line_level': {
            'overall': {},
            'per_vulnerability': {},
            'statistics': {
                'total_lines_processed': 0,
                'total_lines_with_vulnerabilities': 0,
                'total_lines_predicted_vulnerable': 0,
                'per_vulnerability': {}
            }
        }
    }
    
    # Collect validation results with the specified thresholds
    print(f"\n📊 Collecting validation results...")
    
    # We need to collect results for each vulnerability type with its specific threshold
    # This requires multiple passes through the data
    contract_true_all = []
    contract_pred_all = []
    contract_probs_all = []
    
    line_true_all = []
    line_pred_all = []
    line_probs_all = []
    
    # Process each contract
    total_contracts = len(val_dataloader.dataset) if max_contracts is None else min(max_contracts, len(val_dataloader.dataset))
    
    for contract_idx in range(total_contracts):
        try:
            # Get contract data
            contract_data = val_dataloader.dataset.data[contract_idx]
            source_code = contract_data['source_code']
            true_contract_vulns = contract_data['contract_vulnerabilities'].cpu().numpy()
            true_line_vulns = contract_data['vulnerable_lines'].cpu().numpy()  # Shape: (8, 1024)
            
            # Get vulnerability predictions using the regular method
            try:
                analyzer_results = analyzer.detect_vulnerabilities(
                    source_code, 
                    threshold=0.0  # Use 0.0 to get all probabilities, then apply thresholds later
                )
            except Exception as e:
                print(f"    Error in detect_vulnerabilities for contract {contract_idx}: {str(e)}")
                continue
            
            # Extract predictions
            pred_contract_vulns = analyzer_results['contract_vulnerabilities']
            pred_line_vulns = analyzer_results['line_vulnerabilities']
            pred_contract_probs = analyzer_results['contract_probabilities']
            pred_line_probs = analyzer_results['line_probabilities']
            
            # Debug: Print structure of pred_line_probs
            if contract_idx == 0:  # Only for first contract to avoid spam
                print(f"🔍 DEBUG: pred_line_probs structure:")
                print(f"  Type: {type(pred_line_probs)}")
                print(f"  Length: {len(pred_line_probs)}")
                if len(pred_line_probs) > 0:
                    print(f"  pred_line_probs[0] type: {type(pred_line_probs[0])}")
                    print(f"  pred_line_probs[0] length: {len(pred_line_probs[0])}")
                    if len(pred_line_probs[0]) > 0:
                        print(f"  pred_line_probs[0][0] type: {type(pred_line_probs[0][0])}")
                        print(f"  pred_line_probs[0][0] length: {len(pred_line_probs[0][0])}")
                        print(f"  Sample values: {pred_line_probs[0][0][:5]}")
            
            # Process contract-level predictions with specific thresholds
            contract_pred = np.zeros(len(analyzer.vulnerability_types))
            contract_probs = np.zeros(len(analyzer.vulnerability_types))
            
            for vuln_idx, vuln_type in enumerate(analyzer.vulnerability_types):
                if vuln_idx < len(pred_contract_probs[0]):
                    prob = pred_contract_probs[0][vuln_idx]
                    contract_probs[vuln_idx] = prob
                    # Apply specific threshold for this vulnerability type
                    contract_pred[vuln_idx] = 1 if prob > contract_thresholds[vuln_type] else 0
            
            contract_true_all.append(true_contract_vulns)
            contract_pred_all.append(contract_pred)
            contract_probs_all.append(contract_probs)
            
            # Process line-level predictions with specific thresholds
            lines = source_code.split('\n')
            actual_lines = len(lines)
            
            # Only process actual lines in the contract, not the full 1024
            line_pred = np.zeros((len(analyzer.vulnerability_types), actual_lines))
            line_probs = np.zeros((len(analyzer.vulnerability_types), actual_lines))
            
            # Handle the 3D structure of pred_line_probs: [batch][lines][vulnerabilities]
            # pred_line_probs is a list of shape [1][num_lines][num_vulnerabilities]
            if len(pred_line_probs) > 0 and len(pred_line_probs[0]) > 0:
                # pred_line_probs[0] has shape [num_lines][num_vulnerabilities]
                for line_idx in range(min(actual_lines, len(pred_line_probs[0]))):
                    for vuln_idx, vuln_type in enumerate(analyzer.vulnerability_types):
                        if vuln_idx < len(pred_line_probs[0][line_idx]):
                            prob = pred_line_probs[0][line_idx][vuln_idx]
                            line_probs[vuln_idx, line_idx] = prob
                            
                            # Filter out empty lines (probability 0.5000) and apply specific threshold
                            if prob != 0.5000:  # Not an empty line
                                line_pred[vuln_idx, line_idx] = 1 if prob > line_thresholds[vuln_type] else 0
                            else:
                                # Empty line - set prediction to 0 regardless of threshold
                                line_pred[vuln_idx, line_idx] = 0
            
            # Truncate true line vulnerabilities to match actual lines
            true_line_vulns_actual = true_line_vulns[:, :actual_lines] if true_line_vulns.shape[1] > actual_lines else true_line_vulns
            line_true_all.append(true_line_vulns_actual)
            line_pred_all.append(line_pred)
            line_probs_all.append(line_probs)
            
            if (contract_idx + 1) % 10 == 0:
                print(f"  ✅ Processed {contract_idx + 1}/{total_contracts} contracts")
                
        except Exception as e:
            print(f"  ⚠️  Error processing contract {contract_idx}: {str(e)}")
            continue
    
    # Convert to numpy arrays
    contract_true_all = np.array(contract_true_all)
    contract_pred_all = np.array(contract_pred_all)
    contract_probs_all = np.array(contract_probs_all)
    
    # Handle variable-length line arrays by flattening them
    # Since each contract has different number of lines, we need to flatten the results
    line_true_flat = []
    line_pred_flat = []
    line_probs_flat = []
    
    for i in range(len(line_true_all)):
        # Flatten each contract's line data
        for vuln_idx in range(line_true_all[i].shape[0]):
            for line_idx in range(line_true_all[i].shape[1]):
                line_true_flat.append(line_true_all[i][vuln_idx, line_idx])
                line_pred_flat.append(line_pred_all[i][vuln_idx, line_idx])
                line_probs_flat.append(line_probs_all[i][vuln_idx, line_idx])
    
    # Convert to numpy arrays
    line_true_flat = np.array(line_true_flat)
    line_pred_flat = np.array(line_pred_flat)
    line_probs_flat = np.array(line_probs_flat)
    
    print(f"✅ Collected data for {len(contract_true_all)} contracts")
    
    # Calculate line-level statistics using flattened arrays
    total_lines_processed = len(line_true_flat)  # Total number of line-vulnerability combinations
    total_lines_with_vulnerabilities = np.sum(line_true_flat > 0)  # Lines that actually have vulnerabilities
    total_lines_predicted_vulnerable = np.sum(line_pred_flat > 0)  # Lines predicted as vulnerable
    
    # Store statistics
    results['line_level']['statistics']['total_lines_processed'] = int(total_lines_processed)
    results['line_level']['statistics']['total_lines_with_vulnerabilities'] = int(total_lines_with_vulnerabilities)
    results['line_level']['statistics']['total_lines_predicted_vulnerable'] = int(total_lines_predicted_vulnerable)
    
    # Per-vulnerability statistics - need to reconstruct from flattened data
    # We need to track which vulnerability type each flattened element belongs to
    vuln_type_counts = {}
    vuln_type_vulnerable_counts = {}
    vuln_type_predicted_counts = {}
    
    # Initialize counters
    for vuln_type in analyzer.vulnerability_types:
        vuln_type_counts[vuln_type] = 0
        vuln_type_vulnerable_counts[vuln_type] = 0
        vuln_type_predicted_counts[vuln_type] = 0
    
    # Count per vulnerability type from the flattened data
    # Since we flattened by contract -> vulnerability -> line, we need to reconstruct
    current_idx = 0
    for contract_idx in range(len(line_true_all)):
        for vuln_idx in range(line_true_all[contract_idx].shape[0]):
            vuln_type = analyzer.vulnerability_types[vuln_idx]
            num_lines = line_true_all[contract_idx].shape[1]
            
            # Count this vulnerability type's lines
            vuln_type_counts[vuln_type] += num_lines
            
            # Count vulnerable and predicted lines for this contract and vulnerability type
            contract_vuln_true = line_true_all[contract_idx][vuln_idx, :]
            contract_vuln_pred = line_pred_all[contract_idx][vuln_idx, :]
            
            vuln_type_vulnerable_counts[vuln_type] += np.sum(contract_vuln_true > 0)
            vuln_type_predicted_counts[vuln_type] += np.sum(contract_vuln_pred > 0)
    
    # Store per-vulnerability statistics
    for vuln_type in analyzer.vulnerability_types:
        results['line_level']['statistics']['per_vulnerability'][vuln_type] = {
            'total_lines_processed': int(vuln_type_counts[vuln_type]),
            'total_lines_with_vulnerabilities': int(vuln_type_vulnerable_counts[vuln_type]),
            'total_lines_predicted_vulnerable': int(vuln_type_predicted_counts[vuln_type])
        }
    
    print(f"📊 Line-level statistics:")
    print(f"  Total lines processed: {total_lines_processed:,}")
    print(f"  Total lines with vulnerabilities: {total_lines_with_vulnerabilities:,}")
    print(f"  Total lines predicted as vulnerable: {total_lines_predicted_vulnerable:,}")
    
    # Compute contract-level metrics
    print(f"\n📋 Computing contract-level metrics...")
    
    # Overall contract-level metrics (aggregated across all vulnerability types)
    contract_true_flat = contract_true_all.flatten()
    contract_pred_flat = contract_pred_all.flatten()
    contract_probs_flat = contract_probs_all.flatten()
    
    results['contract_level']['overall'] = compute_metrics_for_threshold(
        contract_true_flat, contract_pred_flat, contract_probs_flat
    )
    
    # Per-vulnerability contract-level metrics
    for vuln_idx, vuln_type in enumerate(analyzer.vulnerability_types):
        if vuln_idx < contract_true_all.shape[1]:
            vuln_true = contract_true_all[:, vuln_idx]
            vuln_pred = contract_pred_all[:, vuln_idx]
            vuln_probs = contract_probs_all[:, vuln_idx]
            
            results['contract_level']['per_vulnerability'][vuln_type] = compute_metrics_for_threshold(
                vuln_true, vuln_pred, vuln_probs
            )
    
    # Compute line-level metrics
    print(f"\n📍 Computing line-level metrics...")
    
    # Overall line-level metrics (aggregated across all vulnerability types and lines)
    # Use the already flattened arrays
    # Calculate overall probability statistics
    true_positive_mask = (line_true_flat == 1) & (line_pred_flat == 1)
    should_be_vulnerable_mask = (line_true_flat == 1)
    
    # Debug information
    print(f"🔍 DEBUG: Line-level analysis:")
    print(f"  Total flattened elements: {len(line_true_flat)}")
    print(f"  Elements that should be vulnerable: {np.sum(should_be_vulnerable_mask)}")
    print(f"  True positives: {np.sum(true_positive_mask)}")
    print(f"  Sample of line_probs_flat: {line_probs_flat[:10]}")
    print(f"  Sample of should_be_vulnerable_mask: {should_be_vulnerable_mask[:10]}")
    if np.any(should_be_vulnerable_mask):
        print(f"  Probabilities for should_be_vulnerable lines: {line_probs_flat[should_be_vulnerable_mask][:10]}")
    
    mean_true_positive_prob = float(np.mean(line_probs_flat[true_positive_mask])) if np.any(true_positive_mask) else 0.0
    mean_should_be_vulnerable_prob = float(np.mean(line_probs_flat[should_be_vulnerable_mask])) if np.any(should_be_vulnerable_mask) else 0.0
    
    print(f"  Mean probability for should_be_vulnerable: {mean_should_be_vulnerable_prob:.6f}")
    print(f"  Mean probability for true_positives: {mean_true_positive_prob:.6f}")
    
    # Get basic metrics
    basic_metrics = compute_metrics_for_threshold(line_true_flat, line_pred_flat, line_probs_flat)
    
    # Add probability statistics
    results['line_level']['overall'] = {
        **basic_metrics,
        'mean_true_positive_probability': mean_true_positive_prob,
        'mean_should_be_vulnerable_probability': mean_should_be_vulnerable_prob,
        'num_true_positives': int(np.sum(true_positive_mask)),
        'num_should_be_vulnerable': int(np.sum(should_be_vulnerable_mask))
    }
    
    # Per-vulnerability line-level metrics - need to reconstruct from flattened data
    for vuln_idx, vuln_type in enumerate(analyzer.vulnerability_types):
        # Collect data for this vulnerability type from all contracts
        vuln_true = []
        vuln_pred = []
        vuln_probs = []
        
        for contract_idx in range(len(line_true_all)):
            if vuln_idx < line_true_all[contract_idx].shape[0]:
                vuln_true.extend(line_true_all[contract_idx][vuln_idx, :])
                vuln_pred.extend(line_pred_all[contract_idx][vuln_idx, :])
                vuln_probs.extend(line_probs_all[contract_idx][vuln_idx, :])
        
        # Convert to numpy arrays
        vuln_true = np.array(vuln_true)
        vuln_pred = np.array(vuln_pred)
        vuln_probs = np.array(vuln_probs)
        
        # Calculate mean prediction probability for true positive lines
        true_positive_mask = (vuln_true == 1) & (vuln_pred == 1)
        mean_true_positive_prob = float(np.mean(vuln_probs[true_positive_mask])) if np.any(true_positive_mask) else 0.0
        
        # Calculate mean prediction probability for all lines that should be predicted as vulnerable (true positive + false negative)
        should_be_vulnerable_mask = (vuln_true == 1)
        mean_should_be_vulnerable_prob = float(np.mean(vuln_probs[should_be_vulnerable_mask])) if np.any(should_be_vulnerable_mask) else 0.0
        
        # Debug for this vulnerability type
        if np.sum(should_be_vulnerable_mask) > 0:
            print(f"  🔍 {vuln_type}: {np.sum(should_be_vulnerable_mask)} should be vulnerable, mean prob: {mean_should_be_vulnerable_prob:.6f}")
            print(f"     Sample probs: {vuln_probs[should_be_vulnerable_mask][:5]}")
        
        # Get basic metrics
        basic_metrics = compute_metrics_for_threshold(vuln_true, vuln_pred, vuln_probs)
        
        # Add probability statistics
        results['line_level']['per_vulnerability'][vuln_type] = {
            **basic_metrics,
            'mean_true_positive_probability': mean_true_positive_prob,
            'mean_should_be_vulnerable_probability': mean_should_be_vulnerable_prob,
            'num_true_positives': int(np.sum(true_positive_mask)),
            'num_should_be_vulnerable': int(np.sum(should_be_vulnerable_mask))
        }
    
    # Save results
    print(f"\n💾 Saving results to {output_file}...")
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"✅ Comprehensive evaluation completed! Results saved to {output_file}")
    
    return results

def print_evaluation_summary(results: Dict[str, Any]):
    """
    Print a summary of the evaluation results.
    
    Args:
        results: Evaluation results from evaluate_model_performance
    """
    print("\n" + "=" * 80)
    print("📊 COMPREHENSIVE EVALUATION SUMMARY")
    print("=" * 80)
    
    # Model info
    model_info = results['model_info']
    print(f"\n🔧 Model Information:")
    print(f"  Model Path: {model_info['model_path']}")
    print(f"  Contract Thresholds: {model_info['contract_thresholds']}")
    print(f"  Line Thresholds: {model_info['line_thresholds']}")
    print(f"  Total Contracts: {model_info['total_contracts']}")
    print(f"  Vulnerability Types: {model_info['vulnerability_types']}")
    
    # Contract-level results
    print(f"\n📋 CONTRACT-LEVEL RESULTS:")
    print("-" * 40)
    
    print(f"  Overall Performance:")
    for metric, value in results['contract_level']['overall'].items():
        print(f"    {metric.upper()}: {value:.4f}")
    
    print(f"\n  Per-Vulnerability Performance:")
    for vuln_type, metrics in results['contract_level']['per_vulnerability'].items():
        print(f"    {vuln_type}:")
        for metric, value in metrics.items():
            print(f"      {metric.upper()}: {value:.4f}")
    
    # Line-level results
    print(f"\n📍 LINE-LEVEL RESULTS:")
    print("-" * 40)
    
    print(f"  Overall Performance:")
    for metric, value in results['line_level']['overall'].items():
        if metric in ['mean_true_positive_probability', 'mean_should_be_vulnerable_probability']:
            print(f"    {metric}: {value:.4f}")
        elif metric in ['num_true_positives', 'num_should_be_vulnerable']:
            print(f"    {metric}: {value}")
        else:
            print(f"    {metric.upper()}: {value:.4f}")
    
    # Display line-level statistics
    stats = results['line_level']['statistics']
    print(f"\n  📊 LINE-LEVEL STATISTICS:")
    print(f"    Total lines processed: {stats['total_lines_processed']:,}")
    print(f"    Total lines with vulnerabilities: {stats['total_lines_with_vulnerabilities']:,}")
    print(f"    Total lines predicted as vulnerable: {stats['total_lines_predicted_vulnerable']:,}")
    
    print(f"\n  Per-Vulnerability Statistics:")
    for vuln_type, vuln_stats in stats['per_vulnerability'].items():
        print(f"    {vuln_type}:")
        print(f"      Lines processed: {vuln_stats['total_lines_processed']:,}")
        print(f"      Lines with vulnerabilities: {vuln_stats['total_lines_with_vulnerabilities']:,}")
        print(f"      Lines predicted as vulnerable: {vuln_stats['total_lines_predicted_vulnerable']:,}")
    
    print(f"\n  Per-Vulnerability Performance:")
    for vuln_type, metrics in results['line_level']['per_vulnerability'].items():
        print(f"    {vuln_type}:")
        for metric, value in metrics.items():
            if metric in ['mean_true_positive_probability', 'mean_should_be_vulnerable_probability']:
                print(f"      {metric}: {value:.4f}")
            elif metric in ['num_true_positives', 'num_should_be_vulnerable']:
                print(f"      {metric}: {value}")
            else:
                print(f"      {metric.upper()}: {value:.4f}")
    
    # Add a note about the statistics
    print(f"\n💡 NOTE: To get exact counts of processed lines, vulnerable lines, and predicted vulnerable lines,")
    print(f"   the evaluation function would need to be modified to track these statistics during processing.")


In [32]:
contract_thresholds = {
    'ARTHM': 0.5,
    'DOS': 0.4,
    'LE': 0.4,
    'RENT': 0.4,
    'TimeM': 0.4,
    'TimeO': 0.4,
    'Tx-Origin': 0.4,
    'UE': 0.4
}

line_thresholds = {
    'ARTHM': 0.2,
    'DOS': 0.2,
    'LE': 0.2,
    'RENT': 0.2,
    'TimeM': 0.2,
    'TimeO': 0.2,
    'Tx-Origin': 0.2,
    'UE': 0.2
}

results = evaluate_model_performance(
    model_path=MODEL_PATH,
    val_dataloader=val_dataloader,
    contract_thresholds=contract_thresholds,
    line_thresholds=line_thresholds,
    max_contracts=500,  # Evaluate first 50 contracts
    output_file="comprehensive_evaluation_results.json"
)

print_evaluation_summary(results) 

🚀 Starting comprehensive model evaluation...
🔧 Setting up analyzer...




DEBUG: Initialized line feature extractor layer 1 with small random weights
DEBUG: Initialized line feature extractor layer 3 with small random weights
DEBUG: Initialized custom line feature extractor with small weights
Model loaded from checkpoints_v5_2048_output/best_model_augmented_gan_epoch_106.pt
Training epoch: 106
Best validation loss: 0.773994717746973
Training config: GAN=True
✅ Components initialized. Evaluating on 506 contracts
📋 Contract thresholds: {'ARTHM': 0.5, 'DOS': 0.4, 'LE': 0.4, 'RENT': 0.4, 'TimeM': 0.4, 'TimeO': 0.4, 'Tx-Origin': 0.4, 'UE': 0.4}
📍 Line thresholds: {'ARTHM': 0.2, 'DOS': 0.2, 'LE': 0.2, 'RENT': 0.2, 'TimeM': 0.2, 'TimeO': 0.2, 'Tx-Origin': 0.2, 'UE': 0.2}

📊 Collecting validation results...
🔍 DEBUG: pred_line_probs structure:
  Type: <class 'list'>
  Length: 1
  pred_line_probs[0] type: <class 'list'>
  pred_line_probs[0] length: 1024
  pred_line_probs[0][0] type: <class 'list'>
  pred_line_probs[0][0] length: 8
  Sample values: [1.1084030120400712e

Token indices sequence length is longer than the specified maximum sequence length for this model (1211 > 512). Running this sequence through the model will result in indexing errors


  ✅ Processed 480/500 contracts
  ✅ Processed 490/500 contracts
  ✅ Processed 500/500 contracts
✅ Collected data for 500 contracts
📊 Line-level statistics:
  Total lines processed: 401,016
  Total lines with vulnerabilities: 2,174
  Total lines predicted as vulnerable: 2,566

📋 Computing contract-level metrics...

📍 Computing line-level metrics...
🔍 DEBUG: Line-level analysis:
  Total flattened elements: 401016
  Elements that should be vulnerable: 2174
  True positives: 1403
  Sample of line_probs_flat: [1.10840301e-05 1.83676381e-08 3.51013398e-07 6.64268737e-04
 9.83592749e-01 9.87356424e-01 1.45548191e-02 5.05667813e-05
 1.87427617e-07 1.06397044e-06]
  Sample of should_be_vulnerable_mask: [False False False False  True  True False False False False]
  Probabilities for should_be_vulnerable lines: [0.98359275 0.98735642 0.97724676 0.9405756  0.40342605 0.44837305
 0.22966476 0.49984008 0.5        0.5       ]
  Mean probability for should_be_vulnerable: 0.562761
  Mean probability f