In [4]:
!pip install scikit-learn

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
from inference import *

# PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Model imports
from model import SmartContractTransformer

# Optional but useful imports
import numpy as np
from tqdm import tqdm  # for progress bars
import logging

from transformers import AutoTokenizer
import json
import os
import pandas as pd
from typing import Dict, List, Tuple, Any
import re
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score

  _torch_pytree._register_pytree_node(


# 1. Load Validation Dataset:

In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def parse_solidity_to_ast(code: str) -> Dict[str, Any]:
    """
    Parse Solidity code into a simplified AST structure
    """
    def extract_contract_info(code: str) -> Dict[str, Any]:
        # Extract contract name
        contract_match = re.search(r'contract\s+(\w+)', code)
        contract_name = contract_match.group(1) if contract_match else "Unknown"
        
        # Extract functions
        functions = []
        function_pattern = r'function\s+(\w+)\s*\(([^)]*)\)\s*(?:public|private|internal|external)?\s*(?:view|pure|payable)?\s*(?:returns\s*\(([^)]*)\))?\s*{'
        for match in re.finditer(function_pattern, code):
            func_name = match.group(1)
            params = match.group(2).split(',') if match.group(2) else []
            returns = match.group(3).split(',') if match.group(3) else []
            
            functions.append({
                'name': func_name,
                'parameters': [p.strip() for p in params],
                'returns': [r.strip() for r in returns]
            })
        
        # Extract state variables
        variables = []
        var_pattern = r'(?:uint|address|string|bool|mapping)\s+(?:\w+)\s+(\w+)'
        for match in re.finditer(var_pattern, code):
            variables.append(match.group(1))
        
        return {
            'type': 'Contract',
            'name': contract_name,
            'functions': functions,
            'variables': variables
        }
    
    try:
        # Clean the code
        code = re.sub(r'//.*?\n|/\*.*?\*/', '', code)  # Remove comments
        code = re.sub(r'\s+', ' ', code)  # Normalize whitespace
        
        # Parse the code
        ast = extract_contract_info(code)
        return ast
    except Exception as e:
        print(f"Error parsing code: {str(e)}")
        return None

def prepare_code2vec_input(ast: Dict[str, Any]) -> List[str]:
    """
    Convert AST to codeBert input format
    """
    paths = []
    
    def extract_paths(node: Dict[str, Any], current_path: List[str] = None):
        if current_path is None:
            current_path = []
            
        # Add current node to path
        if 'name' in node:
            current_path.append(node['name'])
            
        # Process functions
        if 'functions' in node:
            for func in node['functions']:
                func_path = current_path + [func['name']]
                paths.append(' '.join(func_path))
                
                # Add parameter paths
                for param in func['parameters']:
                    param_path = func_path + [param]
                    paths.append(' '.join(param_path))
                
                # Add return paths
                for ret in func['returns']:
                    ret_path = func_path + [ret]
                    paths.append(' '.join(ret_path))
        
        # Process variables
        if 'variables' in node:
            for var in node['variables']:
                var_path = current_path + [var]
                paths.append(' '.join(var_path))
    
    extract_paths(ast)
    return paths

class SmartContractVulnerabilityDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        tokenizer: AutoTokenizer,
        max_length: int = 1024,
        split: str = "train",
        vulnerability_types: List[str] = None
    ):
        """
        Args:
            data_path: Path to the CSV file containing the dataset
            tokenizer: Tokenizer for encoding the source code
            max_length: Maximum sequence length
            split: "train" or "val" to specify which split to load
            vulnerability_types: List of vulnerability types to consider
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split
        self.vulnerability_types = vulnerability_types or [
            'ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE'
        ]
        
        # Load the dataset
        self.data = self._load_dataset(data_path)
        
    def _load_dataset(self, data_path: str) -> List[Dict]:
        """Load and preprocess the dataset from CSV"""
        dataset = []
        
        # Read the CSV file
        df = pd.read_csv(data_path)
        
        # Split into train/val if needed
        if self.split == "train":
            df = df.sample(frac=0.8, random_state=42)
        else:
            df = df.sample(frac=0.2, random_state=42)
        
        # Process each contract
        for _, row in df.iterrows():
            try:
                source_code = row['source_code']
                contract_name = row['contract_name']
                
                # Parse AST and get paths
                ast = parse_solidity_to_ast(source_code)
                ast_paths = prepare_code2vec_input(ast) if ast else []
                ast_path_text = ' '.join(ast_paths)
                
                # Split source code into lines
                lines = source_code.split('\n')
                
                # Create token-to-line mapping
                token_to_line = []
                current_line = 0
                
                # Tokenize each line separately to maintain mapping
                for line in lines:
                    line_tokens = self.tokenizer.encode(line, add_special_tokens=False)
                    token_to_line.extend([current_line] * len(line_tokens))
                    current_line += 1
                
                # Add special tokens
                token_to_line = [0] + token_to_line + [0]  # [CLS] and [SEP] tokens
                
                # Truncate if too long
                if len(token_to_line) > self.max_length:
                    token_to_line = token_to_line[:self.max_length]
                
                # Pad if too short
                if len(token_to_line) < self.max_length:
                    token_to_line.extend([0] * (self.max_length - len(token_to_line)))
                
                # Create multi-label line labels for each vulnerability type
                line_labels = self._create_multi_label_line_labels(source_code, row)
                
                # Create contract-level vulnerability labels
                contract_labels = self._create_contract_vulnerability_labels(row)
                
                # Tokenize the source code
                encoding = self.tokenizer(
                    source_code,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Tokenize AST paths
                ast_encoding = self.tokenizer(
                    ast_path_text,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Convert line labels to tensor and ensure consistent shape
                vuln_tensor = torch.zeros((len(self.vulnerability_types), self.max_length), dtype=torch.long)
                for i, labels in enumerate(line_labels):
                    if len(labels) > self.max_length:
                        labels = labels[:self.max_length]
                    vuln_tensor[i, :len(labels)] = torch.tensor(labels, dtype=torch.long)
                
                # Convert contract labels to tensor
                contract_vuln_tensor = torch.tensor(contract_labels, dtype=torch.long)
                
                # Convert token_to_line to tensor
                token_to_line_tensor = torch.tensor(token_to_line, dtype=torch.long)
                
                # Ensure attention masks are boolean
                attention_mask = encoding['attention_mask'].squeeze(0).bool()
                ast_attention_mask = ast_encoding['attention_mask'].squeeze(0).bool()
                
                # Ensure input_ids are the right length
                input_ids = encoding['input_ids'].squeeze(0)
                ast_input_ids = ast_encoding['input_ids'].squeeze(0)
                
                if len(input_ids) > self.max_length:
                    input_ids = input_ids[:self.max_length]
                if len(ast_input_ids) > self.max_length:
                    ast_input_ids = ast_input_ids[:self.max_length]
                
                # Pad if necessary
                if len(input_ids) < self.max_length:
                    input_ids = torch.nn.functional.pad(input_ids, (0, self.max_length - len(input_ids)))
                if len(ast_input_ids) < self.max_length:
                    ast_input_ids = torch.nn.functional.pad(ast_input_ids, (0, self.max_length - len(ast_input_ids)))
                
                dataset.append({
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'ast_input_ids': ast_input_ids,
                    'ast_attention_mask': ast_attention_mask,
                    'vulnerable_lines': vuln_tensor,
                    'contract_vulnerabilities': contract_vuln_tensor,
                    'token_to_line': token_to_line_tensor,
                    'source_code': source_code,
                    'contract_name': contract_name
                })
            except Exception as e:
                print(f"Error processing contract {contract_name}: {str(e)}")
                continue
        
        return dataset
    
    def _create_contract_vulnerability_labels(self, row: pd.Series) -> List[int]:
        """Create contract-level vulnerability labels"""
        contract_labels = []
        for vuln_type in self.vulnerability_types:
            # Check if contract has this vulnerability type
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Contract is vulnerable if it has any vulnerable lines
            has_vulnerability = len(vuln_lines) > 0
            contract_labels.append(1 if has_vulnerability else 0)
        
        return contract_labels
    
    def _create_multi_label_line_labels(self, source_code: str, row: pd.Series) -> List[List[int]]:
        """Create multi-label line labels for each vulnerability type"""
        total_lines = len(source_code.split('\n'))
        line_labels = {vuln_type: [0] * total_lines for vuln_type in self.vulnerability_types}
        
        # Process each vulnerability type
        for vuln_type in self.vulnerability_types:
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Process each vulnerable line/snippet
            for line_or_snippet in vuln_lines:
                if isinstance(line_or_snippet, int):
                    # If it's a line number, mark that line
                    if 0 <= line_or_snippet < total_lines:
                        line_labels[vuln_type][line_or_snippet] = 1
                else:
                    # If it's a code snippet, find matching lines
                    source_lines = source_code.split('\n')
                    for i, line in enumerate(source_lines):
                        # Clean both the line and snippet for comparison
                        clean_line = re.sub(r'\s+', ' ', line.strip())
                        clean_snippet = re.sub(r'\s+', ' ', str(line_or_snippet).strip())
                        if clean_snippet in clean_line:
                            line_labels[vuln_type][i] = 1
        
        # Convert to list format
        return [line_labels[vuln_type] for vuln_type in self.vulnerability_types]
    
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> Dict:
        return self.data[idx]

def custom_collate_fn(batch):
    """
    Custom collate function to handle variable length inputs
    """
    # Get the maximum length in this batch for each type of tensor
    max_input_len = max(item['input_ids'].size(0) for item in batch)
    
    # Pad all tensors to their respective maximum lengths
    padded_batch = {
        'input_ids': torch.stack([
            torch.nn.functional.pad(item['input_ids'], (0, max_input_len - item['input_ids'].size(0)))
            for item in batch
        ]),
        'attention_mask': torch.stack([
            torch.nn.functional.pad(item['attention_mask'], (0, max_input_len - item['attention_mask'].size(0)))
            for item in batch
        ]),
        'ast_input_ids': torch.stack([item['ast_input_ids'] for item in batch]),
        'ast_attention_mask': torch.stack([item['ast_attention_mask'] for item in batch]),
        'vulnerable_lines': torch.stack([item['vulnerable_lines'] for item in batch]),
        'contract_vulnerabilities': torch.stack([item['contract_vulnerabilities'] for item in batch]),
        'token_to_line': torch.stack([item['token_to_line'] for item in batch]),
        'source_code': [item['source_code'] for item in batch],
        'contract_name': [item['contract_name'] for item in batch]
    }
    
    return padded_batch

def create_dataloaders(
    data_path: str,
    tokenizer: AutoTokenizer,
    batch_size: int = 8,
    max_length: int = 1024,
    num_workers: int = 4,
    vulnerability_types: List[str] = None
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Create train and validation dataloaders
    
    Args:
        data_path: Path to the CSV file containing the dataset
        tokenizer: Tokenizer for encoding the source code
        batch_size: Batch size for training
        max_length: Maximum sequence length
        num_workers: Number of workers for data loading
        vulnerability_types: List of vulnerability types to consider
    
    Returns:
        Tuple of (train_dataloader, val_dataloader)
    """
    # Create datasets
    train_dataset = SmartContractVulnerabilityDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length,
        split="train",
        vulnerability_types=vulnerability_types
    )
    
    val_dataset = SmartContractVulnerabilityDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length,
        split="val",
        vulnerability_types=vulnerability_types
    )
    
    # Create dataloaders with custom collate function
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn
    )
    
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn
    )
    
    return train_dataloader, val_dataloader

