In [11]:
import pickle
from slither.slither import Slither

In [2]:
file = open('function_cpgs.pkl', 'rb')
#file.close()

In [None]:
cpgs = pickle.load(file)

In [4]:
len(cpgs)

In [None]:
cpgs.keys()

In [35]:
cpgs['functions'][0].keys()

In [2]:
!python wl_kernel_embeddings.py --no-normalize

In [None]:
# Load the WL embedding bundle; the context manager closes the file handle.
# pickle can execute arbitrary code on load, so only open files produced by this project.
with open('function_wl_embeddings.pkl', 'rb') as fh:
    embs = pickle.load(fh)

In [None]:
embs.keys()

In [None]:
embs['functions'][1].keys()

-- The WL kernel matrix is already a matrix of pairwise similarities, so it can be used directly instead of applying a separate similarity model.

In [None]:
print(embs['functions'][1]['function_full_name'])
embs['functions'][1]['contract']

In [None]:
print(embs['functions'][486]['function_full_name'])
embs['functions'][486]['contract']

In [None]:
embs['embeddings'][1]

In [None]:
embs['embeddings'][1]

In [None]:
import numpy as np

# K is your raw WL kernel matrix, shape (N, N)
def top_k_for(i, K, k=10):
    """Return the indices and scores of the ``k`` best matches for row ``i``.

    Args:
        i: Index of the query row in ``K``.
        K: Square similarity (kernel) matrix; ``K[i]`` must support ``.copy()``.
        k: Number of neighbours to return.

    Returns:
        ``(indices, scores)`` ordered from most to least similar. The query
        itself is excluded by overwriting its own score with a large negative
        sentinel in a copy of the row (``K`` is not modified).
    """
    row = K[i].copy()
    row[i] = -10**15  # sentinel: the query must never rank itself
    ranked = np.argsort(row)[::-1]
    best = ranked[:k]
    return best, row[best]

In [None]:
top_k_for(1, embs['embeddings'], k=10)

In [None]:
embs['functions'][486]['graph']

In [None]:
embs['graphs'][1]

In [3]:
# Load the WL embedding bundle; the context manager closes the file handle.
with open('function_wl_embeddings.pkl', 'rb') as fh:
    wls = pickle.load(fh)

In [4]:
wls.keys()

In [5]:
embs = wls['embeddings']

In [1]:
len(embs)

In [8]:
import pickle
import json
import hashlib
import networkx as nx
import random
import os
from collections import Counter, defaultdict
from pathlib import Path
from tqdm import tqdm
import gen_function_ops

# --- IMPORTS FOR LABELING ---
try:
    from slither.slither import Slither
    from slither.detectors import all_detectors
except ImportError:
    print("CRITICAL ERROR: Slither is not installed.")
    print("Run: pip install slither-analyzer")
    exit(1)

# --- CONFIGURATION ---
CPG_DATASET_PATH = "function_cpgs.pkl"       # Your existing CPG file
RESULTS_FILE = "paper_results_disl.json"     # Where to save findings
LABEL_CACHE_FILE = "slither_label_cache.json" # To save time on restarts

# WL Kernel Settings
WL_ITERATIONS = 2
SIMILARITY_THRESHOLD = 0.85

# Gas Detectors to Flag "Bad" Functions
# If a function triggers ANY of these, it is labeled "Inefficient" (0)
GAS_PATTERNS = [
    'external-function',        # Should be external
    'const-state-vars',         # Should be constant
    'immutable-states',         # Should be immutable
    'dead-code',                # Unused code
    'solc-version',             # Old compiler version (usually gas heavy)
    'unchecked-lowlevel',       # Low level calls
    'shadowing-state',          # Shadowing state variables
    'calls-loop',               # External calls inside loops
    'costly-loop'               # Expensive operations in loops
]

# ==========================================
# PART 1: AUTO-LABELING (Slither Integration)
# ==========================================

def analyze_file_gas_costs(file_path):
    """
    Runs Slither on a single Solidity file and returns a dict:
    { "function_name": is_optimized (1 or 0) }

    Every function is entered twice, under its bare name (``function.name``)
    and under its signature (``function.full_name``). A function reported by
    a detector listed in ``GAS_PATTERNS`` is set to 0 under both keys.

    Returns an empty dict when the file does not exist. When Slither or the
    compiler raises, the error is printed and whatever was collected so far
    is returned (best effort: compilation errors are common in large datasets).
    """
    results = {}

    if not os.path.exists(file_path):
        return results

    try:
        opt_version = gen_function_ops.get_optimal_compiler_version(file_path)
        gen_function_ops.set_solc_version(opt_version)
        # Initialize Slither on the file
        slither = Slither(str(file_path))
        # NOTE(review): no detector is registered in this function. Unless
        # detectors are registered elsewhere, run_detectors() has nothing to
        # run and every function keeps the default label 1 -- confirm.
        #slither.register_detector_classes(all_detectors)

        # Run Detectors
        issues = slither.run_detectors()

        # 1. Default all functions to Optimized (1)
        for contract in slither.contracts:
            for function in contract.functions:
                # Store by name. Note: overloading might cause collisions,
                # but for this scale, name matching is acceptable.
                results[function.name] = 1
                results[function.full_name] = 1

        # 2. Mark inefficient functions as (0)
        for detector_output in issues:
            # run_detectors() yields one list of findings per detector;
            # a flat list of finding dicts is accepted as well.
            if isinstance(detector_output, list):
                findings = detector_output
            else:
                findings = [detector_output]

            for issue in findings:
                if issue['check'] not in GAS_PATTERNS:
                    continue
                for element in issue['elements']:
                    if element['type'] != 'function':
                        continue
                    results[element['name']] = 0
                    # Also flag the signature key so a lookup by full name
                    # does not still report the function as optimized.
                    signature = element.get('type_specific_fields', {}).get('signature')
                    if signature:
                        results[signature] = 0

    except Exception as e:
        # Best effort: report the failure and fall through with partial results.
        print("Error:", e)

    return results

