In [None]:
import requests
import csv
import time
from typing import List, Dict
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from dotenv import load_dotenv

from config import API_KEY


# Module-level API configuration shared by every request in this file.
API_KEY = API_KEY  # no-op re-binding of the key imported from config
URL = "https://api.together.xyz/v1/chat/completions"

# HTTP headers sent with each API call: bearer authentication plus JSON
# content negotiation.
headers = dict(
    [
        ("Authorization", f"Bearer {API_KEY}"),
        ("accept", "application/json"),
        ("content-type", "application/json"),
    ]
)

def clean_bom(column_name: str) -> str:
    """
    Return the column name with every BOM character (U+FEFF) removed and
    with leading/trailing whitespace trimmed.
    """
    # Splitting on the BOM and re-joining drops every occurrence of it.
    bom_free = "".join(column_name.split('\ufeff'))
    return bom_free.strip()

def read_sentences_from_csv(file_path: str) -> List[Dict[str, str]]:
    """
    Read labelled sentences from a CSV file.

    The header must contain the columns 'sentence', 'year' and 'label'; a BOM
    or stray whitespace in the header names is cleaned before lookup.

    Args:
        file_path: Path of the CSV file to read (decoded as UTF-8).

    Returns:
        A list with one dictionary per valid row, holding the stripped
        'sentence' and 'year' strings and the 'label' converted to int.
        Rows that cannot be parsed (missing cells, non-integer label) are
        reported and skipped individually. An empty list is returned when
        the file is missing, empty, unreadable, or lacks a required column.
    """
    required_columns = ('sentence', 'year', 'label')
    data = []
    try:
        with open(file_path, mode='r', newline='', encoding='utf-8') as file:
            reader = csv.DictReader(file)
            # fieldnames is None when the file has no header line at all.
            if reader.fieldnames is None:
                print(f"Error reading {file_path}: file is empty.")
                return data
            # Clean BOM from fieldnames so the first column can be looked up by name
            reader.fieldnames = [clean_bom(field) for field in reader.fieldnames]
            missing = [name for name in required_columns if name not in reader.fieldnames]
            if missing:
                print(f"Error reading {file_path}: missing column(s) {missing}")
                return data
            for row in reader:
                try:
                    record = {
                        'sentence': row['sentence'].strip(),
                        'year': row['year'].strip(),
                        'label': int(row['label'].strip()),
                    }
                except (AttributeError, ValueError) as e:
                    # AttributeError: a short row leaves missing cells as None.
                    # ValueError: the label is blank or not an integer.
                    # Skip only this row so one bad line does not silently
                    # truncate the rest of the dataset.
                    print(f"Skipping line {reader.line_num} of {file_path}: {e}")
                    continue
                data.append(record)
    except FileNotFoundError:
        print(f"File {file_path} not found.")
    except (OSError, UnicodeError, csv.Error) as e:
        print(f"Error reading {file_path}: {e}")
    return data

def _extract_field(message_content: str, marker: str):
    """
    Return the text that follows the first occurrence of `marker` up to the
    end of that line, stripped of whitespace, or None if `marker` is absent.
    """
    if marker not in message_content:
        return None
    return message_content.split(marker)[1].split('\n')[0].strip()


