# IKAT 2024

In [None]:
import os
import json
import textwrap
import ast
from datetime import datetime
import pandas as pd
import numpy as np

from pyserini.search import LuceneSearcher
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer, util

from src.utils import use_llama, prepare_output_for_json, join_and_cap_passages, get_true_ptkbs, has_qrels_for_conversation, combine_performance_metrics, combine_hits, full_evaluations_join, run_generate_run, measure_intrarun_ranking_performance, check_quick_retrieval_success, extract_evaluations, select_turns_by_prefix, run_trec_eval, extract_relevant_passages, passage_neural_retrieval


import settings


## Functions

In [None]:
def llama_ptkb_ranker(model, tokenizer, query, ptkb):
    """Ask Llama which PTKB statements are relevant to the query and rank them.

    Args:
        model: Llama model handle, forwarded unchanged to ``use_llama``.
        tokenizer: Tokenizer handle, forwarded unchanged to ``use_llama``.
        query: Text embedded in the user message (the caller passes the
            conversation context plus the latest user reply).
        ptkb: Mapping of statement id -> statement text. Ids must be
            convertible with ``int`` because they are listed in numeric order
            in the verbose log.

    Returns:
        A list of ``{'id', 'statement', 'score'}`` dicts sorted by ``score``
        from highest to lowest. Entries whose id is not a key of ``ptkb`` are
        dropped (and noted in the verbose log) instead of aborting the turn.

    Raises:
        ValueError / SyntaxError: if no Python/JSON-style literal can be
            recovered from the model reply (same as before for such replies).
    """

    # save all the prints to a text file for tracking system progress
    with open(verbose_output_filename, 'a') as f:
        f.write('\n\n### Available PTKB statements are:\n')
        for key in sorted(ptkb.keys(), key=int):
            f.write(f"{key}. {ptkb[key]}\n")

    # get response
    # NOTE(review): the prompt below is kept byte-for-byte so that runs stay
    # reproducible. It contains what looks like pasted chat residue ("Thank you
    # for the clarification! ...") and forbids zero scores only at the very end
    # -- consider cleaning it up in a dedicated, evaluated change.
    relevant_ptkb_statements_ids = use_llama(
        model=model,
        tokenizer=tokenizer,
        system_message_content="""
        You are an Assistant that returns ONLY console output, without any additional words, niceties or verbosity of any kind. 
        
        # Instructions 
        
        // You are instructed to return a JSON and you are to return only the JSON object and nothing else. 
        // You have been assigned a classification task. The user will provide you with a conversation ending with a latest user reply and a list of statements. 
        // YOU MUST CONSIDER WHICH STATEMENTS ARE THE MOST RELEVANT TO ANSWER THE LATEST USER REPLY.
        // Your job is to remove the statements are irrelevant to the text query. Statements that are only somewhat related should not be eliminated. Only eliminate if not at all related.
        // Then, assign scores of how related the statement is to the query with numbers between 1 and 0, where 1 is the highest. 
        // I REPEAT: ONE IS THE HIGHEST AND ZERO IS THE LOWEST. 
        // You can only return statements from the provided list. I REPEAT: You can only return statements from the provided list.
        // You will return only a JSON of the `id` and `score`. Nothing else. 
        // You must make sure the IDs are returned as strings, regardless that they are integers. 
        // Don't be overly stringent with your definition or relevant. Loosely related is enough. 
        // Do not include statements that are irrelevant to the text passage. I REPEAT: DO NOT include statements if they are not related at all. 
        
        # Examples 
                
        Input:
        The Latest User Reply is "Can you tell me the best place to find Pokemon?" and the statements are 
        {'1': "I'm amazed at the fact that the human body has 206 bones.", 
        '2': "I started my journey as a Pokémon trainers at the age of ten.", 
        '3': "My favorite cuisine is Italian, especially pizza and pasta.", 
        '4': "I enjoy hiking in the mountains during my free time.", 
        '5': "I love the Eiffel Tower as it is one of the most famous landmarks in Paris.", 
        '6': "I believe Pikachu is the most iconic and recognizable Pokémon in the world."}       
        
        Required Output: 
        `[{"id": "2", "score": 0.98}, {"id": "6", "score": 0.81}]`
        

        Input:
        The Latest User Reply is "I wish to get facts on the importance of education?" and the statements are
        {
        '1': "I believe that education is crucial for personal development.",
        '2': "I enjoy hiking in the mountains during my free time.",
        '3': "In my opinion, continuous learning is important for staying relevant in any field.",
        '4': "I think education should be accessible to everyone, regardless of their background.",
        '5': "My favorite cuisine is Italian, especially pizza and pasta.",
        '6': "I feel that a strong educational foundation is essential for success in life."
        }
        
        Required Output:
        `[{"id": "1", "score": 1.0}, {"id": "3", "score": 0.85}, {"id": "4", "score": 0.75}, {"id": "6", "score": 0.7}]`

        Input:
        The Latest User Reply is "What are your favorite foods?" and the statements are
        {
        '1': "I love Italian cuisine, especially pasta.",
        '2': "I have a cat named Whiskers.",
        '3': "Chocolate cake is my go-to dessert.",
        '4': "I enjoy swimming in my free time.",
        '5': "Spicy foods are a favorite of mine.",
        '6': "I recently visited Japan and loved the sushi there."
        }
        
        Required Output:
        `[{"id": "1", "score": 0.99}, {"id": "3", "score": 0.85}, {"id": "5", "score": 0.7}, {"id": "6", "score": 0.4}]`
       
       
       Thank you for the clarification! Here’s a new version of the example with only relevant statements about Paris and the rest being random personal facts starting with "I...".

        ### Input:  
        The Latest User Reply is "Can you share some insights on traveling to Paris?" and the statements are:  
        {
        '1': "I love visiting famous landmarks like the Eiffel Tower in Paris.",  
        '2': "I enjoy painting landscapes during my free time.",  
        '3': "The Louvre is one of the most visited museums in Paris.",  
        '4': "I recently adopted a puppy named Max.",  
        '5': "I have been learning French for the past two years.",  
        '6': "I think Paris is one of the most beautiful cities in the world."  
        }

        ### Required Output:  
        `[{"id": "1", "score": 1.0}, {"id": "3", "score": 0.9}, {"id": "6", "score": 0.85}]`
        
        SCORE OF ZERO ARE NOT ALLOWED.

        
        """,
        user_message_content=f"""The text query is "{query}" and the statements are {ptkb}""")
        # NOTE: The examples above are LLM generated.

    # save all the prints to a text file for tracking system progress
    with open(verbose_output_filename, 'a') as f:
        f.write('\n\n### Raw output\n')
        f.write(f'{relevant_ptkb_statements_ids}')

    # convert from string
    # The prompt's examples wrap the expected list in backticks, so the model
    # may echo them back (or add a code fence / stray words). Keep only the
    # outermost [...] span before evaluating; a clean reply is left untouched.
    literal_text = relevant_ptkb_statements_ids.strip()
    list_start = literal_text.find('[')
    list_end = literal_text.rfind(']')
    if list_start != -1 and list_end > list_start:
        literal_text = literal_text[list_start:list_end + 1]

    # literal_eval only accepts literals, so a malformed reply raises
    # ValueError/SyntaxError here rather than executing anything.
    parsed_items = ast.literal_eval(literal_text)
    if isinstance(parsed_items, dict):
        # a single bare object instead of a one-element list
        parsed_items = [parsed_items]

    # create a list of dictionaries with the relevant PTKB statements, including the score
    relevant_ptkb_statements = []
    unknown_ids = []
    for item in parsed_items:
        statement_id = item['id']
        if statement_id not in ptkb:
            # the model sometimes returns integer ids despite the instructions
            statement_id = str(statement_id)
        if statement_id not in ptkb:
            # hallucinated id: drop this entry instead of losing the whole ranking
            unknown_ids.append(item['id'])
            continue
        relevant_ptkb_statements.append(
            {'id': statement_id, 'statement': ptkb[statement_id], 'score': item['score']}
        )

    if unknown_ids:
        with open(verbose_output_filename, 'a') as f:
            f.write(f'\n\n### Ignored ids that are not in the PTKB: {unknown_ids}\n')

    # sort the list of dictionaries by score, from highest to lowest
    relevant_ptkb_statements.sort(key=lambda entry: entry['score'], reverse=True)

    return relevant_ptkb_statements



