# In [67]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

###################################
#           LIBRARIES             #
###################################

import os
import json
import csv
import time
import uuid
import faiss
import re
import numpy as np
import pandas as pd
from pybliometrics.scopus import ScopusSearch, AbstractRetrieval
from sentence_transformers import SentenceTransformer
from bs4 import BeautifulSoup
import requests
import xml.etree.ElementTree as ET  # For parsing the Ecore file
from codecarbon import EmissionsTracker

# LangChain and related modules
from langchain.prompts.chat import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from langchain.llms import Ollama
from langchain_experimental.llms.ollama_functions import OllamaFunctions
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.callbacks.base import BaseCallbackHandler
from langchain.tools.base import Tool
from typing import Callable, List, Dict, Any, TypedDict
import json
import re

###################################
#         GENERATION CHAIN        #
###################################
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_llm_cache
from langchain.cache import InMemoryCache

# Cache LLM responses in memory for the lifetime of this process, so an identical
# prompt sent to the same model twice in one run is answered from the cache.
set_llm_cache(InMemoryCache())

###################################
#         CONFIGURATION           #
###################################

# Define configuration paths and constants
# CONFIG_FILE: main LLM settings ("llm", "temperature", "max_retries", "base_url",
#              "price_per_input_token", "price_per_output_token", "api_keys").
# MODELS_FILE: per-model entry with the LangChain "class" name and constructor "params".
# CONFIG_RAG_FILE: "llm" and "temperature" of the model used for RAG / context generation.
CONFIG_FILE = "config/llm_config_mistral.json"
MODELS_FILE = "config/llm_models.json"
CONFIG_RAG_FILE = "config/llm_config_openai_rag.json"
# 'Ollama' passes only temperature/base_url to the model class; any other value also
# passes max_retries and api_key.
LLM_TYPE = 'Ollama'  # Options: 'Others', 'Ollama'
# Selects the model used for search-string generation: ChatOpenAI ('OpenAI') or the
# main LangChain model ('LangChain').
RAG_CHAT = 'OpenAI'  # Options: 'OpenAI', 'LangChain'
# Label embedded in the output folder names (presumably the systematic literature review id).
SLR = 'MDTE'

# Define minimum and maximum thresholds for retrieved papers.
# Fewer than min_threshold results -> the query is broadened; more than
# max_threshold -> the query is tightened.
min_threshold = 2     # Minimum desired number of papers
max_threshold = 2000    # Maximum desired number of papers

###################################
#         UTILITY FUNCTIONS       #
###################################

# Function to load configuration from a JSON file
def load_config(config_file):
    """Load a JSON configuration file.

    Args:
        config_file: Path to the JSON file.

    Returns:
        The parsed JSON value (normally a dict). An empty dict is returned when
        the file does not exist or cannot be decoded; the problem is printed so
        callers can continue with ``.get`` defaults.
    """
    try:
        # JSON text is UTF-8 by specification; do not depend on the platform's
        # locale encoding (e.g. cp1252 on Windows).
        with open(config_file, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        print(f"Configuration file {config_file} not found.")
        return {}
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"Error decoding JSON: {e}")
        return {}

# Function to load file content
def load_file_content(file_path):
    """Return the full text of *file_path*, decoded as UTF-8.

    Returns an empty string (and prints an error) when the file does not exist.
    """
    try:
        # Explicit encoding, consistent with load_urls_from_csv; the default is
        # locale-dependent and differs across platforms.
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        print(f"Error: File {file_path} not found.")
        return ""

# Function to save content to a file
def save_to_file(file_path, content):
    """Write the string *content* to *file_path* as UTF-8, replacing any existing file."""
    # Explicit encoding so non-ASCII output does not fail or change bytes with the locale.
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(content)

# Function to save metadata to a file (in JSON format)
def save_metadata(file_path, metadata):
    """Serialize *metadata* to *file_path* as JSON with a four-space indent (overwrites the file)."""
    with open(file_path, mode='w') as sink:
        json.dump(obj=metadata, fp=sink, indent=4)

###################################
#         LLM CONFIGURATION       #
###################################

# Load LLM configuration
config = load_config(CONFIG_FILE)
models_config = load_config(MODELS_FILE)

# Extract LLM parameters from configuration
LLM = config.get("llm")
if not LLM:
    raise ValueError("LLM name must be specified in the configuration file.")

PRICE_PER_INPUT_TOKEN = config.get("price_per_input_token")
PRICE_PER_OUTPUT_TOKEN = config.get("price_per_output_token")
temperature = config.get("temperature")
max_retries = config.get("max_retries")
api_key = config.get("api_keys", {}).get(LLM.lower(), None)
base_url = config.get("base_url")

# Look up the model entry ("class" + constructor "params") for the configured LLM.
# The check applies to both LLM_TYPE paths, so a missing entry always produces this
# ValueError (the Ollama path used to fail with AttributeError on None).
llm_config = models_config.get(LLM)
if not llm_config:
    raise ValueError(f"Model configuration for '{LLM}' not found in {MODELS_FILE}.")

llm_params = llm_config.get("params", {})
llm_params["temperature"] = temperature
if LLM_TYPE != 'Ollama':
    # Non-Ollama providers additionally receive retry and authentication settings.
    llm_params["max_retries"] = max_retries
    llm_params["api_key"] = api_key
llm_params["base_url"] = base_url

# Dynamically resolve the LangChain class named in the models file.
# NOTE(review): eval() executes whatever expression llm_models.json contains; this is
# only acceptable while that file is trusted. A lookup table of allowed classes
# (ChatOpenAI, ChatMistralAI, Ollama, ...) would be safer.
llm_class = eval(llm_config["class"])
llm_LangChain = llm_class(**llm_params)
model_name = LLM  # Use LLM name as the model name

###################################
#       CONFIGURATION FOLDERS     #
###################################

# One output folder per (SLR, model, temperature) run; JSON artifacts go in "JSON".
base_output_dir = f"00-Query/00-Query-Agent-{SLR}-{model_name.lower()}-{temperature}"
base_output_json_dir = f"{base_output_dir}/JSON"

# Create the run folder, then its JSON subfolder (both calls are idempotent).
os.makedirs(base_output_dir, exist_ok=True)
os.makedirs(base_output_json_dir, exist_ok=True)

