In [123]:
import json
import os
import random
import subprocess

from anthropic import Anthropic
from dotenv import load_dotenv, find_dotenv
from IPython.display import display, Markdown
from openai import OpenAI
from functions import *

from langchain_core.pydantic_v1 import BaseModel, Field
from typing import Literal, List

from langchain_openai import ChatOpenAI
from langchain.schema.runnable import Runnable
from typing import Dict, Any, Optional

from langchain.chains import create_structured_output_runnable
from langchain_anthropic import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
from langchain.output_parsers import PydanticOutputParser

In [100]:
# Example request that drives the whole pipeline further down.
query = "I want a hook that has a counter every time someone makes swap or by adding or removing liquidity"

# System prompt for the routing step (chain_one): pick which example hooks are relevant.
find_relevant_hooks = """You are an expert at routing a user question to the appropriate data source.
Based on the intended hooks, choose the most relevant datasources."""

In [101]:
# Load the catalogue of example hooks. Its keys are example file names (the router
# returns them, e.g. "Counter.sol", and read_hooks opens them); the values are
# presumably human-readable descriptions — TODO confirm against hook_examples.json.
with open("hook_examples.json", 'r', encoding='utf-8') as file:
    hook_examples_json = json.load(file)

# Restrict the router's answer to the known example names.
# `Literal[tuple(...)]` is equivalent to `Literal[*tuple(...)]` (Literal unpacks a tuple
# argument), but the star-in-subscript form is only valid syntax on Python 3.11+.
DatasourceOptions = Literal[tuple(hook_examples_json.keys())]


class FindRelevantHooks(BaseModel):
    """Route a user query to the most relevant datasource."""

    # The whole catalogue is embedded in the field description so the model can see
    # every available datasource when choosing.
    datasources: List[DatasourceOptions] = Field(
        ...,
        description=f"Given a user question choose which datasources would be most relevant for answering their question. Available datasources: {json.dumps(hook_examples_json, indent=2)}",
    )

In [102]:
# Provider behind every LLM call in this notebook: "openai" or "anthropic".
# `llm` is reused by chain_two, chain_three and CompileAndFixTestContract.
llm_choice = "anthropic"

if llm_choice == "openai":
    llm = ChatOpenAI(temperature=0, model="gpt-4")

    find_relevant_hooks_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", find_relevant_hooks),
            ("human", "{question}"),
        ]
    )

    # NOTE(review): create_structured_output_runnable is deprecated in recent LangChain
    # releases; `llm.with_structured_output(FindRelevantHooks)` is the suggested successor.
    chain_one = create_structured_output_runnable(
        FindRelevantHooks,
        llm=llm, 
        prompt=find_relevant_hooks_prompt, 
    )
elif llm_choice == "anthropic":
    llm = ChatAnthropic(model="claude-3-sonnet-20240229", temperature=0)

    find_relevant_hooks_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", find_relevant_hooks),
            ("human", "{question}\n\nPlease provide your response in the following JSON format:\n{format_instructions}"),
        ]
    )

    # Parse the model's JSON reply into a FindRelevantHooks instance.
    parser = PydanticOutputParser(pydantic_object=FindRelevantHooks)

    # Pre-fill {format_instructions} so callers only need to supply {question}.
    prompt = find_relevant_hooks_prompt.partial(format_instructions=parser.get_format_instructions())

    # prompt -> model -> parser; invoke({"question": ...}) yields FindRelevantHooks.
    chain_one = prompt | llm | parser
else:
    # Without this, `llm` / `chain_one` would stay undefined (or stale from an earlier
    # run) and later cells would fail with a confusing NameError or use the wrong model.
    raise ValueError(f"Unsupported llm_choice: {llm_choice!r} (expected 'openai' or 'anthropic')")

