In [2]:
import re

from langchain_openai import ChatOpenAI
from langchain_experimental.tools import PythonREPLTool
from langchain_core.prompts import ChatPromptTemplate

from extract import (
    extract_function_names,
    extract_function_by_name,
    extract_imports,
    is_valid_python,
)

# Tool used later to execute the generated tests against the generated function.
# NOTE(review): this runs arbitrary, LLM-produced Python on the host with no
# sandboxing — only use it in a disposable environment.
python_repl = PythonREPLTool()

# Shared chat model for both generation steps (tests, then implementation);
# temperature=0 to reduce sampling variability between runs.
llm = ChatOpenAI(model="gpt-4o", temperature=0)

In [7]:
# System prompt for the test-writing step: the model turns a natural-language
# specification into standalone test functions inside a ```python block.
unit_test_system_prompt_string = """
You are a specialized assistant designed to create unit tests for functions, particularly those involving regular expressions (regex). \
    Your task is to generate concise, self-contained unit tests based on a natural language description of the function's intended behavior.

Instructions:
1. Analyze the given description carefully.
2. Create distinct unit tests that cover all important aspects of the function's expected behavior.
3. Only write unit test functions. Do not write a test Class.
4. If you need to use a unit test framework, use pytest.
5. Each test should focus on a single aspect or scenario.
6. Use descriptive test names that clearly indicate what is being tested.
7. Include assertions that accurately check the expected outcomes.
8. Do not write the actual function implementation or code to run the tests.
9. Ensure tests cover edge cases, invalid inputs, and typical use cases.
10. For regex functions, consider testing pattern matching, capturing groups, and boundary conditions.
11. Keep the tests concise and to the point.

Provide your response as a Python code block containing only the unit tests.

Example input: "Create tests for a function that extracts all email addresses from a given text."

Your task is to generate appropriate unit tests based on similar natural language descriptions.
"""

unit_test_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", unit_test_system_prompt_string),
        ("user", "{input}"),
    ]
)
unit_test_chain = unit_test_prompt | llm

# Natural-language description of the behaviour the tests should pin down.
unit_test_task_description = (
    "Extract emails (including instances of 'name at email dot com') from a text."
)
unit_test_llm_message = unit_test_chain.invoke({"input": unit_test_task_description})

In [12]:
# print(unit_test_llm_message.content)

In [16]:
def extract_python_code(text):
    """Return the body of the first ```python fenced block in ``text``.

    Args:
        text: Raw LLM response text expected to contain a Markdown code fence
            opened with ```python.

    Returns:
        The code between the opening ```python line and the next closing ```.

    Raises:
        ValueError: If ``text`` contains no ```python block, or if
            ``is_valid_python`` reports the extracted code as invalid.
    """
    # Tolerate trailing spaces after the language tag and Windows line endings.
    match = re.search(r"```python[ \t]*\r?\n(.*?)```", text, re.DOTALL)
    if match is None:
        raise ValueError("No ```python code block found in the LLM response.")
    code_block = match.group(1)
    # Act on the validator's verdict instead of discarding it. Comparing to
    # False (rather than using `not`) also works if the validator signals
    # failure by raising and returns None on success.
    # NOTE(review): assumes is_valid_python returns a bool — confirm in extract.py.
    if is_valid_python(code_block) is False:
        raise ValueError("Extracted code block is not valid Python.")
    return code_block

# Pull the fenced test code out of the model's reply and show it for review.
unit_test_code_as_string = extract_python_code(text=unit_test_llm_message.content)
print(unit_test_code_as_string)

import re
import pytest

def test_extract_emails_standard_format():
    text = "Please contact us at support@example.com for further information."
    expected = ["support@example.com"]
    assert extract_emails(text) == expected

def test_extract_emails_multiple_standard_format():
    text = "Send an email to john.doe@example.com or jane.doe@sample.org."
    expected = ["john.doe@example.com", "jane.doe@sample.org"]
    assert extract_emails(text) == expected

def test_extract_emails_obfuscated_format():
    text = "You can reach me at john.doe at example dot com."
    expected = ["john.doe@example.com"]
    assert extract_emails(text) == expected

def test_extract_emails_mixed_formats():
    text = "Contact us at support@example.com or john.doe at example dot com."
    expected = ["support@example.com", "john.doe@example.com"]
    assert extract_emails(text) == expected

def test_extract_emails_no_emails():
    text = "There are no email addresses in this text."
    expected = []
   

In [17]:
# pytest only collects functions whose names start with "test"; any helper or
# fixture the LLM also emitted must not be invoked as if it were a test.
unit_test_function_names = [
    name
    for name in extract_function_names(code_string=unit_test_code_as_string)
    if name.startswith("test")
]
assert unit_test_function_names, "No test functions found in the generated test code"
unit_test_function_names

['test_extract_emails_standard_format',
 'test_extract_emails_multiple_standard_format',
 'test_extract_emails_obfuscated_format',
 'test_extract_emails_mixed_formats',
 'test_extract_emails_no_emails',
 'test_extract_emails_edge_case_empty_string',
 'test_extract_emails_edge_case_special_characters',
 'test_extract_emails_with_subdomains',
 'test_extract_emails_with_numbers',
 'test_extract_emails_with_hyphens']

In [18]:
# Import lines found in the generated test code, displayed for inspection only
# (not referenced by the later cells shown here).
unit_test_imports = extract_imports(unit_test_code_as_string)
unit_test_imports

['import re', 'import pytest']