In [3]:
from transformers import AutoTokenizer

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

# Create dataloaders
train_dataloader, val_dataloader = create_dataloaders(
    data_path="contract_sources_with_vulnerabilities_2048_token_size.csv",
    tokenizer=tokenizer,
    batch_size=8,
    max_length=1024,
    vulnerability_types=['ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE']
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1211 > 512). Running this sequence through the model will result in indexing errors


In [4]:
val_dataloader.dataset.data[0]['vulnerable_lines'][7]

tensor([0, 0, 0,  ..., 0, 0, 0])

# 2. Load Model and SmartContractAnalyzer

In [5]:
def calculate_precision(true_labels: np.ndarray, pred_labels: np.ndarray) -> float:
    """Calculate precision for vulnerability detection."""
    if np.sum(pred_labels) == 0:
        return 0.0
    return np.sum((true_labels == 1) & (pred_labels == 1)) / np.sum(pred_labels)

def calculate_recall(true_labels: np.ndarray, pred_labels: np.ndarray) -> float:
    """Calculate recall for vulnerability detection."""
    if np.sum(true_labels) == 0:
        return 0.0
    return np.sum((true_labels == 1) & (pred_labels == 1)) / np.sum(true_labels)

def calculate_f1_score(precision: float, recall: float) -> float:
    """Calculate F1 score."""
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)