In [103]:
def read_hooks(x, folder_path='../foundry_hook_playground/src/examples/src/'):
    """Concatenate the source of the selected example hooks into one prompt string.

    Args:
        x: Iterable of example file names (e.g. ``["Counter.sol"]``), relative to
            ``folder_path``.
        folder_path: Directory holding the example hook sources.

    Returns:
        ``{"hooks": text}`` where ``text`` is every file's contents, each preceded by
        an ``example {i}`` separator; it fills the ``{hooks}`` slot of chain_two.

    Raises:
        OSError: if a listed file cannot be opened.
    """
    print("reading hooks", x)
    # Collect pieces and join once instead of repeatedly growing a string.
    parts = []
    for i, hook in enumerate(x):
        parts.append(f"\n\n example {i}:\n -------------------------\n\n")
        with open(os.path.join(folder_path, hook), 'r', encoding='utf-8') as file:
            parts.append(file.read())
    return {"hooks": "".join(parts)}

# Contract-generation step (chain_two). The shared instructions plus the retrieved
# example hooks form the system message; the user's question goes in the human turn.
# NOTE(review): the instructions file is spliced into the template text itself, so any
# literal `{`/`}` it contains would be parsed as template variables.
prompt_instructions = read_file('instructions_contract_generation.txt')

chain_two_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            prompt_instructions
            + "\n\n"
            "Here are some relevant hook examples:\n{hooks}\n\n",
        ),
        (
            "human",
            "Based on these instructions, generate a Solidity contract for the following question: {question}\n\n"
            "I only want you to generate the solidity code, nothing else. Write directly the code.\n"
            "You will name the contract as GeneratedContract.sol\n",
        ),
    ]
)

# Prompt -> the model selected by llm_choice.
chain_two = chain_two_template | llm

# Preview the rendered chain_two prompt with placeholder values, to check exactly
# what the model will receive.
sample_inputs = dict(question="[Question]", hooks="[Hooks]")

formatted_prompt = chain_two_template.format_messages(**sample_inputs)

print("Full Prompt:")
for message in formatted_prompt:
    # One print per message: role, content and a separator, each on its own line.
    print(f"Role: {message.type}", f"Content: {message.content}", "---", sep="\n")

Full Prompt:
Role: system
Content: You are an expert solidity developer. Your task is to assist in creating hook smart contracts for Uniswap V4. 

Uniswap v4 Hooks -- also known simply as hooks -- are specially designed contracts that run at distinct points throughout a pool action's lifecycle. They serve as plugins allowing developers to tailor how pools, swaps, fees, and LP positions interact. This enables innovation atop Uniswap v4's core features, thereby supporting the development of custom AMM pools.

During the course of a pool action's lifecycle, a hook invokes custom logic primarily at four critical phases:

Initialize: Activated when the pool is deployed.
Modify Position: Used to add or remove liquidity.
Swap: Engages a swap between tokens within the V4 ecosystem.
Donate: Facilitates the donation of liquidity to a V4 pool. Upon initialization, a pool can be associated with a hook contract. Such a contract has the ability to execute any of the callback functions during the poo

In [104]:
def fix_contract(contract_code, error_message, model=None,
                 output_path='../foundry_hook_playground/src/generated/GeneratedContract.sol'):
    """Ask an LLM to repair a Solidity contract that failed to compile.

    Args:
        contract_code: Source of the contract that failed to compile.
        error_message: Compiler output describing the failure.
        model: Chat model used for the repair. Defaults to
            ``ChatOpenAI(temperature=0.7)`` (the historical behaviour, which needs
            OpenAI credentials even when ``llm_choice`` is "anthropic"); pass the
            notebook's ``llm`` to stay on the selected provider.
        output_path: File the raw model reply is written to (overwritten).

    Returns:
        The model's raw reply text. It may still contain Markdown fences, so callers
        should pass it through ``extract_and_save_code``.
    """
    fix_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert Solidity developer. Fix the compilation error in the following smart contract."),
        ("human", """
         Here's the smart contract with a compilation error:\n\n{contract_code}\n\nError message:\n{error_message}\n\nPlease provide the corrected contract.
         I only want you to fix the contract, nothing else. Write directly the corrected contract.
        """)
    ])

    if model is None:
        model = ChatOpenAI(temperature=0.7)

    fix_chain = fix_prompt | model
    fixed_contract = fix_chain.invoke({"contract_code": contract_code, "error_message": error_message})

    # Save the raw reply; extraction of the code block happens in the caller.
    with open(output_path, 'w', encoding='utf-8') as file:
        file.write(fixed_contract.content)

    return fixed_contract.content