# Metadata file name, plus the two Scopus CSV exports stored in the run folder.
filename = 'results.json'
csv_filename, csv_abstracts_filename = (
    os.path.join(base_output_dir, csv_name)
    for csv_name in ("Scopus_Search_results.csv", "Scopus_AbstractRetrieval_results.csv")
)

# output_trace_path = os.path.join(base_output_dir, file_name.replace(".hepsy", ".xes"))
# metadata_path = os.path.join(base_output_json_dir, file_name.replace(".hepsy", ".json"))

###################################
#       PROFILING & CODECARBON    #
###################################

###################################
#         GLOBAL PROFILING        #
###################################

# Global list to collect CodeCarbon metrics for each node call (per file)
cc_metrics_for_file = []  # Per-node CodeCarbon records appended by cc_profile_node; reset for each file

# Global list for overall CodeCarbon summary per file
cc_summary_records = []

# Global list of {"node", "execution_time"} records appended by timing_profile_node
profiling_records = []

# Profiling, CodeCarbon and evaluation outputs currently share the run's JSON folder.

# Profiling Folder
PROFILING_FOLDER = f"00-Query/00-Query-Agent-{SLR}-{model_name.lower()}-{temperature}/JSON"
os.makedirs(PROFILING_FOLDER, exist_ok=True)
PROFILING_CSV_FILE = os.path.join(PROFILING_FOLDER, "profiling.csv")

# CodeCarbon Folder
CODECARBON_FOLDER = f"00-Query/00-Query-Agent-{SLR}-{model_name.lower()}-{temperature}/JSON"
os.makedirs(CODECARBON_FOLDER, exist_ok=True)
# The CodeCarbon summary gets its own name instead of re-assigning
# PROFILING_CSV_FILE, which silently redirected the timing CSV to this file.
# NOTE(review): code that writes the CodeCarbon summary should use CODECARBON_CSV_FILE.
CODECARBON_CSV_FILE = os.path.join(CODECARBON_FOLDER, "codecarbon_summary.csv")

# Folder to save evaluation results per file
EVALUATION_FOLDER = f"00-Query/00-Query-Agent-{SLR}-{model_name.lower()}-{temperature}/JSON"
os.makedirs(EVALUATION_FOLDER, exist_ok=True)

###################################
#      TIMING NODE PROFILING      #
###################################

def timing_profile_node(func):
    """
    Decorator to profile a node function.

    After each call it appends ``{"node": func.__name__, "execution_time": seconds}``
    to the global ``profiling_records`` list and prints the elapsed time.

    ``functools.wraps`` keeps the original ``__name__``/``__doc__`` on the returned
    wrapper, so decorators stacked on top of this one and code inspecting node names
    see the real node name instead of ``"wrapper"``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(state, *args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted.
        start = time.perf_counter()
        result = func(state, *args, **kwargs)
        elapsed = time.perf_counter() - start
        profiling_records.append({"node": func.__name__, "execution_time": elapsed})
        print(f"[Profiling] {func.__name__} took {elapsed:.4f} seconds")
        return result
    return wrapper

###################################
#    CODECARBON NODE DECORATOR    #
###################################

# os.environ["CODECARBON_API_KEY"] = "CODECARBON_API_KEY"
# os.environ["CODECARBON_API_URL"] = "https://api.codecarbon.io"
# os.environ["CODECARBON_EXPERIMENT_ID"] = "UUID"

def _emissions_metrics_dict(tracker, emissions):
    """Return a stopped tracker's final emissions data as a plain dict.

    Depending on the CodeCarbon version, the final data may be exposed as
    ``_final_emissions_data`` or ``final_emissions_data``, and may be a dataclass
    rather than a mapping. Unpacking a dataclass (or None) with ``**`` raises
    TypeError, so it is converted here. Falls back to
    ``{"total_emissions": <value returned by stop()>}``.
    """
    import dataclasses
    from collections.abc import Mapping

    data = getattr(tracker, "_final_emissions_data", None)
    if data is None:
        data = getattr(tracker, "final_emissions_data", None)
    if isinstance(data, Mapping):
        return dict(data)
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        return dataclasses.asdict(data)
    return {"total_emissions": emissions}


def cc_profile_node(func):
    """
    Decorator that wraps a node function with CodeCarbon tracking.

    A new EmissionsTracker (project ``cc_<node name>``, output in CODECARBON_FOLDER)
    is started before the node runs and stopped right after. The tracker is also
    stopped when the node raises. On success,
    ``{"node": <name>, **metrics}`` is appended to the global ``cc_metrics_for_file``.
    ``functools.wraps`` preserves the node's name for decorators stacked on top.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(state, *args, **kwargs):
        tracker = EmissionsTracker(
            project_name=f"cc_{func.__name__}",
            measure_power_secs=1,
            output_dir=CODECARBON_FOLDER,
            allow_multiple_runs=True,
            # Optional API upload: api_call_interval=4, experiment_id=..., save_to_api=True
        )
        tracker.start()
        try:
            result = func(state, *args, **kwargs)
        finally:
            # Stop even when the node fails, so the tracker is not left running.
            emissions = tracker.stop()
        cc_metrics_for_file.append({
            "node": func.__name__,
            **_emissions_metrics_dict(tracker, emissions),  # flatten the metrics
        })
        return result
    return wrapper

###################################
#       PROFILE & CC DECORATORS   #
###################################

# Combined decorator: every node is profiled for wall-clock time and CodeCarbon
# metrics. cc_profile_node wraps the function first, so the time measured by
# timing_profile_node includes the tracker's start/stop overhead.
def profile_node(func):
    """Wrap *func* with CodeCarbon tracking (inner layer) and timing (outer layer).

    ``functools.wraps`` is applied at both layers. The name recorded by
    timing_profile_node, and the ``__name__`` of the returned callable, is then the
    node's real name, whether or not the individual decorators preserve it.
    """
    from functools import wraps

    tracked = wraps(func)(cc_profile_node(func))
    return wraps(func)(timing_profile_node(tracked))

###################################
#       LOAD URLS FROM CSV        #
###################################