def get_labels_for_dataset(cpg_functions):
    """
    Iterates through all CPGs, finds unique source files, and runs Slither.
    Uses caching to avoid re-running Slither on 47k files if script restarts.

    Args:
        cpg_functions: Iterable of CPG entries; each must have a 'contract'
            key holding the source file path.

    Returns:
        Dict mapping source file path -> label dict produced by
        analyze_file_gas_costs (cached entries included).

    The cache is written even when the run is interrupted or fails part-way,
    so files that were already analysed are not re-processed on restart.
    """
    print("\n--- PHASE 1: GENERATING GAS LABELS ---")

    # 1. Identify Unique Files
    unique_files = {entry['contract'] for entry in cpg_functions}

    print(f"Found {len(unique_files)} unique source files in CPG dataset.")

    # 2. Load Cache if exists
    file_labels = {}
    if os.path.exists(LABEL_CACHE_FILE):
        print("Loading cached labels...")
        with open(LABEL_CACHE_FILE, "r") as f:
            file_labels = json.load(f)

    # 3. Process missing files
    files_to_process = [f for f in unique_files if f not in file_labels]

    if files_to_process:
        print(f"Running Slither on {len(files_to_process)} new files...")
        try:
            for file_path in tqdm(files_to_process):
                file_labels[file_path] = analyze_file_gas_costs(file_path)
        finally:
            # Persist progress even on Ctrl-C / crash. Write to a temp file
            # first so an interrupted dump cannot corrupt the existing cache.
            tmp_path = f"{LABEL_CACHE_FILE}.tmp"
            with open(tmp_path, "w") as f:
                json.dump(file_labels, f)
            os.replace(tmp_path, LABEL_CACHE_FILE)

    return file_labels

# ==========================================
# PART 2: STRUCTURAL ANALYSIS (WL Kernel)
# ==========================================

def cpg_to_networkx(cpg_dict):
    """Build an undirected NetworkX graph from a CPG dict.

    Each node gets a single ``label`` attribute derived from structural
    fields only:
      * CFG nodes: ``cfg_<kind>_<irs_count>`` (the op count acts as a rough
        block-size proxy),
      * all other nodes: ``<graph>_<type>_<kind>``.
    Edges are added from the ``src``/``dst`` fields without attributes.
    """
    graph_data = cpg_dict['graph']
    nx_graph = nx.Graph()

    for node_id, node_attrs in graph_data['nodes'].items():
        layer = str(node_attrs.get('graph', 'unk'))

        if layer != 'cfg':
            # AST / DFG (and anything else): type + kind
            node_type = str(node_attrs.get('type', ''))
            node_kind = str(node_attrs.get('kind', ''))
            label = f"{layer}_{node_type}_{node_kind}"
        else:
            # CFG: kind + number of IR operations in the block
            node_kind = str(node_attrs.get('kind', 'unk'))
            op_count = str(node_attrs.get('irs_count', '0'))
            label = f"cfg_{node_kind}_{op_count}"

        nx_graph.add_node(node_id, label=label)

    for link in graph_data['edges']:
        nx_graph.add_edge(link['src'], link['dst'])

    return nx_graph

"""def wl_hash_graph(G, iterations=2):
    # Computes Weisfeiler-Lehman structural fingerprint.
    # Initial Labels
    current_labels = {n: G.nodes[n]['label'] for n in G.nodes()}
    all_patterns = []
    all_patterns.extend(current_labels.values())
    
    for _ in range(iterations):
        new_labels = {}
        for node in G.nodes():
            neighbors = G.neighbors(node)
            # Get neighbor labels and sort them (crucial for graph isomorphism)
            n_labels = sorted([current_labels[n] for n in neighbors])
            
            # Hash context
            signature = current_labels[node] + "(" + ",".join(n_labels) + ")"
            hashed = hashlib.md5(signature.encode()).hexdigest()
            new_labels[node] = hashed
            
        current_labels = new_labels
        all_patterns.extend(current_labels.values())
        
    return Counter(all_patterns)"""

import hashlib
from collections import Counter
from typing import Any, Callable, Hashable, Optional