In [None]:
def generate_top_response(conversation_history_for_model, relevant_ptkbs, reranked_passages, searcher: LuceneSearcher, num_passages: int, summarizer, tokenizer, current_utterance, pre_reranking=False):
    """Generate the assistant response for a turn from the top-ranked passages.

    Args:
        conversation_history_for_model: Conversation context appended to the
            system prompt.
        relevant_ptkbs: Iterable of dicts with a ``'statement'`` key; the
            statements are joined into the "information about the customer"
            part of the prompt.
        reranked_passages: Ranked hits; only the first ``num_passages`` are
            used. Each hit must expose ``.docid``.
        searcher: Searcher used to fetch the raw passage for each hit.
        num_passages: Number of top hits used as provenance.
        summarizer: Model handle forwarded to ``use_llama``.
        tokenizer: Tokenizer handle forwarded to ``use_llama``.
        current_utterance: Latest user utterance, sent as the user message.
        pre_reranking: When truthy, no response is generated and a fixed
            placeholder string is returned instead.

    Returns:
        A tuple ``([response_text], [top_hits])`` -- two single-element lists
        holding the response string and the list of hits it was based on.
    """

    # the top num_passages hits are the provenance for the response
    top_docs_for_generating_response = reranked_passages[:num_passages]

    # Pre-reranking runs only need the hits back; return before touching the
    # index or building any prompt, since none of that would be used.
    if pre_reranking:
        return ["Pre-reranking. Not needed."], [top_docs_for_generating_response]

    # extract passages from the search results
    top_passages = [json.loads(searcher.doc(hit.docid).raw())['contents'] for hit in top_docs_for_generating_response]

    # join all passages into a single string / cap at 300 words
    text = join_and_cap_passages(top_passages, max_words=300)

    ptkb_statements_string = ' '.join([entry['statement'] for entry in relevant_ptkbs])

    llama_user_inputs = f"""{current_utterance}"""

    # save all the prints to a text file for tracking system progress
    with open(verbose_output_filename, 'a') as f:
        f.write('\nIn:\n') 
        f.write('```\n') 
        f.write(f'{llama_user_inputs}\n')        
        f.write('```\n') 

    # NOTE: the prompt text (including its indentation) is part of the model
    # input and is intentionally left exactly as it was.
    system_message_content=f"""
                You are an Conversational Assistant part of InfoRetCo, a chat service company. You are currently having a conversation with a loyal customer. Over this customer's time with the company, we have gathered the following information about them: "{ptkb_statements_string}".
                
                We have also gathered the following Answer Provenance Passages (APP) for response generation:\n{text}
                
                # Instructions
                
                // You must complete the conversation below using ONLY the provided Answer Provenance Passages (APP):
                // 1. Your reponse must be a maximum of 300 words. I REPEAT: DO NOT EXCEED THE 300 WORD LIMIT.
                // 2. Respond without adding, infering, or assuming any details outside the provided provenance passages. I REPEAT: Strictly adhere to the exact information given in the provenance passagess without introducing any new elements or context. 
                // 3. Your reponse must not cite the provided provenance passages.
                // 4. You must provide a single text string. There must not be any additional text surrounding the response.
                // 5. Your response must begin with "Assistant:  ..."

                The Conversation History Context is as follows:{conversation_history_for_model}
            """

    # get response
    llama_reponse = use_llama(
        model=summarizer,
        tokenizer=tokenizer,
        system_message_content=system_message_content,
        user_message_content=llama_user_inputs,
    )

    # drop the speaker label the prompt asks the model to emit
    llama_reponse = llama_reponse.replace('Assistant: ', '')

    # save all the prints to a text file for tracking system progress
    with open(verbose_output_filename, 'a') as f:
        f.write('\n## Prompt:\n')
        f.write('```')
        f.write(f'\n{system_message_content}\n')
        f.write('```\n')            
        f.write('\nOut:\n')
        f.write('```')
        f.write(f'\n{textwrap.fill(llama_reponse, width=100)}\n')
        f.write('```\n')

    return [llama_reponse], [top_docs_for_generating_response] 