def load_urls_from_csv(csv_file_path):
    """Collect the stripped first-column value of every non-empty row in a CSV file.

    Read errors are printed rather than raised; whatever was collected before the
    error is returned (an empty list if the file is missing).
    """
    collected = []
    try:
        with open(csv_file_path, 'r', newline='', encoding='utf-8') as handle:
            for record in csv.reader(handle):
                if not record:
                    continue  # blank line
                collected.append(record[0].strip())
    except FileNotFoundError:
        print(f"Error: CSV file '{csv_file_path}' not found.")
    except Exception as exc:
        print(f"Error reading CSV file: {exc}")
    return collected

###################################
#       GRAPH WORKFLOW SETUP      #
###################################

from langgraph.graph import StateGraph, START, END

# Extend GraphState to include keys from both retrieval and database parts
class GraphState(TypedDict, total=False):
    # Shared LangGraph state. LangGraph's StateGraph only carries keys declared in
    # this schema from one node to the next, so every key a node writes must be
    # listed here.

    # Retrieval branch keys
    question: str                       # The user's search question or refined query
    generation: str                     # The generated answer or refined query output
    documents: List[Any]                # List of retrieved documents
    file_name: str                      # Name of the file being processed (if applicable)
    context_llm: str                    # The refined context generated by the LLM
    trace_status: str                   # Status of the processing trace
    metadata: Dict[str, Any]            # Additional metadata related to the process
    branch: str                         # Indicates which branch is being used ('retrieve' or 'web_search')
    evaluation_metrics: Dict[str, float]  # Metrics evaluating the generated results
    bert_score: Dict[str, float]        # BERTScore metrics for evaluating document support
    web_bert_score: Dict[str, float]    # BERTScore metrics for evaluating web search branch results
    skip_router: bool                   # Flag to bypass the routing node if not necessary

    # Research database branch keys
    input_text: str                     # The initial input text provided by the user
    raw_context: str                    # Raw context before any refinement
    search_string: str                  # The final search string generated for database queries
    databases: List[str]                # List of database names extracted from the input
    formatted_queries: Dict[str, str]   # Queries formatted specifically for each database
    db_results: Dict[str, str]          # Search results from each database (as JSON strings)
    final_output: Dict[str, Any]        # Aggregated output containing search string and database results

    # Additional evaluation keys for query relaxation
    min_results: int                    # Minimum number of articles expected to be returned
    max_results: int                    # Maximum number of articles desired
    relax_query: bool                   # Flag indicating whether the query should be relaxed to retrieve more results
    adjust_query: str                   # Adjustment mode written/read by the nodes: "broaden", "tighten" or "none"
    adjusted_query: str                 # The adjusted query string (not referenced by the code shown here; kept for compatibility)
    iteration: int                      # Number of search-string generation attempts so far

# LangGraph builder over the shared GraphState schema; nodes and edges are
# presumably registered later in the file (not visible in this section).
workflow = StateGraph(GraphState)

###################################
# RAG AGENT SETUP (Chroma/FAISS)  #
###################################

# Load RAG configuration
config_rag = load_config(CONFIG_RAG_FILE)
# NOTE(review): this key is looked up under the *main* LLM's name (LLM), not LLM_RAG,
# and api_key_rag is never passed to ChatOpenAI below. ChatOpenAI then presumably
# falls back to its own defaults (e.g. the OPENAI_API_KEY environment variable).
# Confirm the intent.
api_key_rag = config_rag.get("api_keys", {}).get(LLM.lower(), None)

# Initialize RAG LLM and router
LLM_RAG = config_rag.get("llm")
LLM_RAG_TEMP = config_rag.get("temperature")

# llm_rag: RAG/router model. llm_for_context: model used by generate_search_string_node.
if RAG_CHAT == 'OpenAI':
    llm_rag = ChatOpenAI(model=LLM_RAG, temperature=LLM_RAG_TEMP)
    llm_for_context = ChatOpenAI(model=LLM_RAG, temperature=LLM_RAG_TEMP)
elif RAG_CHAT == 'LangChain':
    llm_rag = OllamaFunctions(model=LLM)
    llm_for_context = llm_LangChain
else:
    # Fail here rather than with a NameError the first time a node uses the model.
    raise ValueError(f"Unsupported RAG_CHAT value '{RAG_CHAT}'; expected 'OpenAI' or 'LangChain'.")

###################################
#       DATABASE BRANCH NODES     #
# (from generate_search_string onward)
###################################

# --- Define simulated search functions for each database ---
def _is_present(value) -> bool:
    """True when a DataFrame cell holds a usable scalar (not NaN/None and not "")."""
    return bool(pd.notnull(value)) and value != ""


def _retrieve_abstracts(df_articles: pd.DataFrame) -> pd.DataFrame:
    """Look up every article with AbstractRetrieval and save the details to CSV.

    Each row is retrieved by DOI when present, otherwise by EID (FULL view).
    Retrieval errors are printed, and the row is kept with None in every
    AbstractRetrieval-derived field. Row fields missing from *df_articles* become
    None (``row.get``) instead of raising KeyError. The resulting DataFrame is
    written to ``csv_abstracts_filename`` and returned.
    """
    def from_abstract(abs_obj, attr):
        # Stringified AbstractRetrieval attribute, or None when retrieval failed.
        return str(getattr(abs_obj, attr)) if abs_obj is not None else None

    records = []
    print("\nRetrieving abstracts for each article...")
    for index, row in df_articles.iterrows():
        abs_obj = None
        try:
            # Attempt retrieval using DOI if available; otherwise, use EID.
            if _is_present(row.get('doi')):
                abs_obj = AbstractRetrieval(row['doi'], view="FULL")
            elif _is_present(row.get('eid')):
                abs_obj = AbstractRetrieval(row['eid'], view="FULL")
            else:
                print(f"No DOI or EID available for row index {index}")
        except Exception as e:
            print(f"Error retrieving abstract for index {index}: {e}")

        doi = row.get('doi')
        records.append({
            'eid': row.get('eid'),
            'doi': doi,
            'title': row.get('title'),
            'abstract': abs_obj.description if abs_obj is not None else None,
            'authkeywords': from_abstract(abs_obj, 'authkeywords'),
            'doi-link': f"https://doi.org/{doi}" if _is_present(doi) else None,
            'subtype': row.get('subtype'),
            'subtypeDescription': row.get('subtypeDescription'),
            'publicationName': row.get('publicationName'),
            'publisher': from_abstract(abs_obj, 'publisher'),
            'authors': from_abstract(abs_obj, 'authors'),
            'creator': row.get('creator'),
            'affilname': row.get('affilname'),
            'affiliation_city': row.get('affiliation_city'),
            'affiliation_country': row.get('affiliation_country'),
            'language': from_abstract(abs_obj, 'language'),
            'date_created': from_abstract(abs_obj, 'date_created'),
            'coverDate': row.get('coverDate'),
            'coverDisplayDate': row.get('coverDisplayDate'),
            'issn': row.get('issn'),
            'isbn': from_abstract(abs_obj, 'isbn'),
            'source_id': row.get('source_id'),
            'aggregationType': row.get('aggregationType'),
            'citedby_count': row.get('citedby_count'),
            'openaccess': row.get('openaccess'),
            'openaccessFlag': from_abstract(abs_obj, 'openaccessFlag'),
            'refcount': from_abstract(abs_obj, 'refcount'),
            'subject_areas': from_abstract(abs_obj, 'subject_areas'),
            'url': from_abstract(abs_obj, 'url'),
            'website': from_abstract(abs_obj, 'website'),
            'freetoread': row.get('freetoread'),
            'freetoreadLabel': row.get('freetoreadLabel'),
            'volume': row.get('volume'),
            'issueIdentifier': row.get('issueIdentifier'),
            'article_number': row.get('article_number'),
            'pageRange': row.get('pageRange'),
        })

    df_abstracts = pd.DataFrame(records)
    try:
        df_abstracts.to_csv(csv_abstracts_filename, index=False, encoding="utf-8")
        print(f"\nAbstracts have been saved to '{csv_abstracts_filename}'")
    except Exception as e:
        print("Error saving abstracts CSV:", e)
    return df_abstracts