def wl_hash_graph(
    G,
    iterations: int = 2,
    node_label_attr: str = "label",
    default_node_label: str = "0",
    directed_neighbors: str = "auto",   # "auto" | "in" | "out" | "both"
    edge_label_attr: Optional[str] = None,
    include_iteration_in_features: bool = True,
    digest_size: int = 16,              # bytes for blake2b digest (small & fast)
) -> Counter:
    """
    Weisfeiler–Lehman (1-WL) subtree fingerprint for a NetworkX graph.

    Args:
        G: NetworkX-style graph (``nodes``, ``neighbors`` and, for directed
            graphs, ``successors``/``predecessors``).
        iterations: Number of refinement rounds after the initial labels.
        node_label_attr: Node attribute holding the initial label.
        default_node_label: Label used for nodes lacking that attribute.
        directed_neighbors: Which neighbours feed a node's signature.
            "auto" means "both" for directed graphs and "out" otherwise.
        edge_label_attr: Optional edge attribute mixed into the signature.
        include_iteration_in_features: Key features as ``(iteration, label)``
            instead of just ``label``.
        digest_size: blake2b digest size in bytes for compressed labels.

    Returns:
        Counter of WL features. Suitable for similarity via dot-product / cosine.

    Raises:
        ValueError: if ``directed_neighbors`` is not one of the four modes.

    Notes:
      - If you want an isomorphism-invariant *hash*, you can hash the final Counter.
      - If you want a WL kernel feature map, this Counter is already that.
      - Edge labels are read from the edge as stored: an incoming edge
        ``m -> n`` is looked up as ``(m, n)``, not ``(n, m)``.
    """
    if directed_neighbors not in ("auto", "in", "out", "both"):
        raise ValueError("directed_neighbors must be 'auto', 'in', 'out', or 'both'")

    # Graph capabilities are loop-invariant, so resolve them once.
    is_directed = bool(getattr(G, "is_directed", lambda: False)())
    is_multigraph = bool(getattr(G, "is_multigraph", lambda: False)())
    has_successors = hasattr(G, "successors")
    has_predecessors = hasattr(G, "predecessors")
    if directed_neighbors == "auto":
        mode = "both" if is_directed else "out"
    else:
        mode = directed_neighbors

    # ---- Helpers ----
    def stable_digest(s: str) -> str:
        # blake2b is fast and stable; digest_size controls length.
        return hashlib.blake2b(s.encode("utf-8"), digest_size=digest_size).hexdigest()

    def get_node_label(n) -> str:
        return str(G.nodes[n].get(node_label_attr, default_node_label))

    def incident_edges(n):
        """Return (neighbor, u, v) triples; (u, v) is the edge as stored in G."""
        if mode == "out":
            source = G.successors(n) if has_successors else G.neighbors(n)
            return [(m, n, m) for m in source]
        if mode == "in":
            if has_predecessors:
                return [(m, m, n) for m in G.predecessors(n)]
            return [(m, n, m) for m in G.neighbors(n)]
        # mode == "both"
        if has_successors and has_predecessors:
            outgoing = [(m, n, m) for m in G.successors(n)]
            incoming = [(m, m, n) for m in G.predecessors(n)]
            return outgoing + incoming
        return [(m, n, m) for m in G.neighbors(n)]

    def edge_label(u, v) -> str:
        data = G.get_edge_data(u, v, default={}) or {}
        if is_multigraph:
            # Parallel edges: data maps edge key -> attribute dict.
            return "|".join(sorted(str(d.get(edge_label_attr, "")) for d in data.values()))
        return str(data.get(edge_label_attr, ""))

    # Deterministic node order (stringified to be stable across mixed node types)
    nodes = sorted(G.nodes(), key=lambda x: str(x))

    # ---- Iteration 0 labels ----
    current = {n: get_node_label(n) for n in nodes}

    # Standard WL feature map counts labels per iteration
    feats = Counter()
    for n in nodes:
        key = (0, current[n]) if include_iteration_in_features else current[n]
        feats[key] += 1

    # ---- WL refinement iterations ----
    for it in range(1, iterations + 1):
        refined = {}
        for n in nodes:
            parts = []
            for m, u, v in incident_edges(n):
                lbl = current.get(m, default_node_label)
                if edge_label_attr is None:
                    parts.append(lbl)
                else:
                    parts.append(f"{edge_label(u, v)}:{lbl}")

            # WL signature: own label + sorted multiset of neighbor contexts,
            # compressed to a fixed-size hash.
            refined[n] = stable_digest(f"{current[n]}|{','.join(sorted(parts))}")

        current = refined

        # Update feature counts
        for n in nodes:
            key = (it, current[n]) if include_iteration_in_features else current[n]
            feats[key] += 1

    return feats

def jaccard_similarity(vec1, vec2):
    """Jaccard index of the key sets of two feature mappings.

    Only the presence of a key matters; the stored counts are ignored.
    Returns 0.0 when both mappings are empty.
    """
    left, right = set(vec1.keys()), set(vec2.keys())
    combined = left | right
    if len(combined) == 0:
        return 0.0
    shared = left & right
    return len(shared) / len(combined)

# ==========================================
# MAIN DRIVER
# ==========================================


# 1. Load CPGs
# The pickle is expected to be a dict with a 'functions' list. pickle.load can
# execute arbitrary code, so only load files produced by this project.
print(f"Loading CPG dataset: {CPG_DATASET_PATH}...")
try:
    with open(CPG_DATASET_PATH, "rb") as f:
        data = pickle.load(f)
        functions = data['functions']
except FileNotFoundError:
    # Fall back to an empty dataset so the remaining steps run as no-ops.
    print("Error: function_cpgs.pkl not found.")
    functions = []

# 2. Get Labels (Integrated Step)
# This runs Slither on the source files referenced in the pickle
# NOTE(review): the result is only read by the disabled block inside the loop
# below, so this (expensive) step currently has no effect on `database`.
file_labels_map = get_labels_for_dataset(functions)

# 3. Build Embeddings
print("\n--- PHASE 2: COMPUTING EMBEDDINGS ---")
database = []

for entry in tqdm(functions):
    # Pull the identifying fields of this function entry
    f_full_name = entry['function_full_name']
    f_path = entry['contract']
    f_name = entry['function_full_name'] # same key as f_full_name; not used by any live code below
    f_simple_name = entry['function_simple_name']          # not used below; presumably the bare name -- TODO confirm
    
    # Disabled label lookup/filter, kept as an inert string literal.
    # While it is disabled, no entry in `database` carries a label.
    """is_optimized = -1 # Unknown
    if f_path in file_labels_map:
        if f_name in file_labels_map[f_path]:
            is_optimized = file_labels_map[f_path][f_name]
        # Fallback: Check full name if simple name failed
        elif entry['function_full_name'] in file_labels_map[f_path]:
            is_optimized = file_labels_map[f_path][entry['function_full_name']]
    
    # Filter: We need labeled data (0 or 1)
    # If Slither failed on the file, we skip it for the paper results
    if is_optimized == -1:
        continue"""
        
    # Filter: Skip tiny graphs (likely getters/setters)
    if entry['num_nodes'] < 5:
        continue

    # Convert the CPG to NetworkX and compute its WL feature Counter
    G = cpg_to_networkx(entry)
    vec = wl_hash_graph(G, iterations=WL_ITERATIONS)
    
    # 'contract' and 'path' both hold entry['contract'] (the source file path)
    database.append({
        "full_name": f_full_name,
        "contract": entry['contract'],
        "vector": vec,
        "path": f_path,
        "nodes": entry['num_nodes']
    })
    