def calculate_line_accuracy(true_line_vulns: np.ndarray, pred_line_vulns: Dict) -> float:
    """Calculate line-level accuracy (simplified)."""
    try:
        # Convert predicted line vulnerabilities to array format
        pred_array = np.zeros_like(true_line_vulns)
        for line_num, line_vulns in pred_line_vulns.items():
            if line_num < pred_array.shape[0]:
                for vuln_idx, is_vuln in enumerate(line_vulns.values()):
                    if vuln_idx < pred_array.shape[1]:
                        pred_array[line_num, vuln_idx] = 1 if is_vuln else 0
        
        return np.mean(pred_array == true_line_vulns)
    except:
        return 0.

def get_vulnerability_details(analyzer, true_vulns: np.ndarray, pred_vulns: np.ndarray, 
                            probabilities: List[float]) -> Dict[str, Any]:
    """Get detailed vulnerability analysis."""
    details = {}
    for i, vuln_type in enumerate(analyzer.vulnerability_types):
        details[vuln_type] = {
            'true_positive': bool(true_vulns[i] == 1 and pred_vulns[i] == 1),
            'false_positive': bool(true_vulns[i] == 0 and pred_vulns[i] == 1),
            'false_negative': bool(true_vulns[i] == 1 and pred_vulns[i] == 0),
            'true_negative': bool(true_vulns[i] == 0 and pred_vulns[i] == 0),
            'probability': probabilities[i],
            'true_label': int(true_vulns[i]),
            'predicted_label': int(pred_vulns[i])
        }
    return details
    

## 1 Manual one by one:

In [6]:
model_path="checkpoints_v5_2048_output/best_model_augmented_gan_epoch_91.pt"
tokenizer_name = "microsoft/codebert-base"
device = torch.device("cuda")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token



In [7]:
model = SmartContractTransformer(
        d_model=768,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        max_length=1024,
        vocab_size=tokenizer.vocab_size,
        num_vulnerability_types=8,
        use_gan=True 
    )

checkpoint = torch.load(model_path, map_location=device)



DEBUG: Initialized line feature extractor layer 1 with small random weights
DEBUG: Initialized line feature extractor layer 3 with small random weights
DEBUG: Initialized custom line feature extractor with small weights


In [8]:
if 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from {model_path}")
    print(f"Training epoch: {checkpoint.get('epoch', 'Unknown')}")
    print(f"Best validation loss: {checkpoint.get('val_loss', 'Unknown')}")
    print(f"Training config: GAN={checkpoint.get('use_gan', 'Unknown')}")
else:
    # Direct state dict
    model.load_state_dict(checkpoint)
    print(f"Model loaded from {model_path}")