In [None]:
def get_query_llama(ptkb, turn, previous_turn, conversation_history, ptkb_qrels, query_id, skip_turns_without_qrels, expected_qrel_passages, model=None, tokenizer=None):
    """Build the retrieval and reranking queries for one conversation turn.

    The function (1) folds the previous assistant response into the rolling
    conversation history, (2) ranks the PTKB statements with
    ``llama_ptkb_ranker`` and (3) asks Llama for a short answer and a long
    "article" answering the utterance; those two generated texts are returned
    as the queries. Every step is traced to ``verbose_output_filename``.

    Args:
        ptkb: Mapping of PTKB statement id -> statement text.
        turn: Current turn; ``turn['utterance']`` is required,
            ``turn['response']`` is only written to the log when present.
        previous_turn: Previous turn (its ``'response'`` is used), or a falsy
            value (``''`` by convention, ``None`` also accepted) on the first
            turn.
        conversation_history: Rolling history string, one entry per line.
        ptkb_qrels: PTKB qrels, forwarded to ``get_true_ptkbs`` for logging.
        query_id: Identifier forwarded to ``get_true_ptkbs`` as
            ``conversation_id``.
        skip_turns_without_qrels: When truthy, turns with no expected qrel
            passages skip ranking and query generation.
        expected_qrel_passages: Expected passages for the turn; only its
            truthiness is used here.
        model: Llama model handle forwarded to ``use_llama``.
        tokenizer: Tokenizer handle forwarded to ``use_llama``.

    Returns:
        Tuple ``(query_list, reranking_list, conversation_history,
        conversation_history_for_model_2, relevant_ptkb_statements,
        ptkb_query_weights)``. For skipped turns ``query_list``,
        ``reranking_list``, ``relevant_ptkb_statements`` and
        ``ptkb_query_weights`` are empty strings.
    """

    # variables
    utterance = turn['utterance']

    # A falsy previous_turn marks the first turn of a conversation, which has
    # no assistant response to carry over.
    has_previous_turn = bool(previous_turn)

    # defaults that stay in place on the first turn
    previous_response_summary = ""
    conversation_history_for_model = ""
    conversation_history_for_model_2 = ""


    # Conversation history management
    #----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
    #################################
    # Previous response summarization
    #################################

    if has_previous_turn:
        previous_response = previous_turn['response']

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n\n## Response summarization model\n")
            f.write(f"{'-'*100}\n")
            f.write(f"\nIn:")
            f.write(f"\n```\n")
            f.write(f"{previous_response}\n")
            f.write(f"```\n")

        if len(previous_response) < 150:
            # already within the 150 character budget: keep it verbatim
            previous_response_summary = previous_response
        else:
            # too long for the rolling history: let Llama condense it
            previous_response_summary_user_message_content = f"""{previous_response}\nThe 150 character response is:"""

            previous_response_summary = use_llama(
                model=model,
                tokenizer=tokenizer,
                system_message_content="""
                    
                    You are an Summarization Robot part of a broader Information Retrieval system.
                    
                    # Instructions
                    
                    // You will be provided with a text input for which you need to generate a text output and abide to the following policy:
                    // 1. Provide a summary the text input of a 150 character maximum of the passage. I REPEAT: DO NOT EXCEED THE 150 CHARACTER LIMIT.
                    // 2. Paying particular attention, but not exclusively, to any names, topics, and themes that will be useful to recall in the future. 
                    // 3. Respond without adding, infering, or assuming any details outside the provided text. 
                    // 4. I REPEAT: Strictly adhere to the exact information given in the passage without introducing any new elements or context. 
                    // 5. YOU DO NOT HELP THE USER, YOU PROVIDE THE REQUEST RESPONSE AND NOTHING ELSE. You provide only the response without any introductory or concluding statements or labels. I REPEAT: YOU MUST NOT ANWER THE USER'S QUESTION.
                    """,
                user_message_content=previous_response_summary_user_message_content,
            )

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\nOut:")
            f.write(f"\n```\n")
            f.write(f"{previous_response_summary}\n")
            f.write(f"```\n")

        # declutter the conversation history: keep only its last 5 lines
        lines = conversation_history.strip().split('\n')[-5:]
        conversation_history = "\n" + "\n".join(lines)

        # The prompts for this turn see the full previous response; only its
        # summary is carried forward in the rolling history.
        conversation_history_for_model = f"{conversation_history}\n\n- **Latest Assistant Reply**: {previous_response}"
        conversation_history_for_model_2 = f"{conversation_history}\n- Assistant: {previous_response}"
        conversation_history = f"{conversation_history}\n- Assistant: {previous_response_summary}"


    # skip based on settings
    if skip_turns_without_qrels and not expected_qrel_passages:

        # callers receive empty strings (not empty lists) for skipped turns
        query_list=""
        reranking_list=""
        relevant_ptkb_statements=""
        ptkb_query_weights=""

    else:

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n\n\nPassage Ranking Task:\n")
            f.write(f"{'#'*100}")


        # PTKB Ranking
        #----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
        #################################
        # PTKB Ranking
        #################################

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            # true ptkbs
            f.write(f"\nPTKB Ranking Task:\n")
            f.write(f"{'#'*100}")
            f.write('\n\n### True PTKBs according to the qrel:\n')
            f.write(f"\n{get_true_ptkbs(ptkb_qrels=ptkb_qrels, conversation_id=query_id)}\n")

        llama_ptkb_ranker_query = f"""\n**Conversation History Context:**\n{conversation_history_for_model}\n\n- **Latest User Reply:** {utterance}"""

        relevant_ptkb_statements = llama_ptkb_ranker(
                    # Query is the combined context and utterance
                    query=llama_ptkb_ranker_query,
                    ptkb=ptkb,                        # pTKB data source
                    model=model,
                    tokenizer=tokenizer,
        )

        # all relevant statements (not only the top 3) go into the prompts
        ptkb_statements_string_comma = ', '.join([entry['statement'] for entry in relevant_ptkb_statements])

        # Scores of the top 3 statements. NOTE: these are only written to the
        # log; the weights actually returned are reset to [1, 1] further down.
        ptkb_query_weights = [entry['score'] for entry in relevant_ptkb_statements[:3]]

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n\n## Llama PTKB Ranking\n")
            f.write(f"{'-'*100}\n")
            f.write(f"\nIn:")
            f.write(f"\n```")
            f.write(f"\n{llama_ptkb_ranker_query}\n")
            f.write(f"```\n")            
            f.write(f"\nOut:\n")
            for i, item in enumerate(relevant_ptkb_statements, 1):
                f.write(f"{i}. Id: {item['id']}. Score: {item['score']}. Statement: {item['statement']}\n")

        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n\nPTKB Query Weights:\n\n")
            f.write(f"{ptkb_query_weights}")


        # Query Generation
        #----------------------------------------------------------------------------------------------------------------------------------------------------------------------------

        # The gold response is only logged for comparison, so a topic file
        # without responses must not abort the run.
        expected_response = turn.get('response', '')

        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n\n\nQuery Comprehension:\n")
            f.write(f"{'#'*100}")
            f.write(f"\n\n--> Expected comprehension <--")
            f.write(f"\n```")
            f.write(f"\n{expected_response}\n")        
            f.write(f"```\n")


        #################################
        # Short Passage Query
        #################################

        short_passquery_user_message_content = f"""{utterance} I REPEAT: {utterance}"""

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n\n## Generate short answer\n")
            f.write(f"{'-'*100}\n")
            f.write(f"\nIn:")
            f.write(f"\n```")
            f.write(f"\n{short_passquery_user_message_content}\n")
            f.write(f"```\n")

        system_message_content=f"""
            You are an Conversational Assistant part of InfoRetCo, a chat service company. You are currently having a conversation with a loyal customer. Over this customer's time with the company, we have gathered the following information about them: "{ptkb_statements_string_comma}".

            # Instructions
            
            1. Continue the conversation by responding to the user
            2. Your response must begin with "Assistant: ..."
            3. The response must be exactly 5 full sentences long.
            4. You are not allowed to ask follow up questions to the user. I REPEAT: You are not allowed to ask follow up questions to the user.
            
            
            The Conversation History Context is as follows:{conversation_history_for_model_2}
            """

        short_passquery_text = use_llama(
            model=model,
            tokenizer=tokenizer,
            system_message_content=system_message_content,
            user_message_content=short_passquery_user_message_content,
        )

        # clean: drop the speaker label requested by the prompt
        short_passquery_text = short_passquery_text.replace("Assistant: ", "")

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            # labelled "Prompt:" (was a second "Out:") to match the article trace below
            f.write(f"\nPrompt:")
            f.write(f"\n```")
            f.write(f"\n{system_message_content}\n")
            f.write(f"```")            
            f.write(f"\nOut:")
            f.write(f"\n```")
            f.write(f"\n{short_passquery_text}\n")
            f.write(f"```")


        #################################
        # Long Passage Query
        #################################

        long_passquery_user_message_content = f"""{utterance} I REPEAT: {utterance}"""            

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n\n## Generate Article\n")
            f.write(f"{'-'*100}\n")
            f.write(f"\nIn:")
            f.write(f"\n```")
            f.write(f"\n{long_passquery_user_message_content}\n")
            f.write(f"```\n")

        system_message_content=f"""
            You are a Conversational Assistant part of InfoRetCo, a chat service company. You are currently having a conversation with a loyal customer. Over this customer's time with the company, we have gathered the following information about them: "{ptkb_statements_string_comma}".

            # Instructions
            
            1. Please generate a 10 sentence article the customer can read to answer their question. I REPEAT: 10 sentences.
            2. The writing style must be Formal and Informative. It must presents information in a factual, concise, and authoritative manner, referencing sources and providing specific details. It uses precise language, making it suitable for educational or professional contexts.
            2. Your response must begin with "Article: ..."
            3. You are not allowed to ask follow up questions to the user. I REPEAT: You are not allowed to ask follow up questions to the user.
                        
            The Conversation History Context is as follows:{conversation_history_for_model_2}
                        
            """

        long_passquery_text = use_llama(
            model=model,
            tokenizer=tokenizer,
            system_message_content=system_message_content,
            user_message_content=long_passquery_user_message_content,
        )

        # clean: drop the label requested by the prompt
        long_passquery_text = long_passquery_text.replace("Article: ", "")

        # remove all line breaks
        long_passquery_text = long_passquery_text.replace("\n", " ")

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\nPrompt:")
            f.write(f"\n```")
            f.write(f"\n{system_message_content}\n")
            f.write(f"```")            
            f.write(f"\nOut:")
            f.write(f"\n```")
            f.write(f'\n{textwrap.fill(long_passquery_text, width=100)}\n')
            f.write(f"```")

        # the same two generated texts drive both retrieval and reranking
        query_list = [long_passquery_text, short_passquery_text]
        reranking_list = [long_passquery_text, short_passquery_text]

        # both queries are weighted equally; this replaces the PTKB-score
        # weights that were logged above
        ptkb_query_weights = [1, 1]


    # update current utterance to conversation history
    conversation_history = f"{conversation_history}\n- User: {utterance}"

    return query_list, reranking_list, conversation_history, conversation_history_for_model_2, relevant_ptkb_statements, ptkb_query_weights