# NOTE(review): the message says "labeled", but with the filter above disabled
# this counts every function with at least 5 nodes.
print(f"Successfully indexed {len(database)} labeled functions.")

# 4. Find Comparisons
#print("\n--- PHASE 3: FINDING OPTIMIZATIONS ---")
#bad_funcs = [d for d in database if d['is_optimized'] == 0]
#good_funcs = [d for d in database if d['is_optimized'] == 1]

#print(f"Comparison Pool: {len(bad_funcs)} Inefficient vs {len(good_funcs)} Optimized.")

#results = []

# Optimization: Don't compare everything to everything.
# Take a sample of bad functions if too many.
"""search_limit = 200
queries = bad_funcs[:search_limit] if len(bad_funcs) > search_limit else bad_funcs

for bad in tqdm(queries):
    best_match = None
    best_score = 0.0
    
    for good in good_funcs:
        # Heuristic: Don't compare extremely different sizes (optional speedup)
        if abs(bad['nodes'] - good['nodes']) > 15:
            continue
            
        # Don't compare same file
        if bad['path'] == good['path']:
            continue
            
        score = jaccard_similarity(bad['vector'], good['vector'])
        
        if score > SIMILARITY_THRESHOLD and score > best_score:
            best_score = score
            best_match = good
    
    if best_match:
        results.append({
            "inefficient_func": bad['id'],
            "optimized_replacement": best_match['id'],
            "similarity": round(best_score, 4),
            "inefficient_file": bad['path'],
            "optimized_file": best_match['path']
        })
        
# 5. Output
results.sort(key=lambda x: x['similarity'], reverse=True)

print(f"\nFound {len(results)} candidate pairs.")
print(f"Top 5 Matches:")
for res in results[:5]:
    print(f"  [{res['similarity']}] {res['inefficient_func']}  ->  {res['optimized_replacement']}")
    
with open(RESULTS_FILE, "w") as f:
    json.dump(results, f, indent=4)
    
print(f"\nFull results saved to {RESULTS_FILE}")"""

In [5]:
import pickle
# Load the CPG dataset; the context manager closes the file handle.
with open('function_cpgs.pkl', 'rb') as fh:
    cpgs = pickle.load(fh)

In [6]:
len(cpgs['functions'])

In [5]:
cpgs['functions'][50]['function_simple_name']

In [2]:
len(database)

In [7]:
database[0]

In [7]:
import math

def wl_cosine(phi1: Counter, phi2: Counter) -> float:
    """Cosine similarity between two WL feature Counters.

    Features missing from ``phi2`` contribute zero to the dot product. A tiny
    constant is added to the denominator so empty vectors yield 0.0 instead of
    raising ZeroDivisionError.
    """
    shared_mass = sum(weight * phi2.get(feature, 0) for feature, weight in phi1.items())
    norm_left = math.sqrt(sum(weight * weight for weight in phi1.values()))
    norm_right = math.sqrt(sum(weight * weight for weight in phi2.values()))
    denominator = norm_left * norm_right + 1e-12
    return shared_mass / denominator

In [14]:
# Collect every function whose WL cosine similarity to database[158] exceeds 0.9.
# The query itself is part of `database`, so it shows up among its own matches.
sims = []
func = database[158]['vector']
func_name = database[158]['full_name']
func_contract = database[158]['contract']
for other in database:
    score = wl_cosine(func, other['vector'])  # compute once, reuse for filter and result
    if score > 0.9:
        sims.append((other['full_name'], other['contract'], score))

In [11]:
# Same search as the cosine variant, using Jaccard similarity of the feature key sets.
# The query itself is part of `database`, so it shows up among its own matches.
sims = []
func = database[158]['vector']
func_name = database[158]['full_name']
func_contract = database[158]['contract']
for other in database:
    score = jaccard_similarity(func, other['vector'])  # compute once, reuse for filter and result
    if score > 0.9:
        sims.append((other['full_name'], other['contract'], score))

In [13]:
# All-pairs search: for every function, list the functions (itself included)
# whose WL cosine similarity exceeds 0.9, keyed by "<contract path>.<full name>".
sims = defaultdict(list)
for i in tqdm(range(len(database))):
    func = database[i]['vector']
    func_name = database[i]['full_name']
    func_contract = database[i]['contract']
    query_key = f"{func_contract}.{func_name}"  # loop-invariant for the inner loop
    for other in database:
        score = wl_cosine(func, other['vector'])  # compute once per pair
        if score > 0.9:
            sims[query_key].append((other['full_name'], other['contract'], score))

In [12]:
sims

In [15]:
sims

In [4]:
import pickle
# Load the processed-function dump; the context manager closes the file handle.
with open('processed_functions.pkl', 'rb') as fh:
    funcs = pickle.load(fh)

In [11]:
# Group the processed functions by bare name. Each record keeps the op list and
# a total "gas" value, the sum of the second item of every op pair.
funcs_proc = {}
for func in funcs[0]:
    for f_entry in funcs[0][func]:
        bucket = funcs_proc.setdefault(f_entry['name'], [])
        bucket.append({
            'fn_name': f_entry['name'],
            'contract': f_entry['contract_name'],
            'filename': f_entry['filename'],
            'ops': f_entry['ops'],
            'gas': sum(n for _, n in f_entry['ops']),
        })

In [18]:
from tqdm import tqdm