Model loaded from checkpoints_v5_2048_output/best_model_augmented_gan_epoch_91.pt
Training epoch: 91
Best validation loss: 0.83826167229563
Training config: GAN=True


In [9]:
model.to(device)
model.eval()
    
vulnerability_types = [
    'ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE'
]

In [10]:
val_dataloader.dataset.data[0]['source_code']

"pragma solidity ^0.4.19;\r\n\r\ncontract BaseToken {\r\n    string public name;\r\n    string public symbol;\r\n    uint8 public decimals;\r\n    uint256 public totalSupply;\r\n\r\n    mapping (address => uint256) public balanceOf;\r\n    mapping (address => mapping (address => uint256)) public allowance;\r\n\r\n    event Transfer(address indexed from, address indexed to, uint256 value);\r\n    event Approval(address indexed owner, address indexed spender, uint256 value);\r\n\r\n    function _transfer(address _from, address _to, uint _value) internal {\r\n        require(_to != 0x0);\r\n        require(balanceOf[_from] >= _value);\r\n        require(balanceOf[_to] + _value > balanceOf[_to]);\r\n        uint previousBalances = balanceOf[_from] + balanceOf[_to];\r\n        balanceOf[_from] -= _value;\r\n        balanceOf[_to] += _value;\r\n        assert(balanceOf[_from] + balanceOf[_to] == previousBalances);\r\n        Transfer(_from, _to, _value);\r\n    }\r\n\r\n    function transfer

In [11]:
contract_code = val_dataloader.dataset.data[0]['source_code']
ast = parse_solidity_to_ast(contract_code)
ast_paths = prepare_code2vec_input(ast) if ast else []
ast_path_text = ' '.join(ast_paths)

# Tokenize inputs
contract_encoding = tokenizer(
    contract_code,
    max_length=1024,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

ast_encoding = tokenizer(
    ast_path_text,
    max_length=1024,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

In [12]:
input_ids = contract_encoding['input_ids'].to(device)
attention_mask = contract_encoding['attention_mask'].to(device)
ast_input_ids = ast_encoding['input_ids'].to(device)
ast_attention_mask = ast_encoding['attention_mask'].to(device)

# Create proper token-to-line mapping that matches the actual tokenization
lines = contract_code.split('\n')
token_to_line = []
current_line = 0

# Add [CLS] token mapping (line 0)
token_to_line.append(0)

for line in lines:
    line_tokens = tokenizer.encode(line, add_special_tokens=False)
    # Map all tokens in this line to the same line number
    token_to_line.extend([current_line] * len(line_tokens))
    current_line += 1

# Add [SEP] token mapping (line 0)
token_to_line.append(0)

# Ensure the mapping matches the tokenized length (1024)
if len(token_to_line) > 1024:
    token_to_line = token_to_line[:1024]
elif len(token_to_line) < 1024:
    token_to_line.extend([0] * (1024 - len(token_to_line)))

token_to_line = torch.tensor(token_to_line, dtype=torch.long).to(device)


In [13]:
line_tokens = tokenizer.encode(lines[5], add_special_tokens=False)
line_tokens

[1437, 1437, 1437, 49315, 398, 285, 5044, 757, 1536, 131, 50121]

In [14]:
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    ast_input_ids=ast_input_ids,
    ast_attention_mask=ast_attention_mask,
    target_ids=input_ids,  # Use input_ids as target for inference
    token_to_line=token_to_line
)

contract_vuln_logits = outputs.get('contract_vulnerability_logits')
line_vuln_logits = outputs.get('line_vulnerability_logits')



In [15]:
# Get line-level vulnerability probabilities
line_vuln_probs = torch.sigmoid(line_vuln_logits)  # Shape: [1, num_lines, 8]

# For each line, check if it's vulnerable
for line_idx in range(line_vuln_probs.shape[1]):
    line_probs = line_vuln_probs[0, line_idx, :]  # Shape: [8]
    # Check if any vulnerability type has high probability
    is_vulnerable = (line_probs > 0.1).any()
    if is_vulnerable:
        print(f"Line {line_idx} is vulnerable: {line_probs}")
        
#[1.8209e-01, 1.2123e-02, 1.1189e-05, 1.1246e-05, 3.1900e-02, 1.8309e-02, 1.1306e-05, 1.3287e-02]

Line 5 is vulnerable: tensor([2.0804e-01, 5.4823e-03, 1.6645e-05, 4.4838e-05, 1.8898e-03, 2.6380e-03,
        3.6004e-05, 9.2746e-03], device='cuda:0', grad_fn=<SliceBackward0>)
Line 6 is vulnerable: tensor([2.8489e-01, 4.1994e-03, 1.1316e-05, 3.1541e-05, 1.3289e-03, 8.9649e-04,
        1.6679e-05, 2.7547e-03], device='cuda:0', grad_fn=<SliceBackward0>)
Line 7 is vulnerable: tensor([2.7787e-01, 3.7919e-03, 2.0019e-05, 5.4323e-05, 7.0747e-04, 8.6934e-04,
        1.9290e-05, 1.5360e-03], device='cuda:0', grad_fn=<SliceBackward0>)
Line 8 is vulnerable: tensor([1.3494e-01, 8.2156e-03, 4.1825e-05, 9.2689e-05, 4.3523e-04, 2.5948e-03,
        3.3628e-05, 2.1155e-03], device='cuda:0', grad_fn=<SliceBackward0>)
Line 81 is vulnerable: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
       device='cuda:0', grad_fn=<SliceBackward0>)
