In [1]:
import json
import os
import random
import subprocess

from anthropic import Anthropic
from dotenv import load_dotenv, find_dotenv
from IPython.display import display, Markdown
from openai import OpenAI
import tiktoken
from functions import *

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import Literal, List

from langchain.schema.runnable import RunnablePassthrough

from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.chains.router.llm_router import LLMRouterChain
from langchain.chains.router import MultiPromptChain
from langchain.chains.router.llm_router import LLMRouterChain,RouterOutputParser

In [2]:
query = "I want a super simple hook that has a counter every time someone makes swap"

find_relevant_hooks = """You are an expert at routing a user question to the appropriate data source.
Based on the intented hooks, choose the most relevant datasources."""

In [10]:
with open("hook_examples.json", 'r') as file:
    hook_examples_json = json.load(file)

DatasourceOptions = Literal[*tuple(hook_examples_json.keys())]

class FindRelevantHooks(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasources: List[DatasourceOptions] = Field(
        ...,
        description=f"Given a user question choose which datasources would be most relevant for answering their question. Available datasources: {json.dumps(hook_examples_json, indent=2)}",
    )

In [11]:
from langchain.chains import create_structured_output_runnable

llm = ChatOpenAI(temperature=0.9, model="gpt-4")

find_relevant_hooks_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", find_relevant_hooks),
        ("human", "{question}"),
    ]
)

chain_one = create_structured_output_runnable(
    FindRelevantHooks,
    llm=llm, 
    prompt=find_relevant_hooks_prompt, 
)

In [12]:
def read_hooks(x):
    folder_path = '../foundry_hook_playground/src/examples/'
    hooks = ""
    for i, hook in enumerate(x['datasources']):
        link_words = f"\n\n example {i}:\n -------------------------\n\n"
        hooks += link_words
        with open(folder_path + hook, 'r') as file:
            contents = file.read()
            hooks += contents
    return {"hooks": hooks, "datasources": x['datasources'], "question": x['question']}

prompt_instructions = read_file('instructions.txt')

chain_two_template = ChatPromptTemplate.from_messages([
    ("system", prompt_instructions),
    ("human", """
     Based on these instructions, generate a Solidity contract for the following question: {question}\n\n
     I only want you to generate the solidity code, nothing else. Write directly the code.\n
     Here are some relevant hook examples:\n{hooks}\n\n
     Generated Solidity Code:
    """
    )
])

def combine_inputs(inputs):
    return {
        "hooks": inputs["hooks"],
        "question": inputs["question"]
    }

# Create chain_two
chain_two = chain_two_template | llm

In [13]:
def extract_and_save_code(x):
    # Extract the Solidity code from the content
    match = re.search(r'```solidity\n(.*?)```', x.content, re.DOTALL)
    if match:
        code = match.group(1)
    else:
        code = x.content  # If no code block is found, use the entire content

    # Save the code to a file
    with open( "../foundry_hook_playground/src/generated/"+'GeneratedContract.sol', 'w') as file:
        file.write(code)

    return code  # Return the code for printing


overall_chain = (
    {"question": RunnablePassthrough()} | 
    chain_one | 
    (lambda x: {"datasources": x.datasources, "question": RunnablePassthrough()}) |
    read_hooks |
    (lambda x: {
        "hooks": x["hooks"], 
        "question": x["question"]
    }) |
    combine_inputs |
    chain_two |
    extract_and_save_code
)

# Use the chain
result = overall_chain.invoke({"question": query})
print("Generated Solidity Code:\n")
print(result)

Generated Solidity Code:

// SPDX-License-Identifier: MIT
pragma solidity ^0.8.24;

import {BaseHook} from "v4-periphery/BaseHook.sol";
import {Hooks} from "v4-core/src/libraries/Hooks.sol";
import {IPoolManager} from "v4-core/src/interfaces/IPoolManager.sol";
import {PoolKey} from "v4-core/src/types/PoolKey.sol";
import {PoolId, PoolIdLibrary} from "v4-core/src/types/PoolId.sol";
import {BalanceDelta} from "v4-core/src/types/BalanceDelta.sol";
import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "v4-core/src/types/BeforeSwapDelta.sol";

contract HookExample is BaseHook {
    using PoolIdLibrary for PoolKey;

    mapping(PoolId => uint256 count) public beforeSwapCount;
    mapping(PoolId => uint256 count) public afterSwapCount;

    mapping(PoolId => uint256 count) public beforeAddLiquidityCount;
    mapping(PoolId => uint256 count) public beforeRemoveLiquidityCount;

    constructor(IPoolManager _poolManager) BaseHook(_poolManager) {}

    function getHookPermissions() public pure ov

In [15]:
def create_unit_tests(contract_code):
    unit_test_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert Solidity developer. Create unit tests for the following smart contract."),
        ("human", """
         Here's the smart contract:\n\n{contract_code}\n\n
         Please write comprehensive unit tests for this contract using Foundry's testing framework.
         I only want you to write the unit tests, nothing else. Write directly the unit tests.
         """
         )
    ])
    
    unit_test_chain = unit_test_prompt | ChatOpenAI(temperature=0.7)
    unit_tests = unit_test_chain.invoke({"contract_code": contract_code})

    match = re.search(r'```solidity\n(.*?)```', unit_tests.content, re.DOTALL)
    if match:
        code = match.group(1)
    else:
        code = unit_tests.content  # If no code block is found, use the entire content
    
    # Save unit tests to a file
    with open('../foundry_hook_playground/src/test/GeneratedContractTest.t.sol', 'w') as file:
        file.write(code)
    
    return code

