# Install Required Packages 

In [1]:
#! pip install -U --quiet langchain_community tiktoken langchain-mistralai langchainhub chromadb langchain langgraph tavily-python

## Define Functions to build data pipeline

In [106]:

from typing import Dict, List
import json

# Define the system message and context
SYSTEM_MESSAGE = """You are an AI assistant tasked with analyzing function definition and it's dependencies to determine if sufficient information is available for test code generation."""

FUNCTION_CONTEXT = """
class Car:
    def __init__(self, name, price, mileage, fuel_type):
        self.name = name
        self.price = price
        self.mileage = mileage
        self.fuel_type = fuel_type

    def calculate_emi(self, interest_rate, loan_period_years):
        '''Calculate EMI for the car.'''
        loan_period_months = Calculator().multiplication(loan_period_years , 12)
        monthly_interest_rate = Calculator().division(interest_rate, Calculator().multiplication(12 , 100))
        emi = (self.price * monthly_interest_rate * (1 + monthly_interest_rate) ** loan_period_months) / (Calculator().add(1 , monthly_interest_rate) ** loan_period_months - 1)
        return emi
"""

FUNCTION_CONTEXT = """
def create_item(item: Item):
    item_id = len(items_db) + 1
    emi = None
    if item.name == 'punch':
        punch = TataPunch()
        interest_rate = 5
        loan_period_years = 5
        emi = punch.calculate_emi(interest_rate, loan_period_years)
    elif item.name == 'tiago':
        tiago = TataTiago()
        interest_rate = 5
        loan_period_years = 5
        emi = tiago.calculate_emi(interest_rate, loan_period_years)
    elif item.name == 'indica':
        indica = TataIndica()
        interest_rate = 5
        loan_period_years = 5
        emi = indica.calculate_emi(interest_rate, loan_period_years)
    else:
        raise HTTPException(status_code=400, detail="Invalid car name")

    item.emi = emi
    items_db[item_id] = item
    return {"item_id": item_id, "item": item}


"""

In [147]:
import google.generativeai as genai

GOOGLE_API_KEY = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
genai.configure(api_key=GOOGLE_API_KEY)

def initialize_chat(system_msg):
    gemini_model = genai.GenerativeModel('gemini-pro')
    gemini_chat = gemini_model.start_chat(history=[{'role':'user',
                                                    'parts': [system_msg]
                                                    },
                                                    {'role':'model',
                                                    'parts': ['okay.']
                                                    }])

    return gemini_model, gemini_chat

def chat_with_gemini(user_text="Hello!"):
    response = gemini_chat.send_message(user_text)
    return response.text

gemini_model, gemini_chat = initialize_chat(SYSTEM_MESSAGE)

def get_llm_response(prompt):
    '''
    go gemini.
    '''
    prompt = prompt + "\n\n while generating response please strictly follow json format. otherwise it will throw JSONDecodeError error."
    response = chat_with_gemini(prompt)
    response = response.strip("```").strip("json")
    print("TYPE-------",type(response))
    return response

In [138]:
# TOBE DONE: clone repo

# TOBE DONE: get schema information w.r.t module. below cell is example.

In [139]:
# Example Index having module's schema information.
module_meta_schema = {'app.py':
                      {'abs_path':'path to file'
                      ,'function_name':{'definition':'definition goes here','description':'objective goes here..'
                                        ,'params':[{'type':'dict/list/str/int','data':'data sample if any'},]
                                        ,'returns':{'type':'dict/list/str/int','data':'data sample if any'}
                                        ,'caller_function':'caller function/method name goes here'
                                        ,"dependency_from_body": []
                                        ,"dependency_from_input": []
                                        }
                      ,'class_name':{'definition':'class definition goes here'
                                     ,'method_name':{'definition':'definition goes here','description':'objective goes here'
                                                     ,'params':[{'type':'dict/list/str/int','data':'data sample if any'},]
                                                     ,'returns':{'type':'dict/list/str/int','data':'data sample if any'}
                                                     ,'caller_function':'caller function/method name goes here'
                                                     ,"dependency_from_body": []
                                                     ,"dependency_from_input": []
                                                     }
                                     },
                      },
                      'car_emi.py':
                       {'abs_path':'path to file'
                        ,'function_name':{'definition':'definition goes here','description':'objective goes here..'
                                          ,'params':[{'type':'dict/list/str/int','data':'data sample if any'},]
                                          ,'returns':{'type':'dict/list/str/int','data':'data sample if any'}
                                          ,'caller_function':'caller function/method name goes here'
                                          ,"dependency_from_body": []
                                          ,"dependency_from_input": []
                                          }
                        ,'class_name':{'definition':'definition goes here'
                                       ,'method_name':{'definition':'definition goes here','description':'objective goes here'
                                                       ,'params':[{'type':'dict/list/str/int','data':'data sample if any'},]
                                                       ,'returns':{'type':'dict/list/str/int','data':'data sample if any'}
                                                       ,'caller_function':'caller function/method name goes here'
                                                       ,"dependency_from_body": []
                                                       ,"dependency_from_input": []
                                                       }
                                     },
                        },
                    }

