# Search-Engine-Integrated Multi-Expert Inference (SEIMEI)

In [None]:
#!pip install datasets
!pip install transformers
#!pip install accelerate
!pip install numpy, pandas
!pip install matplotlib
!pip install sentence_transformers
!pip install huggingface_hub
!pip install flask

# For vLLM
!pip install vllm
!pip install ray
!pip install packaging
!pip install typing

In [2]:
import os

from huggingface_hub import login

# SECURITY: an access token must never be stored in the notebook. The token that
# used to be hardcoded in this cell has to be treated as leaked and revoked.
# The token is now read from the HF_TOKEN environment variable; when the variable
# is not set, login(token=None) asks for the token interactively.
login(token=os.environ.get("HF_TOKEN"))

# Preparation

In [None]:
from Prepare import Prepare

data_path = "../miller"
save_path = "./processed/miller"

# Designate all the files with 'extensions' inside 'folder_path'.
# Only the first entry is active; the commented entries are alternative examples.
file_info = [
    {"folder_path":"", "extensions":[".tex"]},
#    {"folder_path":"src", "extensions":[".f90"]},
#    {"folder_path":"run", "extensions":["", ".q"]},
#    {"folder_path":"lib", "extensions":[".f90"]},
#    {"folder_path":"", "extensions":[".txt",".md"]},
]


# Rules describing where the text may be split.
#
# `rules` is a list of dicts. Each dict maps a key (the words at which the text
# should be split) to 0 or 1, which selects the side of the key the split falls on:
#   0: <text1><key> | <text2>
#   1: <text1> | <key><text2>
# NOTE(review): the original comment also described a "first element (0 to 1)"
# (process_text_size * element is the start point of the key splitting; the smaller
# the element is, the more likely it is for the key to split the text). The dicts
# below only carry the single 0/1 value, so that part looks outdated -- confirm
# against Prepare.
rules = [
    {
        #"SUBROUTINE " : 1,
        #"class " : 1,
        "\\\\section*" : 1,
    },

    {
        "\\\\subsection*" : 1,
        #"def " : 1,
        #"void " : 1,
    },

    {
        #"if " : 1,
        #"end if" : 0,
        "\\\\begin{center}" : 1,
        "\\\\end{gather*}" : 0,
        "\\\\end{align*}" : 0,   
        "\\\\end{equation*}" : 0,
        "\\\\end{enumerate}" : 0,
    },

    # This dict is empty because all of its entries are commented out.
    {
        #"else " : 1,
        #"elif " : 1,
    },
    

    {
        "\n\n" : 0,
        "<0x0A><0x0A>" : 0,
        # "\x0A\x0A" is the same Python string as "\n\n", so this repeats the key above.
        "\x0A\x0A" : 0,
    },

    {
        "\n" : 0,
        "<0x0A>" : 0,
        # "\x0A" is the same Python string as "\n", so this repeats the key above.
        "\x0A" : 0,
    },
]

# Build the Prepare object from the settings above; it is reused by the cells
# below (make_chunks, gather_save_dirs, modify_chunks_manually, finish_modifying).
prepare = Prepare(
    database_path = data_path,
    save_path = save_path,
    rules = rules, 
    file_info=file_info, 
    model_name = "gpt2",
    max_tokens = 10000,
    min_tokens = 3000,
)

In [None]:
# Requires the `prepare` object created in the Preparation cell above.
prepare.make_chunks()

## Gather all save_paths

In [None]:
# Directories handed to gather_save_dirs; "./processed/miller" is the save_path
# used in the Preparation cell above.
save_dirs = [
    "./processed/gkv-code",
    "./processed/gkw-manual",
    "./processed/miller",
]

# Second argument of gather_save_dirs; the Chat section later uses the same
# "./processed" value as its processed_path.
new_save_dir = "./processed"

# Requires the `prepare` object created in the Preparation cell above.
prepare.gather_save_dirs(save_dirs, new_save_dir)

## Manual modification

In [None]:
# If you want to modify the chunks manually, run the code below.
# Requires the `prepare` object created in the Preparation cell above.
prepare.modify_chunks_manually()

In [None]:
# After modifying the chunks, remember to run the code below.
# Requires the `prepare` object created in the Preparation cell above.
prepare.finish_modifying()

## Examples of rules

In [None]:
# Example rule sets describing where the text may be split.
#
# Each rule set is a list of dicts. Each dict maps a key (the words at which the
# text should be split) to 0 or 1, which selects the side of the key the split
# falls on:
#   0: <text1><key> | <text2>
#   1: <text1> | <key><text2>
# NOTE(review): the original comment also described a "first element (0 to 1)"
# (process_text_size * element is the start point of the key splitting; the smaller
# the element is, the more likely it is for the key to split the text). The dicts
# below only carry the single 0/1 value, so that part looks outdated -- confirm
# against Prepare.
#
# CAUTION: every example assigns to the same name `rules`, so each assignment
# overwrites the previous one; after running this cell `rules` holds the last
# (LaTeX) example.


# For Fortran code

rules = [
    {
        "SUBROUTINE " : 1,
    },

    {
        "if " : 1,
        "end if" : 0,
    },

    {
        "\n\n" : 0,
        "<0x0A><0x0A>" : 0,
        # "\x0A\x0A" is the same Python string as "\n\n" (repeated key).
        "\x0A\x0A" : 0,
    },

    {
        "\n" : 0,
        "<0x0A>" : 0,
        # "\x0A" is the same Python string as "\n" (repeated key).
        "\x0A" : 0,
    },
]



# For python code

rules = [
    {
        "class " : 1,
    },

    {
        "def " : 1,
        #"void " : 1,
    },

    {
        "if " : 1,
        "end if" : 0,
    },

    {
        "else " : 1,
        "elif " : 1,
    },
    

    {
        "\n\n" : 0,
        "<0x0A><0x0A>" : 0,
        # "\x0A\x0A" is the same Python string as "\n\n" (repeated key).
        "\x0A\x0A" : 0,
    },

    {
        "\n" : 0,
        "<0x0A>" : 0,
        # "\x0A" is the same Python string as "\n" (repeated key).
        "\x0A" : 0,
    },
]


# For latex papers or textbooks
# (same active keys as the rule set used in the Preparation cell)

rules = [
    {
        "\\\\section*" : 1,
    },

    {
        "\\\\subsection*" : 1,
    },

    {
        "\\\\begin{center}" : 1,
        "\\\\end{gather*}" : 0,
        "\\\\end{align*}" : 0,   
        "\\\\end{equation*}" : 0,
        "\\\\end{enumerate}" : 0,
    },

    {
        "\n\n" : 0,
        "<0x0A><0x0A>" : 0,
        # "\x0A\x0A" is the same Python string as "\n\n" (repeated key).
        "\x0A\x0A" : 0,
    },

    {
        "\n" : 0,
        "<0x0A>" : 0,
        # "\x0A" is the same Python string as "\n" (repeated key).
        "\x0A" : 0,
    },
]


# Chat

In [None]:
from SEIMEI import SEIMEI
import asyncio

# Input path. The value equals `new_save_dir` passed to gather_save_dirs above.
# NOTE(review): the original comment said "same as save_path you used in Preparation",
# but that save_path is "./processed/miller" -- confirm which directory is expected.
processed_path = "./processed"
# Other class names used elsewhere in this notebook: "StructureAnalysis", "ChunkSurvey", "FileSurvey".
expert_class_names = ["Answer", "CheckInf", "MetaSurvey"]
se_restrictions = ["MetaSurvey"]  # search engine only hits classes in this list usually (except when adding expert_restriction in kwargs)
expert_module_names = ["Experts.Code.Modify"]

# The `seimei` object created here is used by the following cells of this section.
seimei = SEIMEI(
    processed_path = processed_path,
    expert_class_names = expert_class_names,
    expert_module_names = expert_module_names,
    se_restrictions = se_restrictions,
    max_inference_time = 300,
    tensor_parallel_size = 1,
)


In [None]:
original_question = "How to implement a new equilibrium state called Miller equilibrium into gyro-kinetic vlasov simulation?"
final_answer = await seimei.get_answer(query = original_question) # return final answer

print()
print()
print(final_answer)

In [None]:
# Importing here rebinds `Log` to the SEIMEI implementation, even if the
# notebook-local `Log` class from the "Log system test" cell was run earlier.
from SEIMEI import Log
Log().show()

In [None]:
# for debug
# Runs the search directly and prints, for every hit, the chunk id, its file path
# and the matching entry of seimei.job_keys. Requires `seimei` from the cell above.
query = "How to implement a new equilibrium state called Miller equilibrium into gyro-kinetic vlasov simulation?"

outputs = seimei.search(query, topk = 50)

import json

# NOTE(review): these absolute paths are tied to one machine and to the "gkv-code"
# database, while `seimei` above was built from processed_path = "./processed";
# confirm they describe the same data.
with open("/workspace/processed/gkv-code/chunks.json") as json_file:
    chunks = json.load(json_file)
with open("/workspace/processed/gkv-code/file_paths.json") as json_file:
    file_paths = json.load(json_file)

print(outputs)

# The loop variable is named `chunk_id` rather than `id`: at notebook top level it
# is a global, and `id` would shadow the built-in for the rest of the kernel session.
for (expert, chunk_id) in outputs:
    print()
    print(f"--- chunk id {chunk_id} ---")
    print(f"file_path: {file_paths[chunk_id]}")
    print()
    print(seimei.job_keys[chunk_id])
    print()
#print(len(seimei.infs))
#print(seimei.get_num_tokens(seimei.infs[1]["inf"]))
#print(seimei.infs[3]["inf"])

## Kaggle AIMO Test

### Test1

In [1]:
from SEIMEI import SEIMEI
import asyncio

database_name = "gkv-code"
job_classes = ["SearchJob", "StepInference", "SuggestMethod", "EvaluateAnswer", "MakeAnswer", "CheckAnswer2", "SelfCorrection", "GiveHint"]
# NOTE(review): other cells in this notebook construct SEIMEI with keyword arguments
# (expert_class_names=..., expert_module_names=...); confirm this two-argument
# positional form still matches the current SEIMEI constructor.
seimei = SEIMEI(database_name, job_classes)

In [None]:

original_question = "Find the three-digit number n such that writing any other three-digit number 10^2024 times in a row and 10^2024 + 2 times in a row results in two numbers divisible by n."

correct_answer = """Let M = 10^1024. Let a be any three-digit number. Writing M copies of a in a row results
in a number X where
X =a×100100100...1001001
and there are M copies of the digit one in the long number. If instead we wrote M + 2 copies of a in a row, the resulting number would be 106X + 1001a. We use the notation (u, v) to denote the greatest common divisor of two integers u and v which are not both 0.
We apply Euclid’s algorithm so
((106X + 1001a), X) = (1001a, X).
It is therefore a necessary condition that our three-digit number n should divide (1001a,X) for all three-digit numbers a. By considering a = 100 and a = 101, we see that any candidate for n must divide 1001 × 101 − 1001 × 100 = 1001. Moreover, if n is a divisor of 1001, then n will divide X because 1001 divides 10010010010 . . . 01001001 which is
1001 × 10000010000010 . . . 01000001.
The second factor involves M/2 copies of the digit one. Such an n will also divide 106X + 1001a.
Thus it is a necessary and sufficient condition for n to satisfy the conditions of the problem that n be a three-digit divisor of 1001 (= 7 × 11 × 13). There is a unique such number: 143.
"""

await seimei.get_answer(query = original_question, correct_answer = correct_answer) # return final answer

In [None]:
# Print three fields of the first entry of SEIMEI.correct_answers, each preceded
# by an empty line and its field name.
first_record = SEIMEI.correct_answers[0]
for field_name in ("hint", "pre_answer", "answer"):
    print()
    print(field_name)
    print(first_record[field_name])

### Test2

In [None]:
from SEIMEI import SEIMEI
import asyncio

# Expert classes and the module they are looked up in, passed to SEIMEI by keyword.
expert_class_names = ["MakeStrategy", "EvaluateAnswer", "MakeAnswer2"]
expert_module_names = ["Experts.AIMO2.RyuSystem"]
# Rebinds the global `seimei`; the question cell below uses this instance.
seimei = SEIMEI(expert_module_names = expert_module_names, expert_class_names = expert_class_names)


In [None]:

original_question = "Find the three-digit number n such that writing any other three-digit number 10^2024 times in a row and 10^2024 + 2 times in a row results in two numbers divisible by n."

correct_answer = """Let M = 10^1024. Let a be any three-digit number. Writing M copies of a in a row results
in a number X where
X =a×100100100...1001001
and there are M copies of the digit one in the long number. If instead we wrote M + 2 copies of a in a row, the resulting number would be 106X + 1001a. We use the notation (u, v) to denote the greatest common divisor of two integers u and v which are not both 0.
We apply Euclid’s algorithm so
((106X + 1001a), X) = (1001a, X).
It is therefore a necessary condition that our three-digit number n should divide (1001a,X) for all three-digit numbers a. By considering a = 100 and a = 101, we see that any candidate for n must divide 1001 × 101 − 1001 × 100 = 1001. Moreover, if n is a divisor of 1001, then n will divide X because 1001 divides 10010010010 . . . 01001001 which is
1001 × 10000010000010 . . . 01000001.
The second factor involves M/2 copies of the digit one. Such an n will also divide 106X + 1001a.
Thus it is a necessary and sufficient condition for n to satisfy the conditions of the problem that n be a three-digit divisor of 1001 (= 7 × 11 × 13). There is a unique such number: 143.
"""

await seimei.get_answer(query = original_question, correct_answer = correct_answer) # return final answer


In [None]:
# Importing here rebinds `Log` to the SEIMEI implementation, even if the
# notebook-local `Log` class from the "Log system test" cell was run earlier.
from SEIMEI import Log
Log().show()

#### Log system test

In [None]:

import ipywidgets as widgets
from IPython.display import display
import json


class Log:
    """Interactive viewer for the nested expert-call log stored in ``log.json``.

    ``log.json`` is read from the current working directory and its last entry is
    displayed. Every node is a dict that is read through the keys
    "expert_class_name" and "called_experts"; for the selected child the keys
    "args" and "return" are shown as well.
    """

    def __init__(self):
        # Child indices leading from the root to the node currently displayed.
        self.log_dict_ids = []
        # Index of the highlighted child inside the current node.
        self.selected_id = 0

        with open("log.json") as json_file:
            self.logs = json.load(json_file)
        self.all_log_dict = self.logs[-1]

        self.log_dict = self.all_log_dict


    def get_log_dict_text(self):
        """Return the HTML for the current node.

        The node name comes first, followed by its children; the selected child is
        coloured green and is expanded by one more level.
        """
        text = "\n<pre><span style='color:black;'>" + self.log_dict["expert_class_name"] + "\n"

        for i, child in enumerate(self.log_dict["called_experts"]):
            if i == self.selected_id:
                text += "<span style='color:green;'>    " + child["expert_class_name"] + "</span>\n"
                for grandchild in child["called_experts"]:
                    text += "       " + grandchild["expert_class_name"] + "\n"
            else:
                text += "    " + child["expert_class_name"] + "\n"

        text += "</span></pre>"

        return text


    def get_arg_return_text(self):
        """Return the HTML showing "args" and "return" of the selected child.

        When the current node has no child at ``selected_id`` a short placeholder
        is returned instead of raising IndexError.
        """
        children = self.log_dict["called_experts"]
        if not 0 <= self.selected_id < len(children):
            return "<pre>\n\n(no called expert selected)</pre>"

        selected = children[self.selected_id]
        text = f"""<pre>\n\n--- args ---\n{self.json_show(selected["args"], 0)}\n\n"""
        text += f"""--- return ---\n{self.json_show(selected["return"], 0)}</pre>"""
        text = text.replace("<s>","")
        return text

    # recursive function
    def json_show(self, element, num_column):
        """Render a JSON-like value as indented text.

        Args:
            element: list, dict, str, int, float, bool or None.
            num_column: nesting depth; each level indents by three spaces.

        Returns:
            The rendered text.

        Raises:
            TypeError: if ``element`` has any other type.
        """
        indent = " "*3*num_column
        text = ""

        if isinstance(element, list):
            text += indent + "[\n"
            for i, e in enumerate(element):
                text += " "*3*(num_column+1) + f"- {i+1} -\n"
                text += self.json_show(e, num_column+1) + "\n"
            text += indent + "]\n"

        elif isinstance(element, dict):
            for i, key in enumerate(element):
                text += indent + f"- {i+1} -{key} :\n"
                text += self.json_show(element[key], num_column+1) + "\n"

        # float is accepted as well: json.load produces floats for JSON numbers
        # with a fractional part, which previously ended in the error branch.
        elif element is None or isinstance(element, (str, int, float, bool)):
            text += indent + str(element) + "\n"

        else:
            raise TypeError("element must be list, dict, str, int, float, bool or None")

        return text

    # Create a GridBox
    def show(self):
        """Display a 3x3 grid of navigation buttons and the HTML view below it."""

        text_display = widgets.HTML(value=self.get_log_dict_text())

        # Define functions to handle button clicks
        def on_up_button_clicked(b):
            if self.selected_id > 0:
                self.selected_id -= 1
            text_display.value = self.get_log_dict_text()

        def on_down_button_clicked(b):
            if self.selected_id < len(self.log_dict["called_experts"]) - 1:
                self.selected_id += 1
            text_display.value = self.get_log_dict_text()

        def on_left_button_clicked(b):
            if self.log_dict_ids:
                # Re-select the child we are coming back from. Keeping the child's
                # own selection could point past the end of the parent's list.
                self.selected_id = self.log_dict_ids.pop()
            self.log_dict = self.all_log_dict
            for child_index in self.log_dict_ids:
                self.log_dict = self.log_dict["called_experts"][child_index]
            text_display.value = self.get_log_dict_text()

        def on_right_button_clicked(b):
            if self.log_dict["called_experts"] != []:
                self.log_dict = self.log_dict["called_experts"][self.selected_id]
                self.log_dict_ids.append(self.selected_id)
                self.selected_id = 0
            text_display.value = self.get_log_dict_text()

        def on_center_button_clicked(b):
            text = self.get_log_dict_text()
            text += self.get_arg_return_text()
            text_display.value = text

        def on_left_up_button_clicked(b):
            # The "Menu" button has no action yet.
            pass

        up_button = widgets.Button(description='Up')
        down_button = widgets.Button(description='Down')
        left_button = widgets.Button(description='Back')
        right_button = widgets.Button(description='Next')
        center_button = widgets.Button(description='Select')
        left_up_button = widgets.Button(description='Menu')

        # Attach functions to button click events
        up_button.on_click(on_up_button_clicked)
        down_button.on_click(on_down_button_clicked)
        left_button.on_click(on_left_button_clicked)
        right_button.on_click(on_right_button_clicked)
        center_button.on_click(on_center_button_clicked)
        left_up_button.on_click(on_left_up_button_clicked)

        # Row-major 3x3 layout; buttons without description are placeholders.
        buttons = [
            left_up_button,
            up_button,
            widgets.Button(description=''),
            left_button,
            center_button,
            right_button,
            widgets.Button(description=''),
            down_button,
            widgets.Button(description=''),
        ]

        grid = widgets.GridBox(children=buttons,
                               layout=widgets.Layout(grid_template_columns='repeat(3, 150px)',
                                                     grid_template_rows='repeat(3, 30px)',
                                                     grid_gap='10px'))

        # Display the GridBox
        display(grid, text_display)

# Instantiating the Log class defined above opens "log.json" in the current
# working directory, so that file has to exist before this line is run.
Log().show()


## GKV test chat

### Basic Questions

In [1]:
from SEIMEI import SEIMEI

# Shared settings for the question cells of this section, which pass them
# positionally as SEIMEI(database_name, max_llm_iter, job_classes).
database_name = "gkv-code"
max_llm_iter = 10
job_classes = ["SearchJob", "Answer", "ChunkSurvey", "FileSurvey", "MetaSurvey", "CheckInf", "StructureAnalysis"]

In [None]:
# Builds a new SEIMEI instance from the settings cell of this section and asks one question.
# NOTE(review): other cells in this notebook construct SEIMEI with keyword arguments and
# `await` seimei.get_answer(query=...); confirm this positional, non-awaited form still
# matches the current SEIMEI API.
seimei = SEIMEI(database_name, max_llm_iter, job_classes)
original_question = "How to change the parameters for simulating by gkv-code? Start answering this question with figuring out what folder or file is related to user question."
final_answer = seimei.get_answer(original_question) # return final answer
print(final_answer)

In [None]:
# Builds a new SEIMEI instance from the settings cell of this section and asks one question.
# NOTE(review): other cells in this notebook construct SEIMEI with keyword arguments and
# `await` seimei.get_answer(query=...); confirm this positional, non-awaited form still
# matches the current SEIMEI API.
seimei = SEIMEI(database_name, max_llm_iter, job_classes)
original_question = "How to run the entire simulation code? Start answering this question with figuring out what folder or file is related to user question."
final_answer = seimei.get_answer(original_question) # return final answer
print(final_answer)

In [None]:
# Builds a new SEIMEI instance from the settings cell of this section and asks one question.
# NOTE(review): other cells in this notebook construct SEIMEI with keyword arguments and
# `await` seimei.get_answer(query=...); confirm this positional, non-awaited form still
# matches the current SEIMEI API.
seimei = SEIMEI(database_name, max_llm_iter, job_classes)
original_question = "Where should I define the file name of namelist of entire simulation? Start answering this question with figuring out what folder or file is related to user question."
final_answer = seimei.get_answer(original_question) # return final answer
print(final_answer) # hallucination

In [None]:
# Builds a new SEIMEI instance from the settings cell of this section and asks one question.
# NOTE(review): other cells in this notebook construct SEIMEI with keyword arguments and
# `await` seimei.get_answer(query=...); confirm this positional, non-awaited form still
# matches the current SEIMEI API.
seimei = SEIMEI(database_name, max_llm_iter, job_classes)
original_question = "How to input the number of MPI process? Start answering this question with figuring out what folder or file is related to user question."
final_answer = seimei.get_answer(original_question) # return final answer
print(final_answer)  # could mention sub.q but not the header

In [None]:
# Builds a new SEIMEI instance from the settings cell of this section and asks one question.
# NOTE(review): other cells in this notebook construct SEIMEI with keyword arguments and
# `await` seimei.get_answer(query=...); confirm this positional, non-awaited form still
# matches the current SEIMEI API.
seimei = SEIMEI(database_name, max_llm_iter, job_classes)
original_question = "I wanna add a particle which has different mass. How to change the namelist in this case? Start answering this question with figuring out what folder or file is related to user question."
final_answer = seimei.get_answer(original_question) # return final answer
print(final_answer)

In [None]:
# Builds a new SEIMEI instance from the settings cell of this section and asks one question.
# NOTE(review): other cells in this notebook construct SEIMEI with keyword arguments and
# `await` seimei.get_answer(query=...); confirm this positional, non-awaited form still
# matches the current SEIMEI API.
seimei = SEIMEI(database_name, max_llm_iter, job_classes)
original_question = "I wanna run nonlinear gyro kinetic vlasov simulation. Which part of the gkv code and how should I modify? Start answering this question with figuring out what folder or file is related to user question."
final_answer = seimei.get_answer(original_question) # return final answer
print(final_answer)  
# It seems to be a good answer, but it didn't mention name_list because there is no info about it in the database, and the 9b LLM isn't good enough to speculate that the namelist is somewhere in the database. In this case the LLM should notice that some parameters in README_for_namelist are not in the headers.
# To achieve this inference, a job that investigates further would be required.

### Advanced Questions

In [None]:
from SEIMEI import SEIMEI
import asyncio

database_name = "gkv-code"
# Other class names used elsewhere in this notebook: "StructureAnalysis", "ChunkSurvey", "FileSurvey".
expert_class_names = ["Answer", "CheckInf", "MetaSurvey"]
se_restrictions = ["MetaSurvey"]  # search engine only hits classes in this list usually (except when adding expert_restriction in kwargs)
expert_module_names = ["Experts.Code.Modify"]

# NOTE(review): the otherwise identical call in the "Chat" section passes
# processed_path=... instead of database_name=...; confirm which keyword the
# current SEIMEI constructor accepts.
seimei = SEIMEI(
    database_name = database_name,
    expert_class_names = expert_class_names,
    expert_module_names = expert_module_names,
    se_restrictions = se_restrictions,
    max_inference_time = 300,
    tensor_parallel_size = 1,
)


In [None]:
original_question = "How to implement a new equilibrium state called Miller equilibrium into gyro-kinetic vlasov simulation?"
final_answer = await seimei.get_answer(query = original_question) # return final answer

print()
print()
print(final_answer)

In [None]:
# Importing here rebinds `Log` to the SEIMEI implementation, even if the
# notebook-local `Log` class from the "Log system test" cell was run earlier.
from SEIMEI import Log
Log().show()

In [None]:
# for debug
# Runs the search directly and prints, for every hit, the chunk id, its file path
# and the matching entry of seimei.job_keys. Requires `seimei` from the cell above.
query = "How to implement a new equilibrium state called Miller equilibrium into gyro-kinetic vlasov simulation?"

outputs = seimei.search(query, topk = 50)

import json

# NOTE(review): these absolute paths are tied to one machine; another cell of this
# notebook reads chunks.json from "../processed/{database_name}/" -- confirm which
# location is correct.
with open("/workspace/processed/gkv-code/chunks.json") as json_file:
    chunks = json.load(json_file)
with open("/workspace/processed/gkv-code/file_paths.json") as json_file:
    file_paths = json.load(json_file)

print(outputs)

# The loop variable is named `chunk_id` rather than `id`: at notebook top level it
# is a global, and `id` would shadow the built-in for the rest of the kernel session.
for (expert, chunk_id) in outputs:
    print()
    print(f"--- chunk id {chunk_id} ---")
    print(f"file_path: {file_paths[chunk_id]}")
    print()
    print(seimei.job_keys[chunk_id])
    print()
#print(len(seimei.infs))
#print(seimei.get_num_tokens(seimei.infs[1]["inf"]))
#print(seimei.infs[3]["inf"])

In [None]:
# Uses the `seimei` object built by the keyword-argument setup cell of this
# section; top-level `await` is available inside a notebook cell.
original_question = "How a variable Anum in name_list is used in the simulation?"
final_answer = await seimei.get_answer(query = original_question) # return final answer
print(final_answer)

In [None]:
# Importing here rebinds `Log` to the SEIMEI implementation, even if the
# notebook-local `Log` class from the "Log system test" cell was run earlier.
from SEIMEI import Log
Log().show()

In [None]:
# Rebinds `seimei`, replacing the instance built by the keyword-argument setup cell
# of this section. `max_llm_iter` and `job_classes` are only defined in the
# "Basic Questions" settings cell, which therefore has to be run first.
# NOTE(review): other cells in this notebook `await` seimei.get_answer(query=...);
# confirm this positional, non-awaited form still matches the current SEIMEI API.
seimei = SEIMEI(database_name, max_llm_iter, job_classes)
original_question = "How a variable Anum in name_list is used in the simulation? Give me all the relevant calculation code."
final_answer = seimei.get_answer(original_question) # return final answer
print(final_answer)

In [None]:
# Kept for reference: print every stored inference.
#for inf in seimei.infs:
#    print(inf["inf"])
import json

database_name = "gkv-code"

# Load all chunks of the selected database.
chunks_path = f"../processed/{database_name}/chunks.json"
with open(chunks_path) as json_file:
    chunks = json.load(json_file)

# Print the two chunks of interest, each preceded by a separator line.
for chunk_index in (839, 690):
    print("---------")
    print(chunks[chunk_index])

In [None]:
# Rebinds `seimei`. `max_llm_iter` and `job_classes` are only defined in the
# "Basic Questions" settings cell, which therefore has to be run first.
# NOTE(review): other cells in this notebook `await` seimei.get_answer(query=...);
# confirm this positional, non-awaited form still matches the current SEIMEI API.
seimei = SEIMEI(database_name, max_llm_iter, job_classes)
original_question = "I wanna know how does the simulation code run step by step. Analyze the structure of code and figure out the flow of the simulation."
final_answer = seimei.get_answer(original_question) # return final answer
print(final_answer)

## Transformers test chat

In [3]:
# `database_name` is also read as a global by FRAG.get_infs below, which loads
# files from processed/{database_name}/.
database_name = "transformers"
# Same names as the FRAG constructor parameters, which FRAG.get_answer uses as the
# limits of its outer and inner loops -- presumably meant to be passed to FRAG(...).
max_more = 5
max_dispose = 10

In [None]:
# model load
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

# torch.cuda.is_available has to be called: the bare function object is always
# truthy, so `device` used to be "cuda" even on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
emb_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1").to(device)

"""
# Model load for japanese
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

assert transformers.__version__ >= "4.34.1"

model = AutoModelForCausalLM.from_pretrained("cyberagent/calm2-7b-chat", device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
"""

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id, 
    padding_side="left",
    add_eos_token=False,
    add_bos_token=False,)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

model = AutoModelForCausalLM.from_pretrained(model_id).to(device)


# for json enforcer
from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
from transformers import pipeline

# Use the device selected above instead of the hardcoded GPU index 0, so the
# pipeline stays on the same device as `model` and also works on CPU-only machines.
hf_pipeline = pipeline('text-generation', model=model, tokenizer = tokenizer, device = device)




In [None]:
import json
import os

# torch.cuda.is_available has to be called: the bare function object is always
# truthy, so `device` used to be "cuda" even on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

class FRAG:
    def __init__(self, database_name, max_more, max_dispose):
        # path for making function-explanation
        self.path_call = f"processed/{database_name}/calls.json"
        self.path_def = f"processed/{database_name}/defs.json"
        self.file_paths = f"processed/{database_name}/file_paths.json"

        self.max_more = max_more
        self.max_dispose = max_dispose

    
    def get_answer(self, original_question):
        generate = False
        next_question = original_question
        self.code_mem_list = []
        self.keep_id_list = []
        self.dispose_list = []

        i = 0
        while generate == False and i < self.max_more:
            j=0
            i += 1
            keep = False
            
            while keep == False and j < self.max_dispose:
                j += 1
                infs, id = self.get_infs(next_question, self.dispose_list, self.keep_id_list)
                keep, thought = self.LLM1(next_question, infs[0], id)
                
                if keep:
                    self.keep_id_list.append(id)
                    break
                else:
                    self.dispose_list.append(id)
            
            generate, next_question, thought = self.LLM2(original_question, next_question, self.code_mem_list, self.keep_id_list)
            
            code_mem, relation = self.SUMLLM(original_question, next_question, infs[0], id)
            
            self.code_mem_list.append(code_mem)

        answer = self.GENELLM(original_question, self.code_mem_list, self.keep_id_list)

        return answer


    
    def LLM1(self, question, code_inf, code_id):
        func_des = self.get_func_description(code_id)

        # for restricting answer to be json 
        class LLM1Format(BaseModel):
            thought: str
            keep: bool

        parser = JsonSchemaParser(LLM1Format.schema())
        prefix_function = build_transformers_prefix_allowed_tokens_fn(hf_pipeline.tokenizer, parser)

        prompt = f"""
[INST]<<SYS>>You are an excellent commander. Based on the code from system and the question provided by the user, you decide the next action and answer it in json text. Use the following criteria to make your decision:

- Include {{"keep":true}} in json text if the code provided is related, even partially, to the user's question. This indicates that the code, while possibly incomplete or not entirely covering all aspects, still has relevance and may contain useful elements or logic that pertains to the question.
- Include {{"keep":false}} in json text if the code provided is completely unrelated to the user's question. This means the code does not contribute in any way to answering the question and should be disregarded.

<</SYS>>

User question: {question}

<<SYS>>
Function description:
{func_des}

Code from system:
```
{code_inf}
```

Firstly, you need to share your opinion about the reason for your decision, then you need to share your decision. Use the json format below:
{{
    "thought": (Explain whether the given code is necessary to answer the user's question, and how it relates, even if partially.),
    "keep": (Choose from "true" or "false".)
}}
<</SYS>>
[/INST]"""
        
        output = self.get_output(prompt, max_new_tokens = 1000, prefix_function = prefix_function)
        processed, json_mode = self.text2json(output)

        if json_mode:
            keep = processed["keep"]
            thought = processed["thought"]
        else:
            keep = True if "True" in processed else False
            thought = processed
            
        return keep, thought
        


    def LLM2(self, original_question, next_question, code_mem_list, keep_id_list):
        combined_code = self.combine_codes(code_mem_list,keep_id_list)

        # for restricting answer to be json 
        class LLM2Format(BaseModel):
            thought: str
            generate: bool
            next_question: str

        parser = JsonSchemaParser(LLM2Format.schema())
        prefix_function = build_transformers_prefix_allowed_tokens_fn(hf_pipeline.tokenizer, parser)

        prompt = f"""[INST]<<SYS>>You are an excellent commander. Based on the code from system and question provided by the user, you decide the next action and answer it in json text. Use the following criteria to make your decision:

- Include {{"generate":false}} in json text if the code is related but not comprehensive enough to answer the question. This means some elements are missing, which are necessary to complete the answer or to cover all aspects of the question.
- Include {{"generate":true}} in json text if the code provided fully satisfies the requirements to answer the user's question comprehensively.

<</SYS>>

User question: {original_question}
Last search question:{next_question}

<<SYS>>
#Pieces of code from system:
{combined_code}

Firstly, you need to share your opinion about that the provided code is sufficient or insufficient, then you need to share your decision. Additionally, you must formulate a follow-up question to collect the missing information necessary to complete the code. Use the json format below:
{{
    "thought": (Explain why the provided code is sufficient or insufficient),
    "generate": (Choose from 'true' or 'false'),
    "next_question": (Formulate a question to help gather the missing or additional code required)
}}
<</SYS>>
[/INST]"""
        
        output = self.get_output(prompt, max_new_tokens = 3500, prefix_function = prefix_function)
        processed, json_mode = self.text2json(output)

        if json_mode:
            generate = processed["generate"]
            next_question = processed["next_question"]
            thought = processed["thought"]
        else:
            generate = True if "True" in processed else False
            next_question = processed
            thought = processed
            
        return generate, next_question, thought



    def SUMLLM(self, original_question, next_question, code_inf, id):
        func_des = self.get_func_description(id)
        add_code, folder_des = self.get_address_folder(id)

        # for restricting answer to be json 
        class SUMLLMFormat(BaseModel):
            code: str
            relation: str

        parser = JsonSchemaParser(SUMLLMFormat.schema())
        prefix_function = build_transformers_prefix_allowed_tokens_fn(hf_pipeline.tokenizer, parser)
        
        prompt = f"""[INST]<<SYS>>
You are a skilled programmer proficient in explaining code in json text. Your primary task is to identify and extract the crucial parts of code based on the pairings of user-submitted questions and corresponding code snippets. While the code often relates to the users questions, not all parts may be necessary to answer these questions. Users are specifically interested in those portions of the code that are most relevant to their inquiries. Therefore, you must focus solely on extracting these pertinent sections without modifying or editorializing the code. If no relevant code sections are found, output "Nothing".
<</SYS>>

User question:
{original_question}

<<SYS>>
Question for Searching the code below:{next_question}
#Code from system:

##Code Overview Set
{add_code}

{folder_des}

{func_des}

Code:
```
{code_inf}
```

You are required to extract the significant sections from the provided code that are essential for answering the user's question and return it in json text. Highlight these sections and explain their relevance to the question without altering the original code format or content. Please follow the json format below:

{{
    "thought": (Quick explanation of the answer you will give in the folloing.),
    "code": (The critical parts of the code necessary to answer the user's question. Do not modify or editorialize the code. If no sections of the code are critical, you should explicitly output "Nothing".),
    "relation": (Tell me relation between the code and Users question. If no sections of the code are related, you should explicitly output "Nothing")
}}

<</SYS>>
[/INST]"""
        
        output = self.get_output(prompt, max_new_tokens = 2500, prefix_function = prefix_function)
        processed, json_mode = self.text2json(output)

        if json_mode:
            code = processed["code"]
            relation = processed["relation"]
        else:
            code = processed
            relation = processed
            
        return code, relation


    
    def GENELLM(self, original_question, code_mem_list, keep_id_list):
        combined_code = self.combine_codes(code_mem_list,keep_id_list)
        
        prompt = f"""[INST]<<SYS>>
You are an excellent programmer and are adept at explaining code. You will be provided with one or more pieces of code along with corresponding questions from systems. The provided code is selected from a larger codebase specifically to enable you to answer these questions. Your task is to answer the user’s questions as thoroughly and clearly as possible, demonstrating your understanding and ability to communicate key coding concepts.

<</SYS>>

User question:
{original_question}

<<SYS>>
#Pieces of code from system:

{combined_code}
<</SYS>>[/INST]"""
        
        return self.get_output(prompt, max_new_tokens = 2500)

    
    def get_output(self, prompt, max_new_tokens = 1000, prefix_function = None):
        """Run the LLM on ``prompt`` and return only the newly generated text.

        The prompt is echoed to stdout before generation.

        Args:
            prompt: Full prompt string.
            max_new_tokens: Upper bound on the number of generated tokens.
            prefix_function: Optional callable forwarded as
                ``prefix_allowed_tokens_fn``. When given, generation goes through
                the module-level ``hf_pipeline``; otherwise ``model.generate`` is
                called directly with sampling (temperature 0.8) and the
                module-level ``streamer``.

        Returns:
            The generated continuation with the prompt removed.
        """
        print()
        print("=== input ===")
        print(prompt)

        if prefix_function is None:
            print()
            print("=== normal output ===")
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.8,
                streamer=streamer
            )
            # Drop the prompt tokens from the generated sequence. The prompt length
            # is read from the ``input_ids`` tensor instead of integer-indexing the
            # BatchEncoding, which is only supported by fast tokenizers.
            prompt_length = inputs["input_ids"].shape[1]
            return tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens = True)

        print()
        print("=== json output ===")

        output_dict = hf_pipeline(prompt, max_new_tokens = max_new_tokens, prefix_allowed_tokens_fn = prefix_function)
        # ``generated_text`` is cut at len(prompt), i.e. it is expected to start
        # with the prompt itself. Slice once and reuse for both print and return.
        generated = output_dict[0]['generated_text'][len(prompt):]
        print(generated)
        return generated
        

    def text2json(self, text):
        """Try to parse ``text`` as JSON.

        Args:
            text: The string to parse (typically raw LLM output).

        Returns:
            ``(parsed, True)`` when parsing succeeds; otherwise ``(text, False)``
            with the input returned unchanged so the caller can fall back to the
            raw text. A notice is printed on failure.
        """
        try:
            return json.loads(text), True
        except (ValueError, TypeError, RecursionError):
            # ValueError covers json.JSONDecodeError (malformed JSON), TypeError
            # covers non-string input such as None, RecursionError covers
            # pathologically nested input. Anything else (KeyboardInterrupt,
            # SystemExit, ...) is deliberately allowed to propagate.
            print()
            print("Failed to get json type object")
            return text, False

    
    def get_infs(self, question, disposed_id_list, keep_id_list):
        """Pick the most relevant chunk that has been neither disposed of nor kept.

        The question is embedded with the module-level ``emb_model`` and scored
        against the stored summary embeddings with a dot product. Chunks are then
        visited from most to least relevant and the first one whose id is in
        neither exclusion list is selected.

        Args:
            question: Text used as the search query.
            disposed_id_list: Chunk ids that were already rejected.
            keep_id_list: Chunk ids that were already accepted.

        Returns:
            ``(infs, selected_id)``: ``infs`` is a one-element list holding the
            chosen chunk and ``selected_id`` is its index. If every chunk is
            excluded the result is ``([], None)``.
        """
        # Search based on the question text.
        q_embs = torch.tensor(emb_model.encode(question)).to(device)
        inf_embs = torch.load(f"processed/{database_name}/summary_embs.pt").to(device)

        with open(f"processed/{database_name}/chunks.json") as json_file:
            chunks = json.load(json_file)

        relevance = torch.matmul(q_embs, inf_embs.T)
        num_chunks = relevance.shape[0]

        # First look only at the top 3 (clamped so that databases with fewer than
        # three chunks do not make torch.topk raise); if all of those are
        # excluded, fall back to the full ranking.
        for k in (min(3, num_chunks), num_chunks):
            _, ranked_ids = torch.topk(relevance, k=k, dim=0)  # dim=0: rank over chunks
            for chunk_id in ranked_ids.tolist():
                if chunk_id not in disposed_id_list and chunk_id not in keep_id_list:
                    # Stop at the first eligible id.
                    return [chunks[chunk_id]], chunk_id

        return [], None
    
    
    
    def get_func_description(self, id):
        """Build a bullet list describing the functions called by chunk ``id``.

        Reads ``self.path_call`` (JSON; entry ``id`` is a mapping whose keys are
        the called function names) and ``self.path_def`` (JSON; a sequence of
        mappings from function name to description).

        Args:
            id: Index/key of the chunk inside the calls file.

        Returns:
            A header line followed by one ``- name: description.`` line for each
            called function that has a definition entry, in call order. Returns
            ``""`` when no called function has a definition entry.
        """
        with open(self.path_call, 'r') as file:
            functions = json.load(file)[id]

        if not functions:
            return ""

        # Load the definitions once instead of re-reading the file per function.
        with open(self.path_def, 'r') as file:
            defs_data = json.load(file)

        # When a name is defined more than once, the first definition wins.
        descriptions = {}
        for def_item in defs_data:
            for name, description in def_item.items():
                descriptions.setdefault(name, description)

        # Name and description are kept separate (no "name:desc" join/split), so
        # a description that itself contains ':' is no longer truncated.
        formatted_descriptions = [
            f"- {name}: {str(descriptions[name]).strip()}."
            for name in functions
            if name in descriptions
        ]

        if not formatted_descriptions:
            return ""

        # Assemble the final description text.
        description_of_functions = "Description of the functions used in the code below:\n" + "\n".join(formatted_descriptions)

        return description_of_functions
    
    
    
    def get_address_folder(self, id):
        """Return the tree-style address and the path descriptions for chunk ``id``.

        Looks up the chunk's file path in the JSON file at ``self.file_paths``,
        collects the name/summary of each path component and renders them.

        Args:
            id: Index/key of the chunk inside the file-paths JSON.

        Returns:
            ``(address_code, formatted_descriptions)``: the output of
            ``generate_tree_structure`` and of ``format_descriptions``.
        """
        with open(self.file_paths, 'r') as paths_file:
            chunk_path = json.load(paths_file)[id]

        names, summaries = self.get_path_summaries(chunk_path, database_name)
        return (
            self.generate_tree_structure(names),
            self.format_descriptions(names, summaries),
        )
    
    
    
    def get_path_summaries(self, file_path, dataset_name):
        """Collect the name and summary of every component of ``file_path``.

        Summaries are read from ``processed/{dataset_name}/f_summary.json``,
        which maps a path to its summary. The path is walked upwards from the
        file to its parents; the walk stops once no "/" is left, so the
        top-level component itself is not included.

        Args:
            file_path: "/"-separated path of the file inside the database.
            dataset_name: Name of the processed database folder to read from.

        Returns:
            ``(f_name_list, f_summary_list)``: component names and their
            summaries, ordered from the outermost folder down to the file.

        Raises:
            KeyError: If a visited path has no entry in the summary file.
        """
        # Use the argument rather than a module-level global so the method reads
        # the database it was asked about.
        file_path_json = f"processed/{dataset_name}/f_summary.json"
        with open(file_path_json) as json_file:
            f_summary = json.load(json_file)

        f_name_list = []
        f_summary_list = []
        while "/" in file_path:
            # Prepend so the lists end up ordered root -> leaf.
            f_name_list.insert(0, os.path.basename(file_path))
            f_summary_list.insert(0, f_summary[file_path])
            file_path = os.path.dirname(file_path)

        return f_name_list, f_summary_list
    
    
    
    def generate_tree_structure(self, folders_files):
        """Render a list of path components as an indented, brace-wrapped tree.

        Each component goes on its own line, one indentation step ("|   ")
        deeper than the previous one. Every entry, including the last, is
        written with a trailing "/".

        Args:
            folders_files: Path components ordered from outermost to innermost.

        Returns:
            The tree text, starting with "The address of code below:{" and
            ending with "}".
        """
        tree_lines = [
            f"{'|   ' * depth}|─ {name}/\n"
            for depth, name in enumerate(folders_files)
        ]
        return "The address of code below:{\n" + "".join(tree_lines) + "}"

    
    # Format the descriptions of folders and files.
    def format_descriptions(self, f_name_list, f_summary_list):
        """Render name/summary pairs as an indented "name / description" list.

        Args:
            f_name_list: Folder and file names.
            f_summary_list: Summaries, paired positionally with ``f_name_list``
                (extra items in the longer list are ignored).

        Returns:
            A header line followed by two lines per pair.
        """
        entries = [
            f"  - name: {name}\n    description: {desc}\n"
            for name, desc in zip(f_name_list, f_summary_list)
        ]
        return "Folder and file descriptions:\n" + "".join(entries)
        
        
    
    def get_prompt(self, q, inf_list):
        """Append each snippet in ``inf_list`` to the question as a fenced block.

        Args:
            q: Question text placed at the start of the prompt.
            inf_list: Code snippets (strings) to append, in order.

        Returns:
            ``q`` followed by "\\nCode:" and one fenced block per snippet.
        """
        fenced_blocks = ["\n```" + inf + "```" for inf in inf_list]
        return q + "\nCode:" + "".join(fenced_blocks)
    
    def combine_codes(self, code_mem_list,keep_id_list):
        """Concatenate one "Code Overview Set" section per kept snippet.

        For every ``(id, code)`` pair the section contains the tree address and
        folder descriptions from ``get_address_folder``, the function
        descriptions from ``get_func_description`` and the code in a fence.

        Args:
            code_mem_list: Code snippets.
            keep_id_list: Ids paired positionally with ``code_mem_list``.

        Returns:
            All sections joined into a single string ("" when there are no pairs).
        """
        sections = []
        for chunk_id, code in zip(keep_id_list, code_mem_list):
            add_code, folder_des = self.get_address_folder(chunk_id)
            func_des = self.get_func_description(chunk_id)
            sections.append(
                "##Code Overview Set"
                + f"\n{add_code}\n\n{folder_des}\n\n{func_des}\n\n```\n{code}\n```\n\n"
            )
        return "".join(sections)

    def get_new_question(self, output):
        """Extract the follow-up question that the model wrote after its marker.

        The text after ``'Next question:'`` (or, if that is absent,
        ``'Next question :'``) up to the first ``'</s>'`` tag — or to the end of
        the string when there is no tag — is returned with surrounding
        whitespace removed.

        Args:
            output: Raw text generated by the model.

        Returns:
            The extracted question (possibly ``""`` if nothing follows the
            marker), or the sentinel ``"Next question not found in input"`` when
            neither marker is present.
        """
        # Prefer the exact marker; the spaced variant is only a fallback.
        for marker in ('Next question:', 'Next question :'):
            marker_index = output.find(marker)
            if marker_index != -1:
                break
        else:
            return "Next question not found in input"

        question_start_index = marker_index + len(marker)

        # Cut at the end-of-sequence tag if present; None slices to the end.
        question_end_index = output.find('</s>', question_start_index)
        if question_end_index == -1:
            question_end_index = None

        # strip() removes the blanks after the marker, so no manual index walk is
        # needed (the old walk ran off the end when nothing followed the marker).
        return output[question_start_index:question_end_index].strip()


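# There is probably a better way to handle JSON parsing errors (ideally the LLM output would be handled flexibly instead of being split into fixed fields such as `relation`).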

In [None]:
# Ask the FRAG pipeline the question below.
original_question = "\nwhere is a folder to define input of the pretrained_model?\n"
frag = FRAG(database_name, max_more, max_dispose)
frag.get_answer(original_question)

# For this question, all that remains is to improve the search engine's performance and make the description of the environment more detailed.

In [None]:
# Ask the FRAG pipeline the question below.
original_question = "\nwhat's the difference between mistral and mixtral?\n"
frag = FRAG(database_name, max_more, max_dispose)
frag.get_answer(original_question)

#

In [None]:
# Ask the FRAG pipeline the question below.
original_question = "\nExplain the structure of Trainer class.\n"
frag = FRAG(database_name, max_more, max_dispose)
frag.get_answer(original_question)

In [None]:
# Ask the FRAG pipeline the question below.
original_question = (
    "\n"
    "I wanna modify mistral_modeling.py file so that I can designate layers of hidden state and only those outputs are returned. How to modify the code?"
    "\n"
)
frag = FRAG(database_name, max_more, max_dispose)
frag.get_answer(original_question)

In [None]:
# Ask the FRAG pipeline the question below (note: no trailing newline here).
original_question = (
    "\n"
    "I wanna add LeakyReLU function into transformers source code. Tell me where to insert the function and the code to be inserted."
)
frag = FRAG(database_name, max_more, max_dispose)
frag.get_answer(original_question)