In [1]:
# uv pip install torch transformers numpy

In [1]:
# PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Model imports
from model import SmartContractTransformer

# Training imports
from train import SmartContractTrainer

# Data processing imports
from data_processing import SmartContractDataset, preprocess_contract

# Optional but useful imports
import numpy as np
from tqdm import tqdm  # for progress bars
import logging 

In [2]:
print(f"CUDA is available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

CUDA is available: True
Number of GPUs: 2


In [3]:
from datasets import load_dataset

ds = load_dataset("jainabh/smart_contracts_malicious")

Using the latest cached version of the dataset since jainabh/smart_contracts_malicious couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/m20180848/.cache/huggingface/datasets/jainabh___smart_contracts_malicious/default/0.0.0/ba1b2cb9d02f16e398b6e50f59db70c5c3cb1b25 (last modified on Sat May  3 17:02:56 2025).


In [4]:
ds

DatasetDict({
    train: Dataset({
        features: ['contract_source', 'malicious'],
        num_rows: 2000
    })
})

In [9]:
# Load and preprocess data
train_contracts = ds['train'][0:1400]['contract_source']  # Changed from 'contract_source' to 'source_code'
train_labels = ds['train'][0:1400]['malicious']
val_contracts = ds['train'][1400:-1]['contract_source']  # Changed from 'contract_source' to 'source_code'
val_labels = ds['train'][1400:-1]['malicious']

In [10]:
import re
from typing import List, Dict, Any
import json

def parse_solidity_to_ast(code: str) -> Dict[str, Any]:
    """
    Parse Solidity code into a simplified AST structure
    """
    def extract_contract_info(code: str) -> Dict[str, Any]:
        # Extract contract name
        contract_match = re.search(r'contract\s+(\w+)', code)
        contract_name = contract_match.group(1) if contract_match else "Unknown"
        
        # Extract functions
        functions = []
        function_pattern = r'function\s+(\w+)\s*\(([^)]*)\)\s*(?:public|private|internal|external)?\s*(?:view|pure|payable)?\s*(?:returns\s*\(([^)]*)\))?\s*{'
        for match in re.finditer(function_pattern, code):
            func_name = match.group(1)
            params = match.group(2).split(',') if match.group(2) else []
            returns = match.group(3).split(',') if match.group(3) else []
            
            functions.append({
                'name': func_name,
                'parameters': [p.strip() for p in params],
                'returns': [r.strip() for r in returns]
            })
        
        # Extract state variables
        variables = []
        var_pattern = r'(?:uint|address|string|bool|mapping)\s+(?:\w+)\s+(\w+)'
        for match in re.finditer(var_pattern, code):
            variables.append(match.group(1))
        
        return {
            'type': 'Contract',
            'name': contract_name,
            'functions': functions,
            'variables': variables
        }
    
    try:
        # Clean the code
        code = re.sub(r'//.*?\n|/\*.*?\*/', '', code)  # Remove comments
        code = re.sub(r'\s+', ' ', code)  # Normalize whitespace
        
        # Parse the code
        ast = extract_contract_info(code)
        return ast
    except Exception as e:
        print(f"Error parsing code: {str(e)}")
        return None

def prepare_code2vec_input(ast: Dict[str, Any]) -> List[str]:
    """
    Convert AST to code2vec input format
    """
    paths = []
    
    def extract_paths(node: Dict[str, Any], current_path: List[str] = None):
        if current_path is None:
            current_path = []
            
        # Add current node to path
        if 'name' in node:
            current_path.append(node['name'])
            
        # Process functions
        if 'functions' in node:
            for func in node['functions']:
                func_path = current_path + [func['name']]
                paths.append(' '.join(func_path))
                
                # Add parameter paths
                for param in func['parameters']:
                    param_path = func_path + [param]
                    paths.append(' '.join(param_path))
                
                # Add return paths
                for ret in func['returns']:
                    ret_path = func_path + [ret]
                    paths.append(' '.join(ret_path))
        
        # Process variables
        if 'variables' in node:
            for var in node['variables']:
                var_path = current_path + [var]
                paths.append(' '.join(var_path))
    
    extract_paths(ast)
    return paths

# Example usage:
def process_contract_for_code2vec(code: str) -> List[str]:
    """
    Process a Solidity contract for code2vec
    """
    # Parse code to AST
    ast = parse_solidity_to_ast(code)
    if ast is None:
        return []
    
    # Convert AST to code2vec input format
    paths = prepare_code2vec_input(ast)
    return paths

In [11]:
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, AutoModel  # Add these imports

class SmartContractDatasetWithPaths(Dataset):
    def __init__(self, contracts, labels, tokenizer, code2vec_model, max_length=256):
        self.contracts = contracts
        self.labels = labels
        self.tokenizer = tokenizer
        self.code2vec_model = code2vec_model
        self.max_length = max_length
        
    def __len__(self):
        return len(self.contracts)
    
    def __getitem__(self, idx):
        contract = self.contracts[idx]
        label = self.labels[idx]
        
        # Parse contract to AST and generate paths
        ast = parse_solidity_to_ast(contract)
        paths = prepare_code2vec_input(ast)
        
        # Convert paths to string for tokenization
        paths_str = ' '.join([''.join(path) for path in paths])
        
        # Tokenize contract with consistent max_length
        contract_inputs = self.tokenizer(
            contract,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        # Tokenize paths with consistent max_length
        path_inputs = self.tokenizer(
            paths_str,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        # Get tensors and ensure they're 1D
        input_ids = contract_inputs['input_ids'].squeeze(0)
        attention_mask = contract_inputs['attention_mask'].squeeze(0).bool()  # Convert to boolean
        path_input_ids = path_inputs['input_ids'].squeeze(0)
        path_attention_mask = path_inputs['attention_mask'].squeeze(0).bool()  # Convert to boolean
        
        # Create target_ids with the same length as input_ids
        target_ids = input_ids.clone()
        
        # Ensure all tensors have the same length
        if input_ids.size(0) < self.max_length:
            padding = torch.zeros(self.max_length - input_ids.size(0), dtype=input_ids.dtype)
            input_ids = torch.cat([input_ids, padding])
            attention_mask = torch.cat([attention_mask, torch.zeros(self.max_length - attention_mask.size(0), dtype=torch.bool)])  # Use boolean padding
            target_ids = torch.cat([target_ids, padding])
        elif input_ids.size(0) > self.max_length:
            input_ids = input_ids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]
            target_ids = target_ids[:self.max_length]
        
        # Ensure path tensors have the same length
        if path_input_ids.size(0) < self.max_length:
            padding = torch.zeros(self.max_length - path_input_ids.size(0), dtype=path_input_ids.dtype)
            path_input_ids = torch.cat([path_input_ids, padding])
            path_attention_mask = torch.cat([path_attention_mask, torch.zeros(self.max_length - path_attention_mask.size(0), dtype=torch.bool)])  # Use boolean padding
        elif path_input_ids.size(0) > self.max_length:
            path_input_ids = path_input_ids[:self.max_length]
            path_attention_mask = path_attention_mask[:self.max_length]
        
        # Ensure all tensors are 1D
        input_ids = input_ids.view(-1)
        attention_mask = attention_mask.view(-1)
        path_input_ids = path_input_ids.view(-1)
        path_attention_mask = path_attention_mask.view(-1)
        target_ids = target_ids.view(-1)
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'path_input_ids': path_input_ids,
            'path_attention_mask': path_attention_mask,
            'target_ids': target_ids,
            'label': torch.tensor(label, dtype=torch.float)
        }

In [12]:
# Initialize tokenizer and code2vec model
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
code2vec_model = AutoModel.from_pretrained('microsoft/codebert-base').cuda()

# Create datasets
train_dataset = SmartContractDatasetWithPaths(
    train_contracts, 
    train_labels,
    tokenizer,
    code2vec_model
)

val_dataset = SmartContractDatasetWithPaths(
    val_contracts,
    val_labels,
    tokenizer,
    code2vec_model
)

# Create dataloaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False
)