def annotate_sentence(sentence: str, timeout: float = 60.0, max_retries: int = 5) -> Dict[str, str]:
    """
    Send a sentence to the chat-completions API and parse the annotation.

    Args:
        sentence: The sentence to classify.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot hang the whole run.
        max_retries: Maximum number of attempts when the API answers 429.

    Returns:
        On success, a dictionary with the keys "Sentence", "Sentiment",
        "Time", "Certainty" and "Parsed_Annotation" (0 for Dovish, 1 for
        Hawkish, 2 for Neutral).
        An empty dictionary when the reply lacks one of the three fields or
        the sentiment is not exactly one of the three class names.
        A dictionary with "Sentence" and "Error" when the request fails, the
        response body is malformed, or the retries are exhausted.
    """
    payload = {
        "model": "meta-llama/Llama-3-8b-chat-hf",  # Adjust model if needed
        "temperature": 0.7,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "messages": [
            {
                "role": "user",
                "content": f"""Discard all the previous instructions. Behave
                like you are an expert sentence classifier. Classify
                the following sentence from FOMC into ‘Hawkish’, ‘Dovish’, or ‘Neutral’ class. Label
                ‘Hawkish’ if it is corresponding to tightening of the monetary
                policy, ‘Dovish’ if it is corresponding to easing of the monetary policy, or ‘Neutral’ if the stance is neutral.
                No explanation needed.

                1. Text sentiment: [Hawkish, Dovish, Neutral]
                2. Time category: [Forward Looking or Backward Looking]
                3. Certainty: [Certain or Uncertain]

                Sentence: "{sentence}"

                Provide your analysis in the following format:
                Sentiment:
                Time:
                Certainty:
                """
            }
        ]
    }

    # NOTE(review): this encoding (Dovish=0) does not follow the vocabulary
    # order used in the prompt; presumably it matches the dataset's label
    # column — confirm against the data source.
    sentiment_codes = {"Dovish": 0, "Hawkish": 1, "Neutral": 2}

    retry_delay = 1  # Seconds; doubled after every 429 response

    for _ in range(max_retries):
        try:
            response = requests.post(URL, json=payload, headers=headers, timeout=timeout)

            # Handle 429 (Too Many Requests) with exponential backoff
            if response.status_code == 429:
                try:
                    server_hint = int(response.headers.get("Retry-After", retry_delay))
                except (TypeError, ValueError):
                    # Retry-After may also be an HTTP-date or otherwise
                    # non-integer; fall back to our own delay in that case.
                    server_hint = retry_delay
                # Never wait less than the server asks for, but keep growing
                # the wait when the server's hint proves insufficient.
                retry_after = max(server_hint, retry_delay)
                print(f"Rate limit exceeded. Retrying in {retry_after} second(s)...")
                time.sleep(retry_after)
                retry_delay *= 2
                continue

            # Raise exception for other HTTP errors
            response.raise_for_status()

            result = response.json()
            message_content = result['choices'][0]['message']['content']
        except requests.exceptions.RequestException as e:
            print(f"Error annotating sentence: {e}")
            return {"Sentence": sentence, "Error": str(e)}
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # Body was not JSON or did not have the expected structure.
            print(f"Unexpected API response for sentence: {e!r}")
            return {"Sentence": sentence, "Error": f"Unexpected API response: {e!r}"}

        if not isinstance(message_content, str):
            print("Unexpected API response for sentence: message content is not text")
            return {"Sentence": sentence, "Error": "Unexpected API response: message content is not text"}

        sentiment = _extract_field(message_content, 'Sentiment:')
        time_category = _extract_field(message_content, 'Time:')
        certainty = _extract_field(message_content, 'Certainty:')

        # None when the sentiment is missing or not one of the class names
        parsed_annotation = sentiment_codes.get(sentiment)

        # Return the final result only if all necessary fields are present
        if parsed_annotation is None or time_category is None or certainty is None:
            return {}

        return {
            "Sentence": sentence,
            "Sentiment": sentiment,
            "Time": time_category,
            "Certainty": certainty,
            "Parsed_Annotation": parsed_annotation
        }

    print(f"Max retries exceeded for sentence: {sentence}")
    return {"Sentence": sentence, "Error": "Max retries exceeded"}

def process_sentences(data: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Annotate every record through the API and collect the usable answers.

    Each kept annotation gets the record's ground-truth value attached under
    the 'label' key. Records whose annotation is empty or has no
    "Parsed_Annotation" entry are left out of the returned list.
    """
    annotated = []
    for record in data:
        text, gold = record['sentence'], record['label']
        outcome = annotate_sentence(text)
        # Guard clause: drop anything without a parsed class
        if not outcome or "Parsed_Annotation" not in outcome:
            continue
        outcome['label'] = gold  # Keep the ground truth next to the prediction
        annotated.append(outcome)
    return annotated

def save_results(results: List[Dict[str, str]], output_file: str):
    """
    Write the annotation results to `output_file` as a UTF-8 CSV with a
    header row. Any failure is reported on stdout instead of being raised.
    """
    columns = ["Sentence", "Sentiment", "Time", "Certainty", "Parsed_Annotation", "label"]
    try:
        with open(output_file, mode='w', newline='', encoding='utf-8') as handle:
            csv_writer = csv.DictWriter(handle, fieldnames=columns)
            csv_writer.writeheader()
            csv_writer.writerows(results)
    except Exception as err:
        print(f"Error saving results to {output_file}: {err}")

def evaluate_metrics(results: List[Dict[str, str]]):
    """
    Score the annotations against the ground truth.

    Prints the weighted F1, precision and recall together with the accuracy,
    and returns them as the tuple (f1, precision, recall, accuracy).
    """
    # One pass over the results: (ground truth, model annotation) per entry
    pairs = [(entry["label"], entry["Parsed_Annotation"]) for entry in results]
    y_true = [gold for gold, _ in pairs]
    y_pred = [guess for _, guess in pairs]

    weighted = {"average": "weighted"}
    f1 = f1_score(y_true, y_pred, **weighted)
    precision = precision_score(y_true, y_pred, **weighted)
    recall = recall_score(y_true, y_pred, **weighted)
    accuracy = accuracy_score(y_true, y_pred)

    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"Accuracy: {accuracy:.4f}")

    return f1, precision, recall, accuracy

if __name__ == "__main__":
    # Script entry point: read the labelled sentences, annotate each one
    # through the API, write the annotations to disk, then print the metrics.
    # NOTE(review): the input path looks like a Colab upload location —
    # adjust it when running elsewhere.
    input_file = "/content/lab-manual-mm-test-5768.csv"
    output_file = "annotated_sentences.csv"

    data = read_sentences_from_csv(input_file)  # List of {'sentence', 'year', 'label'} records
    annotated_results = process_sentences(data)  # Only sentences with a parsed annotation are kept
    save_results(annotated_results, output_file)  # Annotations plus ground-truth label, as CSV

    # Compare the parsed annotations with the ground-truth labels and print
    # F1, precision, recall and accuracy.
    evaluate_metrics(annotated_results)

    print(f"Annotation complete. Results saved to {output_file}")


Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 second(s)...
Rate limit exceeded. Retrying in 1 secon