## Run

### Run Settings

In [None]:
########################
# RUN SETTINGS
########################

# paths
topic_filename = '2024_test_topics.json'
run_name = 'main_infosense_llama_short_long_qrs'
# passage qrels: e.g. '2023-qrels.all-turns.txt', or 'no_qrel.txt' to run without qrels
passage_qrels_filename = 'no_qrel.txt'
# PTKB qrels: e.g. '2023-ptkb-qrels.txt', or 'no_qrel.txt' to run without qrels
ptkb_qrels_filename = 'no_qrel.txt'
# one of: 'Meta-Llama-3.1-8B-Instruct', 'Meta-Llama-3-70B-Instruct', 'Meta-Llama-3-8B-Instruct'
llama_model_choice = 'Meta-Llama-3.1-8B-Instruct'
# SentenceTransformers rerankers, any of: 'msmarco-distilbert-base-v4',
# 'msmarco-distilbert-base-tas-b', 'multi-qa-mpnet-base-cos-v1',
# 'all-mpnet-base-v2', 'paraphrase-mpnet-base-v2', 'all-MiniLM-L12-v2'
reranker_mode_choice = ['msmarco-distilbert-base-v4', 'all-MiniLM-L12-v2']


# variables
num_docs        = 5000  # number of documents to retrieve
searcher_model = 'bm25'
num_response    = 1  # number of responses to generate
num_passages    = 3  # number of passages to use to generate a responses
cuda_device_num = 1
use_cuda        = True