In [11]:
# Initialize trainer
from train import VulnerabilityDetectionTrainer
import time
from datetime import datetime
import torch
import os

# Create a directory for checkpoints
checkpoint_dir = 'v3-checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Initialize trainer
trainer = VulnerabilityDetectionTrainer(
    model,
    train_dataloader,
    val_dataloader
)

In [14]:
checkpoint_path

'v3-checkpoints/checkpoint_epoch_30_model_v3.pt'

In [16]:
# Load checkpoint
checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_20_model_v3.pt')  # Change this to your checkpoint file
checkpoint = torch.load(checkpoint_path)

# Load model states
model.load_state_dict(checkpoint['model_state_dict'])
model.generator.load_state_dict(checkpoint['generator_state_dict'])
model.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
model.decoder.load_state_dict(checkpoint['decoder_state_dict'])

# Load optimizer states
trainer.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
trainer.optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
trainer.optimizer_decoder.load_state_dict(checkpoint['optimizer_decoder_state_dict'])

# Get the epoch to start from and best validation loss
start_epoch = checkpoint['epoch']
best_val_loss = checkpoint['val_loss']

print(f"Loaded checkpoint from epoch {start_epoch + 1}")
print(f"Previous validation loss: {best_val_loss:.4f}")

Loaded checkpoint from epoch 20
Previous validation loss: 0.0070