def scopus_search(query: str, retrieve_abstracts: bool = False) -> str:
    """Run *query* against Scopus (pybliometrics) and return the hits as JSON.

    The filtered result table is also saved to ``csv_filename``.

    Args:
        query: A complete Scopus advanced-search query (e.g. ``ALL(...)``).
        retrieve_abstracts: When True, also fetch per-article details with
            AbstractRetrieval (one extra API call per article) and save them to
            ``csv_abstracts_filename``. Off by default.

    Returns:
        A JSON array of records (``orient="records"``) limited to the columns in
        ``desired_columns``, or ``"[]"`` when nothing matched. If the search itself
        fails, a plain, non-JSON error string is returned. If the JSON conversion
        fails, a JSON object ``{"error": ...}`` is returned.
    """
    try:
        print("Performing Scopus search for query:", query)
        # COMPLETE view returns the richest records (STANDARD is the lighter view).
        scopus_search_res = ScopusSearch(query, view="COMPLETE")
        # pybliometrics reports zero hits as results == None; normalize to [] so the
        # count below and the caller's "too few results" path work instead of
        # raising TypeError on len(None).
        results = scopus_search_res.results or []
        df_scopus = pd.DataFrame(results)
    except Exception as e:
        print("Error during Scopus search:", e)
        # Return an error string instead of exiting the process.
        return f"Error during Scopus search: {e}"

    print(f"Number of results found: {len(results)}")

    # Keep only the columns of interest; reindex inserts missing ones as NaN.
    desired_columns = [
        'eid', 'doi', 'title', 'subtype', 'subtypeDescription', 'creator',
        'affilname', 'affiliation_city', 'affiliation_country', 'coverDate',
        'coverDisplayDate', 'publicationName', 'issn', 'source_id',
        'aggregationType', 'authkeywords', 'citedby_count', 'openaccess'
    ]
    df_filtered = df_scopus.reindex(columns=desired_columns)

    try:
        df_filtered.to_csv(csv_filename, index=False, encoding="utf-8")
        print(f"\nResults have been saved to '{csv_filename}'")
    except Exception as e:
        print("Error saving CSV:", e)

    if retrieve_abstracts:
        # Use the unfiltered frame: the abstract export also needs columns such as
        # volume / pageRange that desired_columns drops.
        _retrieve_abstracts(df_scopus)

    # Convert the DataFrame to a JSON string (list of records)
    try:
        return df_filtered.to_json(orient="records")
    except Exception as e:
        print("Error converting DataFrame to JSON:", e)
        return json.dumps({"error": f"Error converting DataFrame to JSON: {e}"})

def ieee_search(query: str) -> str:
    """Placeholder IEEE search: echo the query back inside a fixed message."""
    template = "IEEE results for query: '{}'"
    return template.format(query)

def sciencedirect_search(query: str) -> str:
    """Placeholder ScienceDirect search: echo the query back inside a fixed message."""
    template = "ScienceDirect results for query: '{}'"
    return template.format(query)

# --- Map database names (lowercased) to tools ---
# Lookup key (lowercase database name) -> LangChain Tool wrapping its search function.
db_tools = {
    key: Tool(name=tool_name, func=search_fn, description=blurb)
    for key, tool_name, search_fn, blurb in (
        ("scopus", "ScopusSearch", scopus_search,
         "Executes a search query on Scopus."),
        ("ieee", "IEEESearch", ieee_search,
         "Executes a search query on IEEE."),
        ("sciencedirect", "ScienceDirectSearch", sciencedirect_search,
         "Executes a search query on ScienceDirect."),
    )
}

def extract_db_names(db_string: str) -> List[str]:
    """Split a comma-separated list of database names into trimmed, lowercased tokens.

    Empty tokens (from ",," or a trailing comma) are dropped. The input and the
    result are printed for tracing.
    """
    names = []
    for token in db_string.split(","):
        cleaned = token.strip()
        if cleaned:
            names.append(cleaned.lower())
    print("[extract_db_names] Input:", db_string)
    print("[extract_db_names] Extracted:", names)
    return names

def format_db_query(db: str, query: str) -> str:
    """Wrap *query* in the field-code syntax of database *db*.

    - ``scopus``: ``ALL(<query>)``, unless the query already opens with an ``ALL(``
      clause (case-insensitive, whitespace allowed before the parenthesis). In that
      case it is returned unchanged.
    - ``ieee``: ``INDEXTERMS(<query>)``
    - ``sciencedirect``: ``KEY(<query>)``
    - anything else: the query unchanged.

    The mapping is printed for tracing.
    """
    if db == "scopus":
        # Match the ALL field code itself, not any text that merely starts with
        # "all" ("alloy steel", "allergy" must still be wrapped).
        if re.match(r"\s*ALL\s*\(", query, re.IGNORECASE):
            formatted = query
        else:
            formatted = f"ALL({query})"
    elif db == "ieee":
        formatted = f"INDEXTERMS({query})"
    elif db == "sciencedirect":
        formatted = f"KEY({query})"
    else:
        formatted = query
    print(f"[format_db_query] Database: {db} | Query: {query} | Formatted: {formatted}")
    return formatted

