# Finetuning
# Malicious Smart Contract Detection Training Dataset Collection Notebook

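This notebook collects smart contract creation bytecode and decompiled (disassembled) opcodes for malicious contract classification.
Benign contracts are verified contracts exported from block explorers (Etherscan, PolygonScan, BscScan). Malicious contracts come from [Forta Network's labelled datasets GitHub repo](https://github.com/forta-network/labelled-datasets) for Ethereum, plus Polygon and BSC data collected by Sakundi. Creation bytecode for every contract is fetched from Zettablock.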

Code provided by the Forta project

In [None]:
import logging
import pickle
import os
import time
import json

from evmdasm import EvmBytecode
import pandas as pd
from tqdm import tqdm
import requests

# Keep the evmdasm disassembler quiet: with its logger level at CRITICAL,
# warning and error records it emits are dropped instead of flooding the output.
logging.getLogger("evmdasm").setLevel(logging.CRITICAL)

# Register `progress_apply` on pandas objects so long per-row applies show a bar.
tqdm.pandas()

# Chains whose contracts are collected.
blockchains = ["ethereum", "polygon", "bsc"]
# Chain queried by get_contract_bytecode (`<chain>_mainnet.contract_creations`);
# reassigned while iterating over `blockchains`.
current_blockchain = "ethereum"

# Never commit API keys: the key is read from the environment. An empty value
# still lets the cached-pickle code paths run; live queries will be rejected.
# NOTE(review): a key was previously hard-coded here and must be rotated.
ZETTABLOCK_API_KEY = os.environ.get("ZETTABLOCK_API_KEY", "")
# NOTE(review): ZETTABLOCK_URL and EXECUTE_URL are not referenced in the cells
# shown here; remove them if nothing else uses them.
ZETTABLOCK_URL = "https://api.zettablock.com/api/v1/databases/realtimeDB/queries"
ZETTABLOCK_DATA_LAKE_ENDPOINT = "https://api.zettablock.com/api/v1/databases/AwsDataCatalog/queries"
EXECUTE_URL = "https://api.zettablock.com/api/v1/queries/"

# Headers sent with every Zettablock request.
headers = {
    "accept": "application/json",
    "X-API-KEY": ZETTABLOCK_API_KEY,
    "content-type": "application/json",
}

# NOTE(review): not referenced in the cells shown here.
TRACES = {}
CONTRACT_DATA = {}

In [None]:
def get_verified_smart_contracts(
    base_dir="/data/forta/ethereum/text/pretraining/raw/verified",
):
    """Load the verified (benign) contract exports for every supported chain.

    Each CSV holds ~5k verified contracts exported from the chain's explorer:
      - ethereum: https://etherscan.io/exportData?type=open-source-contract-codes
      - polygon:  https://polygonscan.com/exportData?type=open-source-contract-codes
      - bsc:      https://bscscan.com/exportData?type=open-source-contract-codes

    Args:
        base_dir: directory containing the ``<chain>-verified.csv`` files.

    Returns:
        dict mapping chain name ("ethereum", "polygon", "bsc", in that order)
        to the DataFrame read from ``<base_dir>/<chain>-verified.csv``.
    """
    return {
        chain: pd.read_csv(os.path.join(base_dir, f"{chain}-verified.csv"))
        for chain in ("ethereum", "polygon", "bsc")
    }

def get_verified_malicious_smart_contracts(
    base_dir="/data/forta/ethereum/text/pretraining/raw/verified",
):
    """Load the malicious contract lists for every supported chain from local CSVs.

    The Ethereum CSV originates from Forta's labelled-datasets repo
    (https://github.com/forta-network/labelled-datasets, file
    labels/1/malicious_smart_contracts.csv); the Polygon and BSC data were
    collected by Sakundi. No bytecode or opcodes are fetched here.

    Args:
        base_dir: directory containing the ``<chain>-malicious.csv`` files.

    Returns:
        dict mapping chain name ("ethereum", "polygon", "bsc", in that order)
        to the DataFrame read from ``<base_dir>/<chain>-malicious.csv``.
    """
    return {
        chain: pd.read_csv(os.path.join(base_dir, f"{chain}-malicious.csv"))
        for chain in ("ethereum", "polygon", "bsc")
    }

In [None]:
def get_contract_bytecode(contract, blockchain=None):
    """Fetch the creation bytecode of one contract from Zettablock.

    Intended for ``DataFrame.progress_apply(get_contract_bytecode, axis=1)``,
    where ``contract`` is a row with a ``"contract_address"`` field.

    Args:
        contract: row/mapping with a ``"contract_address"`` entry.
        blockchain: chain whose ``<chain>_mainnet.contract_creations`` table is
            queried. Defaults to the module-level ``current_blockchain``
            (read at call time).

    Returns:
        The bytecode string, or ``""`` when no bytecode was returned, the
        address is not ``0x`` followed by 40 hex characters, or the lookup
        raised. Always a ``str`` so the resulting column stays homogeneous
        (it is later disassembled and exported to parquet).
    """
    import re

    chain = current_blockchain if blockchain is None else blockchain
    try:
        address = str(contract["contract_address"])
        # The address is interpolated into SQL, so only accept a well-formed
        # 20-byte hex address; anything else could alter the query.
        if not re.fullmatch(r"0x[0-9a-fA-F]{40}", address):
            print(f"skipping invalid contract address: {address!r}")
            return ""
        query_text = (f"SELECT bytecode FROM {chain}_mainnet.contract_creations "
                      f"WHERE address = '{address}' LIMIT 1")
        contract_bytecode = call_zettablock_api(query_text, chain)
    except Exception as e:
        # Best effort per row: report and move on to the next contract.
        # Return "" (not []) so the column keeps a single type.
        print(e)
        return ""
    return contract_bytecode if contract_bytecode is not None else ""

def get_response(queryrun_id, timeout=None):
    """Poll a Zettablock query run until it reaches a terminal state.

    Adapted from the Zettablock tutorials. Waits 1 s, 2 s, 3 s, ... between
    polls of the run's status endpoint.

    Args:
        queryrun_id: id returned by the query ``trigger`` endpoint.
        timeout: optional upper bound, in seconds, on total waiting time.
            ``None`` (the default) polls indefinitely, as before.

    Returns:
        The terminal state: ``'SUCCEEDED'`` or ``'FAILED'``.

    Raises:
        TimeoutError: if ``timeout`` elapses before a terminal state.
        requests.HTTPError: if the status endpoint answers with an HTTP error.
    """
    queryrun_status_endpoint = f'https://api.zettablock.com/api/v1/queryruns/{queryrun_id}/status'
    deadline = None if timeout is None else time.monotonic() + timeout
    delay = 1
    while True:
        # Per-request timeout so a stalled connection cannot hang the loop.
        res = requests.get(queryrun_status_endpoint, headers=headers, timeout=30)
        res.raise_for_status()
        state = res.json()['state']
        if state in ('SUCCEEDED', 'FAILED'):
            return state
        if deadline is not None and time.monotonic() + delay > deadline:
            raise TimeoutError(
                f'query run {queryrun_id} still {state!r} after {timeout}s')
        time.sleep(delay)
        delay += 1

def call_zettablock_api(query_text, blockchain):
    """Run ``query_text`` on Zettablock's data lake and return the second result line.

    Flow (from the Zettablock tutorials): create a query from the SQL, trigger
    a run of it, poll the run with get_response, then download the result.

    Args:
        query_text: SQL statement to execute.
        blockchain: not used here (the chain is already part of
            ``query_text``); kept so existing callers keep working.

    Returns:
        The line after the first one in the result body (the first line is
        presumably the column header, since includeColumnName is requested),
        with any trailing carriage return removed. Returns ``None`` if the run
        failed or the body has no second line.

    Raises:
        requests.HTTPError: if any Zettablock endpoint answers with an HTTP error.
    """
    request_timeout = 30  # seconds per HTTP call; avoids hanging on a stalled connection

    # 86_400_000 ms = 24 h; presumably how long Zettablock may reuse a cached result.
    query = {"query": query_text, "resultCacheExpireMillis": 86400000}

    # Create a query with the SQL statement and get its query id.
    res = requests.post(ZETTABLOCK_DATA_LAKE_ENDPOINT, headers=headers,
                        data=json.dumps(query), timeout=request_timeout)
    res.raise_for_status()
    query_id = res.json()['id']

    # Trigger the query by query id and get the queryrun id.
    trigger_endpoint = f'https://api.zettablock.com/api/v1/queries/{query_id}/trigger'
    res = requests.post(trigger_endpoint, headers=headers, data='{}',
                        timeout=request_timeout)
    res.raise_for_status()
    queryrun_id = res.json()['queryrunId']

    if get_response(queryrun_id) != 'SUCCEEDED':
        print('query failed, please check status message for details')
        # NOTE: this is the trigger response, not the run's status message.
        print(res.json())
        return None

    # Fetch the run's result. If the result is huge, consider stream=True and
    # writing it to a file instead.
    params = {'includeColumnName': 'true'}
    queryrun_result_endpoint = f'https://api.zettablock.com/api/v1/stream/queryruns/{queryrun_id}/result'
    resp = requests.get(queryrun_result_endpoint, stream=False, headers=headers,
                        params=params, timeout=request_timeout)
    # Without this, an HTTP error body would be returned as if it were data.
    resp.raise_for_status()
    lines = resp.text.split("\n")
    if len(lines) < 2:
        # No second line (previously an IndexError).
        return None
    return lines[1].rstrip("\r")

In [None]:
def get_opcodes(creation_bytecode) -> str:
    """Disassemble EVM creation bytecode into one space-separated opcode string.

    Returns an empty string when ``creation_bytecode`` is None or when
    evmdasm raises while disassembling it. Otherwise returns each
    instruction's ``str()`` form, stripped, joined by single spaces.
    """
    if creation_bytecode is None:
        return ''

    try:
        instructions = EvmBytecode(creation_bytecode).disassemble()
    except Exception:
        # Unparseable input is treated the same as missing bytecode.
        return ''

    rendered = (str(instruction).strip() for instruction in instructions)
    return " ".join(rendered)

In [None]:
def get_malicious_contracts(
    data_path='/data/forta/ethereum/text/pretraining/raw/verified/malicious-data.pkl',
) -> pd.DataFrame:
    """Return malicious contracts with creation bytecode and decompiled opcodes.

    If ``data_path`` exists, the cached pickle is loaded and returned as-is.
    Otherwise:
      - the per-chain CSVs are loaded via get_verified_malicious_smart_contracts;
      - Ethereum rows whose ``contract_creator_etherscan_label`` is
        ``'phish-hack'`` are dropped;
      - addresses are lower-cased;
      - creation bytecode is fetched per chain with get_contract_bytecode;
      - the chains are combined with a fresh 0..n-1 index;
      - opcodes are disassembled with get_opcodes;
      - the result is pickled to ``data_path``.

    Args:
        data_path: location of the pickle cache.

    Returns:
        DataFrame with the CSV columns plus ``creation_bytecode`` and
        ``decompiled_opcodes`` (or whatever the cached pickle contains).

    NOTE: the cache is read with pickle; only point ``data_path`` at trusted files.
    """
    # get_contract_bytecode reads the module-level current_blockchain. Without
    # this declaration the assignment below only creates a local, and every
    # chain would be queried against the stale module-level value.
    global current_blockchain

    if os.path.exists(data_path):
        with open(data_path, "rb") as data_file:
            return pickle.load(data_file)

    per_chain = get_verified_malicious_smart_contracts()
    for blockchain in blockchains:
        current_blockchain = blockchain
        frame = per_chain[blockchain]
        if blockchain == "ethereum":
            # exclude phishing hack related contracts
            frame = frame[frame['contract_creator_etherscan_label'] != 'phish-hack']
        # Work on an explicit copy so adding columns to the filtered frame does
        # not trigger chained-assignment (SettingWithCopy) issues.
        frame = frame.copy()
        frame['contract_address'] = frame['contract_address'].progress_apply(str.lower)
        frame['creation_bytecode'] = frame.progress_apply(get_contract_bytecode, axis=1)
        per_chain[blockchain] = frame

    # Fresh index, consistent with get_benign_contracts.
    malicious = pd.concat(
        [per_chain["ethereum"], per_chain["polygon"], per_chain["bsc"]]
    ).reset_index(drop=True)
    malicious['decompiled_opcodes'] = malicious['creation_bytecode'].progress_apply(get_opcodes)
    # Store data so we don't have to download it all the time
    malicious.to_pickle(data_path)
    return malicious

In [None]:
def get_benign_contracts(
    data_path='/data/forta/ethereum/text/pretraining/raw/verified/benign-data.pkl',
) -> pd.DataFrame:
    """Collect verified (benign) contracts with creation bytecode and decompiled opcodes.

    If ``data_path`` exists, the cached pickle is loaded and returned as-is.
    Otherwise:
      - the per-chain verified-contract CSVs are loaded via
        get_verified_smart_contracts;
      - addresses are lower-cased;
      - creation bytecode is fetched per chain with get_contract_bytecode;
      - the chains are combined with a fresh 0..n-1 index;
      - opcodes are disassembled with get_opcodes;
      - the result is pickled to ``data_path``.

    Args:
        data_path: location of the pickle cache.

    Returns:
        DataFrame with the CSV columns plus ``creation_bytecode`` and
        ``decompiled_opcodes`` (or whatever the cached pickle contains).

    NOTE: the cache is read with pickle; only point ``data_path`` at trusted files.
    """
    # get_contract_bytecode reads the module-level chain, so assign the global.
    global current_blockchain

    if os.path.exists(data_path):
        with open(data_path, "rb") as data_file:
            return pickle.load(data_file)

    per_chain = get_verified_smart_contracts()
    for blockchain in blockchains:
        current_blockchain = blockchain
        frame = per_chain[blockchain]
        frame['contract_address'] = frame['contract_address'].progress_apply(str.lower)
        frame['creation_bytecode'] = frame.progress_apply(get_contract_bytecode, axis=1)

    benign = pd.concat(
        [per_chain["ethereum"], per_chain["polygon"], per_chain["bsc"]]
    ).reset_index(drop=True)
    benign['decompiled_opcodes'] = benign['creation_bytecode'].progress_apply(get_opcodes)
    # Store data so we don't have to download it all the time
    benign.to_pickle(data_path)
    return benign

In [None]:
# Load (or build) the malicious set and label every row as positive.
malicious_contracts = get_malicious_contracts().assign(malicious=True)

In [None]:
# Load (or build) the benign set and label every row as negative.
benign_contracts = get_benign_contracts().assign(malicious=False)

In [None]:
# Stack both labelled sets, malicious first. ignore_index gives each row a
# unique label; the inputs' row labels overlap (each frame's index starts at 0).
dataset = pd.concat([malicious_contracts, benign_contracts], ignore_index=True)

In [None]:
# Keep rows whose bytecode disassembled to something, then drop repeated
# addresses. drop_duplicates keeps the first occurrence, and malicious rows come
# first in `dataset`, so an address present in both sets keeps malicious=True.
# Chaining (instead of inplace on the filtered slice) avoids a
# chained-assignment warning.
dataset = dataset[
    dataset['decompiled_opcodes'].notna() & (dataset['decompiled_opcodes'] != '')
].drop_duplicates('contract_address')

In [None]:
# Class balance: number of malicious (True) vs benign (False) contracts.
dataset.malicious.value_counts()

In [None]:
# Persist the final dataset with missing values replaced by ''.
# index=False: with index=None pandas stores any non-RangeIndex (this frame's
# index is typically no longer a RangeIndex after the filtering above) as an
# extra column. The row labels carry no information, so they are not stored.
dataset.fillna('').to_parquet('/data/forta/ethereum/text/pretraining/raw/verified/verified-smart-contracts.parquet', index=False)

In [None]:
# Pie chart of the class balance (malicious vs benign).
dataset['malicious'].value_counts().plot.pie(figsize=(7, 7))