ptkb_ranker     = 'llama'
query_rewriter  = 'llama'
use_subtree_only = None      # None or a list with specific subtrees, e.g. 2023: ['9-1', '9-2', '10-1']


skip_turns_without_qrels = False

########################
# VARIABLE INITIALIZATION
########################

# every run writes into its own timestamped output directory
run_start_timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

# create the directory if it does not exist yet (exist_ok avoids the
# check-then-create race of a separate os.path.exists test)
directory = f'./output/{run_start_timestamp}/'
os.makedirs(directory, exist_ok=True)

index_path = settings.INDEX_PATH
qrels_filepath = settings.QRELS_PATH
models_path = f'{settings.MODELS_PATH}'
topic_path = f'{settings.TOPICS_PATH}/{topic_filename}'
passage_qrels_filepath = f'{qrels_filepath}/{passage_qrels_filename}'
ptkb_qrels_filepath = f'{qrels_filepath}/{ptkb_qrels_filename}'
# both output files live in the run directory created above
output_filename = f'{directory}{run_name}.json'
verbose_output_filename = f'{directory}{run_name}_verbose.txt'


### Input data

In [None]:
# Load the conversation topics from the specified JSON file.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's default locale encoding.
print(f'Loading data from {topic_filename}')
with open(topic_path, 'r', encoding='utf-8') as f:
    data = json.load(f)
print('[Done].')

# subset if a specific subtree is chosen
if use_subtree_only is not None:
    data = [entry for entry in data if entry['number'] in use_subtree_only]
    print(f'subtree chosen: {use_subtree_only}')

# Further filter the turns by specific turn_id
#specific_turn_id = [3]
#for entry in data:
#    entry['turns'] = [turn for turn in entry['turns'] if turn['turn_id'] in specific_turn_id]


# uncomment to see preview
# Limit the number of lines to be printed
#n = 2
#for item in data[:n]:
#    print(json.dumps(item, indent=4))



In [None]:
# Read the PTKB qrels text file into a list of raw lines
# (each line keeps its trailing newline).
print(f'Loading data from {ptkb_qrels_filename}')
with open(ptkb_qrels_filepath, mode='r') as f:
    ptkb_qrels = list(f)
print('[Done].')

# uncomment to see preview
#ptkb_qrels[:5]


In [None]:
# Read the passage qrels text file into a list of raw lines
# (each line keeps its trailing newline).
print(f'Loading data from {passage_qrels_filename}')
with open(passage_qrels_filepath, mode='r') as f:
    passage_qrels = list(f)
print('[Done].')

# uncomment to see preview
#passage_qrels[:5]


### Procedure

In [None]:
# Wall-clock start of the procedure; elapsed times in the log are relative to it.
beginning_timestamp = datetime.now()


# Pick the requested CUDA device when CUDA is available and enabled, else the CPU.
cuda_device = f'cuda:{cuda_device_num}'
if torch.cuda.is_available() and use_cuda:
    device = torch.device(cuda_device)
else:
    device = torch.device('cpu')


# Load one SentenceTransformer per entry of reranker_mode_choice from the local
# models directory. NOTE: despite its singular name, reranker_model is a list;
# it is later passed as `model` to passage_neural_retrieval().
reranker_model = []
for model_name in reranker_mode_choice:
    model_path = f'{models_path}/{model_name}'
    model = SentenceTransformer(model_path)
    reranker_model.append(model)


# define the local path to the model directory
local_path = f"{models_path}/{llama_model_choice}/"

# load the tokenizer from the specified local path
# (local_files_only=True: only files already present locally are used)
llama_tokenizer = AutoTokenizer.from_pretrained(local_path, local_files_only=True)

