# Evolutionary Prompt Selection
## For Planner/Worker/Solver Framework

### Install Dependencies

In [None]:
!printf 'accelerate\nbitsandbytes\ndatasets\npinecone-client[grpc]\nsentencepiece\nsentence-transformers\ntorch\ntransformers\nwikipedia ' > requirements.txt  
!pip install -r requirements.txt


### Import Statements

In [None]:
import json
import math
import os
import re
import string
import time

import pandas as pd
import pinecone
import torch
import wikipedia

from threading import Thread, Lock

from datasets import load_dataset
from numpy.random import choice
from sentence_transformers import SentenceTransformer
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
from tqdm.auto import tqdm

from kaggle_secrets import UserSecretsClient


### Define Prompt Prefixes/Suffixes

In [None]:
# Planner system prompt: "prefix" opens the system message and "suffix" closes
# it, after the tool descriptions have been listed (see Planner.generate_prompt).
PLANNER_PROMPT = {"prefix": "You are an advanced AI capable of making plans to solve complex problems. For the following tasks, make plans that can solve the problem step-by-step. For each plan, indicate which external tool together with tool input to retrieve evidence. You can store the evidence into a variable #E that can be called by later tools. (Plan, #E1, Plan, #E2, Plan, ...)","suffix": "Describe your plans with rich details. Each Plan should be followed by only one #E. Answer each question directly with plans."}
# Solver prompt: "prefix" is the system message; "suffix" is placed after the
# plans/evidence and is immediately followed by the task text (see Solver.run).
SOLVER_PROMPT = {"prefix": "You are an advanced AI capable of solving tasks based on evidence. Solve the following task or problem based on the provided plans and corresponding evidences. Keep your responses direct and concise.", "suffix": "Based on the provided evidence answer the following question directly and concisely: "}
# Tool name -> description shown to the Planner. The keys are also the tool
# names that Worker.run dispatches on ("Wikipedia" and "LLM").
TOOLS_PROMPT = {"Wikipedia": "Worker that search for similar page contents from Wikipedia. Useful when you need to get holistic knowledge about people, places, companies, historical events, or other subjects. The response are long and might contain some irrelevant information. Input should be a search query.", "LLM": "A pretrained LLM like yourself. Useful when you need to act with general world knowledge and common sense. Prioritize it when you are confident in solving the problem yourself. Input can be any instruction."}
# Extractor system prompt, used to pull a short answer out of a longer statement.
EXTRACTOR_PROMPT = {"prefix": "Based on the given statement, concisely extract the answer to the following question. Respond directly with the concise answer to the question."}


### Define classes and functions

#### Nodes

In [None]:
class Node:
    """Abstract base class for every node in the pipeline.

    Neither the constructor nor ``run`` is implemented here: both raise
    ``NotImplementedError`` and have to be overridden by concrete nodes.
    """

    def __init__(self):
        raise NotImplementedError

    def run(self, inputs):
        """Process ``inputs``; must be provided by subclasses."""
        raise NotImplementedError


class LLMNode(Node):
    """Node whose work is delegated to a wrapped language model.

    The system/user/assistant tags are copied from the model so subclasses
    can assemble prompts, and ``stops`` holds the stop sequences that are
    handed to the model on every call (subclasses may replace it).
    """

    def __init__(self, model):
        # Node.__init__ raises NotImplementedError, so it is not chained here.
        self.model = model
        self.system_tag, self.user_tag, self.ai_tag = (
            model.system_tag,
            model.user_tag,
            model.ai_tag,
        )
        self.stops = ['.', '\n']

    def call_llm(self, prompt):
        """Send a prompt to the wrapped model using this node's stop sequences
        Parameters:
        ------------
        prompt: str
            Fully assembled prompt text

        Returns:
        ------------
        response: str
            Text generated by the model
        """
        return self.model.generate(prompt, self.stops)