In [105]:
def extract_and_save_code(x, path):
    """Pull the Solidity source out of an LLM reply and write it to ``path``.

    The first pattern that matches wins, tried in this order:
      1. a ```solidity fenced block (case-insensitive, LF or CRLF line endings),
      2. a '''solidity quoted block,
      3. a bare ``` (or ```sol) fenced block.
    If nothing matches, the whole reply is treated as code.

    Args:
        x: Raw text returned by the model.
        path: File the extracted code is written to (overwritten).

    Returns:
        The extracted code with surrounding whitespace stripped.
    """
    # `re` is not in the notebook's import cell; without this it is only available
    # if the `functions` star-import happens to export it.
    import re

    patterns = (
        r"```solidity[ \t]*\r?\n(.*?)```",
        r"'''solidity[ \t]*\r?\n(.*?)'''",
        r"```(?:sol)?[ \t]*\r?\n(.*?)```",
    )

    code = x  # fallback: no recognised code block
    for pattern in patterns:
        match = re.search(pattern, x, re.DOTALL | re.IGNORECASE)
        if match:
            code = match.group(1)
            break

    code = code.strip()

    with open(path, 'w', encoding='utf-8') as file:
        file.write(code)

    return code  # returned so callers can print or pass it on

class CompilationFailedException(Exception):
    """Signals that a generated contract could not be compiled.

    The driver cell catches it to report a stopped chain; note that nothing in this
    notebook currently raises it.
    """