Line 82 is vulnerable: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
       device='cuda:0', grad_fn=<SliceBackward

In [18]:
# Get line-level vulnerability probabilities
line_vuln_probs = torch.sigmoid(line_vuln_logits)  # Shape: [1, num_lines, 8]

# Check for uniform predictions across lines
print("=== LINE-LEVEL VULNERABILITY ANALYSIS ===")
print(f"Total lines: {line_vuln_probs.shape[1]}")
print(f"Vulnerability types: {line_vuln_probs.shape[2]}")

# For each vulnerability type, check if all lines have the same probability
for vuln_type in range(line_vuln_probs.shape[2]):
    vuln_probs = line_vuln_probs[0, :, vuln_type]  # Shape: [num_lines]
    
    # Check if all probabilities are the same
    min_prob = vuln_probs.min().item()
    max_prob = vuln_probs.max().item()
    mean_prob = vuln_probs.mean().item()
    std_prob = vuln_probs.std().item()
    
    print(f"\nVulnerability Type {vuln_type}:")
    print(f"  Min: {min_prob:.4f}")
    print(f"  Max: {max_prob:.4f}")
    print(f"  Mean: {mean_prob:.4f}")
    print(f"  Std: {std_prob:.4f}")
    
    # Check if all lines have the same probability (within small tolerance)
    if std_prob < 1e-6:
        print(f"  ⚠️  WARNING: All lines have the same probability ({mean_prob:.4f})")
        print(f"  This suggests the model is not learning line-specific detection!")
    elif std_prob < 0.01:
        print(f"  ⚠️  WARNING: Very low variance ({std_prob:.6f}) - model may not be learning line differences")
    else:
        print(f"  ✅ Good variance ({std_prob:.6f}) - model is learning line-specific detection")

# Overall statistics
print(f"\n=== OVERALL STATISTICS ===")
print(f"All probabilities range: [{line_vuln_probs.min().item():.4f}, {line_vuln_probs.max().item():.4f}]")
print(f"Overall mean: {line_vuln_probs.mean().item():.4f}")
print(f"Overall std: {line_vuln_probs.std().item():.4f}")

# Count vulnerable lines
vulnerable_lines = 0
for line_idx in range(line_vuln_probs.shape[1]):
    line_probs = line_vuln_probs[0, line_idx, :]  # Shape: [8]
    is_vulnerable = (line_probs > 0.2).any() and (line_probs != 0.50000).all()
    if is_vulnerable:
        vulnerable_lines += 1
        print(f"Line {line_idx} is vulnerable: {line_probs}")

print(f"\nTotal vulnerable lines: {vulnerable_lines}/{line_vuln_probs.shape[1]}")

=== LINE-LEVEL VULNERABILITY ANALYSIS ===
Total lines: 1024
Vulnerability types: 8

Vulnerability Type 0:
  Min: 0.0000
  Max: 0.5000
  Mean: 0.4617
  Std: 0.1316
  ✅ Good variance (0.131603) - model is learning line-specific detection

Vulnerability Type 1:
  Min: 0.0000
  Max: 0.5000
  Mean: 0.4606
  Std: 0.1345
  ✅ Good variance (0.134514) - model is learning line-specific detection

Vulnerability Type 2:
  Min: 0.0000
  Max: 0.5000
  Mean: 0.4604
  Std: 0.1350
  ✅ Good variance (0.135013) - model is learning line-specific detection

Vulnerability Type 3:
  Min: 0.0000
  Max: 0.5000
  Mean: 0.4605
  Std: 0.1350
  ✅ Good variance (0.135012) - model is learning line-specific detection

Vulnerability Type 4:
  Min: 0.0000
  Max: 0.5000
  Mean: 0.4606
  Std: 0.1347
  ✅ Good variance (0.134663) - model is learning line-specific detection

Vulnerability Type 5:
  Min: 0.0000
  Max: 0.5000
  Mean: 0.4605
  Std: 0.1347
  ✅ Good variance (0.134705) - model is learning line-specific detection

In [None]:


def parse_solidity_to_ast(self, code: str) -> Dict[str, Any]:
    """Parse Solidity code into a simplified AST structure"""
    def extract_contract_info(code: str) -> Dict[str, Any]:
        contract_match = re.search(r'contract\s+(\w+)', code)
        contract_name = contract_match.group(1) if contract_match else "Unknown"
        
        functions = []
        function_pattern = r'function\s+(\w+)\s*\(([^)]*)\)\s*(?:public|private|internal|external)?\s*(?:view|pure|payable)?\s*(?:returns\s*\(([^)]*)\))?\s*{'
        for match in re.finditer(function_pattern, code):
            func_name = match.group(1)
            params = match.group(2).split(',') if match.group(2) else []
            returns = match.group(3).split(',') if match.group(3) else []
            
            functions.append({
                'name': func_name,
                'parameters': [p.strip() for p in params],
                'returns': [r.strip() for r in returns]
            })
        
        variables = []
        var_pattern = r'(?:uint|address|string|bool|mapping)\s+(?:\w+)\s+(\w+)'
        for match in re.finditer(var_pattern, code):
            variables.append(match.group(1))
        
        return {
            'type': 'Contract',
            'name': contract_name,
            'functions': functions,
            'variables': variables
        }
    
    try:
        code = re.sub(r'//.*?\n|/\*.*?\*/', '', code)
        code = re.sub(r'\s+', ' ', code)
        ast = extract_contract_info(code)
        return ast
    except Exception as e:
        print(f"Error parsing code: {str(e)}")
        return None

