In [None]:
# %pip install cbor2
# %pip install web3
# %pip install ipywidgets
# %pip install jupyterlab
# %pip install jupyter_contrib_nbextensions

In [None]:
# !jupyter nbextension enable --py widgetsnbextension
# !jupyter labextension install @jupyter-widgets/jupyterlab-manager

In [None]:
import binascii
import cbor2
import json
from web3 import Web3
from dotenv import load_dotenv
import os
from web3.contract import Contract
from pathlib import Path
import pprint
import ipywidgets as widgets
from IPython.display import display
from IPython import get_ipython

# Load variables such as ETH_MAINNET_RPC_URL (read in the run cell below) from a local
# .env file into os.environ. load_dotenv() returns False when it parsed no variables
# (e.g. no .env file was found); report that instead of failing silently.
if not load_dotenv():
    print("No variables were loaded from a .env file; relying on the existing environment.")

In [None]:
# Minimal ABI containing only the diamond loupe's `facetAddresses()` view, which takes no
# arguments and returns the addresses of the facets as an `address[]`.
diamondLoupeABI = [
    dict(
        constant=True,
        inputs=[],
        name="facetAddresses",
        outputs=[
            dict(
                internalType="address[]",
                name="facetAddresses_",
                type="address[]",
            )
        ],
        payable=False,
        stateMutability="view",
        type="function",
    )
]