class Planner(LLMNode):
    """Planner node for making plans within the PWS framework"""

    # A well-formed evidence label: "#E" followed by one or more digits
    _EVIDENCE_LABEL = re.compile(r"#E\d+")

    def __init__(self, model):
        super().__init__(model)
        # A plan ends at the first blank line
        self.stops = ['\n\n']
        self.prefix = PLANNER_PROMPT['prefix']
        self.suffix = PLANNER_PROMPT['suffix']
        self.tools = TOOLS_PROMPT

    def run(self, task, examples):
        """Generate plans for the given task, examples and tools
        Parameters:
        ------------
        task: str
            Task for which the plan is to be generated
        examples: list(dict)
            Examples related to the task for the fewshot prompt

        Returns:
        ------------
        planner_response: dict(str:obj)
            'plans' (list of plan lines), 'tool_calls' (evidence variable ->
            tool call) and 'text' (the raw model response)
        """
        prompt = self.generate_prompt(task, examples)
        response = self.call_llm(prompt)
        plans, tool_calls = self.parse_response(response)
        return {'plans': plans, 'tool_calls': tool_calls, 'text': response}

    def generate_prompt(self, task, examples):
        """Generates a planner prompt for the given task, examples and tools
        Only the tools that appear in the examples are described in the prompt
        Parameters:
        ------------
        task: str
            Task for which the plan is to be generated
        examples: list(dict)
            Examples related to the task for the fewshot prompt; each needs
            the keys 'question', 'plan' and 'tools'

        Returns:
        ------------
        prompt: str
            planner prompt
        """
        tools = {}
        for example in examples:
            for tool in example['tools']:
                # Examples may name tools that have no description (e.g. a
                # tool invented by an earlier plan); skip those instead of
                # failing with a KeyError.
                if tool in self.tools:
                    tools[tool] = self.tools[tool]

        parts = [f"{self.system_tag}{self.prefix}\n",
                 "Tools can be one of the following:\n"]
        parts.extend(f"{tool}[input]: {description}\n"
                     for tool, description in tools.items())
        parts.append(f"{self.suffix}\n\n")
        for example in examples:
            parts.append(f"{self.user_tag}{example['question'].strip()}\n\n")
            parts.append(f"{self.ai_tag}{example['plan'].strip()}\n\n")
        parts.append(f"{self.user_tag}{task.strip()}\n\n")
        parts.append(self.ai_tag)
        return "".join(parts)

    def parse_response(self, response):
        """Parse the planner response and return plans and tool calls
        Lines starting with "Plan:" are collected as plans. Lines starting
        with "#E<digit>" are treated as "<label> = <tool call>" assignments;
        such a line without "=" is ignored.
        Parameters:
        ------------
        response: str
            Planner response

        Returns:
        ------------
        plans: list(str)
            List that contains the plans
        tool_calls: dict(str:str)
            Evidence variables and their tool calls; a malformed label is
            kept as a key and mapped to "No evidence found"
        """
        plans = []
        tool_calls = {}
        for line in response.splitlines():
            if line.startswith("Plan:"):
                plans.append(line)
                continue
            if not (line.startswith("#E") and line[2:3].isdigit()):
                continue
            label, separator, tool_call = line.partition("=")
            if not separator:
                # No assignment on this line, so there is nothing to call
                continue
            label, tool_call = label.strip(), tool_call.strip()
            if self._EVIDENCE_LABEL.fullmatch(label):
                tool_calls[label] = tool_call
            else:
                tool_calls[label] = "No evidence found"
        return plans, tool_calls


class WikipediaWorker(Node):
    """Worker that searches Wikipedia"""

    def __init__(self):
        # Nothing to set up; overrides Node.__init__, which raises.
        pass

    def run(self, inputs):
        """Searches Wikipedia for the given inputs and returns the first
        2000 characters of the first page in the search results
        Parameters:
        ------------
        inputs: str
            String input for Wikipedia search (only the first 300 characters
            are sent as the query)

        Returns:
        ------------
        evidence: str
            First 2000 characters of the first page from the search results,
            or "No evidence found." when nothing could be retrieved
        """
        evidence = "No evidence found."
        try:
            pages = wikipedia.search(inputs[:300], results=1)
            if pages:
                page = wikipedia.page(pages[0], auto_suggest=False)
                evidence = page.content[:2000]
        except Exception:
            # Best effort: any lookup failure yields the fallback evidence.
            # `Exception` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            pass

        return evidence