def multi_db_search(query: str, dbs: str) -> str:
    """Run *query* on every database named in the comma-separated *dbs* string.

    Returns:
        A JSON object mapping each requested database name (lowercased) to the raw
        string returned by its tool. Names without a registered tool map to a
        "No tool available ..." message.
    """
    db_list = extract_db_names(dbs)
    aggregated_results = {}
    print("[multi_db_search] Databases:", db_list)
    # dict.fromkeys keeps first-seen order while skipping repeated names, so a
    # database listed twice is not queried (and billed against its API quota) twice.
    for db in dict.fromkeys(db_list):
        tool = db_tools.get(db)
        if tool is None:
            aggregated_results[db] = f"No tool available for database '{db}'."
            continue
        formatted_query = format_db_query(db, query)
        print(f"Formatted {db} query:", formatted_query)
        aggregated_results[db] = tool.func(formatted_query)
    return json.dumps(aggregated_results)

def multi_db_search_wrapper(state: Dict[str, Any], tool_input: str) -> str:
    """Parse a JSON tool call ``{"query": ..., "dbs": ...}`` and run multi_db_search.

    On any failure (malformed JSON, JSON that is not an object, or an exception
    from the search itself) an error string is returned instead of raising. In
    that case ``state["iteration"]`` is incremented, and once it reaches 10,
    ``state["adjust_query"]`` is set to "broaden".
    """
    try:
        data = json.loads(tool_input)
        query = data.get("query", "")
        dbs = data.get("dbs", "")
        print("[multi_db_search_wrapper] Input JSON:", data)
        return multi_db_search(query, dbs)
    except Exception as e:
        error_message = f"Error parsing tool input: {str(e)}"

        # Count the failed attempt (initializing the counter on first use).
        state["iteration"] = state.get("iteration", 0) + 1
        print(f"[multi_db_search_wrapper] Current iteration: {state['iteration']}")

        # After too many failed attempts, ask the generator for a broader query.
        max_iterations = 10
        if state["iteration"] >= max_iterations:
            print("[multi_db_search_wrapper] Maximum iterations reached; switching to 'broaden' mode.")
            state["adjust_query"] = "broaden"

        print("[multi_db_search_wrapper] Error:", error_message)
        return error_message

@profile_node
def evaluate_results_node(state: GraphState) -> GraphState:
    """
    Count the articles returned by the database search and decide whether the
    query must be adjusted.

    ``state["db_results"]`` maps database name -> result. Each value is expected to
    be a JSON string (already-decoded lists/dicts are also accepted) that holds
    either a list of article records, or an object whose "scopus" entry is itself a
    JSON list, such as the aggregated output of multi_db_search. Anything else,
    including the placeholder IEEE/ScienceDirect messages and ``{"error": ...}``
    objects, counts as zero articles; unparsable values are reported.

    Thresholds come from ``state["min_results"]`` / ``state["max_results"]`` when set,
    otherwise from the module-level ``min_threshold`` / ``max_threshold``.

    Sets ``relax_query`` (bool) and ``adjust_query``: "broaden" below the minimum,
    "tighten" above the maximum, "none" otherwise.
    """
    lower = state.get("min_results")
    if lower is None:
        lower = min_threshold
    upper = state.get("max_results")
    if upper is None:
        upper = max_threshold

    total_results = 0
    for db, payload in state.get("db_results", {}).items():
        try:
            db_obj = json.loads(payload) if isinstance(payload, (str, bytes, bytearray)) else payload
            if isinstance(db_obj, dict) and "scopus" in db_obj:
                # Nested shape: the "scopus" value is the JSON produced by scopus_search.
                articles = db_obj["scopus"]
                if isinstance(articles, (str, bytes, bytearray)):
                    articles = json.loads(articles)
            elif isinstance(db_obj, list):
                articles = db_obj
            else:
                articles = []
            # Only a list is a list of articles; an {"error": ...} object is not.
            if not isinstance(articles, list):
                articles = []
            total_results += len(articles)
        except (ValueError, TypeError) as e:
            print(f"Error parsing results for {db}: {e}")

    print(f"Total results found: {total_results}")

    # Determine if the query needs adjustment
    if total_results < lower:
        # Not enough articles: broaden the query (remove some constraints)
        state["relax_query"] = True
        state["adjust_query"] = "broaden"
        print("Insufficient results; triggering relaxed (broaden) search string generation.")
    elif total_results > upper:
        # Too many articles: tighten the query (add more constraints)
        state["relax_query"] = True
        state["adjust_query"] = "tighten"
        print("Too many results; triggering tightened search string generation.")
    else:
        state["relax_query"] = False
        state["adjust_query"] = "none"
        print("The number of results is within the desired range; no adjustment needed.")

    return state

# Post-processing for the LLM's search string. A purely numeric keyword in double
# quotes (typically a year such as "2023" that the model slipped in) is blanked.
# The PUBYEAR constraint is appended later without quotes, so it is not affected.
def remove_numeric_keywords(search_str: str) -> str:
    """Replace each double-quoted run of digits (e.g. ``"2023"``) with an empty ``""`` pair."""
    return re.sub(r'"\d+"', '""', search_str)


# Opening shared by every system prompt. "{iteration}" is a ChatPromptTemplate
# placeholder filled at invoke time.
_SEARCH_PROMPT_HEADER = (
    "You are an expert at generating search strings for research queries. "
    "The current iteration number is {iteration}. "
)

# Output-format contract shared by every system prompt.
_SEARCH_PROMPT_FORMAT_RULES = (
    "Each keyword must be a natural phrase or term, and must use spaces between words. Do not concatenate multiple words into a single token. "
    "Return the search string in the following exact format (do not include any extra text):\n\n"
    "ALL ( ( Keyword1 OR Keyword2 OR ... OR KeywordN ) AND "
    "( KeywordA OR KeywordB OR ... OR KeywordM ) AND ... )\n\n"
    "Do not include the publication year as a keyword."
)

