In [1]:
#!/usr/bin/env python3
"""
Example usage of the Smart Contract Vulnerability Detection Model

This script demonstrates how to:
1. Load a trained model
2. Detect vulnerabilities in smart contracts
3. Generate synthetic contracts
4. Analyze the generated contracts, individually and in a batch
"""

from test_model import SmartContractAnalyzer, print_vulnerability_report

  _torch_pytree._register_pytree_node(


In [2]:


def main(
    model_path: str = "checkpoints_v1_2048_output/best_model_epoch_70.pt",
    threshold: float = 0.3,
) -> None:
    """Run the four demo scenarios against a trained checkpoint.

    The demo loads a ``SmartContractAnalyzer``, then:

    1. runs vulnerability detection on a deliberately flawed sample contract,
    2. generates two synthetic contracts from an ERC20-style template,
    3. runs detection on each generated contract,
    4. runs a batch analysis over the sample, the template and the first
       generated contract (falling back to the template if nothing was
       generated).

    Args:
        model_path: Checkpoint path handed to ``SmartContractAnalyzer``.
            Defaults to the path this notebook was originally run with.
        threshold: Value forwarded as the ``threshold`` argument of every
            ``detect_vulnerabilities`` / ``batch_analyze_contracts`` call.

    Returns:
        None. All results and all errors are reported on stdout; exceptions
        derived from ``Exception`` are caught here and never propagate.
    """
    # Example smart contract with potential vulnerabilities
    example_contract = """
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

contract VulnerableContract {
    mapping(address => uint256) public balances;
    uint256 public totalSupply;
    
    // Reentrancy vulnerability
    function withdraw() public {
        uint256 amount = balances[msg.sender];
        require(amount > 0, "Insufficient balance");
        
        // Vulnerable: state change after external call
        (bool success, ) = msg.sender.call{value: amount}("");
        require(success, "Transfer failed");
        
        balances[msg.sender] = 0;  // State change should be before external call
    }
    
    // Integer overflow vulnerability (fixed in Solidity 0.8.0+)
    function add(uint256 a, uint256 b) public pure returns (uint256) {
        return a + b;  // This is safe in Solidity 0.8.0+
    }
    
    // Timestamp dependency vulnerability
    function getRandomNumber() public view returns (uint256) {
        return block.timestamp % 100;  // Vulnerable to timestamp manipulation
    }
    
    // tx.origin vulnerability
    function transfer(address to, uint256 amount) public {
        require(tx.origin == owner, "Only owner can transfer");  // Vulnerable to phishing
        balances[to] += amount;
        balances[msg.sender] -= amount;
    }
    
    address public owner;
    
    constructor() {
        owner = msg.sender;
    }
}
"""

    # Example contract template for generation
    contract_template = """
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

contract TokenContract {
    string public name;
    string public symbol;
    uint8 public decimals;
    uint256 public totalSupply;
    mapping(address => uint256) public balanceOf;
    mapping(address => mapping(address => uint256)) public allowance;
    
    event Transfer(address indexed from, address indexed to, uint256 value);
    event Approval(address indexed owner, address indexed spender, uint256 value);
    
    constructor(string memory _name, string memory _symbol, uint8 _decimals, uint256 _totalSupply) {
        name = _name;
        symbol = _symbol;
        decimals = _decimals;
        totalSupply = _totalSupply;
        balanceOf[msg.sender] = _totalSupply;
    }
    
    function transfer(address to, uint256 value) public returns (bool) {
        require(balanceOf[msg.sender] >= value, "Insufficient balance");
        balanceOf[msg.sender] -= value;
        balanceOf[to] += value;
        emit Transfer(msg.sender, to, value);
        return true;
    }
    
    function approve(address spender, uint256 value) public returns (bool) {
        allowance[msg.sender][spender] = value;
        emit Approval(msg.sender, spender, value);
        return true;
    }
    
    function transferFrom(address from, address to, uint256 value) public returns (bool) {
        require(balanceOf[from] >= value, "Insufficient balance");
        require(allowance[from][msg.sender] >= value, "Insufficient allowance");
        balanceOf[from] -= value;
        balanceOf[to] += value;
        allowance[from][msg.sender] -= value;
        emit Transfer(from, to, value);
        return true;
    }
}
"""

    try:
        # Initialize the analyzer from the requested checkpoint.
        analyzer = SmartContractAnalyzer(model_path)

        print("🚀 Smart Contract Vulnerability Detection Model Loaded Successfully!")
        print("=" * 90)

        # Example 1: Detect vulnerabilities in a contract
        print("\n📋 SC-T-GAN: Smart Contract Transformer GAN - Vulnerability Detection Mode")
        print("-" * 80)

        result = analyzer.detect_vulnerabilities(example_contract, threshold=threshold)
        print_vulnerability_report(result, "VulnerableContract")

        # Example 2: Generate synthetic contracts
        print("\n\n🔧 EXAMPLE 2: Synthetic Contract Generation")
        print("-" * 40)

        generated_contracts = analyzer.generate_synthetic_contract(
            contract_template,
            num_contracts=2,
            temperature=0.7
        )

        for number, contract in enumerate(generated_contracts, start=1):
            print(f"\nGenerated Contract {number}:")
            print("-" * 30)
            # Show at most the first 500 characters so long outputs stay readable.
            print((contract[:500] + "...") if len(contract) > 500 else contract)

        # Example 3: Analyze the generated contracts
        print("\n\n🔍 EXAMPLE 3: Analyzing Generated Contracts")
        print("-" * 40)

        for number, contract in enumerate(generated_contracts, start=1):
            print(f"\nAnalyzing Generated Contract {number}...")
            analysis = analyzer.detect_vulnerabilities(contract, threshold=threshold)
            print_vulnerability_report(analysis, f"GeneratedContract_{number}")

        # Example 4: Batch analysis
        print("\n\n📊 EXAMPLE 4: Batch Analysis")
        print("-" * 40)

        contracts_to_analyze = [
            example_contract,
            contract_template,
            # Fall back to the template when generation returned nothing.
            generated_contracts[0] if generated_contracts else contract_template
        ]

        batch_results = analyzer.batch_analyze_contracts(contracts_to_analyze, threshold=threshold)

        for number, batch_result in enumerate(batch_results, start=1):
            print(f"\nBatch Analysis Result {number}:")
            summary = batch_result['summary']
            types_found = summary['vulnerability_types_found']
            print(f"  Total vulnerable lines: {summary['total_vulnerable_lines']}")
            print(f"  Vulnerability types: {', '.join(types_found) if types_found else 'None'}")

        print("\n✅ All examples completed successfully!")

    except FileNotFoundError:
        print("❌ Error: Model file not found!")
        print("Please make sure the model path is correct and the model has been trained.")
        # Report the path that was actually requested, not a stale hard-coded one.
        print(f"Expected path: {model_path}")

    except Exception as e:
        # Deliberate best-effort: the demo prints a checklist instead of a traceback.
        print(f"❌ Error: {e}")
        print("Please check that:")
        print("1. The model has been trained successfully")
        print("2. All required dependencies are installed")
        print("3. The model path is correct")


In [4]:
# Run the demo. main() catches Exception subclasses itself and prints a
# diagnostic instead of raising, so inspect the cell output to see whether
# every example actually completed.
main()

Using device: cuda
Model loaded from checkpoints_v1_2048_output/best_model_epoch_70.pt
Training epoch: 70
Best validation loss: 10.655499279499054
🚀 Smart Contract Vulnerability Detection Model Loaded Successfully!

📋 SC-T-GAN: Smart Contract Transformer GAN - Vulnerability Detection Mode
--------------------------------------------------------------------------------

VULNERABILITY ANALYSIS REPORT: VulnerableContract

📋 CONTRACT-LEVEL VULNERABILITIES:
----------------------------------------
ARTHM        | 🔴 VULNERABLE    | Probability: 0.929
DOS          | 🔴 VULNERABLE    | Probability: 0.345
LE           | 🟢 SAFE          | Probability: 0.062
RENT         | 🔴 VULNERABLE    | Probability: 0.589
TimeM        | 🟢 SAFE          | Probability: 0.099
TimeO        | 🔴 VULNERABLE    | Probability: 0.337
Tx-Origin    | 🟢 SAFE          | Probability: 0.000
UE           | 🟢 SAFE          | Probability: 0.270

📄 LINE-LEVEL VULNERABILITIES:
----------------------------------------

📊 SUMMARY:
--