In [191]:
from langchain.prompts import PromptTemplate

from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser

In [192]:
import yaml

# Credentials live in a local YAML file; safe_load builds only plain Python objects.
with open('secrets.yml') as secrets_file:
    secrets = yaml.safe_load(secrets_file)

In [193]:
from langchain_groq import ChatGroq
import os

GROQ_MODEL = "llama3-70b-8192"

os.environ["GROQ_API_KEY"] = secrets['groq'][0]

# Free-form chat model.
chat_model = ChatGroq(model=GROQ_MODEL)
# Same model bound to response_format json_object; params_chain uses this one.
json_model = ChatGroq(model=GROQ_MODEL).bind(response_format={"type": "json_object"})

## Model modifier

In [194]:
#%pip install openpyxl

In [195]:
import pandas as pd
import numpy as np

MODEL_WORKBOOK = 'Models/DEModel.xlsx'

# Open the workbook as a context manager so its file handle is released as soon
# as the sheet has been read (pd.ExcelFile otherwise holds it until close()).
# `tmap` stays bound afterwards but is closed.
with pd.ExcelFile(MODEL_WORKBOOK) as tmap:
    df = pd.read_excel(tmap, "ConversionSubProcess")

# Column 0 holds the conversion-process name; rows tagged 'DEBUG' are excluded.
conversion_processes = np.asarray(df.iloc[:, 0].dropna())
conversion_processes = conversion_processes[conversion_processes != 'DEBUG']

# Every column from index 4 onward is a selectable parameter.
parameters = np.asarray(df.columns[4:])

In [196]:
# The first four columns identify a conversion subprocess (cp, cin, cout, scen);
# rows with a missing identifier or tagged 'DEBUG' are dropped.
cs = np.asarray(df.iloc[:, 0:4].dropna())
cs = cs[cs[:, 0] != 'DEBUG']

# Encode each subprocess as 'cp@cin@cout@scen'. A flat list of str (instead of
# an (n, 1) object array) renders cleanly when formatted into the LLM prompt and
# is never summarised with '...' the way numpy reprs above the print threshold are.
conversion_subprocesses = ['@'.join(str(field) for field in row) for row in cs]