def prepare_code2vec_input(self, ast: Dict[str, Any]) -> List[str]:
    """Convert AST to codeBert input format"""
    paths = []
    
    def extract_paths(node: Dict[str, Any], current_path: List[str] = None):
        if current_path is None:
            current_path = []
            
        if 'name' in node:
            current_path.append(node['name'])
            
        if 'functions' in node:
            for func in node['functions']:
                func_path = current_path + [func['name']]
                paths.append(' '.join(func_path))
                
                for param in func['parameters']:
                    param_path = func_path + [param]
                    paths.append(' '.join(param_path))
                
                for ret in func['returns']:
                    ret_path = func_path + [ret]
                    paths.append(' '.join(ret_path))
        
        if 'variables' in node:
            for var in node['variables']:
                var_path = current_path + [var]
                paths.append(' '.join(var_path))
    
    extract_paths(ast)
    return paths

In [None]:
analyzer = SmartContractAnalyzer(
    model_path="checkpoints_v5_2048_output/best_model_augmented_gan_epoch_13.pt",
    device="cuda" if torch.cuda.is_available() else "cpu",
    use_gan=True
)

In [6]:
# Import the utility functions
from notebook_utils import (
    collect_validation_results_simple,
    collect_validation_results,
    compute_contract_level_metrics,
    print_simplified_validation_summary,
    generate_syntax_aware_contract,
    compute_line_level_metrics
)

# Initialize analyzer with latest checkpoint
analyzer = SmartContractAnalyzer(
    model_path="checkpoints_v5_2048_output/best_model_augmented_gan_epoch_13.pt",
    device="cuda" if torch.cuda.is_available() else "cpu",
    use_gan=True
)

# Step 1: Collect comprehensive validation results (including line-level)
print("🚀 Starting comprehensive validation analysis...")
validation_results = collect_validation_results(
    analyzer=analyzer,
    val_dataloader=val_dataloader,
    threshold=0.5,
    max_contracts=10,  # Process first 1000 contracts
    generate_contracts=True  # Enable syntax-aware generation
)

# Step 2: Compute contract-level PR AUC and performance metrics
print("\n�� Computing contract-level performance metrics...")
contract_metrics = compute_contract_level_metrics(validation_results)

# Step 3: Compute line-level PR AUC and recall metrics
print("\n📈 Computing line-level performance metrics...")
line_metrics = compute_line_level_metrics(validation_results)

# Step 4: Print comprehensive summary
print("=" * 60)
print("📊 COMPREHENSIVE VALIDATION SUMMARY")
print("=" * 60)

# Contract-level performance
print(f"\n📈 Contract-Level Performance:")
print(f"  Overall PR-AUC: {contract_metrics['contract_level']['overall_pr_auc']:.4f}")
print(f"  Overall Accuracy: {contract_metrics['contract_level']['overall_accuracy']:.4f}")

print(f"  Per-Vulnerability Performance:")
for vuln_type in validation_results['metadata']['vulnerability_types']:
    pr_auc = contract_metrics['contract_level']['pr_auc'].get(vuln_type, 0.0)
    accuracy = contract_metrics['contract_level']['accuracy'].get(vuln_type, 0.0)
    print(f"    {vuln_type}: PR-AUC={pr_auc:.4f}, Accuracy={accuracy:.4f}")

# Line-level performance
print(f"\n📊 Line-Level Performance:")
for vuln, vals in line_metrics.items():
    print(f"  {vuln}: PR-AUC={vals['pr_auc']:.4f}, Recall={vals['vuln_line_recall']:.2%} "
          f"({vals['num_correct_vuln_lines']}/{vals['num_true_vuln_lines']} vulnerable lines)")

print("=" * 60)

# Step 5: Detailed inspection of results
print("\n🔍 DETAILED INSPECTION")
print("=" * 60)

# Inspect contract-level results
print(f"\n📊 Contract-Level Results:")
print(f"Total contracts processed: {validation_results['metadata']['total_contracts']}")
print(f"Contract-level PR-AUC: {contract_metrics['contract_level']['overall_pr_auc']:.4f}")
print(f"Contract-level Accuracy: {contract_metrics['contract_level']['overall_accuracy']:.4f}")

# Inspect generation results
print(f"\n�� Generation Results:")
print(f"Generation success rate: {validation_results['metadata']['generation_success_rate']:.2%}")

# Inspect top contracts by vulnerability score
print(f"\n�� Top 5 Contracts by Vulnerability Score:")
contract_probs = np.array(validation_results['contract_level']['predicted_probs'])
vulnerability_scores = np.sum(contract_probs, axis=1)
top_indices = np.argsort(vulnerability_scores)[-5:][::-1]

