# Finetuning
# Malicious Smart Contract Detection Training Dataset Collection Notebook

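This notebook collects smart contract bytecode and its disassembled ("decompiled") opcodes for malicious contract classification. Note that the `creation_bytecode` column is currently filled via `eth_getCode`, i.e. with the deployed runtime bytecode rather than the creation (init) bytecode.
Benign contracts are gathered from an Etherscan export of verified contracts (the Zettablock query path is currently disabled), and malicious contracts from [Forta Network's labelled datasets GitHub repo](https://github.com/forta-network/labelled-datasets).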

In [None]:
import logging
import pickle
import os

from evmdasm import EvmBytecode
import pandas as pd
from tqdm import tqdm
import requests
from web3 import Web3

# Attach tqdm's ``progress_apply`` to pandas objects (used for the slow per-row calls below).
tqdm.pandas()

# Only let CRITICAL records from the evmdasm disassembler through (hides its warning logs).
_evmdasm_logger = logging.getLogger("evmdasm")
_evmdasm_logger.setLevel(logging.CRITICAL)

In [None]:
def get_verified_smart_contracts(
    csv_path="/data/forta/ethereum/text/normal_smart_contracts.csv",
):
    """Load the list of benign (verified) smart contracts.

    The default file holds ~5k verified contracts downloaded from
    https://etherscan.io/exportData?type=open-source-contract-codes .
    Downstream code (get_benign_contracts) expects a ``contract_address`` column.

    Args:
        csv_path: Location of the CSV export. Defaults to the original
            hard-coded path so existing calls behave the same.

    Returns:
        The DataFrame produced by ``pd.read_csv(csv_path)``.
    """
    return pd.read_csv(csv_path)

In [None]:
# NOTE(review): an Alchemy API key was committed here in plain text. Rotate it and
# supply the endpoint through the ALCHEMY_URL environment variable; the literal is
# kept only as a fallback so existing runs keep working.
ALCHEMY_URL = os.environ.get(
    "ALCHEMY_URL",
    "https://eth-mainnet.g.alchemy.com/v2/f8hEqd_nKEpamacB-zXeWPv7R9QmWKPx",
)
# Shared Ethereum JSON-RPC client used by get_contract_bytecode and get_created_contracts.
w3_eth = Web3(Web3.HTTPProvider(ALCHEMY_URL))
# Cache of creation-trace lists keyed by transaction hash; a failed lookup is
# stored as the string 'error'.
TRACES = {}

def get_contract_bytecode(contract):
    """Fetch the bytecode currently deployed at a contract's address.

    This calls ``w3_eth.eth.get_code`` (JSON-RPC ``eth_getCode``), which returns
    the *runtime* bytecode stored at the address, not the creation/init bytecode
    sent in the deployment transaction.

    Args:
        contract: Row-like mapping (e.g. a DataFrame row) with a
            ``'contract_address'`` entry.

    Returns:
        The value returned by ``get_code`` (may be empty, e.g. for addresses
        without code), or ``None`` if the address is invalid or the RPC call fails.
        ``None`` is the "no bytecode" value that get_opcodes explicitly handles.
    """
    try:
        checksum_address = Web3.to_checksum_address(contract["contract_address"])
        return w3_eth.eth.get_code(checksum_address)
    except Exception as e:
        # Best-effort: one bad address or RPC hiccup must not abort the whole
        # DataFrame.progress_apply; such rows end up with empty opcodes.
        print(e)
        return None

In [None]:
def get_created_contracts(tx_hash):
    """Return the contract-creation entries from a transaction's trace.

    Successful results are memoised in the module-level ``TRACES`` dict. A failed
    lookup is recorded there as the string ``'error'``, but that marker is not
    treated as a cache hit, so the trace is requested again on the next call.

    Args:
        tx_hash: Hash of the transaction to trace.

    Returns:
        The trace entries whose ``type`` starts with ``'create'`` (possibly an
        empty list), or ``[]`` when tracing raises.
    """
    cached = TRACES.get(tx_hash)
    if cached is not None and cached != 'error':
        return cached

    try:
        creations = [
            entry
            for entry in w3_eth.tracing.trace_transaction(tx_hash)
            if entry['type'].startswith('create')
        ]
    except Exception as err:
        print(err)
        TRACES[tx_hash] = 'error'
        return []

    TRACES[tx_hash] = creations
    return creations

In [None]:
def get_creation_bytecode(row) -> "str | None":
    """Get a contract's creation (init) bytecode from its deployment trace.

    Args:
        row: Row-like mapping with ``'contract_creation_tx'`` (hash of the
            deployment transaction) and ``'contract_address'`` entries.

    Returns:
        The ``action.init`` value of the ``create`` trace entry whose result
        address matches ``contract_address``, or ``None`` when the trace is
        unavailable or no created contract matches.
    """
    tx_hash = row['contract_creation_tx']
    contract_addr = str(row['contract_address']).lower()

    for created in get_created_contracts(tx_hash):
        result = created.get('result') or {}
        # Compare case-insensitively so checksummed and lower-case hex addresses match.
        if str(result.get('address', '')).lower() == contract_addr:
            return created['action'].get('init')
    return None

def get_opcodes(creation_bytecode) -> str:
    """Disassemble EVM bytecode into a single space-separated opcode string.

    Args:
        creation_bytecode: Bytecode to disassemble with ``evmdasm.EvmBytecode``
            (assumed to be a hex string or bytes, as returned by web3 — confirm).
            Missing or empty values (``None``, ``''``, ``b''``, empty containers)
            are treated as "no bytecode".

    Returns:
        The disassembled instructions joined by single spaces, or ``''`` when
        there is nothing to disassemble or disassembly fails.
    """
    # Guard on falsiness, not just ``is None``, so empty placeholders never
    # reach the disassembler.
    if not creation_bytecode:
        return ''

    try:
        instructions = EvmBytecode(creation_bytecode).disassemble()
    except Exception:
        # Best-effort: undecodable bytecode maps to '' and is filtered out later.
        return ''

    return " ".join(str(op).strip() for op in instructions)

In [None]:
def get_malicious_contracts(
    data_path='/data/forta/ethereum/text/malicious_data.pkl',
) -> pd.DataFrame:
    """Collect malicious contracts from Forta's labelled-datasets repo with their opcodes.

    The result is cached as a pickle at ``data_path``. When that file exists it is
    loaded and nothing is downloaded or fetched.

    Args:
        data_path: Pickle cache location. Defaults to the original hard-coded path.

    Returns:
        Rows of Forta's ``malicious_smart_contracts.csv`` whose
        ``contract_creator_etherscan_label`` is not ``'phish-hack'``, plus
        ``creation_bytecode`` and ``decompiled_opcodes`` columns (or whatever
        was cached at ``data_path``).
    """
    if os.path.exists(data_path):
        # Trusted local cache written below; unpickling runs arbitrary code,
        # so never point data_path at a file from an untrusted source.
        return pd.read_pickle(data_path)

    # csv from https://github.com/forta-network/labelled-datasets
    github_url = (
        'https://raw.githubusercontent.com/forta-network/labelled-datasets/'
        'main/labels/1/malicious_smart_contracts.csv'
    )
    malicious = pd.read_csv(github_url)
    # Exclude phishing hack related contracts. ``.copy()`` makes the filtered frame
    # independent so the column assignments below don't hit SettingWithCopyWarning.
    malicious = malicious[malicious['contract_creator_etherscan_label'] != 'phish-hack'].copy()

    # NOTE(review): get_contract_bytecode calls eth_getCode, which returns the deployed
    # runtime bytecode, not the creation/init bytecode; the column name is kept so
    # existing datasets stay compatible (get_creation_bytecode is the trace-based option).
    malicious['creation_bytecode'] = malicious.progress_apply(get_contract_bytecode, axis=1)
    malicious['decompiled_opcodes'] = malicious['creation_bytecode'].progress_apply(get_opcodes)

    # Store data so we don't have to download it all the time.
    cache_dir = os.path.dirname(data_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    malicious.to_pickle(data_path)
    return malicious

In [None]:
# Etherscan contract-creation records keyed by lower-cased contract address;
# filled by get_contract_transactions and read by get_contract_data.
CONTRACT_DATA = {}
# NOTE(review): this Etherscan API key was committed in plain text. Rotate it and
# supply it through the ETHERSCAN_API_KEY environment variable; the literal is kept
# only as a fallback so existing runs keep working.
ETHERSCAN_API_KEY = os.environ.get("ETHERSCAN_API_KEY", "SX7737W8M9DDTYHSXSV8XG8G945PMC9D4U")

def get_contract_transactions(contracts, timeout=30):
    """Get contract creation info from Etherscan and store it in ``CONTRACT_DATA``.

    Queries Etherscan's ``contract/getcontractcreation`` endpoint in batches and
    stores each returned record under its lower-cased ``contractAddress``.

    Args:
        contracts: Sized, sliceable sequence of contract address strings.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        None. Batches that fail (network error, non-JSON body, or a ``result``
        that is not a list) are reported with ``print`` and skipped, so one bad
        batch does not abort the whole collection.
    """
    # Etherscan API can take up to 5 contract addresses at a time.
    batch_size = 5
    # NOTE(review): Etherscan has been migrating to a V2 endpoint (/v2/api with a
    # chainid parameter); confirm this V1 URL is still served.
    api_url = "https://api.etherscan.io/api"

    for start in range(0, len(contracts), batch_size):
        batch = contracts[start:start + batch_size]
        params = {
            "module": "contract",
            "action": "getcontractcreation",
            "contractaddresses": ",".join(batch),
            "apikey": ETHERSCAN_API_KEY,
        }
        try:
            # Let requests encode the query string; always bound the wait.
            resp = requests.get(api_url, params=params, timeout=timeout)
            result = resp.json().get("result")
        except (requests.RequestException, ValueError) as e:
            print(e)
            continue

        # On failure (rate limit, bad key, no data) Etherscan sends something other
        # than a list in ``result`` (e.g. a message string or null); iterating it
        # as records would crash.
        if not isinstance(result, list):
            print(f"Etherscan getcontractcreation failed for {batch}: {result}")
            continue

        for record in result:
            CONTRACT_DATA[record["contractAddress"].lower()] = record

In [None]:
def get_contract_data(row):
    """Return the creator address and creation tx hash cached for a row's contract.

    The row's ``contract_address`` is lower-cased and looked up in the
    module-level ``CONTRACT_DATA`` dict.

    Returns:
        ``(contractCreator, txHash)`` from the cached record, or ``(None, None)``
        when there is no record (or an empty one) for the address.
    """
    record = CONTRACT_DATA.get(row['contract_address'].lower())
    if not record:
        return None, None
    return record.get('contractCreator'), record.get('txHash')

In [None]:
# Zettablock query for the verified-contract source: rows of
# ethereum_mainnet.contracts (aliased ``abis``) left-joined to any matching address
# label in ethereum_mainnet.labels. Not executed in this notebook; the call that
# would use it in get_benign_contracts is disabled.
verified_contracts_sql = '''
SELECT abis.address as contract_address, 
       abis.name as contract_name,
       tags.name as contract_etherscan_label, 
       tags.type as contract_tag
FROM ethereum_mainnet.contracts abis LEFT JOIN ethereum_mainnet.labels tags ON tags.address = abis.address
'''

# Query for labelled addresses whose label name starts with 'mev' (presumably MEV
# bots), meant to be added to the benign set. Also not executed in this notebook;
# its caller in get_benign_contracts is disabled.
mev_contracts_sql = '''
SELECT tags.address as contract_address, 
       tags.type as label_type, 
       tags.name as label_name 
FROM ethereum_mainnet.labels tags WHERE tags.name like 'mev%'
'''

def get_benign_contracts(
    data_path='/data/forta/ethereum/text/benign_data.pkl',
) -> pd.DataFrame:
    """Collect benign (verified) contracts with creator info, bytecode and opcodes.

    The result is cached as a pickle at ``data_path``. When that file exists it is
    loaded and nothing is fetched.

    Args:
        data_path: Pickle cache location, used both for the existence check and
            for writing a freshly built frame. Defaults to the original path.

    Returns:
        The rows from ``get_verified_smart_contracts()`` plus ``contract_creator``,
        ``contract_creation_tx``, ``creation_bytecode`` and ``decompiled_opcodes``
        columns (or whatever was cached at ``data_path``).
    """
    if os.path.exists(data_path):
        # Trusted local cache written below; unpickling runs arbitrary code,
        # so never point data_path at a file from an untrusted source.
        return pd.read_pickle(data_path)

    # The Zettablock (verified_contracts_sql) and Luabase (mev_contracts_sql)
    # sources are disabled; the benign set currently comes from the Etherscan
    # CSV export only.
    benign = get_verified_smart_contracts()

    # Populate CONTRACT_DATA from Etherscan, then expand each row's record into
    # creator-address and creation-tx columns.
    get_contract_transactions(benign['contract_address'].tolist())
    benign[['contract_creator', 'contract_creation_tx']] = benign.apply(
        get_contract_data, axis=1, result_type='expand'
    )

    # NOTE(review): get_contract_bytecode calls eth_getCode, which returns the deployed
    # runtime bytecode, not the creation/init bytecode; the column name is kept so
    # existing datasets stay compatible (get_creation_bytecode is the trace-based option).
    benign['creation_bytecode'] = benign.progress_apply(get_contract_bytecode, axis=1)
    benign['decompiled_opcodes'] = benign['creation_bytecode'].progress_apply(get_opcodes)

    # Store data so we don't have to download it all the time; write to the same
    # path that is checked above.
    cache_dir = os.path.dirname(data_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    benign.to_pickle(data_path)
    return benign

In [None]:
# Every contract from Forta's malicious dataset is labelled positive.
malicious_contracts = get_malicious_contracts().assign(malicious=True)

In [None]:
# Every contract from the verified (benign) source is labelled negative.
benign_contracts = get_benign_contracts().assign(malicious=False)

In [None]:
# Stack both sources with a fresh 0..n-1 index: each frame carries its own index,
# so a plain concat repeats row labels. Malicious rows come first, so a later
# keep-first de-duplication favours the malicious label.
dataset = pd.concat([malicious_contracts, benign_contracts], ignore_index=True)

In [None]:
# Keep only contracts whose bytecode disassembled to a non-empty opcode string.
has_opcodes = dataset['decompiled_opcodes'].fillna('') != ''
dataset = dataset[has_opcodes]
# De-duplicate by address without in-place mutation of a filtered slice (which
# triggers SettingWithCopyWarning). Compare case-insensitively because the two
# sources may differ only in hex case (checksummed vs lower-case) — confirm
# against the data. The first occurrence (malicious rows are concatenated
# first) is kept.
dataset = dataset[~dataset['contract_address'].str.lower().duplicated(keep='first')]

In [None]:
# Class balance of the final dataset (True = malicious, False = benign).
dataset.malicious.value_counts()

In [None]:
# index=False: after concatenation and filtering the row index is a meaningless,
# non-contiguous leftover; index=None would still write it as an extra
# ``__index_level_0__`` column.
dataset.fillna('').to_parquet(
    '/data/forta/ethereum/text/malicious_contract_training_dataset_final.parquet',
    index=False,
)

In [None]:
# Pie chart of the malicious/benign class balance.
dataset['malicious'].value_counts().plot.pie(figsize=(7, 7))