class LLMWorker(LLMNode):
    """Worker that answers a tool call by querying the language model"""

    def run(self, inputs):
        """Answer a tool call with the LLM
        Parameters:
        ------------
        inputs: str
            Instruction or question for the model; inputs longer than 2000
            tokens are cut down to their first 2000 tokens

        Returns:
        ------------
        evidence: str
            Model response with surrounding whitespace removed
        """
        max_tokens = 2000
        tokenizer = self.model.tokenizer
        token_ids = tokenizer(inputs)['input_ids']
        if len(token_ids) > max_tokens:
            inputs = tokenizer.decode(token_ids[:max_tokens],
                                      skip_special_tokens=True)
        prompt = (
            f"{self.system_tag}"
            "Directly answer the following question with no extra words."
            f"\n\n{self.user_tag}{inputs.strip()}\n\n{self.ai_tag}"
        )
        return self.call_llm(prompt).strip()


class Worker(Node):
    """Worker node that calls appropriate workers for each tool call"""
    def __init__(self, model):
        self.wiki_worker = WikipediaWorker()
        self.llm_worker = LLMWorker(model)

    def run(self, inputs):
        """Faciliates all tool calls and returns evidences
        Parameters:
        ------------
        inputs: dict(str:str)
            A dictionary of evidence variables and associated tool calls

        Returns:
        ------------
        evidences: dict(str:str)
            A dictinary of evidence variables and the outputs of the associated
            tool calls
        """
        evidences = {}
        for e, tool_call in inputs.items():
            # Do not process tools without input
            if "[" not in tool_call:
                evidences[e] = tool_call
                continue

            # Seperate tool and tool input
            tool, tool_input = tool_call.split("[", 1)
            tool_input = tool_input[:-1]

            # Find variables in input and replace with previous evidences
            for var in re.findall(r"#E\d+", tool_input):
                if var in evidences:
                    try:
                        evidence = evidences[var]
                    except KeyError:
                        evidence = "No evidence found."
                    tool_input = tool_input.replace(var, f"[{evidence}]")

            match tool:
                case "Wikipedia":
                    evidences[e] = self.wiki_worker.run(tool_input)
                case "LLM":
                    evidences[e] = self.llm_worker.run(tool_input)
                case _:
                    evidences[e] = "No evidence found."

        return evidences


class Solver(LLMNode):
    """Solver node that answers a task from its plans and gathered evidence"""

    def __init__(self, model):
        super().__init__(model)
        self.prefix = SOLVER_PROMPT['prefix']
        self.suffix = SOLVER_PROMPT['suffix']

    def run(self, task, plans, evidences):
        """Produce the final answer for a task
        The i-th plan is paired with the evidence stored under "#E<i>"
        (counting from 1); a missing entry becomes "No evidence found."
        Parameters:
        ------------
        task: str
            Task to be solved
        plans: list(str)
            Plans produced by the Planner
        evidences: dict(str:str)
            Evidence gathered by the Worker, keyed by evidence variable

        Returns:
        ------------
        output: str
            Model answer based on the plans and evidences
        """
        question = task.strip()
        sections = [
            f"{self.system_tag}{self.prefix}\n\n",
            f"{self.user_tag}{question}\n",
        ]
        for number, plan in enumerate(plans, start=1):
            evidence = evidences.get(f"#E{number}", "No evidence found.")
            # Each evidence is limited to its first 500 characters
            sections.append(f"{plan}\nEvidence: {evidence[:500]}...\n")
        sections.append(f"{self.suffix + question}\n\n{self.ai_tag}")
        return self.call_llm("".join(sections))

class Extractor(LLMNode):
    """Node that asks the LLM to pull a concise answer out of a statement"""

    def __init__(self, model):
        super().__init__(model)
        self.prefix = EXTRACTOR_PROMPT['prefix']

    def __call__(self, statement, question):
        """Extract the answer to a question from a statement
        Parameters:
        ------------
        statement: str
            Text that is expected to contain the answer
        question: str
            Question whose answer is to be extracted

        Returns:
        ------------
        output: str
            Model response
        """
        prompt = "".join([
            f"{self.system_tag}{self.prefix}\n",
            f"{self.user_tag}Statement: {statement}\n",
            f"Question: {question}\n{self.ai_tag}",
        ])
        return self.call_llm(prompt)


#### Utils