In [None]:
# Training loop - start from the next epoch
num_epochs = 120

for epoch in range(start_epoch + 1, num_epochs):  # Start from the next epoch
    # Start timer for this epoch
    epoch_start_time = time.time()
    
    # Training
    g_loss, d_loss, decoder_loss = trainer.train_epoch()
    val_loss = trainer.validate()
    
    # Calculate epoch time
    epoch_time = time.time() - epoch_start_time
    
    # Print training progress
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Generator Loss: {g_loss:.4f}")
    print(f"Discriminator Loss: {d_loss:.4f}")
    print(f"Decoder Loss: {decoder_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print(f"Epoch Time: {epoch_time:.2f}s")
    
    # Save model checkpoint
    if (epoch + 1) % 10 == 0:
        checkpoint = {
            # Model states
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'generator_state_dict': model.generator.state_dict(),
            'discriminator_state_dict': model.discriminator.state_dict(),
            'decoder_state_dict': model.decoder.state_dict(),
            
            # Optimizer states
            'optimizer_G_state_dict': trainer.optimizer_G.state_dict(),
            'optimizer_D_state_dict': trainer.optimizer_D.state_dict(),
            'optimizer_decoder_state_dict': trainer.optimizer_decoder.state_dict(),
            
            # Loss values
            'g_loss': g_loss,
            'd_loss': d_loss,
            'decoder_loss': decoder_loss,
            'val_loss': val_loss,
            
            # Model configuration
            'model_config': {
                'vocab_size': model.decoder.vocab_size,
                'max_length': model.decoder.max_length
            },
            
            # Training metadata
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'epoch_time': epoch_time
        }
        
        # Save regular checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}_model_v3.pt')
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved checkpoint for epoch {epoch+1}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = os.path.join(checkpoint_dir, 'best_model_v3.pt')
            torch.save(checkpoint, best_model_path)
            print(f"New best model saved with validation loss: {val_loss:.4f}")

print("\nTraining completed!")
print(f"Best validation loss: {best_val_loss:.4f}")


Starting training epoch...


In [None]:
print('Done')

NOTES:

1. Input Processing:
Initial input: [32, 512] (batch_size=32, sequence_length=512)
After embedding: [32, 512, 512] (batch_size=32, sequence_length=512, embedding_dim=512)
This is correct because the embedding layer converts each token to a 512-dimensional vector

2. Path Embeddings Processing:
Initial path embeddings: [32, 768] (batch_size=32, code2vec_dim=768)
After path embedding layer: [32, 512] (batch_size=32, transformer_dim=512)
The linear layer converts from code2vec's 768 dimensions to transformer's 512 dimensions
After expansion: [32, 512, 512] (batch_size=32, sequence_length=512, transformer_dim=512)
The path embeddings are expanded to match the sequence length

3. Final Shape:
[32, 512, 512] (batch_size=32, sequence_length=512, transformer_dim=512)
This is the correct shape for the transformer layers


In [15]:
checkpoint = {
    # Model states
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'generator_state_dict': model.generator.state_dict(),
    'discriminator_state_dict': model.discriminator.state_dict(),
    'decoder_state_dict': model.decoder.state_dict(),
    
    # Optimizer states
    'optimizer_G_state_dict': trainer.optimizer_G.state_dict(),
    'optimizer_D_state_dict': trainer.optimizer_D.state_dict(),
    'optimizer_decoder_state_dict': trainer.optimizer_decoder.state_dict(),
    
    # Loss values
    'g_loss': g_loss,
    'd_loss': d_loss,
    'decoder_loss': decoder_loss,
    'val_loss': val_loss,
    
    # Model configuration
    'model_config': {
        #'d_model': model.d_model,
        'vocab_size': model.decoder.vocab_size,
        'max_length': model.decoder.max_length
    },
    
    # Training metadata
    'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    'epoch_time': epoch_time
}

# Save regular checkpoint
checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}_model_v3.pt')
torch.save(checkpoint, checkpoint_path)
print(f"Saved checkpoint for epoch {epoch+1}")