# User message. The question is a template variable, not f-string interpolated, so
# literal "{" / "}" in the user's text are not parsed as template placeholders.
_SEARCH_PROMPT_USER_TEMPLATE = (
    "User question:\n{query_user_input}\n\n"
    "Based on the user question and the current iteration, generate an optimal search string composed solely of essential, unique keywords. "
    "Use the following structure: within each group, keywords are combined with OR; between groups, use AND. "
    "If necessary to increase the number of results, modify the query by either adding new keywords (1 to 5) in OR or by removing one or more groups. "
    "Return only the search string in the exact format specified above.\n\n"
    "Search String:"
)


def _search_string_system_prompt(adjust_query: str) -> str:
    """Return the system prompt for adjustment mode "broaden", "tighten" or anything else (default)."""
    if adjust_query == "broaden":
        instructions = (
            "Generate a broader search string by relaxing the constraints in two ways: "
            "1) Randomly insert a number of new keywords (between 1 and 5) as additional OR alternatives within existing groups; "
            "2) Randomly remove a number of keywords (between 1 and 5) from groups, while keeping the most representative keywords. "
            "For each group, add synonyms or domain-related keywords using OR so that the group covers alternative terms. "
            "Each group represents a parallel domain defined by the research questions, and groups are combined with AND. "
        )
    elif adjust_query == "tighten":
        instructions = (
            "Generate a more specific search string by adding additional specific keywords to each group while still respecting the research questions and goals. "
            "Ensure that within each group keywords are combined with OR and groups are combined with AND. "
            "All keywords must be unique. "
        )
    else:
        instructions = (
            "Generate a search string using only the logical operators OR and AND, where within each group keywords are combined with OR and groups are combined with AND. "
            "Ensure that all keywords in the search string are unique and that the search string varies with each iteration. "
        )
    return _SEARCH_PROMPT_HEADER + instructions + _SEARCH_PROMPT_FORMAT_RULES


def _postprocess_search_string(raw_text: str, input_text: str) -> str:
    """Clean the LLM's search string and append a PUBYEAR constraint if one was requested.

    Steps: blank quoted numeric keywords, collapse whitespace, and append
    ``AND PUBYEAR > <year>`` when *input_text* contains "after <4-digit year>".
    Finally, drop runs of two or more ``site:`` filters joined by OR.
    NOTE(review): a single ``site:`` filter is not removed by the final regex.
    """
    cleaned = remove_numeric_keywords(raw_text.strip())
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()
    year_match = re.search(r'after\s+(\d{4})', input_text, re.IGNORECASE)
    if year_match:
        cleaned += f" AND PUBYEAR > {year_match.group(1)}"
    return re.sub(r'\s*site:\S+(?:\s*OR\s*site:\S+)+', '', cleaned)


@profile_node
def generate_search_string_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Generate an optimized search string from ``state["input_text"]``.

    ``state["adjust_query"]`` selects the prompt:
      - 'broaden': add OR alternatives within groups and/or drop keywords or groups.
      - 'tighten': add more specific keywords to each group.
      - anything else: a fresh search string.
    ``state["iteration"]`` is incremented on every call and passed to the prompt so
    the output varies between attempts. From the 10th iteration on, 'broaden' is
    forced. Call metadata (length, timing, temperature, token usage/price when
    available, model name) is saved to ``base_output_json_dir/filename``. The
    cleaned string, with a PUBYEAR constraint if the input says "after YYYY", is
    stored in ``state["search_string"]``.
    """
    query_user_input = state["input_text"]
    metadata_path = os.path.join(base_output_json_dir, filename)

    # Initialize iteration counter if not present, then increment it.
    state["iteration"] = state.get("iteration", 0) + 1
    current_iter = state["iteration"]
    print(f"[generate_search_string_node] Current iteration: {current_iter}")

    # Define maximum iterations allowed before forcing a relaxed query.
    max_iterations = 10
    if current_iter >= max_iterations:
        print("[generate_search_string_node] Maximum iterations reached; switching to 'broaden' mode.")
        state["adjust_query"] = "broaden"

    adjust_query = state.get("adjust_query", "none")

    prompt_template = ChatPromptTemplate.from_messages([
        ("system", _search_string_system_prompt(adjust_query)),
        ("user", _SEARCH_PROMPT_USER_TEMPLATE),
    ])

    start_time_llm = time.time()
    response = (prompt_template | llm_for_context).invoke(
        {"iteration": current_iter, "query_user_input": query_user_input}
    )
    execution_time = time.time() - start_time_llm
    print(f"Final String:{response}")

    # Chat models return a message with .content, while plain LLMs return a str.
    # Decide by the object's shape: llm_for_context is chosen by RAG_CHAT, not LLM_TYPE.
    search_text = (response.content if hasattr(response, "content") else str(response)).strip()

    metadata = {
        "response_length": len(search_text),
        "execution_time": execution_time,
        "temperature": temperature,
    }
    usage = getattr(response, "usage_metadata", None)
    if usage:
        metadata["usage"] = usage
        # Prices are optional in the config; skip the cost instead of crashing on None.
        if PRICE_PER_INPUT_TOKEN is not None and PRICE_PER_OUTPUT_TOKEN is not None:
            metadata["price_usd"] = (
                usage.get("input_tokens", 0) * PRICE_PER_INPUT_TOKEN
                + usage.get("output_tokens", 0) * PRICE_PER_OUTPUT_TOKEN
            )
    # NOTE(review): model_name is the main LLM, but the string is produced by
    # llm_for_context (LLM_RAG when RAG_CHAT == 'OpenAI'); confirm which to record.
    metadata["model_name"] = model_name

    save_metadata(metadata_path, metadata)
    print(f"Metadata saved to: {metadata_path}")

    state["search_string"] = _postprocess_search_string(search_text, state.get("input_text", ""))
    print("[generate_search_string_node] Search string (final):", state["search_string"])
    return state

@profile_node
def extract_databases_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract the list of target database names from ``state["input_text"]``.

    Two sources are combined, in this order:
      1. An explicit "databases:" segment (matched case-insensitively);
         everything after it is passed to ``extract_db_names``.
      2. Any known database name (scopus, ieee, sciencedirect) that occurs as
         a substring anywhere in the lower-cased input text.

    If neither source yields a database, the user is prompted on stdin until
    they enter at least one known database.

    Args:
        state: Workflow state; ``input_text`` is read (defaults to "").

    Returns:
        The same ``state`` dict with ``state["databases"]`` set to the list of
        extracted database names.
    """
    # List of known database names to search for / accept from the prompt.
    known_dbs = ["scopus", "ieee", "sciencedirect"]
    prompt = "Please enter a valid database (e.g., scopus, ieee, sciencedirect): "

    input_text = state.get("input_text", "")
    print("[extract_databases_node] Input text:", input_text)
    lower_text = input_text.lower()
    extracted_dbs = []

    # First, check for an explicit "databases:" segment. The match is done on
    # the ORIGINAL text: str.lower() can change the length of some Unicode
    # characters (e.g. "İ"), so an index found in lower_text is not guaranteed
    # to point at the same position in input_text.
    segment = re.search(r"databases:", input_text, re.IGNORECASE)
    if segment:
        db_string = input_text[segment.end():].strip()
        # Copy in case extract_db_names hands back a list it also keeps;
        # we append to this list below.
        extracted_dbs = list(extract_db_names(db_string))

    # Add any known database name mentioned anywhere in the input text.
    for db in known_dbs:
        if db in lower_text and db not in extracted_dbs:
            extracted_dbs.append(db)

    # Nothing found: keep asking until the user names at least one known
    # database. Returning an empty list would make the downstream search run
    # against no database at all.
    if not extracted_dbs:
        print("[extract_databases_node] No known databases found in the input.")
        while True:
            valid_db_input = input(prompt)
            extracted_dbs = [
                db for db in extract_db_names(valid_db_input) if db in known_dbs
            ]
            if extracted_dbs:
                break
            print("[extract_databases_node] Invalid database entered. Please try again.")

    state["databases"] = extracted_dbs
    print("[extract_databases_node] Databases extracted:", state["databases"])
    return state