In [None]:
class MultiTokenEOSCriteria(StoppingCriteria):
    """Stops generation once a given (possibly multi-token) string shows up
    in the generated text. See the HuggingFace Transformers documentation
    for the StoppingCriteria interface."""

    def __init__(self, sequence, tokenizer, initial_decoder_input_length):
        self.tokenizer = tokenizer
        self.sequence = sequence
        self.initial_decoder_input_length = initial_decoder_input_length
        self.sequence_ids = tokenizer.encode(sequence,
                                             add_special_tokens=False)
        self.sequence_id_len = len(self.sequence_ids)

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Only the first row of the batch is inspected. The prompt tokens
        # are skipped, then just the last n generated tokens are decoded,
        # n being the token length of the stop sequence.
        generated = input_ids[0][self.initial_decoder_input_length:]
        tail = generated[-self.sequence_id_len:]
        return self.sequence in self.tokenizer.decode(tail)


class LanguageModel:
    """Thin wrapper around a Llama tokenizer/model pair used by the nodes"""

    def __init__(self, model_path, generation_config,
                 device_map='auto', load_in_8bit=False, access_token=None,
                 system_tag='\n', user_tag='\n', ai_tag='\n'):
        self.tokenizer = LlamaTokenizer.from_pretrained(
            model_path, use_auth_token=access_token)
        self.model = LlamaForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map=device_map,
            load_in_8bit=load_in_8bit,
            use_auth_token=access_token,
        )
        self.generation_config = generation_config
        # Device the tokenized prompt is moved to before generation:
        # plain 'cuda' for 'auto', otherwise 'cuda:<device_map>'.
        self.device = 'cuda' if device_map == 'auto' else f'cuda:{device_map}'
        self.system_tag = system_tag
        self.user_tag = user_tag
        self.ai_tag = ai_tag

    def stop_sequences_criteria(self, stop_sequences,
                                initial_decoder_input_length):
        """Build one stopping criterion per stop sequence
        Parameters:
        ------------
        stop_sequences: list(str)
            Strings that end text generation
        initial_decoder_input_length: int
            Number of tokens in the prompt

        Returns:
        ------------
            StoppingCriteriaList object
        """
        criteria = []
        for sequence in stop_sequences:
            criteria.append(MultiTokenEOSCriteria(
                sequence, self.tokenizer, initial_decoder_input_length))
        return StoppingCriteriaList(criteria)

    def generate(self, prompt, stops):
        """Generate a completion for a prompt
        Parameters:
        ------------
        prompt: str
            Prompt for the LLM
        stops: list(str)
            Strings used as stopping criteria

        Returns:
        ------------
        output_text: str
            Decoded completion; the prompt tokens are not included
        """
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_length = encoded['input_ids'].shape[1]
        criteria = self.stop_sequences_criteria(stops, prompt_length)
        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                generation_config=self.generation_config,
                stopping_criteria=criteria,
            )

        # Keep only what was generated after the prompt
        completion = generated[0][prompt_length:]
        return self.tokenizer.decode(completion, skip_special_tokens=True)


class PWS:
    """Planner Worker Solver framework: plan, gather evidence, then solve"""

    def __init__(self, model):
        self.planner = Planner(model=model)
        self.worker = Worker(model=model)
        self.solver = Solver(model=model)

    def run(self, task, examples, verbose=False):
        """Run the three stages on a task
        Parameters:
        ------------
        task: str
            Task to be solved
        examples: list(dict)
            Fewshot examples handed to the Planner
        verbose: bool, default=False
            If True, the intermediate responses are returned as well

        Returns:
        ------------
        pws_response: dict(str:obj)
            'output' (the Solver answer) and 'wall_time' (seconds elapsed);
            with verbose also 'planner_response' and 'worker_response'
        """
        started = time.time()

        # Plan
        planner_response = self.planner.run(task, examples)
        # Work
        evidences = self.worker.run(planner_response["tool_calls"])
        # Solve
        output = self.solver.run(task, planner_response["plans"], evidences)

        pws_response = {"output": output,
                        "wall_time": time.time() - started}
        if verbose:
            pws_response.update(planner_response=planner_response,
                                worker_response=evidences)
        return pws_response