In [22]:
# Exploratory pass over the similarity results: for each match, look up the
# processed functions sharing its name and print the first candidate.
# NOTE(review): nothing is ever appended to opt_sims in this cell.
opt_sims = []
for i in range(len(sims)):
    for sim in sims[list(sims.keys())[i]]:
        # NOTE(review): the cells that build `sims` append
        # (full_name, contract, score), so index 2 is the similarity score,
        # not a file name -- confirm before enabling the check below.
        filename = sim[2]
        # NOTE(review): splitting on every '.' truncates signatures that
        # contain dots, e.g. "current(Counters.Counter)".
        fn_name = sim[0].split('.')[1]
        contract = sim[0].split('.')[0] 
        fns = funcs_proc.get(fn_name, [])
        for fn in fns:
            # Intended filter (disabled): match on contract and file name
            #if fn['contract'] == contract and fn['filename'] == filename:
            print(fn['contract'], fn['filename'])
            break  # only the first candidate is inspected

In [16]:
opt_sims

In [3]:
import json

In [4]:
# Load the saved similarity results; the context manager closes the file handle.
with open('paper_results_disl.json', 'r') as fh:
    sims = json.load(fh)

In [9]:
sims[list(sims.keys())[1]]

In [14]:
funcs_proc.keys()

In [2]:
import pickle
import json
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict

INPUT_FILE = "function_cpgs_test.pkl"
OUTPUT_FILE = "gas_estimates_test.json"

# --- CONFIGURATION ---
LOOP_MULTIPLIER = 10  # Assume loops run ~10 times on average for estimation
COSTS = {
    "SSTORE": 20000, "SLOAD": 2100, "EVENT": 375,
    "EXTERNAL_CALL": 2600, "INTERNAL_CALL": 10, "DEPLOY": 32000,
    "EXP": 50, "KECCAK": 30, "BASE": 3
}

def build_ast_parent_map(nodes, edges):
    """
    Build a child_id -> parent_id lookup from the AST edges.

    Only edges whose 'kind' is 'ast_child' are used; they point
    Parent -> Child, so each one is inverted. If a child appears in several
    such edges, the last one wins. `nodes` is accepted but not consulted.
    """
    return {
        link['dst']: link['src']
        for link in edges
        if link.get('kind') == 'ast_child'
    }

def is_inside_loop(dfg_node_id, nodes, edges, parent_map):
    """
    Checks if a DFG node is structurally nested inside a Loop statement.

    Args:
        dfg_node_id: Id of the DFG node to test.
        nodes: Mapping node_id -> attribute dict (the 'type' field is read).
        edges: Iterable of edge dicts; the first edge with src == dfg_node_id
            and kind == 'maps_to_ast' gives the AST anchor.
        parent_map: child_id -> parent_id map of the AST.

    Returns:
        True if any AST ancestor of the anchor is a for / while / do-while
        statement; False otherwise, including when the node has no anchor.
    """
    loop_types = ('for_statement', 'while_statement', 'do_while_statement')

    # 1. Find the AST Anchor
    # DFG nodes are linked to AST nodes via 'maps_to_ast' edges
    ast_anchor_id = None
    for edge in edges:
        if edge['src'] == dfg_node_id and edge.get('kind') == 'maps_to_ast':
            ast_anchor_id = edge['dst']
            break

    # Compare against None explicitly: 0 (or "") can be a valid node id.
    if ast_anchor_id is None:
        return False

    # 2. Walk up the AST Tree
    visited = {ast_anchor_id}
    current_id = ast_anchor_id
    while current_id in parent_map:
        parent_id = parent_map[current_id]
        if parent_id in visited:
            # Malformed (cyclic) parent map: stop instead of looping forever.
            break
        visited.add(parent_id)

        # Check node type (Tree-Sitter types); tolerate parents missing from `nodes`.
        p_type = nodes.get(parent_id, {}).get('type', '')
        if p_type in loop_types:
            return True

        current_id = parent_id

    return False

def estimate_gas_with_loops(cpg_data):
    """
    Heuristic gas score for one function CPG.

    Every DFG node contributes COSTS["BASE"], plus SSTORE / SLOAD for each
    state variable it defines / uses, plus a surcharge for external calls,
    events and exponentiation. Nodes nested inside a loop are scaled by
    LOOP_MULTIPLIER.

    Args:
        cpg_data: Dict with cpg_data['graph']['nodes'] (id -> attribute dict)
            and cpg_data['graph']['edges'] (list of edge dicts).

    Returns:
        The summed score over all DFG nodes (0 when there are none).
    """
    nodes = cpg_data['graph']['nodes']
    edges = cpg_data['graph']['edges']

    # Pre-compute structural maps
    parent_map = build_ast_parent_map(nodes, edges)

    # Index the DFG -> AST anchor edges by source once, so the loop check
    # below does not rescan every edge for every DFG node.
    anchor_edges = {}
    for edge in edges:
        if edge.get('kind') == 'maps_to_ast':
            anchor_edges.setdefault(edge['src'], []).append(edge)

    # Identify State Variables by node id and, when present, by label
    state_var_ids = set()
    for nid, attrs in nodes.items():
        if attrs.get('kind') == 'state':
            state_var_ids.add(nid)
            if 'label' in attrs:
                state_var_ids.add(attrs['label'])

    total_score = 0

    for nid, attrs in nodes.items():
        if attrs.get('graph') != 'dfg':
            continue

        # --- 1. Determine Base Cost ---
        base_cost = COSTS["BASE"]
        kind = attrs.get('kind', 'Operation')
        label = attrs.get('label', '')

        # Storage writes / reads
        for def_var in attrs.get('defs', []):
            if def_var in state_var_ids:
                base_cost += COSTS["SSTORE"]
        for use_var in attrs.get('uses', []):
            if use_var in state_var_ids:
                base_cost += COSTS["SLOAD"]

        # Check Op Type
        if kind in ['HighLevelCall', 'LowLevelCall', 'Transfer', 'Send']:
            base_cost += COSTS["EXTERNAL_CALL"]
        elif kind == 'Event':
            base_cost += COSTS["EVENT"]
        elif kind == 'Binary' and '**' in label:
            base_cost += COSTS["EXP"]

        # --- 2. Apply Loop Multiplier ---
        # Only this node's own anchor edges are passed on; they keep their
        # original order, so the same anchor is picked as with the full list.
        multiplier = 1
        if is_inside_loop(nid, nodes, anchor_edges.get(nid, ()), parent_map):
            multiplier = LOOP_MULTIPLIER

        total_score += (base_cost * multiplier)

    return total_score