In [None]:
class CompareBytecode:
    """
    Compare compiled (Forge artifact) bytecode with on-chain runtime bytecode for the
    facets of a diamond, plus any extra contracts supplied by the caller.

    On construction the diamond address is read from a JSON file, facet names are
    resolved from Forge broadcast files, and a metadata-stripped comparison is stored in
    `comparison_result`. `compare_contract_bytecodes()` recomputes the comparison and
    honours `compare_with_metadata`.
    """

    def __init__(self, json_with_diamond_address: str, key_to_read_diamond_address: str,
                 w3: "Web3", contract_abi: list, dir_path_to_broadcasts: str = 'broadcast',
                 dir_path_to_artifacts: str = 'forge-artifacts',
                 additional_contracts: dict = None, compare_with_metadata: bool = False):
        """
        Args:
            json_with_diamond_address (str): path of the JSON file holding the diamond address.
            key_to_read_diamond_address (str): key of the diamond address in that file. Broadcast
                files are only read from directories whose trailing path component(s) equal it.
            w3 (Web3): Web3 instance used for the `facetAddresses()` call and `get_code`.
            contract_abi (list): ABI used to call `facetAddresses()` on the diamond.
            dir_path_to_broadcasts (str): root of the broadcast directory tree.
            dir_path_to_artifacts (str): root of the artifacts, laid out as `<Name>.sol/<Name>.json`.
            additional_contracts (dict): extra `name -> address` entries; names already found in
                the broadcasts are kept and reported.
            compare_with_metadata (bool): whether `compare_contract_bytecodes()` keeps the
                trailing Solidity metadata when comparing.
        """
        self.json_with_diamond_address = json_with_diamond_address
        self.key_to_read_diamond_address = key_to_read_diamond_address
        self.w3 = w3
        self.contract_abi = contract_abi
        self.dir_path_to_broadcasts = dir_path_to_broadcasts
        self.dir_path_to_artifacts = dir_path_to_artifacts
        self.compare_with_metadata = compare_with_metadata
        self.diamond_address = self.get_address_from_json(self.key_to_read_diamond_address, self.json_with_diamond_address)
        self.contracts = self.create_contract_dict(self.diamond_address) if self.diamond_address else {}
        # Add any additional contracts
        if additional_contracts:
            for contract_name, address in additional_contracts.items():
                if contract_name not in self.contracts:
                    self.contracts[contract_name] = address
                else:
                    print(f"Contract '{contract_name}' already exists in the dictionary.")
        # These eager results always strip metadata, regardless of compare_with_metadata.
        self.bytecodes_without_metadata = self.get_bytecodes_without_metadata(self.contracts)
        self.onchain_bytecodes_without_metadata = self.get_onchain_bytecodes_without_metadata(self.contracts)
        self.comparison_result = self.compare_bytecodes(self.bytecodes_without_metadata, self.onchain_bytecodes_without_metadata)

    def get_address_from_json(self, key=None, filename=None):
        """
        Reads a JSON file and returns the value corresponding to a given key

        Args:
            key (str): key to look up in the JSON file. If None, self.key_to_read_diamond_address is used.
            filename (str): name of the JSON file. If None, self.json_with_diamond_address is used.

        Returns:
            str: value corresponding to the given key or None if the key is not found
        """
        key = self.key_to_read_diamond_address if key is None else key
        filename = self.json_with_diamond_address if filename is None else filename

        if not os.path.isfile(filename):
            print(f"File '{filename}' does not exist.")
            return None

        try:
            with open(filename, 'r') as f:
                data = json.load(f)
        except json.JSONDecodeError:
            print(f"File '{filename}' is not a valid JSON file.")
            return None

        address = data.get(key)
        if address is None:
            print(f"No address found for key '{key}' in {filename}.")

        return address

    def is_notebook(self):
        """Check if the script is running in a Jupyter notebook (ZMQ-based IPython kernel)."""
        shell = get_ipython().__class__.__name__
        return shell == 'ZMQInteractiveShell'

    def get_input(self, prompt):
        """
        Prompt the user for a line of text, in Jupyter or on the command line.

        The built-in `input()` is used in both environments. A widget cannot be read
        synchronously: while a cell spins in a `while not widget.value` loop, the kernel
        does not process the front-end messages that would update the widget's value,
        so such a wait never finishes.
        """
        return input(prompt)

    def create_contract_dict(self, contract_address=None):
        """
        Map contract names to the facet addresses currently registered in the diamond.

        Calls `facetAddresses()` on the diamond, then scans broadcast JSON files (newest
        first) for transactions whose `contractAddress` is one of those facets and records
        `contractName -> contractAddress`. If different names claim the same address, the
        user is asked which of them should get a new address.

        Args:
            contract_address (str): address of the diamond; defaults to self.diamond_address.

        Returns:
            dict: contract names -> addresses. Empty if the `facetAddresses()` call fails.
        """
        contract_dict = {}
        diamond_address = self.diamond_address if contract_address is None else contract_address

        diamond = self.w3.eth.contract(address=diamond_address, abi=self.contract_abi)
        try:
            facet_addresses = diamond.functions.facetAddresses().call()
            print("Number of facet addresses cut into diamond", len(facet_addresses))
        except Exception as exc:  # RPC / ABI failures; let KeyboardInterrupt propagate
            print(f"Failed to call facetAddresses() method: {exc}")
            return contract_dict

        # Hex addresses may differ only in letter case (checksummed vs lower-case).
        facet_set = {str(addr).lower() for addr in facet_addresses}

        for json_path in self._find_broadcast_files():
            for transaction in self._load_transactions(json_path):
                address = transaction.get('contractAddress')
                name = transaction.get('contractName')
                if not address or str(address).lower() not in facet_set:
                    continue
                # Without a name there is no artifact to compare against.
                if not name:
                    continue
                if address not in contract_dict.values() or contract_dict.get(name) == address:
                    contract_dict[name] = address
                else:
                    self._resolve_duplicate_address(contract_dict, name, address)
        return contract_dict

    def _find_broadcast_files(self):
        """Return broadcast JSON paths in directories matching the key, newest first."""
        key = os.path.normpath(str(self.key_to_read_diamond_address))
        json_files = []
        for dirpath, _dirnames, filenames in os.walk(self.dir_path_to_broadcasts):
            norm = os.path.normpath(dirpath)
            # Match whole trailing path components: a plain suffix test would let key '1'
            # also match directories such as '11155111' or '31'.
            if not (norm == key or norm.endswith(os.sep + key)):
                continue
            if 'dry-run' in dirpath:
                continue
            json_files.extend(os.path.join(dirpath, f) for f in filenames if f.endswith('.json'))
        return sorted(json_files, key=os.path.getmtime, reverse=True)

    @staticmethod
    def _load_transactions(json_path):
        """Return the 'transactions' list of a broadcast file, or [] if it cannot be read."""
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError) as exc:
            print(f"Skipping unreadable broadcast file '{json_path}': {exc}")
            return []
        if not isinstance(data, dict):
            return []
        return data.get('transactions') or []

    def _resolve_duplicate_address(self, contract_dict, contract_name, contract_address):
        """Ask the user which of the names sharing `contract_address` gets a new address."""
        duplicate_keys = [k for k, v in contract_dict.items() if v == contract_address]
        print(f"Duplicate address {contract_address} found for contracts: {duplicate_keys} and {contract_name}")
        duplicate_keys.append(contract_name)
        print("Which contract would you like to provide a new address for?")
        for i, key in enumerate(duplicate_keys, start=1):
            print(f"{i}: {key}")
        # Re-prompt on bad input: int('') raises, and 0 would silently index the last entry.
        while True:
            choice = input("Enter the number corresponding to your selection: ").strip()
            if choice.isdigit() and 1 <= int(choice) <= len(duplicate_keys):
                break
            print(f"Please enter a number between 1 and {len(duplicate_keys)}.")
        selected_key = duplicate_keys[int(choice) - 1]
        new_address = input('Enter new address for the selected contract: ')
        contract_dict[selected_key] = new_address
        # The name that was not re-addressed keeps the contested address; otherwise
        # picking an existing entry would silently drop `contract_name` from the result.
        if selected_key != contract_name:
            contract_dict[contract_name] = contract_address

    def get_bytecodes(self, contracts):
        """
        Extracts deployed (runtime) bytecode from the given contracts' artifact JSON files

        Args:
            contracts (dict): dictionary of contract names and addresses.

        Returns:
            dict: dictionary of contract names and their bytecodes

        Raises:
            FileNotFoundError: if `<artifacts>/<Name>.sol/<Name>.json` does not exist.
        """
        bytecodes = {}
        for contract_name in contracts:
            json_path = os.path.join(self.dir_path_to_artifacts, f'{contract_name}.sol', f'{contract_name}.json')
            with open(json_path) as f:
                data = json.load(f)
            bytecodes[contract_name] = data['deployedBytecode']['object']
        return bytecodes

    def get_onchain_bytecodes(self, contracts):
        """
        Retrieves on-chain runtime bytecodes for given contracts

        Args:
            contracts (dict): dictionary of contract names and addresses.

        Returns:
            dict: contract names -> '0x'-prefixed hex of their on-chain runtime code
        """
        onchain_bytecodes = {}
        for contract_name, contract_address in contracts.items():
            code = self.w3.eth.get_code(contract_address)
            # Whether HexBytes.hex() includes '0x' depends on the installed hexbytes version;
            # build the string from plain bytes so the format matches the artifacts.
            onchain_bytecodes[contract_name] = '0x' + bytes(code).hex()
        return onchain_bytecodes

    def get_metadata(self, bytecode):
        """
        Decode the CBOR metadata section that solc appends to runtime bytecode.

        Solidity ends the bytecode with a CBOR-encoded map followed by a 2-byte big-endian
        length of that map, so the section is located from that trailer. Searching for a
        byte pattern instead can match inside ordinary code and cut at the wrong place.

        Args:
            bytecode (str): hex string, with or without a leading '0x'.

        Returns:
            tuple: (metadata dict, index in `bytecode` where the CBOR section starts), or
            (None, -1) when no decodable metadata section is present.
        """
        if not bytecode:
            return None, -1
        prefix_len = 2 if bytecode[:2].lower() == '0x' else 0
        body = bytecode[prefix_len:]
        # Need whole bytes and at least the 2-byte (4 hex digit) length suffix.
        if len(body) < 4 or len(body) % 2:
            return None, -1
        try:
            cbor_len = int(body[-4:], 16)
        except ValueError:
            return None, -1
        start = len(body) - 4 - 2 * cbor_len
        if cbor_len == 0 or start < 0:
            return None, -1
        try:
            metadata = cbor2.loads(binascii.unhexlify(body[start:-4]))
        except (ValueError, EOFError):
            return None, -1
        if not isinstance(metadata, dict):
            return None, -1
        return metadata, prefix_len + start

    def remove_solidity_metadata(self, bytecode):
        """Return `bytecode` without its trailing metadata (CBOR map + length), if present."""
        _metadata, idx = self.get_metadata(bytecode)
        if idx != -1:
            return bytecode[:idx]
        return bytecode

    def get_bytecodes_without_metadata(self, contracts):
        """
        Extracts bytecode from given contracts' JSON files and removes the metadata

        Args:
            contracts (dict): dictionary of contract names and addresses.

        Returns:
            dict: dictionary of contract names and their bytecodes without metadata
        """
        return {name: self.remove_solidity_metadata(code)
                for name, code in self.get_bytecodes(contracts).items()}

    def get_onchain_bytecodes_without_metadata(self, contracts):
        """
        Retrieves on-chain runtime bytecodes for given contracts and removes the metadata

        Args:
            contracts (dict): dictionary of contract names and addresses.

        Returns:
            dict: dictionary of contract names and their on-chain runtime bytecodes without metadata
        """
        return {name: self.remove_solidity_metadata(code)
                for name, code in self.get_onchain_bytecodes(contracts).items()}

    @staticmethod
    def _normalize_hex(bytecode):
        """Lower-case a hex string and drop any '0x' prefix; None becomes ''."""
        text = '' if bytecode is None else str(bytecode).strip().lower()
        return text[2:] if text.startswith('0x') else text

    def are_bytecodes_different(self, bytecode1, bytecode2):
        """True if the two hex bytecodes differ, ignoring a '0x' prefix and letter case."""
        return self._normalize_hex(bytecode1) != self._normalize_hex(bytecode2)

    def compare_bytecodes(self, bytecodes1, bytecodes2):
        """
        Compares bytecodes for each contract in two given dictionaries

        Args:
            bytecodes1 (dict): dictionary of contract names and their bytecodes
            bytecodes2 (dict): dictionary of contract names and their bytecodes

        Returns:
            dict: for every name in `bytecodes1`, True if the bytecodes differ; a name
            missing from `bytecodes2` is compared against '' (so it counts as different)
        """
        return {name: self.are_bytecodes_different(code, bytecodes2.get(name, ""))
                for name, code in bytecodes1.items()}

    def compare_contract_bytecodes(self):
        """
        Recompute the artifact vs on-chain comparison for `self.contracts`.

        Returns:
            tuple: (contracts dict, {name: True if bytecodes differ}). Metadata is kept or
            stripped according to `self.compare_with_metadata`.
        """
        contracts = self.contracts
        if self.compare_with_metadata:
            bytecodes = self.get_bytecodes(contracts)
            onchain_bytecodes = self.get_onchain_bytecodes(contracts)
        else:
            bytecodes = self.get_bytecodes_without_metadata(contracts)
            onchain_bytecodes = self.get_onchain_bytecodes_without_metadata(contracts)
        comparison_result = self.compare_bytecodes(bytecodes, onchain_bytecodes)
        return contracts, comparison_result