class EPS:
    """ Evolutionary Prompt Selection"""

    def __init__(self, index, embedding_model,
                 similar_pool_size=5, instructive_pool_size=5):
        self.index = index
        index_stats = self.index.describe_index_stats()
        # Entry IDs are assumed to be the consecutive integers
        # 0..index_size-1 (as strings); upsert_entry keeps it that way
        self.index_size = index_stats['total_vector_count']
        self.embedding_model = embedding_model
        self.similar_pool_size = similar_pool_size
        self.instructive_pool_size = instructive_pool_size
        self.most_instructive = []
        self.set_most_instructive()

    @staticmethod
    def _instruction_score(entry):
        """Return the instruction score stored in an entry's metadata"""
        return entry['metadata']['score']

    def set_most_instructive(self):
        """Retrieve the most instructive examples from the index
        Pinecone does not support aggregations over metadata so we fetch all
        instructions and manually select the most instructive ones
        """
        batch_size = 1000
        for start in range(0, self.index_size, batch_size):
            # Find end of batch
            stop = min(start + batch_size, self.index_size)
            # Create IDs batch
            ids = [str(idx) for idx in range(start, stop)]
            batch = list(self.index.fetch(ids)['vectors'].values())
            # Sort and keep the most instructive
            ranked = sorted(batch + self.most_instructive,
                            key=self._instruction_score, reverse=True)
            self.most_instructive = ranked[:self.instructive_pool_size]

    def select_examples(self, task, num_examples=3):
        """Select instructive examples based on a given task
        This method samples instructions from a curated pool of examples
        The pool is curated by combining (similar_pool_size) number of
        semantically similar examples and (instructive_pool_size) number of
        examples with high instruction score
        The examples are then sampled, without replacement, with probability
        proportional to their combined example score:
        (semantic similarity + 1.0) * (log(instruction_score) + 1.0)
        Parameters:
        ------------
        task: str
            Task for which the instructive examples are to be selected
        num_examples: int, default=3
            Number of instructive examples to return

        Returns:
        ------------
        examples: list(dict)
            Matches drawn from the pool. Fewer than num_examples are returned
            when the pool is smaller, and an empty list when it is empty
        """
        task_embedding = self.embedding_model.encode(
            task, show_progress_bar=False).tolist()
        most_similar = self.index.query(task_embedding,
                                        top_k=self.similar_pool_size,
                                        include_metadata=True)['matches']
        instructive_ids = [entry['metadata']['id']
                           for entry in self.most_instructive]
        most_instructive = self.index.query(
            task_embedding,
            top_k=self.instructive_pool_size,
            filter={'id': {"$in": instructive_ids}},
            include_metadata=True)['matches']
        # NOTE(review): an entry that is both similar and instructive appears
        # twice in the pool and can therefore be drawn twice - confirm that
        # this is intended.
        pool = most_similar + most_instructive
        if not pool:
            return []
        weights = [(entry['score'] + 1.0)
                   * (math.log(entry['metadata']['score']) + 1.0)
                   for entry in pool]
        total = sum(weights)
        probabilities = [weight / total for weight in weights]
        # Sampling without replacement cannot draw more than the pool holds
        sample_size = min(num_examples, len(pool))
        sample_ids = choice(range(len(pool)), sample_size,
                            replace=False, p=probabilities)
        return [pool[i] for i in sample_ids]

    def increment_score(self, entry_id):
        """Increment the instruction score of an example
        The new score is written back to the index and the most instructive
        pool is re-ranked
        Parameters:
        ------------
        entry_id: str
            Score of the example with the given entry_id is incremented
        """
        # Check if the entry is in the most instructive pool
        entry = next((candidate for candidate in self.most_instructive
                      if candidate['id'] == entry_id), None)
        if entry is None:
            # Fetch it from the index and add it to the pool
            entry = self.index.fetch([entry_id])['vectors'][entry_id]
            self.most_instructive.append(entry)
        # Update entry score
        entry['metadata']['score'] += 1
        self.index.update(
            id=entry['id'],
            set_metadata={"score": self._instruction_score(entry)})
        # Sort the most instructive pool and keep the most instructive
        ranked = sorted(self.most_instructive,
                        key=self._instruction_score, reverse=True)
        self.most_instructive = ranked[:self.instructive_pool_size]

    def upsert_entry(self, metadata):
        """Upsert a new entry into the index
        The entry gets the next free ID and an initial score of 1; both are
        also written into the given metadata dict
        Parameters:
        ------------
        metadata: dict
            a dictionary containing entry metadata
            {question, plan, tools, dataset_name}
        """
        entry_id = self.index_size
        embedding = self.embedding_model.encode(
            metadata['question'], show_progress_bar=False).tolist()
        metadata['id'] = entry_id
        metadata['score'] = 1
        self.index.upsert(zip([str(entry_id)], [embedding], [metadata]))
        self.index_size += 1


