# Iteratively improve responses using reflection

In this notebook, we use an LLM to reflect on generated responses and iteratively improve them. See our previous work on evaluation [here](./01.coolstore.ipynb).
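At its core, reflection is a loop. A critic reviews the current draft and, if it has concerns, the draft is revised using that feedback, for at most a fixed number of iterations. The following is a minimal sketch of that control flow only. `reflect_and_revise`, `toy_critique`, and `toy_revise` are hypothetical names, and the toy functions stand in for the LLM calls, so this is not Kai's implementation.

```python
from typing import Callable, Optional


def reflect_and_revise(
    draft: str,
    critique: Callable[[str], Optional[str]],
    revise: Callable[[str, str], str],
    iterations: int = 1,
) -> str:
    """Hypothetical reflection loop (not Kai's implementation).

    `critique` returns feedback for the draft, or None when it has no concerns.
    `revise` produces a new draft from the old draft plus the feedback.
    """
    for _ in range(iterations):
        feedback = critique(draft)
        if feedback is None:
            break  # the critic is satisfied, so stop early
        draft = revise(draft, feedback)
    return draft


# Toy stand-ins for the two LLM calls
def toy_critique(draft: str) -> Optional[str]:
    return "still uses javax imports" if "javax." in draft else None


def toy_revise(draft: str, feedback: str) -> str:
    return draft.replace("javax.", "jakarta.")


revised = reflect_and_revise("import javax.inject.Inject;", toy_critique, toy_revise, iterations=3)
print(revised)  # import jakarta.inject.Inject;
```

Stopping as soon as the critic returns no feedback keeps extra iterations from rewriting a draft that is already acceptable.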

For reflection, we reuse some ideas from the evaluation work, but we change the input data: we compare the ASTs of the updated file and the original file, and structure the resulting diff as JSON. The idea is to give the model precise information about what changed, in a format it understands. In our experience, this input has been more effective than passing full source files or a textual file diff.
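To illustrate the kind of structured input this produces, here is a minimal sketch that compares two Java sources and reports added and removed imports and annotations as JSON. It is a loose stand-in built on assumptions: it uses regular expressions rather than a real Java parser, and `summarize` and `structural_diff` are hypothetical names, not the API of Kai's `ast_diff` module.

```python
import json
import re

# Regex stand-ins for a real Java AST: they only recognize simple,
# single-line import statements and annotation names.
IMPORT_RE = re.compile(r"^\s*import\s+([\w.]+)\s*;", re.MULTILINE)
ANNOTATION_RE = re.compile(r"^\s*@(\w+)", re.MULTILINE)


def summarize(java_source: str) -> dict[str, set[str]]:
    """Collect a tiny structural summary of a Java file."""
    return {
        "imports": set(IMPORT_RE.findall(java_source)),
        "annotations": set(ANNOTATION_RE.findall(java_source)),
    }


def structural_diff(old_source: str, new_source: str) -> str:
    """Return a JSON object listing added/removed items per category."""
    old, new = summarize(old_source), summarize(new_source)
    diff = {}
    for key in sorted(old.keys() | new.keys()):
        added = sorted(new[key] - old[key])
        removed = sorted(old[key] - new[key])
        if added or removed:
            diff[key] = {"added": added, "removed": removed}
    return json.dumps(diff, indent=2)


old_src = """import javax.ejb.Stateless;
import javax.inject.Inject;

@Stateless
public class ShoppingCartOrderProcessor {
    @Inject
    Logger log;
}
"""
new_src = """import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

@ApplicationScoped
public class ShoppingCartOrderProcessor {
    @Inject
    Logger log;
}
"""
print(structural_diff(old_src, new_src))
```

Even this crude version reduces the change to a few facts (one annotation swapped, imports moved from `javax` to `jakarta`) that a reviewer model can check directly against the reported issues.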

## Pre-requisites

To run the cells in this notebook, set up and activate a virtual environment at the [project root](../../).

Also make sure the Kai module is installed in the venv (run from the project root):

```sh
pip install -e .
```

We can use models provided by IBM, AWS Bedrock and OpenAI. 

_Configure the model you want to use by updating the `MODEL` variable in the next cell_.

Depending on which models you want to use, add the corresponding keys to the `.env` file in the Kai base directory:

```sh
export GENAI_KEY=<your-ibm-key>
export OPENAI_API_KEY=<your-openai-key>
```
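Once the `.env` file has been loaded (the next cell does this with `%dotenv`), it can help to confirm that the key for the chosen provider is actually visible to Python. A minimal sketch, assuming the provider names used in the next cell and the two variables shown above. `REQUIRED_ENV_VARS` and `missing_env_var` are hypothetical helpers, and Bedrock is left out because it authenticates through the AWS profile instead.

```python
import os
from typing import Optional

# Hypothetical mapping from the provider names used in this notebook
# to the key variables listed above.
REQUIRED_ENV_VARS = {
    "ChatIBMGenAI": "GENAI_KEY",
    "ChatOpenAI": "OPENAI_API_KEY",
}


def missing_env_var(provider: str) -> Optional[str]:
    """Return the name of the unset key for `provider`, or None if nothing is missing."""
    var = REQUIRED_ENV_VARS.get(provider)
    if var is not None and not os.environ.get(var):
        return var
    return None


missing = missing_env_var("ChatOpenAI")
if missing:
    print(f"{missing} is not set; check your .env file")
```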

For Claude Sonnet, we use AWS Bedrock. For authentication, configure your AWS keys in the default AWS profile, which is usually stored in `~/.aws/credentials`.
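One way to confirm the default profile is in place is to parse the credentials file, which uses an INI-style layout with a `[default]` section. A minimal sketch using Python's standard `configparser`. `has_default_aws_profile` is a hypothetical helper, and it only checks that the two standard key entries are present, not that they are valid.

```python
import configparser
import os


def has_default_aws_profile(path: str = "~/.aws/credentials") -> bool:
    """Return True if `path` defines a [default] profile with an access key pair."""
    parser = configparser.ConfigParser()
    if not parser.read(os.path.expanduser(path)):
        return False  # file is missing or unreadable
    if not parser.has_section("default"):
        return False
    profile = parser["default"]
    return bool(profile.get("aws_access_key_id")) and bool(profile.get("aws_secret_access_key"))


print(has_default_aws_profile())
```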

Once you have configured the `.env` file, run the following cell to set up the model connection (change the `MODEL` variable to use another model):

In [1]:
%load_ext dotenv
%dotenv

# some shorthands we can use in our experiments for different models
META_LLAMA_70b = ("ChatIBMGenAI", "meta-llama/llama-3-70b-instruct")
META_LLAMA_8b = ("ChatIBMGenAI", "meta-llama/llama-3-8b-instruct")
MIXTRAL = ("ChatIBMGenAI", "mistralai/mixtral-8x7b-instruct-v01")
GPT_4 = ("ChatOpenAI", "gpt-4o")
GPT_3 = ("ChatOpenAI", "gpt-3.5-turbo")
CLAUDE_SONNET = ("ChatBedrock", "anthropic.claude-3-5-sonnet-20240620-v1:0")

# model used in experiments
MODEL = META_LLAMA_8b

# DO NOT EDIT BELOW, ONLY EDIT 'MODEL' VAR ABOVE TO CHANGE MODEL
from kai.service.llm_interfacing.model_provider import ModelProvider, KaiConfigModels
def setup_model(model: tuple[str, str]) -> ModelProvider:
    provider = model[0]
    model_id = model[1]
    config = KaiConfigModels(provider=provider)
    match provider:
        case "ChatIBMGenAI":
            max_tokens = {
                META_LLAMA_8b: "1536",
                META_LLAMA_70b: "2048",
            }.get(model, "2048")
            config.args["parameters"] = {
                "max_new_tokens": max_tokens,
            }
            config.args["model_id"] = model_id
        case "ChatBedrock":
            config.args["model_id"] = model_id
        case _:
            config.args["model"] = model_id
    return ModelProvider(config=config)
model = setup_model(MODEL)

### Load test data

The following cell clones the [coolstore](https://github.com/konveyor-ecosystem/coolstore) application. It also defines some helper functions we need later:

In [2]:
import re
import os
import sys
import errno
from git import Repo
from datetime import datetime
from collections import Counter
from kai.analyzer_types import Report
from kai.evaluation import BenchmarkExample
from kai.service.incident_store import Application
from kai.analyzer_types import ExtendedIncident
from langchain_core.messages.ai import AIMessage
from jinja2 import Environment, FileSystemLoader
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_core.prompts.chat import ChatPromptTemplate

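def ensure_dirs(path: str):
    # create the directory and any missing parents; do nothing if it already exists
    os.makedirs(path, exist_ok=True)

def clean_path(path: str) -> str:
    # strip trailing separators so os.path.basename() returns the last path component
    return path.rstrip(os.path.sep)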

def clone_coolstore(branch: str, path: str):
    try:
        Repo.clone_from("https://github.com/konveyor-ecosystem/coolstore", 
            depth=1, single_branch=True, branch=branch, to_path=path)
    except Exception as e:
        if "already exists" not in str(e):
            print(f"fatal error cloning repo: {e}")
            sys.exit(1)

def load_report_data(output_file: str, app_path: str, exclude_list: list[str] | None = None) -> dict[str, BenchmarkExample]:
    print(f"loading report data from {output_file}")
    report = Report.load_report_from_file(output_file)
    files = report.get_impacted_files()

    # we filter out filepaths for dependencies; building a new dict means a path
    # matching several exclusions is dropped once instead of being deleted twice
    excld_lst = ['root/.m2']
    excld_lst.extend(exclude_list or [])
    files = {k: v for k, v in files.items()
             if not any(exc != '' and exc in k for exc in excld_lst)}

    examples: dict[str, BenchmarkExample] = {}
    for f in files:
        original_content = ""
        expected_content = ""
        with open(os.path.join(app_path, f), "r") as fl: original_content = fl.read()
        examples[f] = BenchmarkExample(
            application=Application(
                application_name=os.path.basename(clean_path(app_path)),
                current_branch="main",
                repo_uri_local=app_path,
                generated_at=datetime.strptime("24/05/09 19:32:00", "%y/%m/%d %H:%M:%S"),
                repo_uri_origin="https://github.com/konveyor-ecosystem/coolstore",
                current_commit="aa"
            ),
            expected_file=expected_content,
            incidents=files[f],
            original_file=original_content,
            name=f,
            report=report,
        )
    return examples

def summarize_incidents(incidents: list[ExtendedIncident]):
    print(f"Found {len(incidents)} incidents:")
    for inc, total in Counter([inc.violation_name for inc in incidents ]).items():
        print(f"{inc} - {total}")

def get_prompt(sub_path: str, params: dict) -> str:
    # render a Jinja2 template from ./templates/agent-prompts with the given parameters
    template_env = Environment(
        loader=FileSystemLoader(searchpath=os.path.join(".", "templates", "agent-prompts")))
    return template_env.get_template(sub_path).render(params)

def parse_kai_fix(ai_message: AIMessage) -> tuple[str, str]:
    # headings may be "## Updated File" or "**Updated File**"; (?:...) groups the
    # alternatives, whereas [##|\*\*] would only match a single character
    match_updated_file = re.search(r'(?:##|\*\*)\s*[Uu]pdated [Ff]ile.*?```\w*\n(.*?)```', ai_message.content, re.DOTALL)
    if not match_updated_file:
        raise Exception("updated file content not found")
    match_reasoning = re.search(r'(?:##|\*\*)\s*[Rr]easoning\**\s+(.*?)(?:##|\*\*)\s*[Uu]pdated [Ff]ile', ai_message.content, re.DOTALL)
    if not match_reasoning:
        raise Exception("reasoning not found")
    return match_updated_file.group(1).strip(), match_reasoning.group(1).strip()

ensure_dirs("./data/apps/coolstore/javaee/")
clone_coolstore("main", "./data/apps/coolstore/javaee/")
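The dependency filter inside `load_report_data` drops any impacted file whose path contains one of the excluded substrings (by default `root/.m2`, where Maven keeps downloaded dependencies). A standalone sketch of that rule, using made-up paths, a made-up extra exclusion (`webapp`), and a hypothetical `filter_impacted_files` helper, shows which entries survive:

```python
def filter_impacted_files(files: dict[str, list], exclude: list[str]) -> dict[str, list]:
    """Keep only files whose path contains none of the (non-empty) excluded substrings."""
    return {
        path: incidents
        for path, incidents in files.items()
        if not any(exc and exc in path for exc in exclude)
    }


# Made-up report entries: path -> list of incidents
impacted = {
    "src/main/java/com/redhat/coolstore/service/OrderServiceMDB.java": ["incident-a"],
    "root/.m2/repository/org/example/lib/Example.java": ["incident-b"],
    "src/main/webapp/app/index.js": ["incident-c"],
}
kept = filter_impacted_files(impacted, ["root/.m2", "webapp"])
print(sorted(kept))
```

Filtering into a new dict also means a path matching several exclusions is simply skipped, with no risk of deleting the same key twice.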


Now we load our test data into memory. To change the file we want to fix, comment out or uncomment the `INPUT_FILE` assignments below:

In [3]:

report_coolstore = load_report_data(
    "./analysis_output.yaml", 
    "./data/apps/coolstore/javaee/")
INPUT_FILE = report_coolstore["src/main/java/com/redhat/coolstore/service/ShoppingCartOrderProcessor.java"]
# INPUT_FILE = report_coolstore['src/main/java/com/redhat/coolstore/service/OrderServiceMDB.java']
summarize_incidents(INPUT_FILE.incidents)

loading report data from ./analysis_output.yaml
Found 10 incidents:
javax-to-jakarta-import-00001 - 5
ee-to-quarkus-00000 - 1
jms-to-reactive-quarkus-00040 - 2
jms-to-reactive-quarkus-00050 - 2


## Generate a fix

First, we generate a fix for the given input file:

In [4]:
sys_msg_fix_gen = SystemMessage(
        content="""
You are an AI Assistant trained on migrating enterprise JavaEE code to Quarkus.
I will give you a JavaEE file which I want to migrate to Quarkus.
I will provide you with static source code analysis information highlighting issues that need to be addressed.
Fix only the problem described. Other problems will be solved in subsequent steps so it is unnecessary to handle them now.
Before attempting to migrate the code to Quarkus reason through what changes are required and why.
Pay attention to changes you make and impacts to external dependencies in the pom.xml as well as changes to imports we need to consider.
Remember when updating or adding annotations that the class must be imported.
As you make changes that impact the pom.xml or imports, be sure you explain what needs to be updated.
After you have shared your step by step thinking, provide a full output of the updated file.
If you are given a feedback, address the concerns raised in feedback and respond with an updated file.
Structure your output in Markdown format such as:

## Reasoning
Write the step by step reasoning in this markdown section. If you are unsure of a step or reasoning, clearly state you are unsure and why.

## Updated File
```java
// Write the updated file for Quarkus in this section. If the file should be removed, make the content of the updated file a comment explaining it should be removed.
```
""")

user_msg_fix_gen = HumanMessage(content=get_prompt("fix.prompt.jinja", {
                                    "incidents": INPUT_FILE.incidents,
                                    "src_file_contents": INPUT_FILE.original_file,
                               }))

chat_fix_gen = [sys_msg_fix_gen, user_msg_fix_gen]

result_fix_gen_initial = (ChatPromptTemplate(chat_fix_gen) | model.llm).invoke({})
chat_fix_gen.append(AIMessage(content=result_fix_gen_initial.content))
print("## Original file")
print(INPUT_FILE.original_file)
print("## Issues")
import json
print(json.dumps(list(set([inc.message for inc in INPUT_FILE.incidents])), indent=4))
updated_file, reasoning = parse_kai_fix(result_fix_gen_initial)
print("## Updated file")
print(updated_file)
print("\n\n## Reasoning")
print(reasoning)


## Original file
package com.redhat.coolstore.service;

import java.util.logging.Logger;
import javax.ejb.Stateless;
import javax.annotation.Resource;
import javax.inject.Inject;
import javax.jms.JMSContext;
import javax.jms.Topic;

import com.redhat.coolstore.model.ShoppingCart;
import com.redhat.coolstore.utils.Transformers;

@Stateless
public class ShoppingCartOrderProcessor  {

    @Inject
    Logger log;


    @Inject
    private transient JMSContext context;

    @Resource(lookup = "java:/topic/orders")
    private Topic ordersTopic;

    
  
    public void  process(ShoppingCart cart) {
        log.info("Sending order from processor: ");
        context.createProducer().send(ordersTopic, Transformers.shoppingCartToJson(cart));
    }



}

## Issues
[
    "Replace the `javax.ejb` import statement with `jakarta.ejb`",
    "Replace the `javax.inject` import statement with `jakarta.inject`",
    "Replace the `javax.annotation` import statement with `jakarta.annotation`",
    "JMS `T

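`parse_kai_fix` depends on the model following the `## Reasoning` / `## Updated File` layout requested in the system prompt. The sketch below applies equivalent patterns to a hand-written sample response (not real model output). It uses a hypothetical `split_fix` helper that takes a plain string instead of an `AIMessage`, so it runs without an LLM.

```python
import re

FENCE = "`" * 3  # Markdown code fence, built here so it doesn't end this code block

# Equivalent patterns to parse_kai_fix: "##" or "**" headings, either case for the first letter
REASONING_RE = re.compile(
    r"(?:##|\*\*)\s*[Rr]easoning\**\s+(.*?)(?:##|\*\*)\s*[Uu]pdated [Ff]ile", re.DOTALL
)
UPDATED_FILE_RE = re.compile(
    r"(?:##|\*\*)\s*[Uu]pdated [Ff]ile.*?" + FENCE + r"\w*\n(.*?)" + FENCE, re.DOTALL
)


def split_fix(response: str) -> tuple[str, str]:
    """Hypothetical string-based variant of parse_kai_fix: return (updated_file, reasoning)."""
    updated = UPDATED_FILE_RE.search(response)
    if not updated:
        raise ValueError("updated file content not found")
    reasoning = REASONING_RE.search(response)
    if not reasoning:
        raise ValueError("reasoning not found")
    return updated.group(1).strip(), reasoning.group(1).strip()


sample = "\n".join([
    "## Reasoning",
    "1. Replace the javax.inject import with jakarta.inject.",
    "",
    "## Updated File",
    FENCE + "java",
    "import jakarta.inject.Inject;",
    FENCE,
])
sample_file, sample_reasoning = split_fix(sample)
print(sample_file)       # import jakarta.inject.Inject;
print(sample_reasoning)  # 1. Replace the javax.inject import with jakarta.inject.
```

Raising an error when either section is missing makes a malformed response fail loudly instead of silently passing an empty file to the next step.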
In [5]:

import json

from kai.reactive_codeplanner.agents.reflection_agent import ReflectionAgent, ReflectionTask
from kai.reactive_codeplanner.agents.ast_diff.parser import Language

# reflection supports multiple iterations
# let's start with 1
agent = ReflectionAgent(llm=model.llm, silent=False, iterations=1)

task = ReflectionTask(
    file_path=INPUT_FILE.name,
    original_file=INPUT_FILE.original_file,
    updated_file=parse_kai_fix(result_fix_gen_initial)[0],
    reasoning=parse_kai_fix(result_fix_gen_initial)[1],
    issues=set([inc.message for inc in INPUT_FILE.incidents])
)

print(f"{'*'*10}(fix-gen -> user)\n{task.updated_file}")
task_results = agent.execute_task(task)
print(task_results)

**********(fix-gen -> user)
package com.redhat.coolstore.service;

import java.util.logging.Logger;
import jakarta.annotation.Resource;
import jakarta.ejb.ApplicationScoped;
import jakarta.inject.Inject;
import jakarta.jms.JMSContext;
import jakarta.jms.Topic;
import org.eclipse.microprofile.reactive.messaging.Channel;
import org.eclipse.microprofile.reactive.messaging.Emitter;

import com.redhat.coolstore.model.ShoppingCart;
import com.redhat.coolstore.utils.Transformers;

@ApplicationScoped
public class ShoppingCartOrderProcessor {

    @Inject
    Logger log;

    @Inject
    private Emitter<String> ordersEmitter;

    @Resource(lookup = "java:/topic/orders")
    private Channel<String> ordersChannel;

    public void process(ShoppingCart cart) {
        log.info("Sending order from processor: ");
        ordersEmitter.send(Transformers.shoppingCartToJson(cart));
    }
}
**********(user -> reflection)

You will be given a list of migration issues found in an old Java file in JSON fo