In [232]:
# Prompt asking the LLM to pick one parameter, the new numeric value and the
# matching conversion subprocess(es) from the lists provided at invoke time.
# The not-found instruction uses the keys that actually exist in the requested
# JSON schema (an empty 'cs_list' and 'total_cs' of 0).
# NOTE(review): the Llama-3 header tokens are written into the template text and
# the rendered string is presumably sent to ChatGroq as a single user message —
# confirm they are still wanted, or move to a ChatPromptTemplate with roles.
params_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a specialist at identifying the correct conversion subprocess and the correct parameter
    selected by the user in his QUERY. \n
    
    As a context, you will receive two data arrays. PARAMS provides you the name of the parameters
    available to be selected. CONVERSION_SUBPROCESSES provides you the combination of 'cp' (conversion process name),
    'cin' (commodity in), 'cout' (commodity out) and 'scen' (scenario) in the format 'cp@cin@cout@scen'.\n
    
    Your goal is to output a JSON object containing four keys: 'param', 'value', 'cs_list', 'total_cs'.
    'param' must receive the name of the selected parameter, copied exactly from PARAMS;
    'value' is the new value selected by the user;
    'cs_list' is a list with all matching conversion subprocesses (ideally only one if possible), each
    copied verbatim from CONVERSION_SUBPROCESSES without merging, splitting or editing entries;
    'total_cs' is the number of conversion subprocesses that possibly match the user's request and must
    equal the length of 'cs_list'; \n
    
    NEVER MAKE UP DATA, USE ONLY DATA FROM THE GIVEN LIST. If you can't find any conversion subprocess
    matching the user's request, set 'cs_list' to an empty list and 'total_cs' to 0. \n
    
    The field 'value' only accepts numeric input. \n

    <|eot_id|><|start_header_id|>user<|end_header_id|>
    QUERY: {query} \n
    PARAMS: {params} \n
    CONVERSION_SUBPROCESSES: {CSs} \n
    Answer:
    <|eot_id|>
    <|start_header_id|>assistant<|end_header_id|>
    """,
    input_variables=["query","params","CSs"],
)

# prompt -> JSON-mode Groq model -> dict
params_chain = params_prompt | json_model | JsonOutputParser()

In [234]:
query = 'Modify the operational cost of biomass CHPs with help_biomass_chp as input and electricity as output to 1000'

# Call the chain directly (outside the graph) to inspect the raw JSON reply.
chain_inputs = dict(query=query, params=parameters, CSs=conversion_subprocesses)
params_chain.invoke(chain_inputs)

{'param': 'opex_cost_energy',
 'value': 1000,
 'cs_list': ['Biomass_CHP@Biomass@Help_Biomass_CHP@Electricity@Base'],
 'total_cs': 1}

In [208]:
def param_selector(state):
    """Graph node: ask the LLM which parameter and conversion subprocess(es) the query targets.

    Reads 'query', 'num_steps', 'parameters' and 'CSs' from the graph state and
    invokes params_chain with them.

    Returns:
        {"output": <parsed JSON reply>, "num_steps": <num_steps + 1>}
    """
    print("---PARAM SELECTOR---")
    query = state['query']
    num_steps = state['num_steps']
    # numpy arrays are summarised with '...' once they exceed numpy's print
    # threshold, which would silently hide candidates from the LLM. Hand the
    # prompt plain, flat lists of str instead.
    parameters = _flatten_to_str_list(state['parameters'])
    conversion_subprocesses = _flatten_to_str_list(state['CSs'])

    output = params_chain.invoke({"query": query, "params": parameters, "CSs": conversion_subprocesses})

    return {"output": output,
            "num_steps": num_steps + 1}


def _flatten_to_str_list(values):
    """Flatten a possibly nested sequence (list, tuple, numpy array) into a flat list of str."""
    if values is None:
        return []
    if isinstance(values, str):
        return [values]
    flat = []
    for item in values:
        if isinstance(item, str):
            flat.append(item)
        elif hasattr(item, '__iter__'):
            # e.g. the rows of an (n, 1) object array
            flat.extend(_flatten_to_str_list(item))
        else:
            flat.append(str(item))
    return flat

In [202]:
def confirm_selection(state):
    """Graph node: show the single matching conversion subprocess and ask the user to confirm.

    The subprocess comes from the first entry of 'cs_list' (or the older
    'found_CSs' key), formatted 'cp@cin@cout@scen'. Replies that only carry
    separate 'cp'/'cin'/'cout'/'scen' keys are still accepted. The user's answer
    is compared case-insensitively, ignoring surrounding whitespace.

    Note: this node only reports the outcome; nothing in it writes the new value
    to the model.

    Returns:
        {"final_answer": 'The selected parameter was modified!'} on 'Y',
        otherwise {"final_answer": 'Aborted'}.
    """
    print('---CONFIRM SELECTION---')
    output = state['output']

    cs_list = output.get('cs_list') or output.get('found_CSs') or []
    if cs_list:
        fields = str(cs_list[0]).split('@')
        fields += [''] * (4 - len(fields))
        # Fold any surplus '@' pieces into COUT so nothing is hidden from the user.
        cp, cin, cout, scen = fields[0], fields[1], '@'.join(fields[2:-1]), fields[-1]
    else:
        cp, cin, cout, scen = (output.get(key, '') for key in ('cp', 'cin', 'cout', 'scen'))

    print('The following selection of conversion subprocess and parameter was found:\n')
    print(f"CP: {cp} CIN: {cin} COUT: {cout} SCEN: {scen} "
          f"PARAMETER: {output.get('param')} VALUE: {output.get('value')}\n")

    answer = input('Is that correct? (Y or N)\n')
    if answer.strip().upper() == 'Y':
        return {"final_answer": 'The selected parameter was modified!'}
    return {"final_answer": 'Aborted'}

In [223]:
from tabulate import tabulate

def select_from_found(state):
    """Graph node: let the user pick one of several matching conversion subprocesses.

    Candidates are read from 'cs_list' in the LLM reply (falling back to the older
    'found_CSs' key). They are shown as a table split on '@' into CP/CIN/COUT/Scen,
    and the user is re-prompted until a valid row index is entered.

    Returns:
        {"final_answer": <chosen 'cp@cin@cout@scen' string>}, or
        {"final_answer": 'Aborted'} when there are no candidates.
    """
    print('---SELECT CORRECT CS---')
    output = state['output']
    CSs = list(output.get('cs_list') or output.get('found_CSs') or [])

    if not CSs:
        print('No conversion subprocesses were found.')
        return {"final_answer": 'Aborted'}

    data = []
    for index, cs in enumerate(CSs):
        fields = str(cs).split('@')
        # Pad short entries and fold surplus pieces into COUT so every row has 4 columns.
        fields += [''] * (4 - len(fields))
        data.append([index, fields[0], fields[1], '@'.join(fields[2:-1]), fields[-1]])

    print('The following conversion subprocesses were found:\n')
    print(tabulate(data, headers=["Index", "CP", "CIN", "COUT", "Scen"]))

    # Keep asking until the input is an integer that indexes an existing row.
    while True:
        answer = input('Input the number of the correct CS:\n').strip()
        try:
            choice = int(answer)
        except ValueError:
            choice = -1
        if 0 <= choice < len(CSs):
            return {"final_answer": CSs[choice]}
        print(f'Please enter a number between 0 and {len(CSs) - 1}.')

In [204]:
def print_final(state):
    """Graph node: print the final answer stored in the state.

    Returns None.
    """
    print('FINAL ANSWER: {}'.format(state['final_answer']))

In [205]:
def param_router(state):
    """Choose the next graph node from the number of matching conversion subprocesses.

    The count comes from the LLM reply in state['output']. When 'cs_list' (or the
    older 'found_CSs') is present, the number of its entries is used, with
    'NOT_FOUND' placeholders ignored; the candidate list itself is preferred over
    the model's self-reported count. Otherwise 'total_cs', or the legacy
    'found_cps' key, is coerced to int.

    Returns:
        'confirm_selection' for exactly one match, 'select_from_found' for
        several, 'param_not_found' otherwise.
    """
    print("---PARAM ROUTER---")
    output = state['output']

    found = _count_matching_cs(output)
    if found == 1:
        return 'confirm_selection'
    if found > 1:
        return 'select_from_found'
    return 'param_not_found'


def _count_matching_cs(output):
    """Number of conversion subprocesses reported in an LLM reply; 0 if it can't be determined."""
    if not isinstance(output, dict):
        return 0
    cs_list = output.get('cs_list', output.get('found_CSs'))
    if isinstance(cs_list, (list, tuple)):
        return sum(1 for cs in cs_list if cs != 'NOT_FOUND')
    for key in ('total_cs', 'found_cps'):
        if key in output:
            # JSON mode may still return the count as a string.
            try:
                return int(output[key])
            except (TypeError, ValueError):
                return 0
    return 0