In [None]:
# Configuration: update these values for your deployment.
json_with_diamond_address = 'deployedAddresses.json'  # update with your json file name
key_to_read_diamond_address = '1'  # update with your key to read proxy address
rpc_url = os.getenv("ETH_MAINNET_RPC_URL")  # update with your provider details
if not rpc_url:
    # Fail loudly rather than building an HTTPProvider with no endpoint URL.
    raise RuntimeError("ETH_MAINNET_RPC_URL is not set; define it in the environment or in a .env file.")
web3_instance = Web3(Web3.HTTPProvider(rpc_url))
contract_abi = diamondLoupeABI
dir_path_to_broadcasts = 'broadcast'  # update with your directory path
dir_path_to_artifacts = 'forge-artifacts'  # update with your directory path
compare_with_metadata = True
additional_contracts = {'OwnershipFacet': '0x073C1a072845D1d87f42309af9911bd3c07fC599', 'DiamondLoupeFacet': '0x0318ff107aFA55E3dc658cEA06748d0c35fbEC73'}

# Pass every configured value explicitly; previously the two directory paths above
# were defined but never used, so editing them had no effect.
compare_bytecode = CompareBytecode(json_with_diamond_address=json_with_diamond_address,
                                   key_to_read_diamond_address=key_to_read_diamond_address,
                                   w3=web3_instance,
                                   contract_abi=contract_abi,
                                   dir_path_to_broadcasts=dir_path_to_broadcasts,
                                   dir_path_to_artifacts=dir_path_to_artifacts,
                                   compare_with_metadata=compare_with_metadata,
                                   additional_contracts=additional_contracts)

pprint.pprint(compare_bytecode.compare_contract_bytecodes())