In [4]:
!pip install scikit-learn

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
from inference import *

# PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Model imports
from model import SmartContractTransformer

# Optional but useful imports
import numpy as np
from tqdm import tqdm  # for progress bars
import logging

from transformers import AutoTokenizer
import json
import os
import pandas as pd
from typing import Dict, List, Tuple, Any
import re
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score

  _torch_pytree._register_pytree_node(


# 0. AUX Functions:

In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def parse_solidity_to_ast(code: str) -> Dict[str, Any]:
    """
    Parse Solidity code into a simplified AST structure
    """
    def extract_contract_info(code: str) -> Dict[str, Any]:
        # Extract contract name
        contract_match = re.search(r'contract\s+(\w+)', code)
        contract_name = contract_match.group(1) if contract_match else "Unknown"
        
        # Extract functions
        functions = []
        function_pattern = r'function\s+(\w+)\s*\(([^)]*)\)\s*(?:public|private|internal|external)?\s*(?:view|pure|payable)?\s*(?:returns\s*\(([^)]*)\))?\s*{'
        for match in re.finditer(function_pattern, code):
            func_name = match.group(1)
            params = match.group(2).split(',') if match.group(2) else []
            returns = match.group(3).split(',') if match.group(3) else []
            
            functions.append({
                'name': func_name,
                'parameters': [p.strip() for p in params],
                'returns': [r.strip() for r in returns]
            })
        
        # Extract state variables
        variables = []
        var_pattern = r'(?:uint|address|string|bool|mapping)\s+(?:\w+)\s+(\w+)'
        for match in re.finditer(var_pattern, code):
            variables.append(match.group(1))
        
        return {
            'type': 'Contract',
            'name': contract_name,
            'functions': functions,
            'variables': variables
        }
    
    try:
        # Clean the code
        code = re.sub(r'//.*?\n|/\*.*?\*/', '', code)  # Remove comments
        code = re.sub(r'\s+', ' ', code)  # Normalize whitespace
        
        # Parse the code
        ast = extract_contract_info(code)
        return ast
    except Exception as e:
        print(f"Error parsing code: {str(e)}")
        return None

def prepare_code2vec_input(ast: Dict[str, Any]) -> List[str]:
    """
    Convert AST to codeBert input format
    """
    paths = []
    
    def extract_paths(node: Dict[str, Any], current_path: List[str] = None):
        if current_path is None:
            current_path = []
            
        # Add current node to path
        if 'name' in node:
            current_path.append(node['name'])
            
        # Process functions
        if 'functions' in node:
            for func in node['functions']:
                func_path = current_path + [func['name']]
                paths.append(' '.join(func_path))
                
                # Add parameter paths
                for param in func['parameters']:
                    param_path = func_path + [param]
                    paths.append(' '.join(param_path))
                
                # Add return paths
                for ret in func['returns']:
                    ret_path = func_path + [ret]
                    paths.append(' '.join(ret_path))
        
        # Process variables
        if 'variables' in node:
            for var in node['variables']:
                var_path = current_path + [var]
                paths.append(' '.join(var_path))
    
    extract_paths(ast)
    return paths

class SmartContractVulnerabilityDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        tokenizer: AutoTokenizer,
        max_length: int = 1024,
        split: str = "train",
        vulnerability_types: List[str] = None
    ):
        """
        Args:
            data_path: Path to the CSV file containing the dataset
            tokenizer: Tokenizer for encoding the source code
            max_length: Maximum sequence length
            split: "train" or "val" to specify which split to load
            vulnerability_types: List of vulnerability types to consider
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split
        self.vulnerability_types = vulnerability_types or [
            'ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE'
        ]
        
        # Load the dataset
        self.data = self._load_dataset(data_path)
        
    def _load_dataset(self, data_path: str) -> List[Dict]:
        """Load and preprocess the dataset from CSV"""
        dataset = []
        
        # Read the CSV file
        df = pd.read_csv(data_path)
        
        # Split into train/val if needed
        if self.split == "train":
            df = df.sample(frac=0.8, random_state=42)
        else:
            df = df.sample(frac=0.2, random_state=42)
        
        # Process each contract
        for _, row in df.iterrows():
            try:
                source_code = row['source_code']
                contract_name = row['contract_name']
                
                # Parse AST and get paths
                ast = parse_solidity_to_ast(source_code)
                ast_paths = prepare_code2vec_input(ast) if ast else []
                ast_path_text = ' '.join(ast_paths)
                
                # Split source code into lines
                lines = source_code.split('\n')
                
                # Create token-to-line mapping
                token_to_line = []
                current_line = 0
                
                # Tokenize each line separately to maintain mapping
                for line in lines:
                    line_tokens = self.tokenizer.encode(line, add_special_tokens=False)
                    token_to_line.extend([current_line] * len(line_tokens))
                    current_line += 1
                
                # Add special tokens
                token_to_line = [0] + token_to_line + [0]  # [CLS] and [SEP] tokens
                
                # Truncate if too long
                if len(token_to_line) > self.max_length:
                    token_to_line = token_to_line[:self.max_length]
                
                # Pad if too short
                if len(token_to_line) < self.max_length:
                    token_to_line.extend([0] * (self.max_length - len(token_to_line)))
                
                # Create multi-label line labels for each vulnerability type
                line_labels = self._create_multi_label_line_labels(source_code, row)
                
                # Create contract-level vulnerability labels
                contract_labels = self._create_contract_vulnerability_labels(row)
                
                # Tokenize the source code
                encoding = self.tokenizer(
                    source_code,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Tokenize AST paths
                ast_encoding = self.tokenizer(
                    ast_path_text,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Convert line labels to tensor and ensure consistent shape
                vuln_tensor = torch.zeros((len(self.vulnerability_types), self.max_length), dtype=torch.long)
                for i, labels in enumerate(line_labels):
                    if len(labels) > self.max_length:
                        labels = labels[:self.max_length]
                    vuln_tensor[i, :len(labels)] = torch.tensor(labels, dtype=torch.long)
                
                # Convert contract labels to tensor
                contract_vuln_tensor = torch.tensor(contract_labels, dtype=torch.long)
                
                # Convert token_to_line to tensor
                token_to_line_tensor = torch.tensor(token_to_line, dtype=torch.long)
                
                # Ensure attention masks are boolean
                attention_mask = encoding['attention_mask'].squeeze(0).bool()
                ast_attention_mask = ast_encoding['attention_mask'].squeeze(0).bool()
                
                # Ensure input_ids are the right length
                input_ids = encoding['input_ids'].squeeze(0)
                ast_input_ids = ast_encoding['input_ids'].squeeze(0)
                
                if len(input_ids) > self.max_length:
                    input_ids = input_ids[:self.max_length]
                if len(ast_input_ids) > self.max_length:
                    ast_input_ids = ast_input_ids[:self.max_length]
                
                # Pad if necessary
                if len(input_ids) < self.max_length:
                    input_ids = torch.nn.functional.pad(input_ids, (0, self.max_length - len(input_ids)))
                if len(ast_input_ids) < self.max_length:
                    ast_input_ids = torch.nn.functional.pad(ast_input_ids, (0, self.max_length - len(ast_input_ids)))
                
                dataset.append({
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'ast_input_ids': ast_input_ids,
                    'ast_attention_mask': ast_attention_mask,
                    'vulnerable_lines': vuln_tensor,
                    'contract_vulnerabilities': contract_vuln_tensor,
                    'token_to_line': token_to_line_tensor,
                    'source_code': source_code,
                    'contract_name': contract_name
                })
            except Exception as e:
                print(f"Error processing contract {contract_name}: {str(e)}")
                continue
        
        return dataset
    
    def _create_contract_vulnerability_labels(self, row: pd.Series) -> List[int]:
        """Create contract-level vulnerability labels"""
        contract_labels = []
        for vuln_type in self.vulnerability_types:
            # Check if contract has this vulnerability type
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Contract is vulnerable if it has any vulnerable lines
            has_vulnerability = len(vuln_lines) > 0
            contract_labels.append(1 if has_vulnerability else 0)
        
        return contract_labels
    
    def _create_multi_label_line_labels(self, source_code: str, row: pd.Series) -> List[List[int]]:
        """Create multi-label line labels for each vulnerability type"""
        total_lines = len(source_code.split('\n'))
        line_labels = {vuln_type: [0] * total_lines for vuln_type in self.vulnerability_types}
        
        # Process each vulnerability type
        for vuln_type in self.vulnerability_types:
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Process each vulnerable line/snippet
            for line_or_snippet in vuln_lines:
                if isinstance(line_or_snippet, int):
                    # If it's a line number, mark that line
                    if 0 <= line_or_snippet < total_lines:
                        line_labels[vuln_type][line_or_snippet] = 1
                else:
                    # If it's a code snippet, find matching lines
                    source_lines = source_code.split('\n')
                    for i, line in enumerate(source_lines):
                        # Clean both the line and snippet for comparison
                        clean_line = re.sub(r'\s+', ' ', line.strip())
                        clean_snippet = re.sub(r'\s+', ' ', str(line_or_snippet).strip())
                        if clean_snippet in clean_line:
                            line_labels[vuln_type][i] = 1
        
        # Convert to list format
        return [line_labels[vuln_type] for vuln_type in self.vulnerability_types]
    
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> Dict:
        return self.data[idx]

def custom_collate_fn(batch):
    """
    Custom collate function to handle variable length inputs
    """
    # Get the maximum length in this batch for each type of tensor
    max_input_len = max(item['input_ids'].size(0) for item in batch)
    
    # Pad all tensors to their respective maximum lengths
    padded_batch = {
        'input_ids': torch.stack([
            torch.nn.functional.pad(item['input_ids'], (0, max_input_len - item['input_ids'].size(0)))
            for item in batch
        ]),
        'attention_mask': torch.stack([
            torch.nn.functional.pad(item['attention_mask'], (0, max_input_len - item['attention_mask'].size(0)))
            for item in batch
        ]),
        'ast_input_ids': torch.stack([item['ast_input_ids'] for item in batch]),
        'ast_attention_mask': torch.stack([item['ast_attention_mask'] for item in batch]),
        'vulnerable_lines': torch.stack([item['vulnerable_lines'] for item in batch]),
        'contract_vulnerabilities': torch.stack([item['contract_vulnerabilities'] for item in batch]),
        'token_to_line': torch.stack([item['token_to_line'] for item in batch]),
        'source_code': [item['source_code'] for item in batch],
        'contract_name': [item['contract_name'] for item in batch]
    }
    
    return padded_batch


# 0 Constants:

In [3]:
DATA_PATH = "contract_sources_with_vulnerabilities_2048_token_size.csv"
MODEL_PATH = "checkpoints_v5_2048_output/best_model_augmented_gan_epoch_91.pt"
TOKENIZER_NAME = "microsoft/codebert-base"
VULNERABILITY_TYPES = ['ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE']
DEVICE = torch.device("cuda")
MODEL_LINE_CODE_VULNERABILITY_THRESHOLD = 0.2

# 1 Load Dataset:

In [4]:
def create_dataloaders(
    data_path: str,
    tokenizer: AutoTokenizer,
    batch_size: int = 8,
    max_length: int = 1024,
    num_workers: int = 4,
    vulnerability_types: List[str] = None
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    
    val_dataset = SmartContractVulnerabilityDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length,
        split="val",
        vulnerability_types=vulnerability_types
    )
    
    dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn
    )
    
    return  dataloader

In [5]:
from transformers import AutoTokenizer

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

val_dataloader = create_dataloaders(
    data_path=DATA_PATH,
    tokenizer=tokenizer,
    batch_size=8,
    max_length=1024,
    vulnerability_types=VULNERABILITY_TYPES
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1211 > 512). Running this sequence through the model will result in indexing errors


In [6]:
val_dataloader.dataset.data[0]['vulnerable_lines'][7]

tensor([0, 0, 0,  ..., 0, 0, 0])

## 1.1. ML performance Metrics

In [7]:
def calculate_precision(true_labels: np.ndarray, pred_labels: np.ndarray) -> float:
    """Calculate precision for vulnerability detection."""
    if np.sum(pred_labels) == 0:
        return 0.0
    return np.sum((true_labels == 1) & (pred_labels == 1)) / np.sum(pred_labels)

def calculate_recall(true_labels: np.ndarray, pred_labels: np.ndarray) -> float:
    """Calculate recall for vulnerability detection."""
    if np.sum(true_labels) == 0:
        return 0.0
    return np.sum((true_labels == 1) & (pred_labels == 1)) / np.sum(true_labels)

def calculate_f1_score(precision: float, recall: float) -> float:
    """Calculate F1 score."""
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)

def calculate_line_accuracy(true_line_vulns: np.ndarray, pred_line_vulns: Dict) -> float:
    """Calculate line-level accuracy (simplified)."""
    try:
        # Convert predicted line vulnerabilities to array format
        pred_array = np.zeros_like(true_line_vulns)
        for line_num, line_vulns in pred_line_vulns.items():
            if line_num < pred_array.shape[0]:
                for vuln_idx, is_vuln in enumerate(line_vulns.values()):
                    if vuln_idx < pred_array.shape[1]:
                        pred_array[line_num, vuln_idx] = 1 if is_vuln else 0
        
        return np.mean(pred_array == true_line_vulns)
    except:
        return 0.

def get_vulnerability_details(analyzer, true_vulns: np.ndarray, pred_vulns: np.ndarray, 
                            probabilities: List[float]) -> Dict[str, Any]:
    """Get detailed vulnerability analysis."""
    details = {}
    for i, vuln_type in enumerate(analyzer.vulnerability_types):
        details[vuln_type] = {
            'true_positive': bool(true_vulns[i] == 1 and pred_vulns[i] == 1),
            'false_positive': bool(true_vulns[i] == 0 and pred_vulns[i] == 1),
            'false_negative': bool(true_vulns[i] == 1 and pred_vulns[i] == 0),
            'true_negative': bool(true_vulns[i] == 0 and pred_vulns[i] == 0),
            'probability': probabilities[i],
            'true_label': int(true_vulns[i]),
            'predicted_label': int(pred_vulns[i])
        }
    return details
    

## 1.2 Load Model:

In [8]:
model_path=MODEL_PATH
device = DEVICE
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token



In [9]:
model = SmartContractTransformer(
        d_model=768,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        max_length=1024,
        vocab_size=tokenizer.vocab_size,
        num_vulnerability_types=8,
        use_gan=True 
    )

checkpoint = torch.load(model_path, map_location=device)



DEBUG: Initialized line feature extractor layer 1 with small random weights
DEBUG: Initialized line feature extractor layer 3 with small random weights
DEBUG: Initialized custom line feature extractor with small weights


In [10]:
if 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from {model_path}")
    print(f"Training epoch: {checkpoint.get('epoch', 'Unknown')}")
    print(f"Best validation loss: {checkpoint.get('val_loss', 'Unknown')}")
    print(f"Training config: GAN={checkpoint.get('use_gan', 'Unknown')}")
else:
    # Direct state dict
    model.load_state_dict(checkpoint)
    print(f"Model loaded from {model_path}")

Model loaded from checkpoints_v5_2048_output/best_model_augmented_gan_epoch_91.pt
Training epoch: 91
Best validation loss: 0.83826167229563
Training config: GAN=True


In [11]:
model.to(device)
model.eval()
    
vulnerability_types = [
    'ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE'
]

In [12]:
print(val_dataloader.dataset.data[5]['source_code'])

/**
 * Source Code first verified at https://etherscan.io on Tuesday, December 18, 2018
 (UTC) */

pragma solidity ^0.4.23;

library SafeMath {

  /**
  * @dev Multiplies two numbers, throws on overflow.
  */
  function mul(uint256 a, uint256 b) internal pure returns (uint256 c) {

    if (a == 0) {
      return 0;
    }

    c = a * b;
    assert(c / a == b);
    return c;
  }

  /**
  * @dev Integer division of two numbers, truncating the quotient.
  */
  function div(uint256 a, uint256 b) internal pure returns (uint256) {

    return a / b;
  }

  /**
  * @dev Subtracts two numbers, throws on overflow (i.e. if subtrahend is greater than minuend).
  */
  function sub(uint256 a, uint256 b) internal pure returns (uint256) {
    assert(b <= a);
    return a - b;
  }

  /**
  * @dev Adds two numbers, throws on overflow.
  */
  function add(uint256 a, uint256 b) internal pure returns (uint256 c) {
    c = a + b;
    assert(c >= a);
    return c;
  }
}


contract ERC20Basic {
    
  functi

In [13]:
contract_code = val_dataloader.dataset.data[5]['source_code']
ast = parse_solidity_to_ast(contract_code)
ast_paths = prepare_code2vec_input(ast) if ast else []
ast_path_text = ' '.join(ast_paths)

# Tokenize inputs
contract_encoding = tokenizer(
    contract_code,
    max_length=1024,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

ast_encoding = tokenizer(
    ast_path_text,
    max_length=1024,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

In [14]:
input_ids = contract_encoding['input_ids'].to(device)
attention_mask = contract_encoding['attention_mask'].to(device)
ast_input_ids = ast_encoding['input_ids'].to(device)
ast_attention_mask = ast_encoding['attention_mask'].to(device)

print(f"DEBUG: input_ids shape: {input_ids.shape}")
print(f"DEBUG: actual sequence length: {input_ids.shape[1]}")

# Create proper token-to-line mapping that matches the actual tokenization
lines = contract_code.split('\n')
token_to_line = []
current_line = 0

# Add [CLS] token mapping (line 0)
token_to_line.append(0)

for line in lines:
    line_tokens = tokenizer.encode(line, add_special_tokens=False)
    # Map all tokens in this line to the same line number
    token_to_line.extend([current_line] * len(line_tokens))
    current_line += 1

# Add [SEP] token mapping (line 0)
token_to_line.append(0)

# FIXED: Ensure the mapping matches the actual tokenized length
actual_seq_len = input_ids.shape[1]  # Get the actual sequence length from input_ids
if len(token_to_line) > actual_seq_len:
    token_to_line = token_to_line[:actual_seq_len]
elif len(token_to_line) < actual_seq_len:
    token_to_line.extend([0] * (actual_seq_len - len(token_to_line)))

token_to_line = torch.tensor(token_to_line, dtype=torch.long).to(device)


DEBUG: input_ids shape: torch.Size([1, 1024])
DEBUG: actual sequence length: 1024


In [15]:
# Enable debug mode
model._debug_mode = False

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        ast_input_ids=ast_input_ids,
        ast_attention_mask=ast_attention_mask,
        target_ids=input_ids,  # Use input_ids as target for inference
        token_to_line=token_to_line
    )

contract_vuln_logits = outputs.get('contract_vulnerability_logits')
line_vuln_logits = outputs.get('line_vulnerability_logits')

line_vuln_probs = torch.sigmoid(line_vuln_logits)  # Shape: [1, num_lines, 8]

for line_idx in range(line_vuln_probs.shape[1]):
    line_probs = line_vuln_probs[0, line_idx, :]  # Shape: [8]
    # Check if any vulnerability type has high probability
    is_vulnerable = (line_probs > 0.1).any()
    if is_vulnerable:
        print(f"Line {line_idx} is vulnerable:")
        for i, vuln_type in enumerate(VULNERABILITY_TYPES):
            if line_probs[i] > 0.1 and line_probs[i] < 0.5000:
                print(f"  {vuln_type}: {line_probs[i]:.4f}")

print("✅ Test completed successfully!") 

Line 18 is vulnerable:
  ARTHM: 0.1058
Line 43 is vulnerable:
  ARTHM: 0.3970
Line 44 is vulnerable:
  ARTHM: 0.1128
Line 66 is vulnerable:
  TimeM: 0.1090
Line 67 is vulnerable:
  ARTHM: 0.1214
Line 92 is vulnerable:
  TimeM: 0.1630
Line 122 is vulnerable:
Line 123 is vulnerable:
Line 124 is vulnerable:
Line 125 is vulnerable:
Line 126 is vulnerable:
Line 127 is vulnerable:
Line 128 is vulnerable:
Line 129 is vulnerable:
Line 130 is vulnerable:
Line 131 is vulnerable:
Line 132 is vulnerable:
Line 133 is vulnerable:
Line 134 is vulnerable:
Line 135 is vulnerable:
Line 136 is vulnerable:
Line 137 is vulnerable:
Line 138 is vulnerable:
Line 139 is vulnerable:
Line 140 is vulnerable:
Line 141 is vulnerable:
Line 142 is vulnerable:
Line 143 is vulnerable:
Line 144 is vulnerable:
Line 145 is vulnerable:
Line 146 is vulnerable:
Line 147 is vulnerable:
Line 148 is vulnerable:
Line 149 is vulnerable:
Line 150 is vulnerable:
Line 151 is vulnerable:
Line 152 is vulnerable:
Line 153 is vulnerable

In [29]:
line_tokens = tokenizer.encode(lines[5], add_special_tokens=False)
line_tokens

[50121]