for i, contract_idx in enumerate(top_indices):
    contract_name = validation_results['contract_level']['contract_names'][contract_idx]
    source_code = validation_results['contract_level']['source_codes'][contract_idx]
    generated_code = validation_results['contract_level']['generated_codes'][contract_idx]
    vuln_score = vulnerability_scores[contract_idx]
    
    print(f"\nRank {i+1}: {contract_name}")
    print(f"  Vulnerability Score: {vuln_score:.4f}")
    print(f"  Source Length: {len(source_code)} characters")
    print(f"  Generated Length: {len(generated_code) if generated_code else 0} characters")
    print(f"  Generation Success: {generated_code is not None and generated_code != 'Generation failed'}")

# Step 6: Save results
import json
import pickle

# Save comprehensive results
results_summary = {
    'validation_results': validation_results,
    'contract_metrics': contract_metrics,
    'line_metrics': line_metrics,
    'metadata': {
        'model_path': "checkpoints_v5_2048_output/best_model_augmented_gan_epoch_148.pt",
        'threshold': 0.5,
        'total_contracts': validation_results['metadata']['total_contracts'],
        'processing_time': validation_results['metadata']['processing_time']
    }
}

# Save as pickle for full data
with open('validation_analysis_comprehensive.pkl', 'wb') as f:
    pickle.dump(results_summary, f)

# Save as JSON for easy inspection
json_results = {
    'contract_metrics': {
        'overall_pr_auc': float(contract_metrics['contract_level']['overall_pr_auc']),
        'overall_accuracy': float(contract_metrics['contract_level']['overall_accuracy']),
        'per_vulnerability': {
            vuln_type: {
                'pr_auc': float(contract_metrics['contract_level']['pr_auc'][vuln_type]),
                'accuracy': float(contract_metrics['contract_level']['accuracy'][vuln_type])
            }
            for vuln_type in validation_results['metadata']['vulnerability_types']
        }
    },
    'line_metrics': {
        vuln_type: {
            'pr_auc': float(metrics['pr_auc']),
            'vuln_line_recall': float(metrics['vuln_line_recall']),
            'num_true_vuln_lines': int(metrics['num_true_vuln_lines']),
            'num_correct_vuln_lines': int(metrics['num_correct_vuln_lines'])
        }
        for vuln_type, metrics in line_metrics.items()
    },
    'generation_success_rate': validation_results['metadata']['generation_success_rate'],
    'metadata': results_summary['metadata']
}

with open('validation_analysis_comprehensive.json', 'w') as f:
    json.dump(json_results, f, indent=2)

print(f"\n✅ Results saved to:")
print(f"  - validation_analysis_comprehensive.pkl (complete data)")
print(f"  - validation_analysis_comprehensive.json (summary metrics)")



Model loaded from checkpoints_v5_2048_output/best_model_augmented_gan_epoch_13.pt
Training epoch: 13
Best validation loss: 1.8405638737604022
Training config: GAN=True
🚀 Starting comprehensive validation analysis...
🚀 Starting comprehensive validation analysis...
📊 Processing 10 contracts...
    Contract 0: true_line_vulns shape: (8, 1024)
    Contract 0: actual_lines=90
    Contract 0: pred_line_probs_array shape: (8, 1024)
    Contract 0: pred_line_labels_array shape: (8, 1024)
    Contract 0: found 4 vulnerable lines
🎯 Generating 1 contract(s) with syntax-aware generation...
📊 Parameters: temperature=0.9, max_length=1024
Generating contract 1/1...
✓ Using syntax-aware generation with built-in constraints
Error generating contract 0: stack expects each tensor to be equal size, but got [1024, 768] at entry 0 and [768] at entry 1
Using template-based generation...
Successfully generated contract 1 with template method after error
✅ Successfully generated 1 contract(s)
  Contract 1: 322