Saved checkpoint for epoch 10


In [25]:
print(f"Epoch [{epoch}/{num_epochs}]")
print(f"Generator Loss: {g_loss:.4f}")
print(f"Discriminator Loss: {d_loss:.4f}")
print(f"Validation Loss: {val_loss:.4f}")

Epoch [119/120]
Generator Loss: 10.5703
Discriminator Loss: 0.0062
Validation Loss: 0.0002


# 1. Load Model:

In [13]:
import torch
from model import SmartContractTransformer
from train import Discriminator

def load_trained_model(checkpoint_path, device='cuda:1'):
    """
    Load the trained model and discriminator from checkpoint
    """
    # Initialize model and discriminator
    model = SmartContractTransformer()
    discriminator = Discriminator()
    
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Load state dicts
    model.load_state_dict(checkpoint['model_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    
    # Move to device
    model = model.to(device)
    discriminator = discriminator.to(device)
    
    # Set to eval mode
    model.eval()
    discriminator.eval()
    
    return model, discriminator


In [14]:
# Load the trained model
checkpoint_path = 'checkpoints_v1/latest_model.pt'  # or 'best_model_epoch_X.pt'
model_loaded, discriminator = load_trained_model(checkpoint_path)

# 2. Model Validation:

In [15]:
contract_data = val_dataset[9]

In [16]:
def analyze_contract(contract_data, model, discriminator, tokenizer=None, device='cuda:1'):
    """
    Analyze a smart contract for vulnerabilities and generate synthetic version
    """
    # Move input data to device
    input_ids = contract_data['input_ids'].unsqueeze(0).to(device)
    attention_mask = contract_data['attention_mask'].unsqueeze(0).to(device)
    path_input_ids = contract_data['path_input_ids'].unsqueeze(0).to(device)
    path_attention_mask = contract_data['path_attention_mask'].unsqueeze(0).to(device)
    
    # Decode original contract if tokenizer is provided
    original_text = None
    if tokenizer is not None:
        original_text = tokenizer.decode(input_ids[0].cpu().tolist())
    
    with torch.no_grad():
        # Get vulnerability score
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            path_input_ids=path_input_ids,
            path_attention_mask=path_attention_mask,
            target_ids=None
        )
        
        # Get discriminator predictions
        vuln_pred, synth_pred = discriminator(outputs['encoder_output'])
        vulnerability_score = torch.sigmoid(vuln_pred).item()
        synthetic_score = torch.sigmoid(synth_pred).item()
        
        result = {
            'vulnerability_score': vulnerability_score,
            'synthetic_score': synthetic_score,
            'is_vulnerable': vulnerability_score > 0.5,
            'is_synthetic': synthetic_score > 0.5,
            'original_contract': {
                'text': original_text,
                'input_ids': input_ids[0].cpu().tolist()
            }
        }
        
        # Generate synthetic contract
        synthetic_outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            path_input_ids=path_input_ids,
            path_attention_mask=path_attention_mask,
            target_ids=None
        )
        
        # Get the generated sequence
        generated_sequence = synthetic_outputs['generated_sequence']
        
        # Get discriminator predictions for synthetic contract
        synth_vuln_pred, synth_synth_pred = discriminator(synthetic_outputs['encoder_output'])
        synth_vulnerability_score = torch.sigmoid(synth_vuln_pred).item()
        synth_synthetic_score = torch.sigmoid(synth_synth_pred).item()
        
        result['synthetic_contract'] = {
            'sequence': generated_sequence[0].cpu().tolist(),
            'vulnerability_score': synth_vulnerability_score,
            'synthetic_score': synth_synthetic_score,
            'is_vulnerable': synth_vulnerability_score > 0.5,
            'is_synthetic': synth_synthetic_score > 0.5
        }
        
        # If tokenizer is provided, decode the synthetic contract
        if tokenizer is not None:
            synthetic_tokens = result['synthetic_contract']['sequence']
            result['synthetic_contract']['text'] = tokenizer.decode(synthetic_tokens)
        
        return result

