In [None]:
import pandas as pd
import re
from huggingface_hub import InferenceClient
import numpy as np
import ast
import time
import threading
from tqdm.notebook import tqdm
import itertools

In [None]:
# Load the selected papers. The 'Title' and 'Abstract' columns are the ones
# used further down to build the prompt for each paper.
df_selected = pd.read_csv('../1_data/selected_papers_normalized.csv')
# Disabled step: would parse a stringified 'embeddings' column back into Python
# objects. Nothing else in this notebook reads that column.
# df_selected['embeddings'] = df_selected['embeddings'].apply(ast.literal_eval)

In [None]:
# Sanity check of the loaded frame: (rows, columns) first, then the first two rows.
print(df_selected.shape)
df_selected.head(2)

In [None]:
import re
import time

def _extract_loss_types(generated_text):
    """
    Pull the loss types out of a model answer.

    Looks for a "Type of loss" / "Types of losses" line (case-insensitive),
    drops anything between parentheses and any punctuation other than commas,
    and splits the remainder on commas.

    Args:
        generated_text (str): Raw text returned by the model.

    Returns:
        tuple | None: ``(types_losses, str_types_losses)`` where ``types_losses`` is the list
        of extracted types and ``str_types_losses`` is the one-line summary appended to the
        answer ("none" if any extracted type contains "none"). ``None`` when no
        "Type of loss" line is found.
    """
    match = re.search(r'(type[s]? of ?(losses|loss)\**\s*\n*[=|:]*\s*\**)(\S.+)', generated_text.lower())
    if not match:
        return None

    text_no_parentheses = re.sub(r'\([^)]*\)', '', match.group(3))
    text_clean = re.sub(r'[^\w\s,]', '', text_no_parentheses)
    types_losses = [loss for loss in re.split(r",\s*", text_clean) if loss]
    if "none" in str(types_losses).lower():
        str_types_losses = "none"
    else:
        str_types_losses = str(types_losses).replace("'", "").replace("[", "").replace("]", "")
    return types_losses, str_types_losses