# TOBE DONE: get in hierarchical order of dependencies for ease of retrieval.

def get_dependencies(attribute_name):
    '''
     attribute_name: function name or module name or class name
     return: text having all dependencies
    '''
    return module_meta_schema.find(attribute_name)

### Graph

In [140]:
from typing_extensions import TypedDict
from typing import List, Optional

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        function_context: given function definition as input context
        test_code: LLM generation
        additional_info: get additional information from user
        additional_info_rqd: does additional information required
        greeting: say bye when encounters max loop limit
        counter: keep track of numer of iteration 
    """
    object_definition : str
    test_cases: List[str]
    test_code : str
    input_mockable: str
    dependency_mockable: str
    mock_creation: str
    fixture_creation: str
    complexity: str
    additional_info_rqd : str
    additional_info : Dict[str, str]
    mock_code: str
    fixture_code: str
    prompt_history: List[str] = []
    greeting: Optional[str] = None
    counter : int = 0

In [141]:
import yaml

with open('prompt_template.yaml', 'r') as file:
    prompt_template_config = yaml.safe_load(file)

generate_test_cases_prompt = prompt_template_config.get('generate_test_cases')
generate_test_prompt = prompt_template_config.get('generate_test_code')
check_test_prompt = prompt_template_config.get('check_test_info')
mock_prompt = prompt_template_config.get('generate_mock_code')
fixture_prompt = prompt_template_config.get('generate_fixture_code')

# format prompts with object_definition, fixture_code, mock_code, previous_responses
#print(check_test_prompt)


In [142]:
#check_test_prompt.format(object_definition="def add_numbers(a: int, b: int) ->")

In [143]:
from langchain.schema import Document

### Nodes Definition goes here
def generate_test_code(state):
    print("Generate test code:::::")
    object_definition = state["object_definition"]
    additional_info = state["additional_info"]
    counter = state['counter']
    prompt_history = state.get('prompt_history')
    mock_code = state['mock_code']
    fixture_code = state['fixture_code']
    
    prompt = generate_test_prompt.format(object_definition=object_definition, fixture_code=fixture_code, 
                            mock_code=mock_code, previous_responses=prompt_history)
    print(f"Final ----- Prompt: {prompt}")
    llm_response = get_llm_response(prompt)
    state['test_code'] = llm_response

    return state

def generate_test_cases(state):
    print("Generate test cases:::::")
    object_definition = state["object_definition"]
    
    prompt = generate_test_cases_prompt.format(object_definition=object_definition)
    
    llm_response = get_llm_response(prompt)
    print("test cases......")
    print(llm_response)
    llm_response = json.loads(llm_response)
    print("after------",llm_response)
    state['test_cases'] = llm_response

    return state

def check_test_info(state):
    print("Check whether all information given to go for generating test code::::::")
    object_definition = state["object_definition"]
    counter = state["counter"]
    test_cases = state['test_cases']
    
    if counter is None:
        counter = 0
        state["prompt_history"] = []
    
    prompt = check_test_prompt.format(object_definition=object_definition, test_cases=test_cases)
    llm_response = get_llm_response(prompt)
    '''
    llm_response = {
                    "input_mockable": "no",
                    "dependency_mockable": "yes",
                    "complexity": "simple",
                    "additional_questions": {"please provide sample data":""},
                    "fixture_creation": "no",
                    "mock_creation": "yes"
                }
    '''
    llm_response = json.loads(llm_response)
    print("check_test_info.........")
    print(llm_response)
    if counter==2:
        print("&&&&&&&&&&&&----------counter=2")
        llm_response["input_mockable"] = "yes"
        llm_response["dependency_mockable"] = "yes"
        llm_response["complexity"] == "simple"
    # Parse LLM response into yes/no answers
    additional_info_rqd = 'no' if len(llm_response.get('additional_questions'))==0 else 'yes'
    
    state['additional_info_rqd'] = additional_info_rqd
    state['additional_info'] = llm_response.get('additional_questions')
    state['fixture_creation'] = llm_response.get('fixture_creation')
    state['mock_creation'] = llm_response.get('mock_creation')
    state['dependency_mockable'] = llm_response.get('dependency_mockable')
    state['input_mockable'] = llm_response.get('input_mockable')
    state['complexity'] = llm_response.get('complexity')
    state["prompt_history"].append({"asked_questions": llm_response})
    state['counter'] = counter + 1
    
    return state 

def get_user_input(state):
    print("Additional information needed. Please provide::::::::")
    
    questions = state.get("additional_info")
    print("You need to provide additional information to proceed with generating the test code.")
    for question, _ in questions.items():
        user_input = input(f"{question}: ")
        state['additional_info_rqd'] = "no"
        state['additional_info'][question] = f"{user_input}"
    state['prompt_history'].append({"collect_user_input": state['additional_info']})

    return state

def check_user_input(state):
    print("Decide based on user input::::::::")
    counter = state['counter']

    if counter >= 5:
        print("---DECISION: generate test code ended with counter exceed---")
        return "bye"
    else:
        print("---DECISION: check again whether all information given to go for generating test code::::::")
        return "check_test_info"

def bye(state):
    return {"greeting":"The graph has finished"}

### Edges Definition goes here
def decide_to_generate(state):
    print("Decide to generate or ask user::::::")
    #responses = state.get("responses", {})
    #state['prompt_history'].append({"process_responses": responses})

    # Decision-making based on responses
    if (state.get("input_mockable") == "yes" and
        state.get("dependency_mockable") == "yes" and
        state.get("complexity") == "simple"):
        print("---DECISION: all information given to go for generating test code::::::::")
        return "generate_fixtures"
    else:
        print("---DECISION: take user input---")
        return "get_user_input"

def generate_fixtures(state):
    print("generate_fixtures::::::::")
    if state.get("fixture_creation") == "yes":
        previous_responses = "\n".join([f"{k}: {v}" for entry in state['prompt_history'] for k, v in entry.items()])
        prompt = fixture_prompt.format(object_definition=state["object_definition"], previous_responses=previous_responses)
        llm_response = get_llm_response(prompt)
        state['prompt_history'].append({"generate_fixtures": llm_response})
        state["fixture_code"] = llm_response
    return state#"generate_mock_code"

def generate_mock_code(state):
    print("generate_mock_code::::::::")
    if state.get("mock_creation") == "yes":
        previous_responses = "\n".join([f"{k}: {v}" for entry in state['prompt_history'] for k, v in entry.items()])
        prompt = mock_prompt.format(object_definition=state["object_definition"], previous_responses=previous_responses)
        llm_response = get_llm_response(prompt)
        state['prompt_history'].append({"generate_mock_code": llm_response})
        state["mock_code"] = llm_response
    return state#"generate_test_code"

### Build Graph

In [144]:
from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("generate_test_cases", generate_test_cases)
workflow.add_node("check_test_info", check_test_info)
workflow.add_node("get_user_input", get_user_input)
workflow.add_node("generate_fixtures", generate_fixtures)
workflow.add_node("generate_mock_code", generate_mock_code)
workflow.add_node("generate_test_code", generate_test_code)
workflow.add_node("bye", bye)

# Build graph
workflow.set_entry_point("generate_test_cases")
workflow.add_edge("generate_test_cases", "check_test_info")
workflow.add_conditional_edges(
    "check_test_info",
    decide_to_generate,
    {
        "get_user_input": "get_user_input",
        "generate_fixtures": "generate_fixtures",
    },
)
workflow.add_edge("generate_fixtures", "generate_mock_code")
workflow.add_edge("generate_mock_code", "generate_test_code")
workflow.add_conditional_edges(
    "get_user_input",
    check_user_input,
    {
        "check_test_info": "check_test_info",
        "bye": "bye"
    }
)
workflow.add_edge("generate_test_code", END)
workflow.add_edge('bye', END)

# Compile
app = workflow.compile()

In [145]:
# Verify Control Flow
app.channels

{'object_definition': <langgraph.channels.last_value.LastValue at 0x24750b6eb10>,
 'test_cases': <langgraph.channels.last_value.LastValue at 0x24750b6e420>,
 'test_code': <langgraph.channels.last_value.LastValue at 0x24750b6c9b0>,
 'input_mockable': <langgraph.channels.last_value.LastValue at 0x24750b6e7b0>,
 'dependency_mockable': <langgraph.channels.last_value.LastValue at 0x24750b6e5d0>,
 'mock_creation': <langgraph.channels.last_value.LastValue at 0x24750b6f7a0>,
 'fixture_creation': <langgraph.channels.last_value.LastValue at 0x24750b6f710>,
 'complexity': <langgraph.channels.last_value.LastValue at 0x24750b6e990>,
 'additional_info_rqd': <langgraph.channels.last_value.LastValue at 0x24750b6e660>,
 'additional_info': <langgraph.channels.last_value.LastValue at 0x24750b6ede0>,
 'mock_code': <langgraph.channels.last_value.LastValue at 0x24750b6e930>,
 'fixture_code': <langgraph.channels.last_value.LastValue at 0x24750b6e120>,
 'prompt_history': <langgraph.channels.last_value.LastVal

### Execution

In [146]:
from pprint import pprint

inputs = {"object_definition": FUNCTION_CONTEXT}

for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Finished running: {key}:")

if value.get('generation',None):
    pprint(value["generation"])
else:
    print("Done.")

Generate test cases:::::
TYPE------- <class 'str'>
test cases......

{
  "test_case_scenarios": [
    {
      "description": "Valid car name - punch",
      "input": {
        "item": {
          "name": "punch"
        }
      },
      "expected_output": {
        "item_id": 1,
        "item": {
          "name": "punch",
          "emi": 10000  # Assuming the EMI is 10000 for this example
        }
      }
    },
    {
      "description": "Valid car name - tiago",
      "input": {
        "item": {
          "name": "tiago"
        }
      },
      "expected_output": {
        "item_id": 2,
        "item": {
          "name": "tiago",
          "emi": 10000  # Assuming the EMI is 10000 for this example
        }
      }
    },
    {
      "description": "Valid car name - indica",
      "input": {
        "item": {
          "name": "indica"
        }
      },
      "expected_output": {
        "item_id": 3,
        "item": {
          "name": "indica",
          "emi": 10000  # Assu

JSONDecodeError: Expecting ',' delimiter: line 15 column 25 (char 281)

In [105]:
print(value['test_code'])

```python
import unittest

# Mock the Calculator class
mock_calculator = Mock()

# Mock the multiplication method of the Calculator class
mock_calculator.multiplication.return_value = 12

# Mock the division method of the Calculator class
mock_calculator.division.return_value = 0.08333333333333333

class CarTest(unittest.TestCase):

    def setUp(self):
        # Create an instance of the Car class
        self.car = Car(name="Car A", price=100000, mileage=10, fuel_type="Petrol")

        # Mock the calculate_emi method of the Car class
        self.car.calculate_emi = Mock(return_value=2155.39)

    def test_calculate_emi(self):
        # Call the calculate_emi method of the Car class
        emi = self.car.calculate_emi(interest_rate=10, loan_period_years=5)

        # Assert that the calculate_emi method was called with the correct arguments
        self.car.calculate_emi.assert_called_with(interest_rate=10, loan_period_years=5)

        # Assert that the multiplication method of th