def fix_contract(contract_code, error_message):
    fix_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert Solidity developer. Fix the compilation error in the following smart contract."),
        ("human", "Here's the smart contract with a compilation error:\n\n{contract_code}\n\nError message:\n{error_message}\n\nPlease provide the corrected contract.")
    ])
    
    fix_chain = fix_prompt | ChatOpenAI(temperature=0.7)
    fixed_contract = fix_chain.invoke({"contract_code": contract_code, "error_message": error_message})
    
    # Save fixed contract to a file
    with open('../foundry_hook_playground/src/generated/GeneratedContract.sol', 'w') as file:
        file.write(fixed_contract.content)
    
    return fixed_contract.content

In [14]:
def run_tests():
    command = "forge test --match-path src/test/GeneratedContractTest.t.sol -vvv"
    working_directory = "../foundry_hook_playground"

    result = subprocess.run(command, shell=True, capture_output=True, text=True, cwd=working_directory)
    return result.stdout, result.stderr, result.returncode

stdout, stderr, returncode = compile_contract('GeneratedContract.sol')
if returncode == 0:
    print("\U0001F389 \U0001F389 Contract compiled successfully! \U0001F389 \U0001F389")
    print("\nGenerating unit tests...")
    unit_tests = create_unit_tests(result)
    print("Unit tests have been generated and saved to '../foundry_hook_playground/test/GeneratedContractTest.t.sol'")
    print("\nUnit Tests:\n")
    print(unit_tests)
    
    print("\nRunning unit tests...")
    test_stdout, test_stderr, test_returncode = run_tests()
    if test_returncode == 0:
        print("\U0001F389 \U0001F389 All tests passed successfully! \U0001F389 \U0001F389")
        print("\nTest output:")
        print(test_stdout)
    else:
        print("Some tests failed. Please review the test output:")
        print(test_stderr)
else:
    print("Compilation failed. Attempting to fix the contract...")
    fixed_contract = fix_contract(result, stderr)
    print("\nFixed Contract:\n")
    print(fixed_contract)
    print("\nAttempting to compile the fixed contract...")
    
    # Compile the fixed contract
    stdout, stderr, returncode = compile_contract('../foundry_hook_playground/src/generated/GeneratedContract.sol')
    
    if returncode == 0:
        print("\U0001F389 \U0001F389 Fixed contract compiled successfully! \U0001F389 \U0001F389")
        print("\nGenerating unit tests for the fixed contract...")
        unit_tests = create_unit_tests(fixed_contract)
        print("Unit tests have been generated and saved to '../foundry_hook_playground/test/GeneratedContractTest.t.sol'")
        print("\nUnit Tests:\n")
        print(unit_tests)
    else:
        print("Fixed contract still fails to compile. Please review the contract manually.")
        print("Compilation Error:")
        print(stderr)

🎉 🎉 Contract compiled successfully! 🎉 🎉

Generating unit tests...
Unit tests have been generated and saved to '../foundry_hook_playground/test/GeneratedContractTest.t.sol'

Unit Tests:

// SPDX-License-Identifier: MIT
pragma solidity ^0.8.24;

import "https://github.com/Foundry-DAO/foundry-contracts/blob/main/contracts/testing/contracts/Assert.sol";
import "https://github.com/Foundry-DAO/foundry-contracts/blob/main/contracts/testing/contracts/OneTime.sol";

contract HookExampleTest {
    HookExample hookExample;
    
    function beforeEach() public {
        hookExample = new HookExample();
    }
    
    function testBeforeSwap() public {
        PoolKey memory key = PoolKey({poolAddress: address(0), fee: 0});
        IPoolManager.SwapParams memory swapParams;
        bytes memory data;
        
        hookExample.beforeSwap(address(0), key, swapParams, data);
        
        Assert.equal(hookExample.beforeSwapCount[key.toId()], 1, "BeforeSwap count should increase by 1");
    }
  

In [2]:
def run_tests():
    command = "forge test --match-path src/test/GeneratedContractTest.t.sol -vvv"
    working_directory = "../foundry_hook_playground"

    result = subprocess.run(command, shell=True, capture_output=True, text=True, cwd=working_directory)
    return result.stdout, result.stderr, result.returncode

print("\nRunning unit tests...")
test_stdout, test_stderr, test_returncode = run_tests()
if test_returncode == 0:
    print("\U0001F389 \U0001F389 All tests passed successfully! \U0001F389 \U0001F389")
    print("\nTest output:")
    print(test_stdout)
else:
    print("Some tests failed. Please review the test output:")
    print(test_stderr)


Running unit tests...
Some tests failed. Please review the test output:
Error: 
Compilation failed