# Load the CPG dataset. pickle.load can execute arbitrary code, so only load
# files produced by this project.
print(f"Loading CPGs from {INPUT_FILE}...")
try:
    with open(INPUT_FILE, "rb") as f:
        dataset = pickle.load(f)
except FileNotFoundError:
    print("File not found.")
    # Stop here: continuing would fail on an undefined `dataset`, or silently
    # reuse one left over from an earlier cell.
    raise

functions = dataset.get("functions", [])
print(f"Estimating Loop-Aware Gas for {len(functions)} functions...")

# Score every function, keyed by "<source file name>.<function full name>".
results = {}
for func in tqdm(functions):
    score = estimate_gas_with_loops(func)

    # Helper: Create consistent ID (file name only, without its directory)
    c_name = Path(func['contract']).name
    func_key = f"{c_name}.{func['function_full_name']}"

    results[func_key] = score

with open(OUTPUT_FILE, "w") as f:
    json.dump(results, f, indent=4)
print(f"Saved estimates to {OUTPUT_FILE}")

Loading CPGs from function_cpgs_test.pkl...
Estimating Loop-Aware Gas for 706 functions...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 706/706 [00:00<00:00, 10726.62it/s]

Saved estimates to gas_estimates_test.json





In [19]:
# Load the gas estimates; the context manager closes the file handle.
with open('gas_estimates.json', 'r') as fh:
    gases = json.load(fh)

In [20]:
len(gases)

In [21]:
gases.keys()

In [29]:
# For every query function, keep the similar functions whose estimated gas is
# lower than the query's. `gases` is keyed by "<file name>.<Contract>.<signature>".
opt_sims = defaultdict(list)
for query_key, matches in tqdm(sims.items()):
    gas_orig = gases.get(query_key.split('/')[-1], None)
    if gas_orig is None:
        # No estimate for the query: nothing to compare against
        # (comparing a number with None would raise TypeError).
        continue
    for sim in matches:
        filename = sim[1].split('/')[-1]
        # sim[0] is already "<Contract>.<signature>". Use it whole: splitting on
        # every '.' truncates signatures such as "current(Counters.Counter)".
        func_key = f"{filename}.{sim[0]}"
        gas = gases.get(func_key, None)
        # Missing and zero estimates are both skipped.
        if gas and gas < gas_orig:
            opt_sims[query_key].append((func_key, sim[2], gas_orig, gas))

In [31]:
opt_sims.keys()

In [32]:
opt_sims['contracts_organized_disl/ASBOT_92de669a259b5fdc0887adb2f8835398.sol.ASBOT.openTrading()']

In [16]:
# Load the saved similarity results; the context manager closes the file handle.
with open('paper_results_disl.json', 'r') as fh:
    sims = json.load(fh)

In [25]:
sims["contracts_organized_disl/$$$PEPE_afa8bb971580558ca89d9c36c37caa38.sol.$$$PEPE.constructor()"]

In [1]:
import os
contrs = os.listdir('contracts_disl')
len(contrs)

In [3]:
import json
# Load the test-set gas estimates and similarity results; the context
# managers close both file handles.
with open('gas_estimates_test.json', 'r') as fh:
    gas_estimates = json.load(fh)
with open('paper_results_disl_test.json', 'r') as fh:
    sim_results = json.load(fh)

In [2]:
sim_results.keys()