In [19]:
# System prompt for the implementation step: given the generated tests, the
# model writes a single Python function intended to make them pass.
function_writer_system_prompt_string = """
You are a Python code generation assistant. Your task is to create a Python function that satisfies all the provided pytest unit tests. Follow these guidelines:

1. Analyze the given unit tests carefully to understand the function's required behavior.
2. Write a single Python function that passes all the provided tests.
3. Use type hints for parameters and return values.
4. Include a clear and concise docstring explaining the function's purpose and parameters.
5. Follow Python best practices and PEP 8 style guidelines.
6. Do not include comments within the function body.
7. Ensure the function handles all edge cases and scenarios covered in the tests.
8. If the tests imply the use of regular expressions, import the 're' module and use it appropriately.
9. Provide only the function definition and its implementation, nothing else.

Your response should be a Python code block containing only the requested function.
"""

function_writer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", function_writer_system_prompt_string),
        ("user", "{input}"),
    ]
)
function_writer_chain = function_writer_prompt | llm

# The generated test code is the only specification the writer receives.
function_writer_llm_message = function_writer_chain.invoke({"input": unit_test_code_as_string})

In [21]:
# Pull the fenced implementation out of the model's reply and show it for review.
target_function_code_as_string = extract_python_code(
    text=function_writer_llm_message.content
)
print(target_function_code_as_string)

import re
from typing import List

def extract_emails(text: str) -> List[str]:
    """
    Extracts email addresses from the given text. Supports both standard and obfuscated formats.

    Args:
        text (str): The input text containing email addresses.

    Returns:
        List[str]: A list of extracted email addresses.
    """
    standard_emails = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text)
    obfuscated_emails = re.findall(r'\b[A-Za-z0-9._%+-]+ at [A-Za-z0-9.-]+ dot [A-Z|a-z]{2,}\b', text)
    obfuscated_emails = [email.replace(' at ', '@').replace(' dot ', '.') for email in obfuscated_emails]
    return standard_emails + obfuscated_emails



In [25]:
# The writer was asked for exactly one function; anything else means the
# response did not follow the prompt.
target_function_names = extract_function_names(code_string=target_function_code_as_string)
assert len(target_function_names) == 1, (
    f"Expected exactly one generated function, found: {target_function_names}"
)
target_function_name = target_function_names[0]

# The tests call the function by name, so a naming mismatch would only surface
# later as a NameError in every single test; fail here with a clear message.
assert re.search(rf"\b{re.escape(target_function_name)}\s*\(", unit_test_code_as_string), (
    f"The generated tests never call {target_function_name}()"
)

In [23]:
# Show the tests and the candidate implementation together, one after the other.
print(unit_test_code_as_string, target_function_code_as_string, sep="\n")

import re
import pytest

def test_extract_emails_standard_format():
    text = "Please contact us at support@example.com for further information."
    expected = ["support@example.com"]
    assert extract_emails(text) == expected

def test_extract_emails_multiple_standard_format():
    text = "Send an email to john.doe@example.com or jane.doe@sample.org."
    expected = ["john.doe@example.com", "jane.doe@sample.org"]
    assert extract_emails(text) == expected

def test_extract_emails_obfuscated_format():
    text = "You can reach me at john.doe at example dot com."
    expected = ["john.doe@example.com"]
    assert extract_emails(text) == expected

def test_extract_emails_mixed_formats():
    text = "Contact us at support@example.com or john.doe at example dot com."
    expected = ["support@example.com", "john.doe@example.com"]
    assert extract_emails(text) == expected

def test_extract_emails_no_emails():
    text = "There are no email addresses in this text."
    expected = []
   

In [28]:
test_function_pass_status_and_message = []

for unit_test_name in unit_test_function_names:
    # Each test is executed in its own REPL call, together with the generated
    # tests and the candidate function, and reports one status line.
    # SECURITY: this executes LLM-generated code on the host without sandboxing.
    code_to_run = f"""
{unit_test_code_as_string}

{target_function_code_as_string}

try:
    {unit_test_name}()
    message = "test passed"
except AssertionError as e:
    message = "Assertion failed: " + (str(e) or "(no assertion message)")
except Exception as e:
    message = "Error: " + type(e).__name__ + ": " + str(e)

print(message)
"""

    output = python_repl.run(code_to_run)
    test_function_pass_status_and_message.append(unit_test_name + "(): " + output)

In [30]:
print(test_function_pass_status_and_message)


def all_end_with(string_list, pattern):
    """Return True if every string in ``string_list`` ends with ``pattern``.

    Like the built-in ``all``, this is vacuously True for an empty list, so
    callers that need at least one element must check that separately.
    """
    return all(s.endswith(pattern) for s in string_list)


# An empty result list means no test was executed, which must not be reported
# as "all tests passed".
all_tests_passed = bool(test_function_pass_status_and_message) and all_end_with(
    string_list=test_function_pass_status_and_message, pattern="test passed\n"
)
all_tests_passed

['test_extract_emails_standard_format(): test passed\n', 'test_extract_emails_multiple_standard_format(): test passed\n', 'test_extract_emails_obfuscated_format(): test passed\n', 'test_extract_emails_mixed_formats(): test passed\n', 'test_extract_emails_no_emails(): test passed\n', 'test_extract_emails_edge_case_empty_string(): test passed\n', 'test_extract_emails_edge_case_special_characters(): test passed\n', 'test_extract_emails_with_subdomains(): test passed\n', 'test_extract_emails_with_numbers(): test passed\n', 'test_extract_emails_with_hyphens(): test passed\n']


True