def correct_llama_parallel(words, client, y, responses, progress_bar, system, instruct, indexes_analized, indexes_errors,
                           max_invalid_retries=5, max_consecutive_errors=10):
    """
    Processes a batch of texts with a text generation client and stores the answers of the batch.

    Meant to be used as a thread target: it returns nothing and reports its results by
    writing into slot ``y`` of ``responses``, ``indexes_analized`` and ``indexes_errors``.

    For every text the model is queried until the answer contains a "Type of loss" line with
    at most three types. The normalised types are appended to the answer as its last line.
    If the answer is still invalid after the allowed retries, it is stored prefixed with
    "LAST ERROR: ". A valid answer is always accepted, even on the last allowed attempt.

    Args:
        words (list of tuples): (index, text) pairs to process, in order.
        client (object): Object exposing ``text_generation(prompt, **kwargs) -> str``.
        y (int): Slot of this batch in ``responses``, ``indexes_analized`` and ``indexes_errors``.
        responses (list | dict): Receives ``[y, answers]`` at position ``y``.
        progress_bar (object): Object exposing ``update(n)`` and ``close()``.
        system (str): System prompt.
        instruct (str): User prompt template containing a ``{word}`` placeholder.
        indexes_analized (list | dict): Receives at ``y`` the indexes that were actually queried.
        indexes_errors (list | dict): Receives at ``y`` the indexes whose answer was flagged as an error.
        max_invalid_retries (int): An invalid answer is retried while the number of invalid answers
            already seen for the current text is not greater than this value (default 5, i.e. up to
            7 generations per text).
        max_consecutive_errors (int): The batch is interrupted once more than this many texts in a
            row were flagged as errors (default 10).

    Returns:
        None

    Note:
        Exceptions raised by the client are retried indefinitely (10 minutes of sleep for
        rate-limit / connection errors, 5 seconds otherwise) and do not count towards
        ``max_invalid_retries``.
    """

    # Template for the prompt to be sent to the language model.
    template = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>{system}<|eot_id|><|start_header_id|>user<|end_header_id|>{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

    response = []
    consecutives_errors = 0

    indexes_viewed = []
    indexes_errors_core = []

    for index, word in words:
        # Interrupt on a long run of failures (sometimes the model starts to hallucinate
        # and keeps returning nonsense). Checked before recording the index so that a
        # text that was never queried is not reported as analysed.
        if consecutives_errors > max_consecutive_errors:
            print("Too many consecutive errors. Interrupting...")
            break

        indexes_viewed.append(index)

        # Format the instruction with the current text.
        instruct_w = instruct.format(word=word)
        prompt = template.format(system=system, user=instruct_w)

        out = None
        failed = False
        error_sum = 0

        # Retry until a valid answer is received or the invalid-answer budget is exhausted.
        while out is None:
            parsed = None
            try:
                # Generate text using the client.
                out = client.text_generation(prompt, max_new_tokens=500, temperature=0.001, do_sample=False, top_p=0.01, top_k=1)

                # Extract the types of losses and append them as the last line of the answer.
                parsed = _extract_loss_types(out)
                if parsed is not None:
                    out += "\n" + parsed[1]

            except Exception as e:
                # Handle rate limit and other exceptions.
                if "Rate limit reached." in str(e) or "Max retries exceeded with url" in str(e):
                    print(str(e))
                    print("Sleeping 10 min at", time.strftime("%Y-%m-%d %H:%M:%S"))
                    time.sleep(600)
                else:
                    print("\n\nSleeping 5 seconds\n\n")
                    print(str(e))
                    time.sleep(5)
                out = None
                continue

            # Valid answer: a "Type of loss" line was found and it lists at most three types.
            if parsed is not None and len(parsed[0]) <= 3:
                break

            # Invalid answer: give up on this text once the retry budget is spent...
            if error_sum > max_invalid_retries:
                print("Too many errors, skipping")
                out = "LAST ERROR: " + out
                print(out)
                indexes_errors_core.append(index)
                failed = True
                break

            # ...otherwise count the failure and ask again.
            error_sum += 1
            out = None

        # Update the consecutive error count.
        consecutives_errors = consecutives_errors + 1 if failed else 0

        response.append(out)
        progress_bar.update(1)

    # Store the indexes of analyzed and erroneous texts, and the answers of the batch.
    indexes_errors[y] = indexes_errors_core
    indexes_analized[y] = indexes_viewed
    responses[y] = [y, response]
    progress_bar.close()

# Prompts used for every request: both are passed to correct_llama_parallel by the worker threads below.
# `system` fills the {system} slot of the chat template built inside correct_llama_parallel.
system = "You are a helpful assistant specialized in summarizing scientific articles that studied a type of loss in decision-making."

# `instruct` is the user message. Its {word} placeholder is filled with instruct.format(word=...),
# where the value is the "Title: ... Abstract: ..." text of one paper.
# The "*Type of loss*:" line requested here is what the parsing regex in correct_llama_parallel looks for.
# NOTE(review): "separed" below is a typo, but it is part of the text sent to the model, so fixing it
# would change the prompt; left untouched on purpose.
instruct = """For the following scientific article, provide both (1) a concise one-sentence summary of the article's main content and (2) the specific type of loss that is explicitly studied or used in the article.

For example, if the article uses monetary incentives to study choices between gambles, your response would be:
*Summary*: The article studies the decision-making in monetary gambles.
*Type of loss*: Financial

Use concise, simple and everyday language.

If multiple types of losses are used, list ONLY the most important ones for the study (and ONLY up to *THREE*, the rest will be discarded) separed by commas. Within the types, avoid using the word "loss" or similar ones. Avoid examples or clarifications between parenthesis. *Do not infer or assume any additional types beyond what is explicitly stated*.

Use the *exact* following format:
*Summary*: [summary]
*Type of loss*: [type of loss]


Scientific article:
{word}\n\n"""

In [None]:
def split_list(lst, n_parts):
    """
    Cut a sequence into ``n_parts`` contiguous slices of near-equal length.

    When the length is not a multiple of ``n_parts``, the leading slices each
    take one extra element. Slices may be empty if there are fewer elements
    than parts.

    Args:
        lst (sequence): Any sliceable sequence.
        n_parts (int): Number of slices to produce.

    Returns:
        list: ``n_parts`` slices of ``lst``, in original order.
    """
    base, extra = divmod(len(lst), n_parts)

    # Cumulative cut points: the first `extra` slices are one element longer.
    bounds = [0]
    for position in range(n_parts):
        size = base + 1 if position < extra else base
        bounds.append(bounds[-1] + size)

    return [lst[lo:hi] for lo, hi in zip(bounds, bounds[1:])]

In [None]:
# Fan the papers out over `workers` threads. Each thread gets its own client
# and its own contiguous slice of the papers, and writes its results into its own
# slot of the shared lists below.
threads = []
workers = 5
n_responses = workers
llama_resp = [[] for _ in range(workers)]
indexes_analized = [[] for _ in range(workers)]
indexes_errors = [[] for _ in range(workers)]
candidates_total = [(index, f"Title: {row['Title']}\n Abstract: {row['Abstract']}") for index, row in df_selected.iterrows()]
candidates = split_list(candidates_total, workers)

# Create as many clients as workers, and as many workers as tokens you have (or different runs you want).
# TOKEN1..TOKEN5 are not defined in this notebook: they must exist in the kernel before this cell runs.
clients = [
    InferenceClient(token=token,
                    model="meta-llama/Meta-Llama-3.1-70B-Instruct",
                    headers={"X-use-cache": "false"})
    for token in (TOKEN1, TOKEN2, TOKEN3, TOKEN4, TOKEN5)
]
assert len(clients) == workers, "one client per worker is required"

# Keep the numbered names alive: the error-correction cell further down looks the clients up as client1..client5.
client1, client2, client3, client4, client5 = clients

progress_bars = [tqdm(total=len(candidates[j]), desc=f"Progress {j}", leave=True) for j in range(len(candidates))]


for y in range(workers):
    # Direct list lookup instead of eval(f'client{y+1}'): same object, no string evaluation.
    thread = threading.Thread(target=correct_llama_parallel, args=(candidates[y], clients[y], y, llama_resp, progress_bars[y], system, instruct, indexes_analized, indexes_errors))
    thread.start()
    threads.append(thread)


# Block until every worker has finished its slice.
for thread in threads:
    thread.join()

In [None]:
# Flatten the per-worker results: every slot of llama_resp is [worker_id, answers],
# so concatenate the `answers` lists in worker order.
llama_resp_no_core = [answer for resp in llama_resp for answer in resp[1]]
llama_resp_no_core

In [None]:
# Clean responses and find errors.
# Positions of the answers that correct_llama_parallel flagged as failed.
errors = [position for position, resp in enumerate(llama_resp_no_core) if "LAST ERROR" in resp]

# Keep only the text after the last newline of each answer (the line with the
# normalised types appended during generation); failed answers become "error".
llama_resp_clean = [
    "error" if "LAST ERROR" in resp else resp.rsplit("\n", 1)[-1]
    for resp in llama_resp_no_core
]

print(llama_resp_clean)
print(errors)
print(len(errors))

In [None]:
# Correct errors: re-query only the papers whose first answer was flagged as failed.
# NOTE: `errors` holds positions in the flattened answer list while the filter below compares
# them with the DataFrame index, so this relies on df_selected having the default 0..n-1 index.
error_positions = set(errors)  # set membership instead of scanning the list for every row
candidates_errors = split_list([(index, f"Title: {row['Title']}\n Abstract: {row['Abstract']}") for index, row in df_selected.iterrows() if index in error_positions], workers)
corrected_errors = [[] for _ in range(workers)]
indexes_analized = [[] for _ in range(workers)]
indexes_errors = [[] for _ in range(workers)]

progress_bars = [tqdm(total=len(candidates_errors[j]), desc=f"Progress {j}", leave=True) for j in range(len(candidates_errors))]

# Explicit list instead of eval(f'client{y+1}').
retry_clients = [client1, client2, client3, client4, client5]

# Start from a fresh list so the already-finished threads of the first pass are not joined again.
threads = []
for y in range(workers):
    thread = threading.Thread(target=correct_llama_parallel, args=(candidates_errors[y], retry_clients[y], y, corrected_errors, progress_bars[y], system, instruct, indexes_analized, indexes_errors))
    thread.start()
    threads.append(thread)


for thread in threads:
    thread.join()

# Flatten the list of corrected answers (a slot stays empty if its worker died before storing a result).
corrected_errors_no_core = list(itertools.chain.from_iterable(
    batch[1] for batch in corrected_errors if len(batch) > 1
))

# Write every corrected answer back to the position of the paper it belongs to.
# Pairing is done per worker, (index, text) with answer, so a worker that stopped early
# cannot shift the answers of the other workers onto the wrong papers.
for batch_words, batch in zip(candidates_errors, corrected_errors):
    batch_answers = batch[1] if len(batch) > 1 else []
    for (index, _text), corrected in zip(batch_words, batch_answers):
        llama_resp_no_core[index] = corrected
        # Keep the cleaned list in sync: the following cells read llama_resp_clean, and without
        # this line the corrections would be lost unless the cleaning cell were re-run by hand.
        llama_resp_clean[index] = "error" if "LAST ERROR" in corrected else corrected.split("\n")[-1]

In [None]:
# Number of cleaned answers that contain "none" (case-insensitive).
count = sum(1 for resp in llama_resp_clean if "none" in resp.lower())
print(count)

In [None]:
# Attach the cleaned labels as a new column. llama_resp_clean must hold exactly one
# entry per row of df_selected, in row order, for the labels to line up with the papers.
df_selected['normalized'] = llama_resp_clean
df_selected.head()

In [None]:
# Persist the frame including the new 'normalized' column.
# NOTE: this is the same path the notebook reads from in its load cell, so the input file is overwritten.
df_selected.to_csv('../1_data/selected_papers_normalized.csv', index=False)