# load the model from the specified local path, in bfloat16
llama_model = AutoModelForCausalLM.from_pretrained(
    local_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # comment out if not enough GPU resources
    local_files_only=True,

)


# initialize the searcher with the specified index
searcher = LuceneSearcher(index_dir=index_path)

# BM25 params (k1 and b are passed straight to the searcher below)
k1 = 1.2 
b = 0.75 
searcher.set_bm25(k1=k1, b=b)



# Record the run parameters at the top of the verbose progress log.
with open(verbose_output_filename, 'a') as f:
    f.writelines([
        f'Beginning Timestamp: {beginning_timestamp.strftime("%Y-%m-%d %H:%M:%S")}\n\n',
        'Params:\n',
        f'Using device: {device}\n',
        'Run Type ==> automatic\n',
        f'Number of responses to generate ==> {num_response}\n',
        f'Number of passages to use for response generation ==> {num_passages}\n',
        f'PTKB ranking function ==> {ptkb_ranker}\n',
        f'Retrieval Model ==> {searcher_model}\n',
        f'Reranker Model ==> {reranker_mode_choice}\n',
        f"\n{'#'*100}\n\n",
    ])


# Start the run JSON file: write the header fields and open the "turns" array.
# The array is filled turn by turn in the main loop and closed after it.
with open(output_filename, 'a') as f:
    f.write("{\n")
    # json.dumps quotes and escapes run_name, so a name containing quotes or
    # backslashes cannot produce an invalid JSON document.
    f.write(f'    "run_name": {json.dumps(run_name)},\n')
    f.write('    "run_type": "automatic",\n')
    f.write('    "eval_response": true,\n')
    f.write('    "turns": [\n')


# Run-level counters: number of turns seen so far and number of conversations.
num_subtrees, num_turns = len(data), 0

# Run-level retrieval tallies, accumulated conversation by conversation.
total_found_passages = total_total_passages_expected = 0

# One dict of metrics per evaluated conversation (rendered as a table in the log).
full_results_table = list()