In [17]:
# Analyze the contract
results = analyze_contract(contract_data, model_loaded, discriminator, tokenizer)

# Print results
print("\nContract Analysis Results:")
print("-" * 50)
print(f"Vulnerability Score: {results['vulnerability_score']:.4f}")
print(f"Synthetic Score: {results['synthetic_score']:.4f}")
print(f"Vulnerability Status: {'Vulnerable' if results['is_vulnerable'] else 'Safe'}")
print(f"Synthetic Status: {'Synthetic' if results['is_synthetic'] else 'Real'}")

# Print original contract
if results['original_contract']['text']:
    print("\nOriginal Contract:")
    print("-" * 50)
    print(results['original_contract']['text'])
    

print("\nSynthetic Contract Analysis:")
print("-" * 50)
print(f"Vulnerability Score: {results['synthetic_contract']['vulnerability_score']:.4f}")
print(f"Synthetic Score: {results['synthetic_contract']['synthetic_score']:.4f}")
print(f"Vulnerability Status: {'Vulnerable' if results['synthetic_contract']['is_vulnerable'] else 'Safe'}")
print(f"Synthetic Status: {'Synthetic' if results['synthetic_contract']['is_synthetic'] else 'Real'}")

if 'text' in results['synthetic_contract']:
    print("\nGenerated Synthetic Contract:")
    print("-" * 50)
    print(results['synthetic_contract']['text'])

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)



Contract Analysis Results:
--------------------------------------------------
Vulnerability Score: 0.0151
Synthetic Score: 0.7673
Vulnerability Status: Safe
Synthetic Status: Synthetic

Original Contract:
--------------------------------------------------
<s>/**
 *Submitted for verification at Etherscan.io on 2020-11-22
*/

// SPDX-License-Identifier: MIT + WTFPL
// File: contracts/uniswapv2/interfaces/IUniswapV2Factory.sol

pragma solidity >=0.5.0;

interface IUniswapV2Factory {
    event PairCreated(address indexed token0, address indexed token1, address pair, uint);

    function feeTo() external view returns (address);
    function feeToSetter() external view returns (address);
    function migrator() external view returns (address);

    function getPair(address tokenA, address tokenB) external view returns (address pair);
    function allPairs(uint) external view returns (address pair);
    function allPairsLength() external view returns (uint);

    function createPair(address 