dict_keys(['test_contracts/contract_0106.sol.Address._verifyCallResult(bool,bytes,string)', 'test_contracts/contract_0106.sol.Address.functionCall(address,bytes)', 'test_contracts/contract_0106.sol.Address.functionCall(address,bytes,string)', 'test_contracts/contract_0106.sol.Address.functionCallWithValue(address,bytes,uint256)', 'test_contracts/contract_0106.sol.Address.functionCallWithValue(address,bytes,uint256,string)', 'test_contracts/contract_0106.sol.Address.functionDelegateCall(address,bytes)', 'test_contracts/contract_0106.sol.Address.functionDelegateCall(address,bytes,string)', 'test_contracts/contract_0106.sol.Address.functionStaticCall(address,bytes)', 'test_contracts/contract_0106.sol.Address.functionStaticCall(address,bytes,string)', 'test_contracts/contract_0106.sol.Address.isContract(address)', 'test_contracts/contract_0106.sol.Address.sendValue(address,uint256)', 'test_contracts/contract_0106.sol.Context._msgData()', 'test_contracts/contract_0106.sol.Context._msgSender

In [4]:
sim_results['test_contracts/contract_0106.sol.ERC721._transfer(address,address,uint256)']

[['PaymentSplitter.release(address)',
  'test_contracts/sample_006.sol',
  0.9309413433074951],
 ['PaymentSplitter._addPayee(address,uint256)',
  'test_contracts/sample_006.sol',
  0.9209734797477722],
 ['ERC721._transfer(address,address,uint256)',
  'test_contracts/sample_006.sol',
  0.995765745639801],
 ['ERC721._mint(address,uint256)',
  'test_contracts/sample_006.sol',
  0.9512621760368347],
 ['ERC721._burn(uint256)',
  'test_contracts/sample_006.sol',
  0.9213229417800903],
 ['OceanWorld._mint(address)',
  'test_contracts/sample_004_changed.sol',
  0.9397016763687134],
 ['ERC721._mint(address,uint256)',
  'test_contracts/sample_004_changed.sol',
  0.9512621760368347],
 ['ERC721._burn(uint256)',
  'test_contracts/sample_004_changed.sol',
  0.9213229417800903],
 ['OceanWorld._mint(address)',
  'test_contracts/sample_004.sol',
  0.9397016763687134],
 ['ERC721._transfer(address,address,uint256)',
  'test_contracts/sample_004.sol',
  1.000002384185791],
 ['ERC721._mint(address,uint256)

In [4]:
# Load the list of modified test functions; the context manager closes the file handle.
with open('test_contracts/changes.json', 'r') as fh:
    test = json.load(fh)

In [5]:
test[0]

{'file': 'sample_000.sol',
 'contract': 'GnosisSafeProxy',
 'function': 'constructor(address)',
 'key': 'test_contracts/sample_000.sol.GnosisSafeProxy.constructor(address)'}

In [5]:
# Count the query functions that have at least one match with the same
# "<Contract>.<signature>" (typically the same function in another file).
# Keys look like "<dir>/<file>.sol.<Contract>.<signature>".
count = 0
for sim in sim_results:
    # Everything after the first ".sol." is "<Contract>.<signature>". Compare
    # it whole: splitting on every '.' truncates signatures that contain dots
    # (e.g. "current(Counters.Counter)") and can conflate overloads.
    qualified_name = sim.split('/')[-1].partition('.sol.')[2]
    if not qualified_name:
        continue  # key does not follow the expected layout
    for match in sim_results[sim]:
        if match[0] == qualified_name:
            count += 1
            break

In [6]:
count

639

In [13]:
sim_results.keys()

dict_keys(['test_contracts/contract_0106.sol.Address._verifyCallResult(bool,bytes,string)', 'test_contracts/contract_0106.sol.Address.functionCall(address,bytes)', 'test_contracts/contract_0106.sol.Address.functionCall(address,bytes,string)', 'test_contracts/contract_0106.sol.Address.functionCallWithValue(address,bytes,uint256)', 'test_contracts/contract_0106.sol.Address.functionCallWithValue(address,bytes,uint256,string)', 'test_contracts/contract_0106.sol.Address.functionDelegateCall(address,bytes)', 'test_contracts/contract_0106.sol.Address.functionDelegateCall(address,bytes,string)', 'test_contracts/contract_0106.sol.Address.functionStaticCall(address,bytes)', 'test_contracts/contract_0106.sol.Address.functionStaticCall(address,bytes,string)', 'test_contracts/contract_0106.sol.Address.isContract(address)', 'test_contracts/contract_0106.sol.Address.sendValue(address,uint256)', 'test_contracts/contract_0106.sol.Context._msgData()', 'test_contracts/contract_0106.sol.Context._msgSender

In [20]:
sim_results['test_contracts/contract_0107_changed.sol.ERC721.tokenURI(uint256)']

[['ERC721.tokenURI(uint256)',
  'test_contracts/sample_006.sol',
  0.9999979734420776],
 ['ERC721.tokenURI(uint256)',
  'test_contracts/contract_0107.sol',
  0.9999979734420776]]

In [78]:
list(gas_data.keys())[0]

NameError: name 'gas_data' is not defined

In [4]:
list(results.keys())[0]

'contracts_organized_disl/$$$PEPE_afa8bb971580558ca89d9c36c37caa38.sol.$$$PEPE.constructor()'

In [None]:
[['Address.functionCallWithValue(address,bytes,uint256,string)',
  'test_contracts/contract_0106_changed.sol',
  0.9593871831893921],
  ['ERC721._isApprovedOrOwner(address,uint256)',
  'test_contracts/contract_0106_changed.sol',
  0.9230046272277832],
  ['ERC721._transfer(address,address,uint256)',
  'test_contracts/contract_0106_changed.sol',
  0.9707795977592468],
  ['MerkleProof.verify(bytes32[],bytes32,bytes32)',
  'test_contracts/contract_0106_changed.sol',
  0.9154568910598755],
  ['Strings.toHexString(uint256,uint256)',
  'test_contracts/contract_0106_changed.sol',
  0.9570364952087402],
  ['Address.verifyCallResult(bool,bytes,string)',
  'test_contracts/contract_0107_changed.sol',
  0.9370400309562683],
  ['ERC721._checkOnERC721Received(address,address,uint256,bytes)',
  'test_contracts/contract_0107_changed.sol',
  0.9632788896560669],
  ['MerkleProof.verify(bytes32[],bytes32,bytes32)',
  'test_contracts/sample_004_changed.sol',
  0.915773332118988],
  ['Strings.toHexString(uint256,uint256)',
  'test_contracts/sample_004_changed.sol',
  0.9134950637817383],
  ['PaymentSplitter._addPayee(address,uint256)',
  'test_contracts/contract_0107_changed.sol',
  0.963034987449646]]

In [9]:
import hashlib
import pickle

def _normalize_code(code):
    # Collapse whitespace to avoid formatting-only differences
    return " ".join(code.split())

def count_dups_by_signature_and_code(pkl_path):
    """Group the functions of a CPG pickle that share signature and code.

    Two functions are duplicates when they have the same signature
    ('function_full_name', falling back to 'function_simple_name') and the
    same whitespace-normalized source of their AST root node.

    Args:
        pkl_path: Path to a pickle holding a dict with a 'functions' list.
            pickle.load can execute arbitrary code; only pass trusted files.

    Returns:
        (dup_entries, dup_group_count, dup_groups), where dup_groups maps the
        SHA-256 of each (signature, code) key to the list of function indices
        in that group (groups with more than one member only) and dup_entries
        is the total number of functions in those groups.
    """
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)

    groups = {}

    for idx, fn in enumerate(data["functions"]):
        sig = fn.get("function_full_name") or fn.get("function_simple_name") or ""
        root = fn.get("ast_root")
        code = ""
        # `is not None` rather than truthiness: a root id of 0 is valid.
        if root is not None and "graph" in fn and "nodes" in fn["graph"]:
            # `or ""` guards against a node that stores code=None.
            code = fn["graph"]["nodes"].get(root, {}).get("code", "") or ""
        key = (sig, _normalize_code(code))
        h = hashlib.sha256(repr(key).encode("utf-8")).hexdigest()
        groups.setdefault(h, []).append(idx)

    dup_groups = {h: idxs for h, idxs in groups.items() if len(idxs) > 1}
    dup_entries = sum(len(v) for v in dup_groups.values())
    return dup_entries, len(dup_groups), dup_groups

# Example
dup_entries, dup_group_count, dup_groups = count_dups_by_signature_and_code("function_cpgs_test.pkl")
print("duplicate entries:", dup_entries)
print("duplicate groups:", dup_group_count)


duplicate entries: 637
duplicate groups: 129


In [7]:
import json
import pickle
from collections import defaultdict

def normalize_code(code):
    """Collapse all whitespace runs in `code` to single spaces and trim the ends."""
    pieces = code.split()
    normalized = " ".join(pieces)
    return normalized

def count_identified_duplicate_pairs(pkl_path, sim_results_path):
    """Count how many exact-duplicate function pairs the similarity search found.

    Duplicates share the same signature and whitespace-normalized root code.
    A pair (a, b) counts as identified when b is listed among a's similarity
    results or a among b's.

    Args:
        pkl_path: CPG pickle (dict with a 'functions' list). pickle.load can
            execute arbitrary code; only pass trusted files.
        sim_results_path: JSON file mapping "<contract path>.<full name>" to a
            list of [full_name, contract_path, score] entries.

    Returns:
        (number of duplicate groups, number of duplicate pairs,
         number of pairs identified by the similarity results).
    """
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)
    with open(sim_results_path, "r") as f:
        sim_results = json.load(f)

    records = []
    for idx, fn in enumerate(data["functions"]):
        contract = fn.get("contract", "")
        full_name = fn.get("function_full_name") or fn.get("function_simple_name") or ""
        root = fn.get("ast_root")
        code = ""
        # `is not None` rather than truthiness: a root id of 0 is valid.
        if root is not None and "graph" in fn and "nodes" in fn["graph"]:
            code = fn["graph"]["nodes"].get(root, {}).get("code", "") or ""
        dup_key = (full_name, normalize_code(code))
        sim_key = f"{contract}.{full_name}" if contract and full_name else None
        records.append(
            {"idx": idx, "contract": contract, "full_name": full_name,
             "dup_key": dup_key, "sim_key": sim_key}
        )

    by_dup = defaultdict(list)
    for r in records:
        by_dup[r["dup_key"]].append(r)

    dup_groups = [g for g in by_dup.values() if len(g) > 1]

    # Index each result list once as a set of (full_name, contract_path), so
    # the per-pair check is a set lookup instead of a scan of the whole list.
    sim_index = {}
    for key, entries in sim_results.items():
        sim_index[key] = {
            (entry[0], entry[1])
            for entry in entries
            if isinstance(entry, list) and len(entry) >= 2
            and isinstance(entry[0], str) and isinstance(entry[1], str)
        }

    def appears_in_sim(a, b):
        """True if b is listed in the similarity results of a."""
        if not a["sim_key"]:
            return False
        return (b["full_name"], b["contract"]) in sim_index.get(a["sim_key"], ())

    pairs = set()
    identified = set()

    for group in dup_groups:
        for i in range(len(group)):
            for j in range(i + 1, len(group)):
                a = group[i]
                b = group[j]
                pair = frozenset((a["idx"], b["idx"]))
                pairs.add(pair)
                if appears_in_sim(a, b) or appears_in_sim(b, a):
                    identified.add(pair)

    return len(dup_groups), len(pairs), len(identified)

dup_groups, dup_pairs, identified_pairs = count_identified_duplicate_pairs(
    "function_cpgs_test.pkl", "paper_results_disl_test.json"
)
print("duplicate groups:", dup_groups)
print("duplicate pairs:", dup_pairs)
print("identified duplicate pairs:", identified_pairs)


duplicate groups: 129
duplicate pairs: 1451
identified duplicate pairs: 1354


In [8]:
1354/1451

0.9331495520330806

In [10]:
# Load the test CPG dataset; the context manager closes the file handle.
with open('function_cpgs_test.pkl', 'rb') as fh:
    cpgs = pickle.load(fh)

In [11]:
len(cpgs['functions'])

706

In [1]:
import json
# Load the test-set gas estimates; the context manager closes the file handle.
with open('gas_estimates_test.json', 'r') as fh:
    gases = json.load(fh)

In [2]:
gases.keys()

dict_keys(['contract_0106.sol.Address._verifyCallResult(bool,bytes,string)', 'contract_0106.sol.Address.functionCall(address,bytes)', 'contract_0106.sol.Address.functionCall(address,bytes,string)', 'contract_0106.sol.Address.functionCallWithValue(address,bytes,uint256)', 'contract_0106.sol.Address.functionCallWithValue(address,bytes,uint256,string)', 'contract_0106.sol.Address.functionDelegateCall(address,bytes)', 'contract_0106.sol.Address.functionDelegateCall(address,bytes,string)', 'contract_0106.sol.Address.functionStaticCall(address,bytes)', 'contract_0106.sol.Address.functionStaticCall(address,bytes,string)', 'contract_0106.sol.Address.isContract(address)', 'contract_0106.sol.Address.sendValue(address,uint256)', 'contract_0106.sol.Context._msgData()', 'contract_0106.sol.Context._msgSender()', 'contract_0106.sol.Counters.current(Counters.Counter)', 'contract_0106.sol.Counters.decrement(Counters.Counter)', 'contract_0106.sol.Counters.increment(Counters.Counter)', 'contract_0106.sol

In [6]:
gases['contract_0107.sol.ERC721._checkOnERC721Received(address,address,uint256,bytes)']

2687

In [5]:
gases['contract_0107_changed.sol.ERC721._checkOnERC721Received(address,address,uint256,bytes)']

2690