Traceback (most recent call last):
  File "/home/m20180848/smrt-transformer/07.training-model/inference.py", line 904, in generate_synthetic_contract
    outputs = self.model(
  File "/home/m20180848/.conda/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/m20180848/.conda/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/m20180848/smrt-transformer/07.training-model/model.py", line 420, in forward
    lines = torch.stack(lines, dim=0)  # [num_lines, d_model]
RuntimeError: stack expects each tensor to be equal size, but got [1024, 768] at entry 0 and [768] at entry 1
Traceback (most recent call last):
  File "/home/m20180848/smrt-transformer/07.training-model/inference.py", line 904, in generate_synthetic_contract
    outputs = self.model(
  File "/home/m20180848/.con

Error generating contract 0: stack expects each tensor to be equal size, but got [1024, 768] at entry 0 and [768] at entry 1
Using template-based generation...
Successfully generated contract 1 with template method after error
✅ Successfully generated 1 contract(s)
  Contract 1: 3108 characters, valid syntax
    Contract 5: true_line_vulns shape: (8, 1024)
    Contract 5: actual_lines=211
    Contract 5: pred_line_probs_array shape: (8, 1024)
    Contract 5: pred_line_labels_array shape: (8, 1024)
    Contract 5: found 5 vulnerable lines
🎯 Generating 1 contract(s) with syntax-aware generation...
📊 Parameters: temperature=0.9, max_length=1024
Generating contract 1/1...
✓ Using syntax-aware generation with built-in constraints
Error generating contract 0: stack expects each tensor to be equal size, but got [1024, 768] at entry 0 and [768] at entry 1
Using template-based generation...
Successfully generated contract 1 with template method after error
✅ Successfully generated 1 contract(s)

Traceback (most recent call last):
  File "/home/m20180848/smrt-transformer/07.training-model/inference.py", line 904, in generate_synthetic_contract
    outputs = self.model(
  File "/home/m20180848/.conda/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/m20180848/.conda/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/m20180848/smrt-transformer/07.training-model/model.py", line 420, in forward
    lines = torch.stack(lines, dim=0)  # [num_lines, d_model]
RuntimeError: stack expects each tensor to be equal size, but got [1024, 768] at entry 0 and [768] at entry 1
Traceback (most recent call last):
  File "/home/m20180848/smrt-transformer/07.training-model/inference.py", line 904, in generate_synthetic_contract
    outputs = self.model(
  File "/home/m20180848/.con

In [7]:
from notebook_utils import (
    debug_line_predictions,
    analyze_vulnerable_line_probabilities,
    check_model_line_predictions,
    print_probability_analysis
)

debug_line_predictions(validation_results, contract_idx=0)

🔍 Debugging line predictions for contract 0...
📊 Contract 0 shapes:
  True labels: (8, 1024)
  Pred labels: (8, 1024)
  Pred probs: (8, 1024)
📏 Actual lines in contract: 90

🔍 Vulnerability Analysis:
  ARTHM:
    True vulnerable lines: 3
    Predicted vulnerable lines: 0.0
    Max probability: 0.0000
    Min probability: 0.0000
    Mean probability: 0.0000
    Vulnerable line indices: [ 4  5 19]
      Line 4: prob=0.0000, pred=0.0
      Line 5: prob=0.0000, pred=0.0
      Line 19: prob=0.0000, pred=0.0
  DOS:
    True vulnerable lines: 0
    Predicted vulnerable lines: 0.0
    Max probability: 0.0000
    Min probability: 0.0000
    Mean probability: 0.0000
  LE:
    True vulnerable lines: 0
    Predicted vulnerable lines: 0.0
    Max probability: 0.0000
    Min probability: 0.0000
    Mean probability: 0.0000
  RENT:
    True vulnerable lines: 0
    Predicted vulnerable lines: 0.0
    Max probability: 0.0000
    Min probability: 0.0000
    Mean probability: 0.0000
  TimeM:
    True vul

In [8]:
prob_analysis = analyze_vulnerable_line_probabilities(validation_results)
print_probability_analysis(prob_analysis)

📊 Analyzing probability scores for vulnerable lines...
✅ Probability analysis completed!
📊 VULNERABLE LINE PROBABILITY ANALYSIS

📈 Overall Statistics:
  Total Vulnerable Lines: 42
  Mean Probability (Vulnerable Lines): 0.0000
  Mean Probability (All Lines): 0.0000

🎯 Probability Distribution (Vulnerable Lines):
  High Confidence (>0.8): 0 (0.0%)
  Medium Confidence (0.5-0.8): 0 (0.0%)
  Low Confidence (<0.5): 42 (100.0%)

🔍 Per-Vulnerability Type Analysis:
  ARTHM:
    Vulnerable Lines: 41
    Mean Probability (Vulnerable): 0.0000
    Max Probability (Vulnerable): 0.0000
    Min Probability (Vulnerable): 0.0000
    Mean Probability (All Lines): 0.0000

  DOS:
    Vulnerable Lines: 0
    Mean Probability (All Lines): 0.0000

  LE:
    Vulnerable Lines: 0
    Mean Probability (All Lines): 0.0000

  RENT:
    Vulnerable Lines: 0
    Mean Probability (All Lines): 0.0000

  TimeM:
    Vulnerable Lines: 1
    Mean Probability (Vulnerable): 0.0000
    Max Probability (Vulnerable): 0.0000
    

In [9]:
sample_contract = val_dataloader.dataset.data[0]['source_code']
check_model_line_predictions(analyzer, sample_contract)

🔍 Testing model line-level predictions on sample contract...
Contract vuln logits shape: torch.Size([1, 8])
Line vuln logits shape: torch.Size([1, 1024, 8])
Number of lines in contract: 90
Contract predictions shape: (1, 8)
Line predictions shape: (1, 1024, 8)
📊 Model output keys: ['contract_vulnerabilities', 'line_vulnerabilities', 'contract_probabilities', 'line_probabilities']
📋 Line vulnerabilities type: <class 'dict'>
📋 Number of lines with predictions: 90
  Line 0: {'ARTHM': False, 'DOS': False, 'LE': False, 'RENT': False, 'TimeM': False, 'TimeO': False, 'Tx-Origin': False, 'UE': False}
  Line 1: {'ARTHM': False, 'DOS': False, 'LE': False, 'RENT': False, 'TimeM': False, 'TimeO': False, 'Tx-Origin': False, 'UE': False}
  Line 2: {'ARTHM': False, 'DOS': False, 'LE': False, 'RENT': False, 'TimeM': False, 'TimeO': False, 'Tx-Origin': False, 'UE': False}
  Line 3: {'ARTHM': False, 'DOS': False, 'LE': False, 'RENT': False, 'TimeM': False, 'TimeO': False, 'Tx-Origin': False, 'UE': False