### Test the system

#### Define variables

In [None]:
# Credentials are read from the Kaggle secrets store
user_secrets = UserSecretsClient()
PINECONE_API_KEY = user_secrets.get_secret('PINECONE_API_KEY')
PINECONE_ENV = user_secrets.get_secret('PINECONE_ENVIRONMENT')
# Name of the Pinecone index that holds the example plans
INDEX_NAME = 'plans'

# SentenceTransformer model used to embed questions
EMBEDDING_MODEL = 'all-MiniLM-L6-v2'

# Language model and the tags that open the system / user / assistant
# sections of every prompt
MODEL_PATH = "stabilityai/StableBeluga-7B"
SYSTEM_TAG = "### System:\n"
USER_TAG = "### User:\n"
AI_TAG = "### Assistant:\n"

LOAD_IN_8BIT = True
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")

# Default sampling settings passed to GenerationConfig
TEMPERATURE = 0.01
TOP_K = 50
TOP_P = 0.9
REPETITION_PENALTY= 1.0
MAX_NEW_TOKENS = 256
# 'auto' makes the experiment cell use torch.cuda.device_count()
DEVICE_COUNT = 'auto'

DATASET_NAME = "trivia_qa"

# EPS pool sizes and the number of fewshot examples sampled per question
SIMILAR_POOL_SIZE = 5
INSTRUCTIVE_POOL_SIZE = 5
NUM_EXAMPLES = 3



#### Initialize models and prepare the data

In [None]:
# Default generation settings, built from the constants defined above.
# NOTE: the hyperparameter cells below rebind this name.
generation_config = GenerationConfig(
    do_sample=True,
    temperature=TEMPERATURE,
    top_k=TOP_K,
    top_p=TOP_P,
    repetition_penalty=REPETITION_PENALTY,
    max_new_tokens=MAX_NEW_TOKENS
)


In [None]:
# Load the language model once and share it between the PWS agent and the
# answer extractor (device_map is left at its 'auto' default here).
model = LanguageModel(MODEL_PATH, generation_config=generation_config,
                      load_in_8bit=LOAD_IN_8BIT, access_token=HF_TOKEN,
                      system_tag=SYSTEM_TAG, user_tag=USER_TAG, ai_tag=AI_TAG)
agent = PWS(model)
extractor = Extractor(model)


In [None]:
# Connect to Pinecone and open the index that stores the example plans
pinecone.init(
    api_key=PINECONE_API_KEY,
    environment=PINECONE_ENV
)
index = pinecone.GRPCIndex(INDEX_NAME)

# Sentence embedding model used to embed questions for the index
embedding_model = SentenceTransformer(EMBEDDING_MODEL)

# Example selector; its constructor fetches the most instructive entries
prompter = EPS(index, embedding_model, SIMILAR_POOL_SIZE, INSTRUCTIVE_POOL_SIZE)

# 'rc.nocontext' is the dataset configuration name - presumably the
# question/answer pairs without context documents; TODO confirm
dataset = load_dataset(DATASET_NAME, 'rc.nocontext')

def sanitize(text):
    """Normalize text for exact-match comparison: trim surrounding
    whitespace, lowercase, then remove ASCII punctuation."""
    punctuation_remover = str.maketrans('', '', string.punctuation)
    trimmed = text.strip()
    return trimmed.lower().translate(punctuation_remover)


#### Hyperparameter Optimization

