In [1]:
from test_model import SmartContractAnalyzer, print_vulnerability_report
from inference import *

# PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Model imports
from model import SmartContractTransformer

# Optional but useful imports
import numpy as np
from tqdm import tqdm  # for progress bars
import logging

from transformers import AutoTokenizer
import json
import os
import pandas as pd
from typing import Dict, List, Tuple, Any
import re

  _torch_pytree._register_pytree_node(


# 1. Load Validation Dataset:

In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def parse_solidity_to_ast(code: str) -> Dict[str, Any]:
    """
    Parse Solidity code into a simplified AST structure
    """
    def extract_contract_info(code: str) -> Dict[str, Any]:
        # Extract contract name
        contract_match = re.search(r'contract\s+(\w+)', code)
        contract_name = contract_match.group(1) if contract_match else "Unknown"
        
        # Extract functions
        functions = []
        function_pattern = r'function\s+(\w+)\s*\(([^)]*)\)\s*(?:public|private|internal|external)?\s*(?:view|pure|payable)?\s*(?:returns\s*\(([^)]*)\))?\s*{'
        for match in re.finditer(function_pattern, code):
            func_name = match.group(1)
            params = match.group(2).split(',') if match.group(2) else []
            returns = match.group(3).split(',') if match.group(3) else []
            
            functions.append({
                'name': func_name,
                'parameters': [p.strip() for p in params],
                'returns': [r.strip() for r in returns]
            })
        
        # Extract state variables
        variables = []
        var_pattern = r'(?:uint|address|string|bool|mapping)\s+(?:\w+)\s+(\w+)'
        for match in re.finditer(var_pattern, code):
            variables.append(match.group(1))
        
        return {
            'type': 'Contract',
            'name': contract_name,
            'functions': functions,
            'variables': variables
        }
    
    try:
        # Clean the code
        code = re.sub(r'//.*?\n|/\*.*?\*/', '', code)  # Remove comments
        code = re.sub(r'\s+', ' ', code)  # Normalize whitespace
        
        # Parse the code
        ast = extract_contract_info(code)
        return ast
    except Exception as e:
        print(f"Error parsing code: {str(e)}")
        return None

def prepare_code2vec_input(ast: Dict[str, Any]) -> List[str]:
    """
    Convert AST to codeBert input format
    """
    paths = []
    
    def extract_paths(node: Dict[str, Any], current_path: List[str] = None):
        if current_path is None:
            current_path = []
            
        # Add current node to path
        if 'name' in node:
            current_path.append(node['name'])
            
        # Process functions
        if 'functions' in node:
            for func in node['functions']:
                func_path = current_path + [func['name']]
                paths.append(' '.join(func_path))
                
                # Add parameter paths
                for param in func['parameters']:
                    param_path = func_path + [param]
                    paths.append(' '.join(param_path))
                
                # Add return paths
                for ret in func['returns']:
                    ret_path = func_path + [ret]
                    paths.append(' '.join(ret_path))
        
        # Process variables
        if 'variables' in node:
            for var in node['variables']:
                var_path = current_path + [var]
                paths.append(' '.join(var_path))
    
    extract_paths(ast)
    return paths

class SmartContractVulnerabilityDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        tokenizer: AutoTokenizer,
        max_length: int = 1024,
        split: str = "train",
        vulnerability_types: List[str] = None
    ):
        """
        Args:
            data_path: Path to the CSV file containing the dataset
            tokenizer: Tokenizer for encoding the source code
            max_length: Maximum sequence length
            split: "train" or "val" to specify which split to load
            vulnerability_types: List of vulnerability types to consider
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split
        self.vulnerability_types = vulnerability_types or [
            'ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE'
        ]
        
        # Load the dataset
        self.data = self._load_dataset(data_path)
        
    def _load_dataset(self, data_path: str) -> List[Dict]:
        """Load and preprocess the dataset from CSV"""
        dataset = []
        
        # Read the CSV file
        df = pd.read_csv(data_path)
        
        # Split into train/val if needed
        if self.split == "train":
            df = df.sample(frac=0.8, random_state=42)
        else:
            df = df.sample(frac=0.2, random_state=42)
        
        # Process each contract
        for _, row in df.iterrows():
            try:
                source_code = row['source_code']
                contract_name = row['contract_name']
                
                # Parse AST and get paths
                ast = parse_solidity_to_ast(source_code)
                ast_paths = prepare_code2vec_input(ast) if ast else []
                ast_path_text = ' '.join(ast_paths)
                
                # Split source code into lines
                lines = source_code.split('\n')
                
                # Create token-to-line mapping
                token_to_line = []
                current_line = 0
                
                # Tokenize each line separately to maintain mapping
                for line in lines:
                    line_tokens = self.tokenizer.encode(line, add_special_tokens=False)
                    token_to_line.extend([current_line] * len(line_tokens))
                    current_line += 1
                
                # Add special tokens
                token_to_line = [0] + token_to_line + [0]  # [CLS] and [SEP] tokens
                
                # Truncate if too long
                if len(token_to_line) > self.max_length:
                    token_to_line = token_to_line[:self.max_length]
                
                # Pad if too short
                if len(token_to_line) < self.max_length:
                    token_to_line.extend([0] * (self.max_length - len(token_to_line)))
                
                # Create multi-label line labels for each vulnerability type
                line_labels = self._create_multi_label_line_labels(source_code, row)
                
                # Create contract-level vulnerability labels
                contract_labels = self._create_contract_vulnerability_labels(row)
                
                # Tokenize the source code
                encoding = self.tokenizer(
                    source_code,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Tokenize AST paths
                ast_encoding = self.tokenizer(
                    ast_path_text,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Convert line labels to tensor and ensure consistent shape
                vuln_tensor = torch.zeros((len(self.vulnerability_types), self.max_length), dtype=torch.long)
                for i, labels in enumerate(line_labels):
                    if len(labels) > self.max_length:
                        labels = labels[:self.max_length]
                    vuln_tensor[i, :len(labels)] = torch.tensor(labels, dtype=torch.long)
                
                # Convert contract labels to tensor
                contract_vuln_tensor = torch.tensor(contract_labels, dtype=torch.long)
                
                # Convert token_to_line to tensor
                token_to_line_tensor = torch.tensor(token_to_line, dtype=torch.long)
                
                # Ensure attention masks are boolean
                attention_mask = encoding['attention_mask'].squeeze(0).bool()
                ast_attention_mask = ast_encoding['attention_mask'].squeeze(0).bool()
                
                # Ensure input_ids are the right length
                input_ids = encoding['input_ids'].squeeze(0)
                ast_input_ids = ast_encoding['input_ids'].squeeze(0)
                
                if len(input_ids) > self.max_length:
                    input_ids = input_ids[:self.max_length]
                if len(ast_input_ids) > self.max_length:
                    ast_input_ids = ast_input_ids[:self.max_length]
                
                # Pad if necessary
                if len(input_ids) < self.max_length:
                    input_ids = torch.nn.functional.pad(input_ids, (0, self.max_length - len(input_ids)))
                if len(ast_input_ids) < self.max_length:
                    ast_input_ids = torch.nn.functional.pad(ast_input_ids, (0, self.max_length - len(ast_input_ids)))
                
                dataset.append({
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'ast_input_ids': ast_input_ids,
                    'ast_attention_mask': ast_attention_mask,
                    'vulnerable_lines': vuln_tensor,
                    'contract_vulnerabilities': contract_vuln_tensor,
                    'token_to_line': token_to_line_tensor,
                    'source_code': source_code,
                    'contract_name': contract_name
                })
            except Exception as e:
                print(f"Error processing contract {contract_name}: {str(e)}")
                continue
        
        return dataset
    
    def _create_contract_vulnerability_labels(self, row: pd.Series) -> List[int]:
        """Create contract-level vulnerability labels"""
        contract_labels = []
        for vuln_type in self.vulnerability_types:
            # Check if contract has this vulnerability type
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Contract is vulnerable if it has any vulnerable lines
            has_vulnerability = len(vuln_lines) > 0
            contract_labels.append(1 if has_vulnerability else 0)
        
        return contract_labels
    
    def _create_multi_label_line_labels(self, source_code: str, row: pd.Series) -> List[List[int]]:
        """Create multi-label line labels for each vulnerability type"""
        total_lines = len(source_code.split('\n'))
        line_labels = {vuln_type: [0] * total_lines for vuln_type in self.vulnerability_types}
        
        # Process each vulnerability type
        for vuln_type in self.vulnerability_types:
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Process each vulnerable line/snippet
            for line_or_snippet in vuln_lines:
                if isinstance(line_or_snippet, int):
                    # If it's a line number, mark that line
                    if 0 <= line_or_snippet < total_lines:
                        line_labels[vuln_type][line_or_snippet] = 1
                else:
                    # If it's a code snippet, find matching lines
                    source_lines = source_code.split('\n')
                    for i, line in enumerate(source_lines):
                        # Clean both the line and snippet for comparison
                        clean_line = re.sub(r'\s+', ' ', line.strip())
                        clean_snippet = re.sub(r'\s+', ' ', str(line_or_snippet).strip())
                        if clean_snippet in clean_line:
                            line_labels[vuln_type][i] = 1
        
        # Convert to list format
        return [line_labels[vuln_type] for vuln_type in self.vulnerability_types]
    
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> Dict:
        return self.data[idx]

def custom_collate_fn(batch):
    """
    Custom collate function to handle variable length inputs
    """
    # Get the maximum length in this batch for each type of tensor
    max_input_len = max(item['input_ids'].size(0) for item in batch)
    
    # Pad all tensors to their respective maximum lengths
    padded_batch = {
        'input_ids': torch.stack([
            torch.nn.functional.pad(item['input_ids'], (0, max_input_len - item['input_ids'].size(0)))
            for item in batch
        ]),
        'attention_mask': torch.stack([
            torch.nn.functional.pad(item['attention_mask'], (0, max_input_len - item['attention_mask'].size(0)))
            for item in batch
        ]),
        'ast_input_ids': torch.stack([item['ast_input_ids'] for item in batch]),
        'ast_attention_mask': torch.stack([item['ast_attention_mask'] for item in batch]),
        'vulnerable_lines': torch.stack([item['vulnerable_lines'] for item in batch]),
        'contract_vulnerabilities': torch.stack([item['contract_vulnerabilities'] for item in batch]),
        'token_to_line': torch.stack([item['token_to_line'] for item in batch]),
        'source_code': [item['source_code'] for item in batch],
        'contract_name': [item['contract_name'] for item in batch]
    }
    
    return padded_batch

def create_dataloaders(
    data_path: str,
    tokenizer: AutoTokenizer,
    batch_size: int = 8,
    max_length: int = 1024,
    num_workers: int = 4,
    vulnerability_types: List[str] = None
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """
    Create train and validation dataloaders
    
    Args:
        data_path: Path to the CSV file containing the dataset
        tokenizer: Tokenizer for encoding the source code
        batch_size: Batch size for training
        max_length: Maximum sequence length
        num_workers: Number of workers for data loading
        vulnerability_types: List of vulnerability types to consider
    
    Returns:
        Tuple of (train_dataloader, val_dataloader)
    """
    # Create datasets
    train_dataset = SmartContractVulnerabilityDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length,
        split="train",
        vulnerability_types=vulnerability_types
    )
    
    val_dataset = SmartContractVulnerabilityDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length,
        split="val",
        vulnerability_types=vulnerability_types
    )
    
    # Create dataloaders with custom collate function
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn
    )
    
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn
    )
    
    return train_dataloader, val_dataloader

In [3]:
from transformers import AutoTokenizer

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

# Create dataloaders
train_dataloader, val_dataloader = create_dataloaders(
    data_path="contract_sources_with_vulnerabilities_2048_token_size.csv",
    tokenizer=tokenizer,
    batch_size=8,
    max_length=1024,
    vulnerability_types=['ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE']
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1211 > 512). Running this sequence through the model will result in indexing errors


In [4]:
val_dataloader.dataset.max_length

1024

# 2. Load Model and SmartContractAnalyzer

In [5]:
def calculate_precision(true_labels: np.ndarray, pred_labels: np.ndarray) -> float:
    """Calculate precision for vulnerability detection."""
    if np.sum(pred_labels) == 0:
        return 0.0
    return np.sum((true_labels == 1) & (pred_labels == 1)) / np.sum(pred_labels)

def calculate_recall(true_labels: np.ndarray, pred_labels: np.ndarray) -> float:
    """Calculate recall for vulnerability detection."""
    if np.sum(true_labels) == 0:
        return 0.0
    return np.sum((true_labels == 1) & (pred_labels == 1)) / np.sum(true_labels)

def calculate_f1_score(precision: float, recall: float) -> float:
    """Calculate F1 score."""
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)

def calculate_line_accuracy(true_line_vulns: np.ndarray, pred_line_vulns: Dict) -> float:
    """Calculate line-level accuracy (simplified)."""
    try:
        # Convert predicted line vulnerabilities to array format
        pred_array = np.zeros_like(true_line_vulns)
        for line_num, line_vulns in pred_line_vulns.items():
            if line_num < pred_array.shape[0]:
                for vuln_idx, is_vuln in enumerate(line_vulns.values()):
                    if vuln_idx < pred_array.shape[1]:
                        pred_array[line_num, vuln_idx] = 1 if is_vuln else 0
        
        return np.mean(pred_array == true_line_vulns)
    except:
        return 0.

def get_vulnerability_details(analyzer, true_vulns: np.ndarray, pred_vulns: np.ndarray, 
                            probabilities: List[float]) -> Dict[str, Any]:
    """Get detailed vulnerability analysis."""
    details = {}
    for i, vuln_type in enumerate(analyzer.vulnerability_types):
        details[vuln_type] = {
            'true_positive': bool(true_vulns[i] == 1 and pred_vulns[i] == 1),
            'false_positive': bool(true_vulns[i] == 0 and pred_vulns[i] == 1),
            'false_negative': bool(true_vulns[i] == 1 and pred_vulns[i] == 0),
            'true_negative': bool(true_vulns[i] == 0 and pred_vulns[i] == 0),
            'probability': probabilities[i],
            'true_label': int(true_vulns[i]),
            'predicted_label': int(pred_vulns[i])
        }
    return details
    

In [6]:
analyzer = SmartContractAnalyzer(
    model_path="checkpoints_v2_2048_output/best_model_epoch_126.pt",
    device="cuda" if torch.cuda.is_available() else "cpu"
)



In [7]:
batch_idx = 0
source_code = val_dataloader.dataset.data[0]['source_code']
contract_name = f'Contract_{batch_idx}'
true_contract_vulns = val_dataloader.dataset.data[0]['contract_vulnerabilities'].cpu().numpy()
true_line_vulns = val_dataloader.dataset.data[0]['vulnerable_lines'].cpu().numpy()

In [8]:
pred_results = analyzer.detect_vulnerabilities(source_code, threshold=0.5)
            
# Extract predictions
pred_contract_vulns = pred_results['contract_vulnerabilities']
pred_line_vulns = pred_results['line_vulnerabilities']
pred_contract_probs = pred_results['contract_probabilities'][0]

# Convert predictions to numpy arrays for comparison
pred_contract_array = np.array([
    1 if pred_contract_vulns[vuln_type] else 0 
    for vuln_type in analyzer.vulnerability_types
])

# Calculate metrics
contract_accuracy = np.mean(pred_contract_array == true_contract_vulns)
contract_precision = calculate_precision(true_contract_vulns, pred_contract_array)
contract_recall = calculate_recall(true_contract_vulns, pred_contract_array)
contract_f1 = calculate_f1_score(contract_precision, contract_recall)

# Line-level analysis (simplified)
line_accuracy = calculate_line_accuracy(true_line_vulns, pred_line_vulns)

generated_contracts = analyzer.generate_synthetic_contract(
    contract_template=source_code,
    num_contracts=1,
    temperature=0.9
)
generated_contract = generated_contracts[0] if generated_contracts else "Generation failed"
            

result = {
    'contract_name': contract_name,
    'batch_idx': batch_idx,
    'source_code': source_code,
    'true_contract_vulnerabilities': true_contract_vulns.tolist(),
    'predicted_contract_vulnerabilities': pred_contract_array.tolist(),
    'contract_probabilities': pred_contract_probs,
    'contract_accuracy': float(contract_accuracy),
    'contract_precision': float(contract_precision),
    'contract_recall': float(contract_recall),
    'contract_f1': float(contract_f1),
    'line_accuracy': float(line_accuracy),
    'vulnerability_details': get_vulnerability_details(analyzer, true_contract_vulns, pred_contract_array, pred_contract_probs),
    'generated_contract': generated_contracts
}

Contract vuln logits shape: torch.Size([1, 8])
Line vuln logits shape: torch.Size([1, 1024, 8])
Number of lines in contract: 90
Contract predictions shape: (1, 8)
Line predictions shape: (1, 1024, 8)
Generating contract 1/1...
Memory shape after encoding: torch.Size([1, 1024, 768])
Starting generation with max_len: 1024
Initial tgt shape: torch.Size([1, 1])
Initial tgt: tensor([[0]], device='cuda:0')
Step 0: Logits shape: torch.Size([1, 50265])
Step 0: Logits range: [-21.5992, -2.5827]
Step 0: Probs sum: 1.0000
Step 0: Probs max: 1.0000
Step 0: Next token: 4862
Step 0: Current sequence length: 2
Step 0: Last 5 tokens: [0, 4862]
Step 1: Logits shape: torch.Size([1, 50265])
Step 1: Logits range: [-17.6600, 1.5444]
Step 1: Probs sum: 1.0000
Step 1: Probs max: 1.0000
Step 1: Next token: 1073
Step 2: Logits shape: torch.Size([1, 50265])
Step 2: Logits range: [-13.9757, 5.1368]
Step 2: Probs sum: 1.0000
Step 2: Probs max: 1.0000
Step 2: Next token: 1916
Step 3: Logits shape: torch.Size([1, 5

In [9]:
result

{'contract_name': 'Contract_0',
 'batch_idx': 0,
 'source_code': "pragma solidity ^0.4.19;\r\n\r\ncontract BaseToken {\r\n    string public name;\r\n    string public symbol;\r\n    uint8 public decimals;\r\n    uint256 public totalSupply;\r\n\r\n    mapping (address => uint256) public balanceOf;\r\n    mapping (address => mapping (address => uint256)) public allowance;\r\n\r\n    event Transfer(address indexed from, address indexed to, uint256 value);\r\n    event Approval(address indexed owner, address indexed spender, uint256 value);\r\n\r\n    function _transfer(address _from, address _to, uint _value) internal {\r\n        require(_to != 0x0);\r\n        require(balanceOf[_from] >= _value);\r\n        require(balanceOf[_to] + _value > balanceOf[_to]);\r\n        uint previousBalances = balanceOf[_from] + balanceOf[_to];\r\n        balanceOf[_from] -= _value;\r\n        balanceOf[_to] += _value;\r\n        assert(balanceOf[_from] + balanceOf[_to] == previousBalances);\r\n        Tr

In [9]:
print(source_code)

pragma solidity ^0.4.19;

contract BaseToken {
    string public name;
    string public symbol;
    uint8 public decimals;
    uint256 public totalSupply;

    mapping (address => uint256) public balanceOf;
    mapping (address => mapping (address => uint256)) public allowance;

    event Transfer(address indexed from, address indexed to, uint256 value);
    event Approval(address indexed owner, address indexed spender, uint256 value);

    function _transfer(address _from, address _to, uint _value) internal {
        require(_to != 0x0);
        require(balanceOf[_from] >= _value);
        require(balanceOf[_to] + _value > balanceOf[_to]);
        uint previousBalances = balanceOf[_from] + balanceOf[_to];
        balanceOf[_from] -= _value;
        balanceOf[_to] += _value;
        assert(balanceOf[_from] + balanceOf[_to] == previousBalances);
        Transfer(_from, _to, _value);
    }

    function transfer(address _to, uint256 _value) public returns (bool success) {
        _trans

In [11]:
print(generated_contracts[0])

pragma solidity ^0.4.19;

contract BaseToken {
    string public name;
    string public symbol;
    uint8 public decimals;
    uint256 public totalSupply;

    mapping (address => uint256) public balanceOf;
    mapping (address => mapping (address => uint256)) public allowance;

    event Transfer(address indexed from, address indexed to, uint256 value);
    event Approval(address indexed owner, address indexed spender, uint256 value);

    function _transfer(address _from, address _to, uint _value) internal {
        Our(_to!= 0x0);
        ratios require(balanceOf[_from] >= _value);
       require(balanceOf[_to] + _value > balanceOf[_to]);
       John uint previousBalances = balanceOf[_from] + balanceOf[_to];
       balanceOf[_from] -= _value;
       beitOf[_to] += _value;
       db(balanceOf[_from] + balanceOf[_to] == previousBalances);
       athan(_from, _to, _value);
    }

    function transfer(address _to, uint256 _value) public returns (bool success) {
        eras(msg.sender

In [None]:
analyzer = SmartContractAnalyzer("checkpoints/best_model.pt")
result = analyzer.test_generation_simple(template)
print(result)

In [13]:
output_layer = analyzer.model.output_layer
weights = output_layer.weight.data
bias = output_layer.bias.data

print(f"Output layer shape: {weights.shape}")
print(f"Output layer weight stats:")
print(f"  Mean: {weights.mean().item():.6f}")
print(f"  Std: {weights.std().item():.6f}")
print(f"  Min: {weights.min().item():.6f}")
print(f"  Max: {weights.max().item():.6f}")

print(f"Output layer bias stats:")
print(f"  Mean: {bias.mean().item():.6f}")
print(f"  Std: {bias.std().item():.6f}")
print(f"  Min: {bias.min().item():.6f}")
print(f"  Max: {bias.max().item():.6f}")

# Check if weights are all the same (indicating poor training)
weight_variance = weights.var().item()
bias_variance = bias.var().item()

print(f"Weight variance: {weight_variance:.6f}")
print(f"Bias variance: {bias_variance:.6f}")

if weight_variance < 1e-6:
    print("⚠️  WARNING: Output layer weights have very low variance - model may not be properly trained for generation")

if bias_variance < 1e-6:
    print("⚠️  WARNING: Output layer bias has very low variance - model may not be properly trained for generation")

# Test a simple forward pass
print("\n=== Testing Simple Forward Pass ===")
test_input = torch.randint(0, 1000, (1, 10)).to(analyzer.device)
test_attention = torch.ones(1, 10).to(analyzer.device)

try:
    with torch.no_grad():
        # Test encoder
        test_emb = analyzer.model.embedding(test_input) * math.sqrt(analyzer.model.d_model)
        test_emb = analyzer.model.embedding_dropout(test_emb)
        test_emb = analyzer.model.embedding_norm(test_emb)
        test_emb = analyzer.model.pos_encoder(test_emb.transpose(0, 1)).transpose(0, 1)
        
        memory = analyzer.model.encoder(test_emb, src_key_padding_mask=~test_attention.bool())
        
        # Test decoder with single token
        tgt = torch.full((1, 1), 0, dtype=torch.long, device=analyzer.device)
        tgt_emb = analyzer.model.embedding(tgt) * math.sqrt(analyzer.model.d_model)
        tgt_emb = analyzer.model.embedding_dropout(tgt_emb)
        tgt_emb = analyzer.model.embedding_norm(tgt_emb)
        tgt_emb = analyzer.model.pos_encoder(tgt_emb.transpose(0, 1)).transpose(0, 1)
        
        out = analyzer.model.decoder(
            tgt_emb,
            memory,
            tgt_mask=None,
            memory_key_padding_mask=~test_attention.bool()
        )
        
        out = analyzer.model.output_norm(out)
        out = analyzer.model.output_dropout(out)
        logits = analyzer.model.output_layer(out[:, -1, :])
        
        print(f"Test logits shape: {logits.shape}")
        print(f"Test logits range: [{logits.min().item():.4f}, {logits.max().item():.4f}]")
        
        # Check if logits are reasonable
        probs = torch.softmax(logits, dim=-1)
        print(f"Test probs sum: {probs.sum().item():.4f}")
        print(f"Test probs max: {probs.max().item():.4f}")
        
        # Check top predictions
        top_probs, top_indices = torch.topk(probs, 5, dim=-1)
        print(f"Top 5 token probabilities: {top_probs[0].tolist()}")
        print(f"Top 5 token indices: {top_indices[0].tolist()}")
        
        # Check if token 1 is always the top prediction
        if top_indices[0, 0] == 1:
            print("⚠️  WARNING: Token 1 is the top prediction - this explains the all-1 generation issue")
        else:
            print("✅ Token 1 is not the top prediction - model should generate diverse tokens")
            
except Exception as e:
    print(f"Error during test forward pass: {str(e)}")

print("=== Diagnosis Complete ===")

Output layer shape: torch.Size([50265, 768])
Output layer weight stats:
  Mean: -0.004437
  Std: 0.102096
  Min: -1.941332
  Max: 2.517758
Output layer bias stats:
  Mean: -3.233994
  Std: 0.944389
  Min: -3.701849
  Max: 2.135393
Weight variance: 0.010424
Bias variance: 0.891871

=== Testing Simple Forward Pass ===
Test logits shape: torch.Size([1, 50265])
Test logits range: [-24.1880, -9.9746]
Test probs sum: 1.0000
Test probs max: 0.1560
Top 5 token probabilities: [0.15604272484779358, 0.053808506578207016, 0.001391868805512786, 0.0006813530344516039, 0.00045423494884744287]
Top 5 token indices: [1, 163, 1437, 846, 10975]
=== Diagnosis Complete ===


In [None]:

    results = analyzer.detect_vulnerabilities(contract_code)
    
    print("\n=== Smart Contract Vulnerability Analysis ===")
    
    # Contract-level vulnerabilities
    print("\n📋 Contract-Level Vulnerabilities:")
    contract_vulns = results['contract_vulnerabilities']
    for vuln_type, is_vulnerable in contract_vulns.items():
        status = "❌ VULNERABLE" if is_vulnerable else "✅ SAFE"
        print(f"  {vuln_type}: {status}")
    
    # Line-level vulnerabilities
    print("\n📝 Line-Level Vulnerabilities:")
    line_vulns = results['line_vulnerabilities']
    lines = contract_code.strip().split('\n')
    
    for line_num, line in enumerate(lines):
        if line.strip():  # Only show non-empty lines
            line_vulns_for_line = line_vulns.get(line_num, {})
            vulnerable_types = [vuln_type for vuln_type, is_vuln in line_vulns_for_line.items() if is_vuln]
            
            if vulnerable_types:
                print(f"  Line {line_num + 1}: ❌ {', '.join(vulnerable_types)}")
                print(f"    Code: {line.strip()}")
            else:
                print(f"  Line {line_num + 1}: ✅ SAFE")
    
    # Probabilities summary
    print("\n📊 Vulnerability Probabilities:")
    contract_probs = results['contract_probabilities'][0]  # First (and only) contract
    for i, vuln_type in enumerate(analyzer.vulnerability_types):
        prob = contract_probs[i]
        print(f"  {vuln_type}: {prob:.4f} ({prob*100:.2f}%)")
    
    # Generate synthetic contracts
    print("\n🔧 Generating Synthetic Contracts...")
    synthetic_contracts = analyzer.generate_synthetic_contract(
        contract_template=contract_code,
        num_contracts=2
    )
    
    print("\n📄 Generated Contracts:")
    for i, contract in enumerate(synthetic_contracts, 1):
        print(f"\n--- Contract {i} ---")
        print(contract)
        
        # Analyze generated contract
        print(f"\n🔍 Analysis of Generated Contract {i}:")
        gen_results = analyzer.detect_vulnerabilities(contract)
        gen_contract_vulns = gen_results['contract_vulnerabilities']
        
        for vuln_type, is_vulnerable in gen_contract_vulns.items():
            status = "❌ VULNERABLE" if is_vulnerable else "✅ SAFE"
            print(f"  {vuln_type}: {status}")


In [3]:
main() 

Using device: cuda




Model loaded from checkpoints_v2_2048_output/best_model_epoch_126.pt
Training epoch: 126
Best validation loss: 1.448835600167513
🚀 Smart Contract Vulnerability Detection Model Loaded Successfully!

📋 SC-T-GAN: Smart Contract Transformer GAN - Vulnerability Detection Mode
--------------------------------------------------------------------------------

VULNERABILITY ANALYSIS REPORT: VulnerableContract

📋 CONTRACT-LEVEL VULNERABILITIES:
----------------------------------------
ARTHM        | 🔴 VULNERABLE    | Probability: 0.699
DOS          | 🔴 VULNERABLE    | Probability: 0.325
LE           | 🔴 VULNERABLE    | Probability: 0.336
RENT         | 🔴 VULNERABLE    | Probability: 0.447
TimeM        | 🔴 VULNERABLE    | Probability: 0.322
TimeO        | 🔴 VULNERABLE    | Probability: 0.313
Tx-Origin    | 🟢 SAFE          | Probability: 0.002
UE           | 🟢 SAFE          | Probability: 0.287

📄 LINE-LEVEL VULNERABILITIES:
----------------------------------------

📊 SUMMARY:
--------------------