class CompileAndFixContract(Runnable):
    """Compile ``GeneratedContract.sol`` and, if that fails, let an LLM repair it.

    Each repair round sends the current source and the compiler output to
    ``fix_contract``, extracts the code from the reply, overwrites the contract
    file, and recompiles.
    """

    def __init__(self, max_attempts=3,
                 contract_path='../foundry_hook_playground/src/generated/GeneratedContract.sol'):
        """
        Args:
            max_attempts: Maximum number of LLM repair rounds after the first compile.
            contract_path: File repaired code is written to; it should be the file
                that ``compile_contract('GeneratedContract.sol')`` builds.
        """
        self.max_attempts = max_attempts
        self.contract_path = contract_path

    @staticmethod
    def _compile_generated():
        """Compile the generated contract; return ``(stdout, stderr, succeeded)``."""
        stdout, stderr, returncode = compile_contract('GeneratedContract.sol')
        return stdout, stderr, returncode == 0

    @staticmethod
    def _compiler_output(stdout, stderr):
        """Join both compiler streams for the fixer.

        NOTE(review): in this notebook's captured run, forge's stderr was only
        "Error: Compilation failed", so the detailed diagnostics are presumably on
        stdout. Assumes both streams are ``str`` (or empty/None).
        """
        return "\n".join(part for part in (stderr, stdout) if part)

    def invoke(self, inputs: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Compile the contract, repairing it up to ``max_attempts`` times.

        Args:
            inputs: Must contain ``"contract"``, the source currently saved at
                ``contract_path``.
            config: Unused; accepted for the ``Runnable`` interface.

        Returns:
            ``{"contract": source, "compiled": bool, "stderr": ...}``; ``stderr`` is
            ``None`` on success and the last compiler stderr on failure.
        """
        contract = inputs['contract']
        stdout, stderr, ok = self._compile_generated()

        attempt = 0
        while not ok and attempt < self.max_attempts:
            attempt += 1
            print("attempt", attempt)

            raw_reply = fix_contract(contract, self._compiler_output(stdout, stderr))
            # extract_and_save_code requires the destination path; omitting it raised
            # TypeError on the very first repair attempt.
            contract = extract_and_save_code(raw_reply, self.contract_path)

            stdout, stderr, ok = self._compile_generated()

        if ok:
            print("\U0001F389 \U0001F389 Contract compiled successfully! \U0001F389 \U0001F389")
            return {"contract": contract, "compiled": True, "stderr": None}

        print("could not compile the contract")
        return {"contract": contract, "compiled": False, "stderr": stderr}


In [117]:
# Unit-test generation step (chain_three). It reads the same contract instructions as
# chain_two plus the test-specific instructions; the latter are supplied at invoke time
# through the {test_instructions} slot (overall_chain passes unit_test_instructions).
# NOTE(review): prompt_instructions is spliced into the template text, so literal
# `{`/`}` in that file would be parsed as template variables.
prompt_instructions = read_file('instructions_contract_generation.txt')
unit_test_instructions = read_file('instructions_test_contract_generation.txt')

chain_three_template = ChatPromptTemplate.from_messages([
    ("system", prompt_instructions + "\n\n" + 
     "Here are some relevant hook examples:\n{test_hooks}\n\n" +
     "{test_instructions}"
    ),
    ("human", 
     # Previously garbled ("generate a the unit this of this ..."), which muddied the request.
     "Based on these instructions, generate the unit tests for this Solidity contract: \n\n {contract_code} \n\n" +
     "I only want you to generate the solidity code, nothing else. Write directly the code.\n" + 
     "You will name the contract as GeneratedContractTest.t.sol\n"
    )
])

chain_three = chain_three_template | llm

In [126]:
class CompileAndFixTestContract(Runnable):
    """Compile the generated Foundry test file and let an LLM repair it on failure.

    Each repair round sends the failing test source, the compiler output and the
    contract under test to ``llm``, extracts the code from the reply, overwrites
    ``test_path`` and recompiles.
    """

    def __init__(self, max_attempts=3,
                 test_path="../foundry_hook_playground/src/test/GeneratedContractTest.t.sol"):
        """
        Args:
            max_attempts: Maximum number of LLM repair rounds after the first compile.
            test_path: File the repaired test source is written to.
        """
        self.max_attempts = max_attempts
        self.test_path = test_path

    @staticmethod
    def _compiler_output(stdout, stderr):
        """Join both process streams for the fixer.

        NOTE(review): in this notebook's captured run, forge's stderr was only
        "Error: Compilation failed", so the detailed diagnostics are presumably on
        stdout. Assumes both streams are ``str`` (or empty/None).
        """
        return "\n".join(part for part in (stderr, stdout) if part)

    def invoke(self, inputs: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Compile the test contract, repairing it up to ``max_attempts`` times.

        Args:
            inputs: Must contain ``"test_contract"`` (the test source currently saved
                at ``test_path``) and ``"contract_code"`` (the contract under test).
            config: Unused; accepted for the ``Runnable`` interface.

        Returns:
            ``{"test_contract": source, "compiled": bool, "stderr": ...}``; ``stderr``
            is ``None`` on success and the last stderr on failure.
        """
        test_contract = inputs['test_contract']
        contract_code = inputs['contract_code']

        # NOTE(review): needs a compile_test_contract that accepts the file name.
        stdout, stderr, returncode = compile_test_contract('GeneratedContractTest.t.sol')
        if returncode == 0:
            print("\U0001F389 \U0001F389 Test contract compiled successfully! \U0001F389 \U0001F389")
            return {"test_contract": test_contract, "compiled": True, "stderr": None}

        # The repair prompt is identical for every attempt, so build the chain once.
        fix_prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an expert Solidity developer. Fix the compilation error in the following test contract."),
            ("human", """
                 Here's the test contract with a compilation error:\n\n{test_contract}\n\n
                 Error message:\n{error_message}\n\n
                 Original contract being tested:\n{contract_code}\n\n
                 Please provide the corrected test contract.
                 I only want you to fix the contract, nothing else. Write directly the corrected contract.
                """)
        ])
        fix_chain = fix_prompt | llm

        for attempt in range(1, self.max_attempts + 1):
            print("Test contract compilation attempt", attempt)

            fixed_test_contract = fix_chain.invoke({
                "test_contract": test_contract,
                "error_message": self._compiler_output(stdout, stderr),
                "contract_code": contract_code,
            }).content

            # Extract the code block from the reply and overwrite the test file.
            test_contract = extract_and_save_code(fixed_test_contract, self.test_path)

            stdout, stderr, returncode = compile_test_contract('GeneratedContractTest.t.sol')
            if returncode == 0:
                print("\U0001F389 \U0001F389 Test contract compiled successfully! \U0001F389 \U0001F389")
                return {"test_contract": test_contract, "compiled": True, "stderr": None}

        print("Could not compile the test contract")
        return {"test_contract": test_contract, "compiled": False, "stderr": stderr}