@profile_node
def format_queries_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build one database-specific query per target database.

    Reads ``search_string`` (default "") and ``databases`` (default []) from
    the state, converts the generic search string with ``format_db_query``
    for every database, and stores the resulting ``{db: query}`` mapping
    under ``state["formatted_queries"]``.
    """
    query = state.get("search_string", "")
    target_dbs = state.get("databases", [])
    print("[format_queries_node] Search string:", query)
    print("[format_queries_node] Databases:", target_dbs)

    # The mapping is assigned only once every database has been formatted.
    state["formatted_queries"] = {
        db_name: format_db_query(db_name, query) for db_name in target_dbs
    }
    print("[format_queries_node] Formatted queries:", state["formatted_queries"])
    return state

@profile_node
def run_db_search_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run the search against every database that has a formatted query.

    Each query is packed into the JSON payload passed to
    ``multi_db_search_wrapper`` (``{"query": ..., "dbs": ...}``). The raw
    result for each database is stored under ``state["db_results"][db]``.
    """
    queries_by_db = state.get("formatted_queries", {})
    print("[run_db_search_node] Formatted queries:", queries_by_db)

    # db_results is written only after all searches have completed.
    state["db_results"] = {
        db_name: multi_db_search_wrapper(
            state, json.dumps({"query": query, "dbs": db_name})
        )
        for db_name, query in queries_by_db.items()
    }
    return state

@profile_node
def aggregate_results_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Bundle the final search string and per-database results into one payload.

    The payload is stored under ``state["final_output"]`` with the keys
    ``search_string`` (default "") and ``db_results`` (default {}).
    """
    search_string = state.get("search_string", "")
    db_results = state.get("db_results", {})
    state["final_output"] = {"search_string": search_string, "db_results": db_results}
    return state

###################################
#         ROUTER QUESTION         #
###################################

# ---- Nodes -------------------------------------------------------------
# Registered in the same order as before: the database pipeline first, then
# the evaluation step that may trigger a relaxed regeneration of the query.
workflow.add_node("generate_search_string", generate_search_string_node)
workflow.add_node("extract_databases", extract_databases_node)
workflow.add_node("format_queries", format_queries_node)
workflow.add_node("run_db_search", run_db_search_node)
workflow.add_node("aggregate_results", aggregate_results_node)
workflow.add_node("evaluate_results", evaluate_results_node)

# ---- Edges -------------------------------------------------------------
# Linear path: START -> generate -> extract -> format -> search -> evaluate.
workflow.add_edge(START, "generate_search_string")
workflow.add_edge("generate_search_string", "extract_databases")
workflow.add_edge("extract_databases", "format_queries")
workflow.add_edge("format_queries", "run_db_search")
workflow.add_edge("run_db_search", "evaluate_results")

# Router: when evaluate_results sets state["relax_query"], loop back and
# regenerate the search string (relax mode); otherwise aggregate and finish.
workflow.add_conditional_edges(
    "evaluate_results",
    lambda state: "relax" if state.get("relax_query") else "continue",
    {
        "relax": "generate_search_string",
        "continue": "aggregate_results",
    },
)
workflow.add_edge("aggregate_results", END)

# (Optional) Graph visualization.
# Disabled: the code below is wrapped in a bare string literal, so it never
# runs. Remove the surrounding triple quotes to compile the workflow, render it
# with draw_mermaid_png(), save the bytes to 'graph.png', and display the image
# inline via IPython.display. Any rendering error is caught and printed.
"""
try:
    from IPython.display import display, Markdown, Image
    graph = workflow.compile().get_graph()
    graph.mermaid_config = {"graph_direction": "TD"}
    png_bytes = graph.draw_mermaid_png()
    with open("graph.png", "wb") as f:
        f.write(png_bytes)
    print("The graph has been saved as 'graph.png'.")
    display(Markdown("### LangGraph Visualization ###"))
    display(Image(png_bytes))
except Exception as e:
    print("Graph rendering failed:", e)
"""

###################################
#       EXECUTION OF WORKFLOW     #
###################################

if __name__ == "__main__":
    # Initial state: note that for the database part, input_text is required.
    # The string below is an alternative example input kept for reference; it
    # is a bare string literal and is not executed.
    """
    user_text = (
                Purpose To identify and classify existing solutions leveraging the combination of MDE, DevOps, and AI/ML principles and practices supporting the system and software engineering of cyber-physical systems from the point of view of researchers and practitioners.
                RQ1 : Is there a systems and software engineering methodology that explicitly incorporates and integrates the principles and practices of MDE, AI/ML, and DevOps research areas? If such a methodology exists, how does it combine these research areas?
                RQ2 : Are the principles and practices of MDE, AI/ML, and DevOps research areas integrated throughout the entire process, or are they applied to specific engineering activities?
                RQ3 : Which research fields and application domains are the target of these approaches?
                RQ4 : What are the future research directions?
                Database: scopus
                Pubblication year higher then 2006
    )
    """

    user_text = ("""