# Main procedure. For every conversation (topic subtree) and every turn in it:
# build the queries, retrieve candidates, rerank them, generate responses,
# log progress to the verbose file and append the turn to the run JSON.
#
# Run JSON invariant: every turn is appended followed by ",\n". The separator
# is never removed while the loop is running (doing so mid-file would glue two
# turn objects together and make the file unparseable); the single trailing
# separator is trimmed once, after the loop.
for idx, d in enumerate(data):

    print(
        f'========PROCESSING CONVERSATION-{idx + 1} OF {num_subtrees}==============')

    # save all the prints to a text file for tracking system progress
    with open(verbose_output_filename, 'a') as f:
        f.write(
            f'\n========PROCESSING CONVERSATION-{idx + 1} OF {num_subtrees}==============\n')

    number = str(d['number']) # convert to string to handle 2024 topics
    turns = d['turns']
    ptkb = d['ptkb']


    # initialize variables that need to reset at every conversation
    conversation_history = ""
    previous_turn = ""
    turns_in_subtree = len(turns)
    conversation_found_passages = 0
    conversation_total_passages_expected = 0

    # update the total number of turns
    num_turns += len(turns)


    # skip the whole conversation when it has no qrels and skipping is enabled
    conv_has_qrels = has_qrels_for_conversation(passage_qrels, number)
    if skip_turns_without_qrels and not conv_has_qrels:

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n### Skipping {number} for no qrels in entire conversation:")

        # The run JSON is deliberately left untouched here: later conversations
        # may still append turns, so the separator after the previous turn must stay.
        continue # go to next conversation



    # save all the prints to a text file for tracking system progress
    with open(verbose_output_filename, 'a') as f:
        f.write('\nSTART CONVERSATION\n')

    # process each turn in the conversation
    for turn in tqdm(turns, total=len(turns)):


        # turn level vars
        turn_id = turn['turn_id']
        query_id = number + '_' + str(turn_id)

        turn_found_passages = 0
        turn_total_passages_expected = 0

        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n{'#'*100}\n")
            f.write(f"topic-subtree: {number}\n")
            f.write(f"turn: {turn['turn_id']} of {turns_in_subtree}\n")
            f.write(f"{'#'*100}\n")


        # save all the prints to a text file for tracking system progress
        with open(verbose_output_filename, 'a') as f:
            # users utters
            f.write(f"\n### User utters:")
            f.write(f"\n```")
            f.write(f"\n{textwrap.fill(turn['utterance'], width=100)}\n")
            f.write(f"```\n")



        # get the passages the qrels expect for this turn (empty when there are none)
        expected_qrel_passages = extract_relevant_passages(passage_qrels=passage_qrels, conversation_id=query_id, verbose_output_filename=verbose_output_filename)

        ##############
        # generate query
        #############

        # NOTE: this runs even for turns that are skipped below, because it also
        # returns the updated conversation history carried into the next turn.
        query_list, reranking_list, conversation_history, conversation_history_for_model, relevant_ptkb_statements, ptkb_query_weights = get_query_llama(
            ptkb=ptkb,
            turn=turn,
            model=llama_model,
            tokenizer=llama_tokenizer,
            conversation_history=conversation_history,
            ptkb_qrels=ptkb_qrels,
            query_id=query_id,
            previous_turn=previous_turn,
            skip_turns_without_qrels=skip_turns_without_qrels,
            expected_qrel_passages=expected_qrel_passages,
        )



        # skip based on settings
        if skip_turns_without_qrels and not expected_qrel_passages:

            # save all the prints to a text file for tracking system progress
            with open(verbose_output_filename, 'a') as f:
                f.write(f"\n{'#'*100}\n")
                f.write(f"This turn has no qrels. Skipping the search task...\n")
                f.write(f"{'#'*100}\n")

            # Nothing is written to the run JSON for a skipped turn; the trailing
            # separator is handled once after the loop.
            previous_turn = turn

        else:

            #########################
            # quick retrieve documents
            ##########################

            # save all the prints to a text file for tracking system progress
            with open(verbose_output_filename, 'a') as f:
                f.write(f"\n\nRetrieval Task\n")
                f.write(f"{'#'*100}\n")

            preliminary_hits, query_source_counts, query_source_docs = combine_hits(query_list, query_id, num_docs, searcher, verbose_output_filename)

            # save all the prints to a text file for tracking system progress
            with open(verbose_output_filename, 'a') as f:
                f.write(f'\n{len(preliminary_hits):,} documents retrieved.\n')

            ######################
            # check success
            ######################

            if not expected_qrel_passages:  # if empty list
                # save all the prints to a text file for tracking system progress
                with open(verbose_output_filename, 'a') as f:
                    f.write(f'\nNo passage qrels for turn {query_id}. Moving on ...\n')
            else:

                # check retrieval success of original
                turn_found_passages, turn_total_passages_expected = check_quick_retrieval_success(preliminary_hits, expected_qrel_passages, query_source_counts, query_source_docs, verbose_output_filename)


            with open(verbose_output_filename, 'a') as f:
                f.write(f"\n\nReranking Task\n")
                f.write(f"{'#'*100}\n")
                f.write(f'\n## Reranking the documents...\n')


            reranked_passages = passage_neural_retrieval(
                preliminary_hits=preliminary_hits,
                reranking_list=reranking_list,
                verbose_output_filename=verbose_output_filename,
                model=reranker_model,
                query_weights=ptkb_query_weights,
                top_k=1000,
                )

            # save all the prints to a text file for tracking system progress
            with open(verbose_output_filename, 'a') as f:
                f.write(f'\n## Reranking done.\n')
                f.write(f"\n\nResponse Generation Task\n")
                f.write(f"{'#'*100}\n")
                f.write(f'\n## Generating the response...\n')

            # generate responses from the pre-reranking ranking
            sorted_responses_prereranking, sorted_responses_provenance_prereranking = generate_top_response(
                conversation_history_for_model=conversation_history_for_model,
                relevant_ptkbs=relevant_ptkb_statements,
                reranked_passages=preliminary_hits,
                searcher=searcher,
                num_passages=num_passages,
                summarizer=llama_model,
                tokenizer=llama_tokenizer,
                current_utterance=turn['utterance'],
                pre_reranking=True,
            )
            # generate responses from the reranked passages
            sorted_responses, sorted_responses_provenance = generate_top_response(
                conversation_history_for_model=conversation_history_for_model,
                relevant_ptkbs=relevant_ptkb_statements,
                reranked_passages=reranked_passages,
                searcher=searcher,
                num_passages=num_passages,
                summarizer=llama_model,
                current_utterance=turn['utterance'],
                tokenizer=llama_tokenizer,
            )

            # log the reference response of the turn
            with open(verbose_output_filename, 'a') as f:
                f.write(f"\n\nExpected response\n")
                f.write('```\n')
                f.write(f"{turn['response']}\n")
                f.write('```\n')

            #########################
            # performance measurement
            #########################

            # NOTE: the pre-rerank JSON exists only to measure ranking performance
            # before reranking (it is passed to measure_intrarun_ranking_performance
            # below and nowhere else). It must not be written to the run file.
            turn_json_pre_rerank = prepare_output_for_json(
                turn_id=query_id,
                sorted_responses=sorted_responses_prereranking,
                sorted_responses_provenance=sorted_responses_provenance_prereranking,
                reranked_passages=preliminary_hits,            # Only this matters here
                searcher=searcher,
                relevant_ptkbs=relevant_ptkb_statements
            )

            # the post-rerank JSON is what gets written to the run file
            turn_json_post_rerank = prepare_output_for_json(
                turn_id=query_id,
                sorted_responses=sorted_responses,
                sorted_responses_provenance=sorted_responses_provenance,
                reranked_passages=reranked_passages,
                searcher=searcher,
                relevant_ptkbs=relevant_ptkb_statements
            )

            # Calculate current turn performance
            with open(verbose_output_filename, 'a') as f:
                f.write(f"\n\nReranking task performance measurement\n")
                f.write(f"{'#'*100}\n")

            if not expected_qrel_passages:  # if empty list
                with open(verbose_output_filename, 'a') as f:
                    f.write(f'\nNo passage qrels for turn {query_id}. Moving on ...\n')

            else:

                turn_result_evaluations_pre = measure_intrarun_ranking_performance(turn_json_pre_rerank, run_name, output_filename, passage_qrels_filename)
                turn_result_evaluations_post = measure_intrarun_ranking_performance(turn_json_post_rerank, run_name, output_filename, passage_qrels_filename)

                combined_metrics = combine_performance_metrics(turn_result_evaluations_pre, turn_result_evaluations_post)

                # Save to output file
                with open(verbose_output_filename, 'a') as f:
                    f.write(f'\n## Combined Performance Metrics\n')
                    f.write(f"{'-'*100}\n\n")
                    f.write(combined_metrics)


            # Update variables
            previous_turn = turn
            conversation_found_passages += turn_found_passages
            conversation_total_passages_expected += turn_total_passages_expected


            # save all the prints to a text file for tracking system progress
            with open(verbose_output_filename, 'a') as f:
                f.write(f"\n\nHousekeeping\n")
                f.write(f"{'#'*100}\n")
                f.write(f'\nUpdating the run JSON file ...\n')


            # indent the turn by 8 spaces so it nests under "turns" in the run JSON
            json_str = json.dumps(turn_json_post_rerank, indent=4)
            indented_json_str = '\n'.join('        ' + line for line in json_str.splitlines())


            # Write turn to JSON. The separator is always written; whether this
            # is the last turn of the run is only known after the loop, which is
            # where the final separator is trimmed.
            with open(output_filename, 'a') as f:
                f.write(indented_json_str + ',\n')

        # Separate the run in verbose output (for processed and skipped turns alike)
        with open(verbose_output_filename, 'a') as f:

            # ending timestamps
            midrun_timestamp = datetime.now()
            total_time = midrun_timestamp - beginning_timestamp
            f.write(f'\nTimestamp: {midrun_timestamp.strftime("%Y-%m-%d %H:%M:%S")}\n')
            f.write(f'Time Elapsed: {total_time}\n')
            f.write(f"\nNext turn ...\n\n\n")
            f.write(f"{'|'*100}\n")
            f.write(f"{'|'*100}\n")
            f.write(f"{'|'*100}\n\n")



    # update conversation level vars
    total_found_passages += conversation_found_passages
    total_total_passages_expected += conversation_total_passages_expected


    # conversation level reporting
    ########################

    if not conv_has_qrels:  # no qrels for this conversation
        with open(verbose_output_filename, 'a') as f:
            f.write(f'\nNo conversation-level qrels for turn {number}. Moving on ...\n')

    else:

        # avoid division by zero when no qrels for topic
        if conversation_total_passages_expected != 0:
            conversation_found_passages_fraction = conversation_found_passages / conversation_total_passages_expected
        else:
            conversation_found_passages_fraction = 0

        # Separate the run in verbose output
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n\nConversation-level performance\n")
            f.write(f"{'#'*100}\n\n")
            f.write(f'Total correct passages found in conversation: {conversation_found_passages}\n')
            f.write(f'Total correct passages expected  in conversation: {conversation_total_passages_expected}\n')
            f.write(f'Fraction found: {conversation_found_passages_fraction:,}\n\n')


        # The run JSON is still open-ended at this point: cut it after the last
        # closing brace and close the turn object, the "turns" array and the
        # root object so that it can be parsed.
        with open(output_filename, 'r') as file:
            json_data = file.read()
        last_brace_index = json_data.rfind('}')
        json_data = json_data[:last_brace_index]
        json_data = json_data + "}]}" # complete it
        partial_json = json.loads(json_data)


        # compute the metrics over the turns of this conversation only
        conversation_turn_json = select_turns_by_prefix(partial_json, number)

        conversation_performance_metrics = measure_intrarun_ranking_performance(conversation_turn_json, run_name, output_filename, passage_qrels_filename, section_type = "conversation")

        # Save to output file
        with open(verbose_output_filename, 'a') as f:
            f.write(f'{conversation_performance_metrics}\n\n')

        # tracker table
        full_results_evaluations = extract_evaluations(conversation_performance_metrics)

        full_results_table.append(
            {
                'topic_subtree': number,
                'found_passages_fraction': conversation_found_passages_fraction,
                'ndcg_3': full_results_evaluations["ndcg_cut_3"],
                'ndcg_5': full_results_evaluations["ndcg_cut_5"],
                'ndcg': full_results_evaluations["ndcg"],
                'p_20': full_results_evaluations["P_20"],
                'recall_20': full_results_evaluations["recall_20"],
                'map': full_results_evaluations["map"],
            })


        # separate the run in verbose output
        with open(verbose_output_filename, 'a') as f:
            f.write(f"\n\nPerformance Summary\n\n")
            f.write(f'{pd.DataFrame(full_results_table).to_string(index=False)}\n\n')