In [None]:
# Grid search over temperature and repetition penalty on the first
# hp_opt_size training questions; the score is the number of exact matches.
hp_opt_size = 20
hp_opt_data = dataset['train'][:hp_opt_size]
temp_values = [0.01, 0.25, 0.5, 0.75, 1.0]
rep_values = [1.0, 1.1, 1.2, 1.3]
results = []
grid = [(temp, rep) for temp in temp_values for rep in rep_values]
for temp, rep in grid:
    # NOTE: this rebinds the notebook-level generation_config, so the last
    # combination stays in effect for later cells.
    generation_config = GenerationConfig(
        do_sample=True,
        temperature=temp,
        top_k=TOP_K,
        top_p=TOP_P,
        repetition_penalty=rep,
        max_new_tokens=MAX_NEW_TOKENS
    )
    model.generation_config = generation_config
    em = []
    samples = zip(hp_opt_data['question'], hp_opt_data['answer'])
    for question, answer in tqdm(samples, total=hp_opt_size):
        list_of_candidates = [sanitize(alias) for alias in answer["aliases"]]

        selection = prompter.select_examples(question, NUM_EXAMPLES)
        examples = [entry['metadata'] for entry in selection]
        response = agent.run(question, examples)

        # Accept the raw output first; only fall back to the extractor
        # when the raw output is not an exact match.
        matched = sanitize(response['output']) in list_of_candidates
        if not matched:
            extracted = sanitize(extractor(response['output'], question))
            matched = extracted in list_of_candidates
        em.append(matched)

    print(f"Temperature: {temp}\nRepetition Penalty: {rep}\nScore: {sum(em)}\n")
    results.append({'temp':temp, 'rep':rep, 'score':sum(em)})

with open("results.json", "w") as f:
    json.dump(results, f)
    

In [None]:
# Load the grid-search scores and show them as a
# Temperature x Repetition Penalty table.
with open('results.json', 'r') as f:
    results = json.load(f)

results_df = (
    pd.DataFrame.from_dict(results)
    .rename(columns={'temp': 'Temperature',
                     'rep': 'Repetition Penalty',
                     'score': 'EM Score'})
    .pivot(index='Temperature', columns='Repetition Penalty')
)
results_df

#### Hyperparameter Optimization Round 2

In [None]:
# Second round: compare two (temperature, repetition penalty) settings on
# the first hp_opt_size training questions.
hp_opt_size = 50
hp_opt_data = dataset['train'][:hp_opt_size]
configs = [(0.01, 1.0), (0.5, 1.1)]
results_v2 = []
for temp, rep in configs:
    # NOTE: this rebinds the notebook-level generation_config, so the last
    # setting stays in effect for later cells.
    generation_config = GenerationConfig(
        do_sample=True,
        temperature=temp,
        top_k=TOP_K,
        top_p=TOP_P,
        repetition_penalty=rep,
        max_new_tokens=MAX_NEW_TOKENS
    )
    model.generation_config = generation_config
    em = []
    samples = zip(hp_opt_data['question'], hp_opt_data['answer'])
    for question, answer in tqdm(samples, total=hp_opt_size):
        list_of_candidates = [sanitize(alias) for alias in answer["aliases"]]

        selection = prompter.select_examples(question, NUM_EXAMPLES)
        examples = [entry['metadata'] for entry in selection]
        response = agent.run(question, examples)

        # Accept the raw output first; only fall back to the extractor
        # when the raw output is not an exact match.
        matched = sanitize(response['output']) in list_of_candidates
        if not matched:
            extracted = sanitize(extractor(response['output'], question))
            matched = extracted in list_of_candidates
        em.append(matched)

    print(f"Temperature: {temp}\nRepetition Penalty: {rep}\nScore: {sum(em)}\n")
    results_v2.append({'temp':temp, 'rep':rep, 'score':sum(em)})

with open("results_v2.json", "w") as f:
    json.dump(results_v2, f)


### Run experiments

In [None]:
def _save_batch(results, batch_id):
    """Print the accuracy of one batch and write its results to
    results/results_batch_<batch_id>.json"""
    acc = sum(result['em'] for result in results) / len(results)
    print(f"Processed batch number {batch_id} with {acc} accuracy.")
    with open(f"results/results_batch_{batch_id}.json", "w") as f:
        json.dump(results, f)