In [224]:
from langchain.schema import Document
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict
from typing import List

### State

class GraphState(TypedDict):
    """
    State shared between the nodes of the parameter-modifier graph.

    Attributes:
        query: the user's request in natural language
        parameters: names of the selectable model parameters
                    (this notebook passes a numpy array here)
        CSs: conversion subprocesses encoded as 'cp@cin@cout@scen'
        num_steps: step counter; param_selector increments it by one
        user_answer: not read or written by any node in this notebook
        output: JSON object parsed from the LLM reply by params_chain
        final_answer: result set by confirm_selection / select_from_found
                      and printed by print_final
    """
    query : str
    parameters: List[str]
    CSs: List[str]
    num_steps : int
    user_answer: str
    output: dict
    final_answer : str

workflow = StateGraph(GraphState)

# Register each node under the same name as the function implementing it.
for node_name, node_fn in (
    ("param_selector", param_selector),
    ("confirm_selection", confirm_selection),
    ("select_from_found", select_from_found),
    ("print_final", print_final),
):
    workflow.add_node(node_name, node_fn)

In [225]:
workflow.set_entry_point("param_selector")

# param_router's return value names the next node; "param_not_found" ends the run.
route_map = dict(
    confirm_selection="confirm_selection",
    select_from_found="select_from_found",
    param_not_found=END,
)
workflow.add_conditional_edges("param_selector", param_router, route_map)

# Both interactive nodes converge on print_final, which then ends the graph.
for source, target in (
    ("confirm_selection", "print_final"),
    ("select_from_found", "print_final"),
    ("print_final", END),
):
    workflow.add_edge(source, target)

In [226]:
# Compile the node/edge definitions into a runnable graph (driven with app.stream below).
app = workflow.compile()

In [227]:
query = 'Modify the operational cost of biomass CHPs to 1000'

inputs = dict(query=query, parameters=parameters, CSs=conversion_subprocesses, num_steps=0)
stream_config = {"recursion_limit": 50}

# Each streamed item is a dict keyed by the node that just ran.
for output in app.stream(inputs, stream_config):
    for key in output:
        print(f"Finished running: {key}:")

---PARAM SELECTOR---
---PARAM ROUTER---
Finished running: param_selector:
---SELECT CORRECT CS---
The following conversion subprocesses were found:

  Index  CP           CIN               COUT                Scen
-------  -----------  ----------------  ------------------  ------
      0  Biomass_CHP  Biomass           Help_Biomass_CHP    Base
      1  Biomass_CHP  Help_Biomass_CHP  Dummy               Base
      2  Biomass_CHP  Help_Biomass_CHP  Electricity         Base
      3  Biomass_CHP  Help_Biomass_CHP  Industrial_Heat_LT  Base
Finished running: select_from_found:
FINAL ANSWER: Biomass_CHP@Biomass@Help_Biomass_CHP@Base