In [119]:
def read_test_hooks(x, folder_path='../foundry_hook_playground/src/examples/test/'):
    """Concatenate the Foundry test files that accompany the selected example hooks.

    For each example ``Name.sol`` this looks for ``Name.t.sol`` in ``folder_path``.
    Missing or unreadable files are reported and skipped (best effort), so the result
    may contain fewer examples than ``x``.

    Args:
        x: Example file names as chosen by the router (e.g. ``["Counter.sol"]``).
        folder_path: Directory holding the example test files.

    Returns:
        ``{"test_hooks": text}``; ``text`` is ``""`` when ``x`` is empty or no test
        file could be read.
    """
    print("reading test hooks", x)

    if not x:  # Nothing selected: skip the file system entirely.
        print("No hooks provided.")
        return {"test_hooks": ""}

    parts = []
    for i, hook in enumerate(x):
        try:
            # "Counter.sol" -> "Counter.t.sol". splitext does not mangle names that
            # lack a ".sol" suffix, unlike slicing off the last three characters.
            stem, _ext = os.path.splitext(hook)
            file_path = os.path.join(folder_path, stem + '.t.sol')
            if os.path.exists(file_path):
                with open(file_path, 'r', encoding='utf-8') as file:
                    contents = file.read()
                # Append the label only after a successful read; the label keeps the
                # index from `x`, so numbering can skip entries that were not found.
                parts.append(f"\n\n example {i}:\n -------------------------\n\n")
                parts.append(contents)
            else:
                print(f"Warning: File not found - {file_path}")
        except Exception as e:  # best effort: one bad entry must not abort the rest
            print(f"Error reading hook {hook}: {str(e)}")

    return {"test_hooks": "".join(parts)}

In [127]:
def overall_chain(question, test_instructions):
    """Run the full pipeline: route, generate, compile/fix, then generate and compile tests.

    Args:
        question: Natural-language description of the desired hook.
        test_instructions: Text for the ``{test_instructions}`` slot of chain_three.

    Returns:
        dict with ``relevant_hooks``, ``contract``, ``compiled``, ``unit_tests`` and
        ``compiled_test``. Test generation only runs when the contract compiled;
        otherwise ``unit_tests`` and ``compiled_test`` are ``None``.
    """
    # Step 1: Find relevant hooks
    relevant_hooks = chain_one.invoke({"question": question}).datasources

    # Step 2: Read hooks
    hooks_data = read_hooks(relevant_hooks)

    # Step 3: Generate Solidity contract
    generated_contract = chain_two.invoke({"question": question, "hooks": hooks_data['hooks']}).content

    # Step 4: Extract and save code
    contract_code = extract_and_save_code(generated_contract, "../foundry_hook_playground/src/generated/GeneratedContract.sol")

    # Step 5: Compile and fix contract
    compile_and_fix = CompileAndFixContract(max_attempts=1)
    compile_result = compile_and_fix.invoke({"contract": contract_code})

    # Defaults for the "contract did not compile" path; without them the return
    # below raised UnboundLocalError whenever compilation failed.
    unit_tests = None
    compiled_test = None

    if compile_result['compiled']:
        # Step 6: Read test hooks
        test_hooks_data = read_test_hooks(relevant_hooks)

        # Step 7: Generate unit tests
        generated_unit_tests = chain_three.invoke({
            "contract_code": contract_code, 
            "test_hooks": test_hooks_data['test_hooks'], 
            "test_instructions": test_instructions
        }).content
        unit_tests = extract_and_save_code(generated_unit_tests, "../foundry_hook_playground/src/test/GeneratedContractTest.t.sol")

        # Step 8: Compile and fix unit tests
        compile_and_fix_test = CompileAndFixTestContract(max_attempts=3)
        compile_result_test = compile_and_fix_test.invoke({"test_contract": unit_tests, "contract_code": contract_code})
        compiled_test = compile_result_test['compiled']

    return {
        "relevant_hooks": relevant_hooks,
        "contract": contract_code,
        "unit_tests": unit_tests,
        "compiled": compile_result['compiled'],
        "compiled_test": compiled_test
    }