# All turns are written. Drop the separator that follows the last one so the
# "turns" array can be closed into valid JSON. If no turn was written the file
# does not end with ",\n" and is left as it is.
with open(output_filename, 'r') as f:
    content = f.read()
if content.endswith(",\n"):
    with open(output_filename, 'w') as f:
        f.write(content[:-2] + "\n")
        

# Finish the run JSON: close the "turns" array and then the root object.
with open(output_filename, 'a') as f:
    f.writelines(["    ]\n", "}\n"])



########################
# run level reporting
########################

# Run-level evaluation only happens when a real passage qrels file was given.
if passage_qrels_filename != 'no_qrel.txt':

    # ending timestamps
    ending_timestamp = datetime.now()
    total_time = ending_timestamp - beginning_timestamp

    # share of expected qrel passages found over the whole run (0 when none were expected)
    total_found_passages_fraction = (
        total_found_passages / total_total_passages_expected
        if total_total_passages_expected != 0
        else 0
    )

    # Append the run footer and the run-level retrieval tallies to the verbose log.
    with open(verbose_output_filename, 'a') as f:
        f.write("\nDone\n")
        f.write(f'Ending Timestamp: {ending_timestamp.strftime("%Y-%m-%d %H:%M:%S")}\n')
        f.write(f'Total Time: {total_time}\n')
        f.write("\n\nRun-level performance\n")
        f.write(f"{'#'*100}\n\n")
        f.write(f'Total correct passages found in run: {total_found_passages}\n')
        f.write(f'Total correct passages expected in run: {total_total_passages_expected}\n')
        f.write(f'Fraction found: {total_found_passages_fraction:,}\n')


    # Convert passage output in JSON to TRECRun format
    run_generate_run(output_filename)


    # Evaluate the generated run file: ndcg_cut.3 on its own, then the standard
    # metric families, and join the two reports.
    run_filepath = output_filename.replace('.json', '.json.run')
    evaluations_ndcg_cut_3 = run_trec_eval(passage_qrels_filename, run_filepath, metric_list=['ndcg_cut.3'])
    evaluations = run_trec_eval(passage_qrels_filename, run_filepath, metric_list=['ndcg', 'ndcg_cut', 'P', 'recall', 'map'])

    evaluations = full_evaluations_join(evaluations_ndcg_cut_3, evaluations)


    # Save to verbose output
    with open(verbose_output_filename, 'a') as f:
        f.write('\n\nOverall Passage Ranking Task Performance\n')
        f.write(evaluations)