def process_chunk(model, data, prompter, lock, thread_id, batch_offset):
    """Run the PWS agent over one chunk of the dataset (thread target)
    Results are saved in batches of 100 questions. Whenever an answer is an
    exact match, the generated plan is added to the index and the scores of
    the examples used for it are incremented.
    Parameters:
    ------------
    model: LanguageModel
        Model used by this thread's agent and extractor
    data: iterable of (question, answer)
        Questions with their answer records; answer["aliases"] lists the
        accepted answers
    prompter: EPS
        Example selector shared between threads
    lock: threading.Lock
        Guards every access to the shared prompter
    thread_id: int
        Thread number; also the ID of this thread's first batch
    batch_offset: int
        Step between consecutive batch IDs of this thread
    """
    print(f"Thread {thread_id} started.\n")
    # Initialize the agent and the extractor
    agent = PWS(model)
    extractor = Extractor(model)
    # Initialize loop variables
    batch_id = thread_id
    results = []
    for i, (question, answer) in enumerate(data):
        # Process and save results for each full batch
        if i and not i % 100:
            _save_batch(results, batch_id)
            batch_id += batch_offset
            results = []
        # Select examples using the prompter. `with` releases the lock even
        # if the call raises, so other threads are never left blocked.
        with lock:
            selection = prompter.select_examples(question, NUM_EXAMPLES)
        # Run the agent
        examples = [entry['metadata'] for entry in selection]
        response = agent.run(question, examples, verbose=True)
        # Check the correctness of the answer
        list_of_candidates = [sanitize(alias) for alias in answer["aliases"]]
        em = sanitize(response['output']) in list_of_candidates
        if not em:
            # Try extracting the answer from the output
            extracted_output = extractor(response['output'], question)
            em = sanitize(extracted_output) in list_of_candidates

        instructions = [{'id': entry['id'], 'similarity': entry['score']}
                        for entry in selection]
        results.append({'em': em, 'instructions': instructions})
        # In case of an exact match, add the new plans to the index
        # and increment the scores of the selected instructions
        if em:
            # Aggregate the tools used for this instance. Only tools with a
            # description in TOOLS_PROMPT are stored: the Planner looks up
            # the description of every tool listed in an example.
            tool_calls = response['planner_response']['tool_calls']
            called = {call.split('[', 1)[0].strip()
                      for call in tool_calls.values()}
            tools = sorted(tool for tool in called if tool in TOOLS_PROMPT)
            # Metadata for the new plans
            new_entry_metadata = {
                'question': question,
                'plan': response['planner_response']['text'],
                'tools': tools,
                'dataset_name': DATASET_NAME,
            }
            with lock:
                # Add new plans to the index
                prompter.upsert_entry(new_entry_metadata)
                # Increment scores of the selected instructions
                for entry in selection:
                    prompter.increment_score(entry['id'])
    # Process and save results for the last batch (nothing to do when the
    # chunk was empty)
    if results:
        _save_batch(results, batch_id)


In [None]:
# Split the training set into one chunk per GPU and process every chunk in
# its own thread, each with its own copy of the model.
if DEVICE_COUNT == 'auto':
    device_count = torch.cuda.device_count()
else:
    device_count = DEVICE_COUNT

if device_count < 1:
    # Without this check the chunk size computation divides by zero
    raise RuntimeError("At least one CUDA device is required to run the "
                       "experiments.")

dataset_size = dataset['train'].num_rows
chunk_size = int(math.ceil(dataset_size / device_count))
lock = Lock()
threads = []
for device in range(device_count):
    chunk = dataset['train'][(device * chunk_size):((device + 1) * chunk_size)]
    data = zip(chunk['question'], chunk['answer'])
    # NOTE: generation_config is whatever the last executed cell bound it to
    model = LanguageModel(MODEL_PATH, generation_config=generation_config,
                          device_map=device, load_in_8bit=LOAD_IN_8BIT, access_token=HF_TOKEN,
                          system_tag=SYSTEM_TAG, user_tag=USER_TAG, ai_tag=AI_TAG)
    args = (model, data, prompter, lock, device, device_count)
    threads.append(Thread(target=process_chunk, args=args))

# exist_ok lets the cell be re-run when the directory is already there
os.makedirs('results', exist_ok=True)
for thread in threads:
    thread.start()
# Wait for all workers so the cell only finishes once every batch is saved
for thread in threads:
    thread.join()
    