# Run the whole pipeline for `query` and summarise the outcome.
try:
    # The call must be inside `try`: previously only the prints were guarded, so
    # errors raised by the pipeline itself bypassed both handlers below.
    result = overall_chain(query, unit_test_instructions)

    print("Selected Hooks:")
    print(result['relevant_hooks'])
    
    print("\nCompilation Result:")
    print(result['compiled'])

    print("\nCompilation Result Test:")
    print(result['compiled_test'])
    
    # print("\nGenerated Solidity Code:")
    # print(result['contract'])
    
    # print("\nGenerated Unit Tests:")
    # print(result['unit_tests'])

except CompilationFailedException as e:
    print("Compilation failed. Chain execution stopped.")
    print(f"Error: {str(e)}")

except Exception as e:
    print(f"An unexpected error occurred: {str(e)}")

reading hooks ['Counter.sol', 'PointsHook.sol']
🎉 🎉 Contract compiled successfully! 🎉 🎉
reading test hooks ['Counter.sol', 'PointsHook.sol']
Test contract compilation attempt 1
Could not compile the test contract
Selected Hooks:
['Counter.sol', 'PointsHook.sol']

Compilation Result:
True

Compilation Result Test:
False


In [None]:
# Standalone re-check: compile the saved GeneratedContract.sol once more.
stdout, stderr, returncode = compile_contract('GeneratedContract.sol')

print(
    "\U0001F389 \U0001F389 Contract compiled successfully! \U0001F389 \U0001F389"
    if returncode == 0
    else "did not compile"
)

🎉 🎉 Contract compiled successfully! 🎉 🎉


In [122]:
def compile_test_contract(test_file=None, working_directory="../foundry_hook_playground/src/test"):
    """Run ``forge test`` and return its output.

    Args:
        test_file: Optional test file name (e.g. ``"GeneratedContractTest.t.sol"``).
            Accepted so this definition matches the one-argument calls made by
            ``CompileAndFixTestContract``. When given, only tests from files with
            that name are run (``--match-path``).
        working_directory: Directory ``forge`` is run from.

    Returns:
        ``(stdout, stderr, returncode)`` of the ``forge test`` process.
    """
    command = ["forge", "test"]
    if test_file:
        # A leading **/ lets the glob match wherever the file sits in the project.
        command += ["--match-path", f"**/{test_file}"]

    # Argument list without a shell: nothing here needs shell features, and
    # test_file can never be interpreted as shell syntax.
    result = subprocess.run(command, capture_output=True, text=True, cwd=working_directory)

    return result.stdout, result.stderr, result.returncode

# Run the Foundry test suite once by hand; show the error stream if it fails.
stdout, stderr, returncode = compile_test_contract()

if returncode != 0:
    print(stderr)
else:
    print("\U0001F389 \U0001F389 Contract compiled successfully! \U0001F389 \U0001F389")

Error: 
Compilation failed

