In [4]:
!pip install scikit-learn

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
from inference import *

# PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Model imports
from model import SmartContractTransformer
from inference import SmartContractAnalyzer

# Optional but useful imports
import numpy as np
from tqdm import tqdm  # for progress bars
import logging

from transformers import AutoTokenizer
import json
import os
import pandas as pd
from typing import Dict, List, Tuple, Any
import re
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score

  _torch_pytree._register_pytree_node(


# 0. AUX Functions:

In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def parse_solidity_to_ast(code: str) -> Dict[str, Any]:
    """
    Parse Solidity code into a simplified AST structure
    """
    def extract_contract_info(code: str) -> Dict[str, Any]:
        # Extract contract name
        contract_match = re.search(r'contract\s+(\w+)', code)
        contract_name = contract_match.group(1) if contract_match else "Unknown"
        
        # Extract functions
        functions = []
        function_pattern = r'function\s+(\w+)\s*\(([^)]*)\)\s*(?:public|private|internal|external)?\s*(?:view|pure|payable)?\s*(?:returns\s*\(([^)]*)\))?\s*{'
        for match in re.finditer(function_pattern, code):
            func_name = match.group(1)
            params = match.group(2).split(',') if match.group(2) else []
            returns = match.group(3).split(',') if match.group(3) else []
            
            functions.append({
                'name': func_name,
                'parameters': [p.strip() for p in params],
                'returns': [r.strip() for r in returns]
            })
        
        # Extract state variables
        variables = []
        var_pattern = r'(?:uint|address|string|bool|mapping)\s+(?:\w+)\s+(\w+)'
        for match in re.finditer(var_pattern, code):
            variables.append(match.group(1))
        
        return {
            'type': 'Contract',
            'name': contract_name,
            'functions': functions,
            'variables': variables
        }
    
    try:
        # Clean the code
        code = re.sub(r'//.*?\n|/\*.*?\*/', '', code)  # Remove comments
        code = re.sub(r'\s+', ' ', code)  # Normalize whitespace
        
        # Parse the code
        ast = extract_contract_info(code)
        return ast
    except Exception as e:
        print(f"Error parsing code: {str(e)}")
        return None

def prepare_code2vec_input(ast: Dict[str, Any]) -> List[str]:
    """
    Convert AST to codeBert input format
    """
    paths = []
    
    def extract_paths(node: Dict[str, Any], current_path: List[str] = None):
        if current_path is None:
            current_path = []
            
        # Add current node to path
        if 'name' in node:
            current_path.append(node['name'])
            
        # Process functions
        if 'functions' in node:
            for func in node['functions']:
                func_path = current_path + [func['name']]
                paths.append(' '.join(func_path))
                
                # Add parameter paths
                for param in func['parameters']:
                    param_path = func_path + [param]
                    paths.append(' '.join(param_path))
                
                # Add return paths
                for ret in func['returns']:
                    ret_path = func_path + [ret]
                    paths.append(' '.join(ret_path))
        
        # Process variables
        if 'variables' in node:
            for var in node['variables']:
                var_path = current_path + [var]
                paths.append(' '.join(var_path))
    
    extract_paths(ast)
    return paths

class SmartContractVulnerabilityDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        tokenizer: AutoTokenizer,
        max_length: int = 1024,
        split: str = "train",
        vulnerability_types: List[str] = None
    ):
        """
        Args:
            data_path: Path to the CSV file containing the dataset
            tokenizer: Tokenizer for encoding the source code
            max_length: Maximum sequence length
            split: "train" or "val" to specify which split to load
            vulnerability_types: List of vulnerability types to consider
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split
        self.vulnerability_types = vulnerability_types or [
            'ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE'
        ]
        
        # Load the dataset
        self.data = self._load_dataset(data_path)
        
    def _load_dataset(self, data_path: str) -> List[Dict]:
        """Load and preprocess the dataset from CSV"""
        dataset = []
        
        # Read the CSV file
        df = pd.read_csv(data_path)
        
        # Split into train/val if needed
        if self.split == "train":
            df = df.sample(frac=0.8, random_state=42)
        else:
            df = df.sample(frac=0.2, random_state=42)
        
        # Process each contract
        for _, row in df.iterrows():
            try:
                source_code = row['source_code']
                contract_name = row['contract_name']
                
                # Parse AST and get paths
                ast = parse_solidity_to_ast(source_code)
                ast_paths = prepare_code2vec_input(ast) if ast else []
                ast_path_text = ' '.join(ast_paths)
                
                # Split source code into lines
                lines = source_code.split('\n')
                
                # Create token-to-line mapping
                token_to_line = []
                current_line = 0
                
                # Tokenize each line separately to maintain mapping
                for line in lines:
                    line_tokens = self.tokenizer.encode(line, add_special_tokens=False)
                    token_to_line.extend([current_line] * len(line_tokens))
                    current_line += 1
                
                # Add special tokens
                token_to_line = [0] + token_to_line + [0]  # [CLS] and [SEP] tokens
                
                # Truncate if too long
                if len(token_to_line) > self.max_length:
                    token_to_line = token_to_line[:self.max_length]
                
                # Pad if too short
                if len(token_to_line) < self.max_length:
                    token_to_line.extend([0] * (self.max_length - len(token_to_line)))
                
                # Create multi-label line labels for each vulnerability type
                line_labels = self._create_multi_label_line_labels(source_code, row)
                
                # Create contract-level vulnerability labels
                contract_labels = self._create_contract_vulnerability_labels(row)
                
                # Tokenize the source code
                encoding = self.tokenizer(
                    source_code,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Tokenize AST paths
                ast_encoding = self.tokenizer(
                    ast_path_text,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                
                # Convert line labels to tensor and ensure consistent shape
                vuln_tensor = torch.zeros((len(self.vulnerability_types), self.max_length), dtype=torch.long)
                for i, labels in enumerate(line_labels):
                    if len(labels) > self.max_length:
                        labels = labels[:self.max_length]
                    vuln_tensor[i, :len(labels)] = torch.tensor(labels, dtype=torch.long)
                
                # Convert contract labels to tensor
                contract_vuln_tensor = torch.tensor(contract_labels, dtype=torch.long)
                
                # Convert token_to_line to tensor
                token_to_line_tensor = torch.tensor(token_to_line, dtype=torch.long)
                
                # Ensure attention masks are boolean
                attention_mask = encoding['attention_mask'].squeeze(0).bool()
                ast_attention_mask = ast_encoding['attention_mask'].squeeze(0).bool()
                
                # Ensure input_ids are the right length
                input_ids = encoding['input_ids'].squeeze(0)
                ast_input_ids = ast_encoding['input_ids'].squeeze(0)
                
                if len(input_ids) > self.max_length:
                    input_ids = input_ids[:self.max_length]
                if len(ast_input_ids) > self.max_length:
                    ast_input_ids = ast_input_ids[:self.max_length]
                
                # Pad if necessary
                if len(input_ids) < self.max_length:
                    input_ids = torch.nn.functional.pad(input_ids, (0, self.max_length - len(input_ids)))
                if len(ast_input_ids) < self.max_length:
                    ast_input_ids = torch.nn.functional.pad(ast_input_ids, (0, self.max_length - len(ast_input_ids)))
                
                dataset.append({
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'ast_input_ids': ast_input_ids,
                    'ast_attention_mask': ast_attention_mask,
                    'vulnerable_lines': vuln_tensor,
                    'contract_vulnerabilities': contract_vuln_tensor,
                    'token_to_line': token_to_line_tensor,
                    'source_code': source_code,
                    'contract_name': contract_name
                })
            except Exception as e:
                print(f"Error processing contract {contract_name}: {str(e)}")
                continue
        
        return dataset
    
    def _create_contract_vulnerability_labels(self, row: pd.Series) -> List[int]:
        """Create contract-level vulnerability labels"""
        contract_labels = []
        for vuln_type in self.vulnerability_types:
            # Check if contract has this vulnerability type
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Contract is vulnerable if it has any vulnerable lines
            has_vulnerability = len(vuln_lines) > 0
            contract_labels.append(1 if has_vulnerability else 0)
        
        return contract_labels
    
    def _create_multi_label_line_labels(self, source_code: str, row: pd.Series) -> List[List[int]]:
        """Create multi-label line labels for each vulnerability type"""
        total_lines = len(source_code.split('\n'))
        line_labels = {vuln_type: [0] * total_lines for vuln_type in self.vulnerability_types}
        
        # Process each vulnerability type
        for vuln_type in self.vulnerability_types:
            vuln_lines = row[f'{vuln_type}_lines']
            if isinstance(vuln_lines, str):
                try:
                    vuln_lines = eval(vuln_lines)
                except:
                    vuln_lines = [vuln_lines]
            
            # Process each vulnerable line/snippet
            for line_or_snippet in vuln_lines:
                if isinstance(line_or_snippet, int):
                    # If it's a line number, mark that line
                    if 0 <= line_or_snippet < total_lines:
                        line_labels[vuln_type][line_or_snippet] = 1
                else:
                    # If it's a code snippet, find matching lines
                    source_lines = source_code.split('\n')
                    for i, line in enumerate(source_lines):
                        # Clean both the line and snippet for comparison
                        clean_line = re.sub(r'\s+', ' ', line.strip())
                        clean_snippet = re.sub(r'\s+', ' ', str(line_or_snippet).strip())
                        if clean_snippet in clean_line:
                            line_labels[vuln_type][i] = 1
        
        # Convert to list format
        return [line_labels[vuln_type] for vuln_type in self.vulnerability_types]
    
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> Dict:
        return self.data[idx]

def custom_collate_fn(batch):
    """
    Custom collate function to handle variable length inputs
    """
    # Get the maximum length in this batch for each type of tensor
    max_input_len = max(item['input_ids'].size(0) for item in batch)
    
    # Pad all tensors to their respective maximum lengths
    padded_batch = {
        'input_ids': torch.stack([
            torch.nn.functional.pad(item['input_ids'], (0, max_input_len - item['input_ids'].size(0)))
            for item in batch
        ]),
        'attention_mask': torch.stack([
            torch.nn.functional.pad(item['attention_mask'], (0, max_input_len - item['attention_mask'].size(0)))
            for item in batch
        ]),
        'ast_input_ids': torch.stack([item['ast_input_ids'] for item in batch]),
        'ast_attention_mask': torch.stack([item['ast_attention_mask'] for item in batch]),
        'vulnerable_lines': torch.stack([item['vulnerable_lines'] for item in batch]),
        'contract_vulnerabilities': torch.stack([item['contract_vulnerabilities'] for item in batch]),
        'token_to_line': torch.stack([item['token_to_line'] for item in batch]),
        'source_code': [item['source_code'] for item in batch],
        'contract_name': [item['contract_name'] for item in batch]
    }
    
    return padded_batch


# 0 Constants:

In [3]:
DATA_PATH = "contract_sources_with_vulnerabilities_2048_token_size.csv"
MODEL_PATH = "checkpoints_v5_2048_output/best_model_augmented_gan_epoch_106.pt"
TOKENIZER_NAME = "microsoft/codebert-base"
VULNERABILITY_TYPES = ['ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE']
DEVICE = torch.device("cuda")
MODEL_LINE_CODE_VULNERABILITY_THRESHOLD = 0.2

# 1 Load Dataset:

In [4]:
def create_dataloaders(
    data_path: str,
    tokenizer: AutoTokenizer,
    batch_size: int = 8,
    max_length: int = 1024,
    num_workers: int = 4,
    vulnerability_types: List[str] = None
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    
    val_dataset = SmartContractVulnerabilityDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length,
        split="val",
        vulnerability_types=vulnerability_types
    )
    
    dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn
    )
    
    return  dataloader

In [5]:
from transformers import AutoTokenizer

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

val_dataloader = create_dataloaders(
    data_path=DATA_PATH,
    tokenizer=tokenizer,
    batch_size=8,
    max_length=1024,
    vulnerability_types=VULNERABILITY_TYPES
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1211 > 512). Running this sequence through the model will result in indexing errors


In [6]:
val_dataloader.dataset.data[0]['vulnerable_lines'][7]

tensor([0, 0, 0,  ..., 0, 0, 0])

# 1. Contract Generation

## 1.2 Load Model:

In [7]:
model_path=MODEL_PATH
device = DEVICE
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

model = SmartContractTransformer(
        d_model=768,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        max_length=1024,
        vocab_size=tokenizer.vocab_size,
        num_vulnerability_types=8,
        use_gan=True 
    )

checkpoint = torch.load(model_path, map_location=device)

if 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from {model_path}")
    print(f"Training epoch: {checkpoint.get('epoch', 'Unknown')}")
    print(f"Best validation loss: {checkpoint.get('val_loss', 'Unknown')}")
    print(f"Training config: GAN={checkpoint.get('use_gan', 'Unknown')}")
else:
    # Direct state dict
    model.load_state_dict(checkpoint)
    print(f"Model loaded from {model_path}")

model.to(device)
model.eval()
    
vulnerability_types = [
    'ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE'
]



DEBUG: Initialized line feature extractor layer 1 with small random weights
DEBUG: Initialized line feature extractor layer 3 with small random weights
DEBUG: Initialized custom line feature extractor with small weights
Model loaded from checkpoints_v5_2048_output/best_model_augmented_gan_epoch_106.pt
Training epoch: 106
Best validation loss: 0.773994717746973
Training config: GAN=True


In [8]:
print(val_dataloader.dataset.data[5]['source_code'])

/**
 * Source Code first verified at https://etherscan.io on Tuesday, December 18, 2018
 (UTC) */

pragma solidity ^0.4.23;

library SafeMath {

  /**
  * @dev Multiplies two numbers, throws on overflow.
  */
  function mul(uint256 a, uint256 b) internal pure returns (uint256 c) {

    if (a == 0) {
      return 0;
    }

    c = a * b;
    assert(c / a == b);
    return c;
  }

  /**
  * @dev Integer division of two numbers, truncating the quotient.
  */
  function div(uint256 a, uint256 b) internal pure returns (uint256) {

    return a / b;
  }

  /**
  * @dev Subtracts two numbers, throws on overflow (i.e. if subtrahend is greater than minuend).
  */
  function sub(uint256 a, uint256 b) internal pure returns (uint256) {
    assert(b <= a);
    return a - b;
  }

  /**
  * @dev Adds two numbers, throws on overflow.
  */
  function add(uint256 a, uint256 b) internal pure returns (uint256 c) {
    c = a + b;
    assert(c >= a);
    return c;
  }
}


contract ERC20Basic {
    
  functi

In [9]:
model = SmartContractTransformer(
        d_model=768,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        max_length=1024,
        vocab_size=tokenizer.vocab_size,
        num_vulnerability_types=8,
        use_gan=True 
    )

checkpoint = torch.load(model_path, map_location=device)

if 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from {model_path}")
    print(f"Training epoch: {checkpoint.get('epoch', 'Unknown')}")
    print(f"Best validation loss: {checkpoint.get('val_loss', 'Unknown')}")
    print(f"Training config: GAN={checkpoint.get('use_gan', 'Unknown')}")
else:
    # Direct state dict
    model.load_state_dict(checkpoint)
    print(f"Model loaded from {model_path}")

contract_code = val_dataloader.dataset.data[5]['source_code']
ast = parse_solidity_to_ast(contract_code)
ast_paths = prepare_code2vec_input(ast) if ast else []
ast_path_text = ' '.join(ast_paths)

# Tokenize inputs
contract_encoding = tokenizer(
    contract_code,
    max_length=1024,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

ast_encoding = tokenizer(
    ast_path_text,
    max_length=1024,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

input_ids = contract_encoding['input_ids'].to(device)
attention_mask = contract_encoding['attention_mask'].to(device)
ast_input_ids = ast_encoding['input_ids'].to(device)
ast_attention_mask = ast_encoding['attention_mask'].to(device)

model._debug_mode = False

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        ast_input_ids=ast_input_ids,
        ast_attention_mask=ast_attention_mask,
        target_ids=input_ids,  # Use input_ids as target for inference
        token_to_line=None
    )



DEBUG: Initialized line feature extractor layer 1 with small random weights
DEBUG: Initialized line feature extractor layer 3 with small random weights
DEBUG: Initialized custom line feature extractor with small weights
Model loaded from checkpoints_v5_2048_output/best_model_augmented_gan_epoch_106.pt
Training epoch: 106
Best validation loss: 0.773994717746973
Training config: GAN=True


In [10]:
model.to(device)
model.eval()
    
vulnerability_types = [
    'ARTHM', 'DOS', 'LE', 'RENT', 'TimeM', 'TimeO', 'Tx-Origin', 'UE'
]

In [11]:
contract_code = val_dataloader.dataset.data[5]['source_code']
ast = parse_solidity_to_ast(contract_code)
ast_paths = prepare_code2vec_input(ast) if ast else []
ast_path_text = ' '.join(ast_paths)

# Tokenize inputs
contract_encoding = tokenizer(
    contract_code,
    max_length=1024,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

ast_encoding = tokenizer(
    ast_path_text,
    max_length=1024,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

In [12]:
input_ids = contract_encoding['input_ids'].to(device)
attention_mask = contract_encoding['attention_mask'].to(device)
ast_input_ids = ast_encoding['input_ids'].to(device)
ast_attention_mask = ast_encoding['attention_mask'].to(device)

In [14]:
model._debug_mode = False

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        ast_input_ids=ast_input_ids,
        ast_attention_mask=ast_attention_mask,
        target_ids=input_ids,  # Use input_ids as target for inference
        token_to_line=None
    )

In [15]:
outputs

{'logits': tensor([[-11.3019,  -8.9195,   3.9378,  ..., -11.2757, -11.2904, -11.2999],
         [-12.9732,  -5.3152,   4.2387,  ..., -12.9763, -12.8868, -12.9534],
         [ -1.0804,   4.8225,   9.7027,  ...,  -1.1027,  -1.0450,  -1.1748],
         ...,
         [-16.8429,  -9.4144,   3.8516,  ..., -16.7851, -16.8252, -16.8731],
         [-10.9757,  -6.3530,   8.9934,  ..., -10.9440, -11.0186, -11.0578],
         [ -7.8117,  -4.9354,  22.3554,  ...,  -7.8987,  -7.8499,  -7.9367]],
        device='cuda:0'),
 'target_ids': tensor([49795, 50121, 50118,  ..., 44547, 18134,     2], device='cuda:0'),
 'contract_vulnerability_logits': tensor([[ 4.8884e-01, -2.7327e-01, -2.4371e-01,  1.0683e-04, -2.8266e-01,
          -3.0378e-01, -3.6681e+00, -3.8108e-01]], device='cuda:0'),
 'line_vulnerability_logits': tensor([[[-46.0587, -82.9822, -96.5135,  ..., -76.5406, -90.5835, -93.7482],
          [-34.9762, -70.2816, -79.0512,  ..., -63.8398, -74.7924, -72.7562],
          [-30.4257, -69.3439, -73.

In [8]:
import torch
import torch.nn.functional as F
import math
import numpy as np # Added for token frequency debugging


In [13]:
def generate_from_working_logits(model, outputs, tokenizer, max_length=2048, temperature=0.7, top_k=50, top_p=0.95):
    """
    Generate synthetic smart contract from the logits produced by your working forward pass.
    This function uses the actual logits from your model outputs to generate tokens.
    
    Args:
        model: The trained SmartContractTransformer model
        outputs: Outputs from your working forward pass containing logits
        tokenizer: The tokenizer for decoding
        max_length: Maximum generation length
        temperature: Sampling temperature
        top_k: Top-k sampling parameter
        top_p: Nucleus sampling parameter
    
    Returns:
        Generated source code as string
    """
    device = next(model.parameters()).device
    
    # Get the logits from your working outputs
    if 'logits' not in outputs:
        print("Error: logits not found in model outputs.")
        print("Available keys:", list(outputs.keys()))
        return None
    
    logits = outputs['logits']  # [batch_size * seq_len, vocab_size]
    print(f"Logits shape: {logits.shape}")
    
    # Reshape logits to [batch_size, seq_len, vocab_size]
    batch_size = 1  # Assuming single batch
    seq_len = logits.size(0) // batch_size
    vocab_size = logits.size(1)
    
    logits_reshaped = logits.view(batch_size, seq_len, vocab_size)
    print(f"Reshaped logits: {logits_reshaped.shape}")
    
    # Apply temperature scaling
    logits_reshaped = logits_reshaped / temperature
    
    # Apply top-k filtering
    if top_k > 0:
        top_k_logits, top_k_indices = torch.topk(logits_reshaped, top_k, dim=-1)
        logits_mask = torch.full_like(logits_reshaped, float('-inf'))
        logits_mask.scatter_(-1, top_k_indices, top_k_logits)
        logits_reshaped = logits_mask
    
    # Apply nucleus sampling
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits_reshaped, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits_reshaped = logits_reshaped.masked_fill(indices_to_remove, float('-inf'))
    
    # Sample tokens from the logits
    probs = torch.softmax(logits_reshaped, dim=-1)
    sampled_tokens = torch.multinomial(probs.view(-1, vocab_size), num_samples=1)
    sampled_tokens = sampled_tokens.view(batch_size, seq_len)
    
    print(f"Sampled tokens shape: {sampled_tokens.shape}")
    print(f"Sampled tokens (first 20): {sampled_tokens[0, :20].cpu().numpy()}")
    print(f"Sampled tokens (last 20): {sampled_tokens[0, -20:].cpu().numpy()}")
    
    # Decode the generated tokens
    generated_tokens = sampled_tokens[0].cpu().numpy()
    
    # Filter out padding and special tokens
    special_tokens = {0, 1, 2, 3}  # PAD, UNK, CLS, SEP
    filtered_tokens = [t for t in generated_tokens if t not in special_tokens and t != tokenizer.pad_token_id]
    
    print(f"Filtered tokens (first 20): {filtered_tokens[:20]}")
    print(f"Total filtered tokens: {len(filtered_tokens)}")
    
    if filtered_tokens:
        generated_code = tokenizer.decode(
            filtered_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        print(f"Generated code length: {len(generated_code)}")
        return generated_code
    else:
        print("No meaningful tokens found!")
        return ""


In [14]:
generated_code = generate_from_working_logits(
    model=model,
    outputs=outputs,  # Your existing outputs with logits
    tokenizer=tokenizer,
    temperature=0.1
)

Logits shape: torch.Size([1023, 50265])
Reshaped logits: torch.Size([1, 1023, 50265])
Sampled tokens shape: torch.Size([1, 1023])
Sampled tokens (first 20): [ 4862  1073  1916  2705  1571 37249   288     4   306     4  1978   131
 50121 50118 50121 50118 49051 50121 50118  1437]
Sampled tokens (last 20): [ 1437  1437  1437  1437 46486     4 40095  1640  9226 41014  4397 50121
 50118  1437  1437  1437  1437  1437  1437     2]
Filtered tokens (first 20): [4862, 1073, 1916, 2705, 1571, 37249, 288, 4, 306, 4, 1978, 131, 50121, 50118, 50121, 50118, 49051, 50121, 50118, 1437]
Total filtered tokens: 1022
Generated code length: 1580


In [15]:
generated_code

'pragma solidity ^0.4.24;\r\n\r\n/*\r\n    ��退�逎��了���交易亀�挜��，批戎。������仧�扎��不胡�要�������送�诡�。�戌我退���玎���掜接用了�OS區�敀��戼�戎��于�掜�����OS贬��敞�。��甴���逨�OS��\r\n     ����了��仨亡涯捡�仠明���尯接甡�。�从�掄\uf39c賡��与���戯接甡�。�截��。�����釨亡����，�退�逑�掯掁��方��祿���瓈����挕�����鸡���\uf344�app洣���泎��\r\n    �逈�專�Dapp接�����站釜�要������玦戨倀��贪��退��息�昳注輌戠甤昹君甀�立令昈秄逼�����戎甄我退��尶�注亁俴昅注����句叠����贪家�泎�站亶����OS贬���尴�贬�������O，��。�������尼���。�小。�戯��採漌。�\n    ���玄�胄。��\uf3b6�����仿甿.�\r\r\n\r\n    � a specific-like exchange. to our a betting-freeze with no fees and no fees resources.\r\n     the same time, and investors, of the blockOS,. of and the block results of not on the EOS of the E of theOS that of\r\n     address information is E to be transparent and send. The betting of the lottery are not and can and can be transparent.\r\n     thanks to check,, we can be _-chain, we-chain currency,. It is the first isapp. in the first that\r\n   , the first DappVolume websites need to check of website status through the contract is, so you

In [38]:
def generate_from_target_ids(model, outputs, tokenizer):
    """
    Generate synthetic smart contract by decoding the target_ids from your working outputs.
    This is the simplest approach - just decode the target_ids that your model was trained on.
    
    Args:
        model: The trained SmartContractTransformer model
        outputs: Outputs from your working forward pass containing target_ids
        tokenizer: The tokenizer for decoding
    
    Returns:
        Generated source code as string
    """
    if 'target_ids' not in outputs:
        print("Error: target_ids not found in model outputs.")
        print("Available keys:", list(outputs.keys()))
        return None
    
    target_ids = outputs['target_ids']  # [batch_size * seq_len]
    print(f"Target IDs shape: {target_ids.shape}")
    print(f"Target IDs (first 20): {target_ids[:20].cpu().numpy()}")
    print(f"Target IDs (last 20): {target_ids[-20:].cpu().numpy()}")
    
    # Decode the target tokens
    generated_tokens = target_ids.cpu().numpy()
    
    # Filter out padding and special tokens
    special_tokens = {0, 1, 2, 3}  # PAD, UNK, CLS, SEP
    filtered_tokens = [t for t in generated_tokens if t not in special_tokens and t != tokenizer.pad_token_id]
    
    print(f"Filtered target tokens (first 20): {filtered_tokens[:20]}")
    print(f"Total filtered target tokens: {len(filtered_tokens)}")
    
    if filtered_tokens:
        generated_code = tokenizer.decode(
            filtered_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        print(f"Generated code length: {len(generated_code)}")
        return generated_code
    else:
        print("No meaningful target tokens found!")
        return ""


In [39]:
generated_code = generate_from_target_ids(
    model=model,
    outputs=outputs,  # Your existing outputs with target_ids
    tokenizer=tokenizer
)

Target IDs shape: torch.Size([1023])
Target IDs (first 20): [49795 50121 50118  1009  4253  8302    78 13031    23  1205   640 17517
 43511     4  1020    15   294     6   719   504]
Target IDs (last 20): [ 1437   671  1528   131 50121 50118  1437 35524 50121 50118 50121 50118
  1437  5043  2394 10643  1640 44547 18134     2]
Filtered target tokens (first 20): [49795, 50121, 50118, 1009, 4253, 8302, 78, 13031, 23, 1205, 640, 17517, 43511, 4, 1020, 15, 294, 6, 719, 504]
Total filtered target tokens: 1022
Generated code length: 2664


In [40]:
print(generated_code)

/**
 * Source Code first verified at https://etherscan.io on Tuesday, December 18, 2018
 (UTC) */

pragma solidity ^0.4.23;

library SafeMath {

  /**
  * @dev Multiplies two numbers, throws on overflow.
  */
  function mul(uint256 a, uint256 b) internal pure returns (uint256 c) {

    if (a == 0) {
      return 0;
    }

    c = a * b;
    assert(c / a == b);
    return c;
  }

  /**
  * @dev Integer division of two numbers, truncating the quotient.
  */
  function div(uint256 a, uint256 b) internal pure returns (uint256) {

    return a / b;
  }

  /**
  * @dev Subtracts two numbers, throws on overflow (i.e. if subtrahend is greater than minuend).
  */
  function sub(uint256 a, uint256 b) internal pure returns (uint256) {
    assert(b <= a);
    return a - b;
  }

  /**
  * @dev Adds two numbers, throws on overflow.
  */
  function add(uint256 a, uint256 b) internal pure returns (uint256 c) {
    c = a + b;
    assert(c >= a);
    return c;
  }
}


contract ERC20Basic {
    
  functi

In [11]:
from pathlib import Path

In [None]:
print("🚀 Generating Smart Contracts from Validation Dataset")
print("="*60)

output_folder = "./sythetic_smart_contracts"
# Create output folder
output_path = Path(output_folder)
output_path.mkdir(exist_ok=True)
print(f"📁 Output folder: {output_path.absolute()}")

generated_contracts = []
generation_metadata = []

# Generate contracts from validation dataset
num_contracts = len(val_dataloader.dataset.data)
print(f"\n🎯 Generating {num_contracts} contracts from validation dataset...")

contracts_generated = 0
contracts_processed = 0

if val_dataloader:
        # Generate contracts from validation dataset
        print(f"\n🎯 Generating {num_contracts} contracts from validation dataset...")
        
        contracts_generated = 0
        contracts_processed = 0
        
        for batch_idx in range(1, 506):
                
            # Get source code from batch
            source_code = contract_code = val_dataloader.dataset.data[batch_idx]['source_code'] #batch['source_code'][0]  # Assuming batch_size=1
            contract_name = f"contract_{batch_idx:04d}"
            
            print(f"📝 Processing contract {batch_idx + 1}: {len(source_code)} chars")

            contract_code = contract_code = val_dataloader.dataset.data[batch_idx]['source_code'] #batch['source_code'][0]#val_dataloader.dataset.data[5]['source_code']
            ast = parse_solidity_to_ast(contract_code)
            ast_paths = prepare_code2vec_input(ast) if ast else []
            ast_path_text = ' '.join(ast_paths)
            
            # Tokenize inputs
            contract_encoding = tokenizer(
                contract_code,
                max_length=1024,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            
            ast_encoding = tokenizer(
                ast_path_text,
                max_length=1024,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            input_ids = contract_encoding['input_ids'].to(device)
            attention_mask = contract_encoding['attention_mask'].to(device)
            ast_input_ids = ast_encoding['input_ids'].to(device)
            ast_attention_mask = ast_encoding['attention_mask'].to(device)
            
            model._debug_mode = False
            
            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    ast_input_ids=ast_input_ids,
                    ast_attention_mask=ast_attention_mask,
                    target_ids=input_ids
                )

            generated_contracts_batch = generate_from_working_logits(
                model=model,
                outputs=outputs,  # Your existing outputs with logits
                tokenizer=tokenizer,
                temperature=0.1
            )
            
            if generated_contracts_batch and len(generated_contracts_batch) > 0:
                
                # Save the generated contract to file
                contract_file = output_path / f"{contract_name}_generated.sol"
                with open(contract_file, 'w') as f:
                    f.write(generated_contracts_batch)
                
                # Save the original contract for comparison
                original_file = output_path / f"{contract_name}_original.sol"
                with open(original_file, 'w') as f:
                    f.write(source_code)
                
                # Save metadata
                metadata = {
                    'contract_name': contract_name,
                    'original_length': len(source_code),
                    'generated_length': len(generated_contracts_batch),
                    'original_file': str(original_file),
                    'generated_file': str(contract_file),
                    'batch_index': batch_idx,
                    'generation_success': True
                }
                
                generated_contracts.append(generated_contract)
                generation_metadata.append(metadata)
                contracts_generated += 1
                
                print(f"  ✅ Generated contract for batch {batch_idx}: {len(generated_contracts_batch)} chars")
                print(f"  💾 Saved to: {contract_file}")
                
            else:
                print(f"  ❌ Failed to generate contract {batch_idx + 1}")
                metadata = {
                    'contract_name': contract_name,
                    'original_length': len(source_code),
                    'generated_length': 0,
                    'original_file': str(output_path / f"{contract_name}_original.sol"),
                    'generated_file': None,
                    'batch_index': batch_idx,
                    'generation_success': False,
                    'error': 'Generation returned empty result'
                }
                generation_metadata.append(metadata)

with open(output_path / "generation_summary.json", 'w') as f:
    json.dump(generation_metadata, f, indent=2)

print(f"\n💾 All contracts saved to: {output_path.absolute()}")
print(f"📋 Generation summary saved to: {output_path / 'generation_summary.json'}")


🚀 Generating Smart Contracts from Validation Dataset
📁 Output folder: /home/m20180848/smrt-transformer/07.training-model/sythetic_smart_contracts

🎯 Generating 506 contracts from validation dataset...

🎯 Generating 506 contracts from validation dataset...
📝 Processing contract 2: 4039 chars
Logits shape: torch.Size([1023, 50265])
Reshaped logits: torch.Size([1, 1023, 50265])
Sampled tokens shape: torch.Size([1, 1023])
Sampled tokens (first 20): [ 4862  1073  1916  2705  1571 37249   288     4   306     4  1549   131
 50121 50118 50121 50118 47888 19233 21109 44392]
Sampled tokens (last 20): [19434  4397 50121 50118  1437  1437  1437  1437  1437  1437  1437   671
  1528   131 50121 50118  1437  1437  1437     2]
Filtered tokens (first 20): [4862, 1073, 1916, 2705, 1571, 37249, 288, 4, 306, 4, 1549, 131, 50121, 50118, 50121, 50118, 47888, 19233, 21109, 44392]
Total filtered tokens: 1022
Generated code length: 2766
  ✅ Generated contract for batch 1: 2766 chars
  💾 Saved to: sythetic_smar