Research Questions:

RQ0: What are the bibliometric key facts of peer-reviewed literature documenting applications of MDE to DTs?
RQ0.1: In which years are they published?
RQ0.2: In which types of venues are they published?

RQ1: How and how often are automation techniques applied to DTs in peer-reviewed literature?
RQ1.1: How often are the different automation techniques applied in the context of DTs?
RQ1.2: Which modeling artifacts and software artifacts are used by these automation techniques?
RQ1.3: Which combinations of input and output artifacts are used by these automation techniques?
RQ1.4: What is the research type of the studies that apply these automation techniques?

RQ2: To which types of DTs are automation techniques applied in peer-reviewed literature?
RQ2.1: Which TT does the DT represent, in which SLCP of the TT are automation techniques applied, and what is the TLCP of DTs to which automation techniques are applied?
RQ2.2: How does the application of automation techniques (identified with RQ1) vary for different DT types?

RQ3: In which domains are automation techniques applied to DTs in peer-reviewed literature?
RQ3.1: For which domains are automation techniques applied to DTs?
RQ3.2: How does the application of automation techniques (identified with RQ1) vary for the identified domains?
RQ3.3: How does the DT type (identified with RQ2) vary for the identified domains?

Database: Scopus
            """
    )

    # user_text = input("Enter the search text: ")

    initial_state: GraphState = {
        "input_text": user_text
    }

    # stream() yields one {node_name: state_update} chunk per executed node.
    # Keep only the most recent chunk instead of materializing the whole run
    # in a list. If the graph yields nothing, final_state stays {} rather
    # than raising IndexError.
    final_state = {}
    for chunk in workflow.compile().stream(initial_state, config={"recursion_limit": 100}):
        final_state = chunk

    # On a normal finish the last chunk is keyed by the final node
    # ("aggregate_results"); otherwise fall back to the chunk itself.
    data = final_state.get("aggregate_results", final_state)

    # Print the various elements
    print("\n--- FINAL WORKFLOW STATE ---\n")

    print("Input Text:")
    print(json.dumps(data.get("input_text", ""), indent=4))

    print("\nDatabases:")
    print(json.dumps(data.get("databases", []), indent=4))

    print("\nFormatted Queries:")
    print(json.dumps(data.get("formatted_queries", {}), indent=4))

    print("\nSearch String:")
    print(json.dumps(data.get("search_string", ""), indent=4))

    print("\nDB Results:")
    db_results = data.get("db_results", {})

    for db, raw_result in db_results.items():
        print(f"\nResults for {db}:")

        # The search wrapper normally returns a JSON string such as
        # '{"scopus": "[{...}, {...}]"}'; accept an already-decoded object too.
        try:
            if isinstance(raw_result, (str, bytes, bytearray)):
                db_obj = json.loads(raw_result)
            else:
                db_obj = raw_result
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            print(f"Error parsing results for {db}: {e}")
            continue

        # Unwrap the per-database payload keyed by THIS database's name. The
        # key was previously hard-coded to "scopus", so results for any other
        # database were iterated as dict keys instead of articles.
        articles = db_obj
        if isinstance(db_obj, dict) and db in db_obj:
            articles = db_obj[db]
            if isinstance(articles, (str, bytes, bytearray)):
                try:
                    articles = json.loads(articles)
                except ValueError as e:
                    print(f"Error parsing articles for {db}: {e}")
                    continue

        # A single object (e.g. an error payload) is printed whole rather than
        # being iterated key by key.
        if isinstance(articles, dict):
            articles = [articles]

        # Print each article in a pretty format.
        for article in articles:
            print(json.dumps(article, indent=4))
            print("-" * 50)

    # Print aggregate results (if available)
    """
    if "aggregate_results" in final_state:
        print("Aggregate Results:")
        print(json.dumps(final_state["aggregate_results"], indent=4))
        print("\n" + "="*40 + "\n")
    else:
        print("No aggregate_results found.")
    """

    # Optionally, print the entire final_state
    # print("Full Final State:")
    # print(json.dumps(final_state, indent=4))

[generate_search_string_node] Current iteration: 1
Final String:content='ALL ( ( model driven engineering OR MDE OR digital twins OR DTs ) AND ( automation techniques OR peer reviewed literature OR bibliometric analysis ) AND ( modeling artifacts OR software artifacts OR input artifacts OR output artifacts ) AND ( research type OR study application OR domain application ) )' additional_kwargs={'refusal': None} response_metadata={'token_usage': {'completion_tokens': 55, 'prompt_tokens': 576, 'total_tokens': 631, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_07871e2ad8', 'finish_reason': 'stop', 'logprobs': None} id='run-b5997950-36ae-4327-92fe-217ec1cbebd5-0' usage_metadata={'input_tokens': 576, 'output_tokens': 55, 'total_tokens': 631, 'input_token_details': {'audio': 

  df = pd.concat([df, pd.DataFrame.from_records([dict(total.values)])])


[extract_databases_node] Input text: 
Research Questions:

RQ0: What are the bibliometric key facts of peer-reviewed literature documenting applications of MDE to DTs?
RQ0.1: In which years are they published?
RQ0.2: In which types of venues are they published?

RQ1: How and how often are automation techniques applied to DTs in peer-reviewed literature?
RQ1.1: How often are the different automation techniques applied in the context of DTs?
RQ1.2: Which modeling artifacts and software artifacts are used by these automation techniques?
RQ1.3: Which combinations of input and output artifacts are used by these automation techniques?
RQ1.4: What is the research type of the studies that apply these automation techniques?

RQ2: To which types of DTs are automation techniques applied in peer-reviewed literature?
RQ2.1: Which TT does the DT represent, in which SLCP of the TT are automation techniques applied, and what is the TLCP of DTs to which automation techniques are applied?
RQ2.2: How doe


KeyboardInterrupt

