In [None]:
from __future__ import annotations

import copy
import json
import logging
import os
import re
import unicodedata
from functools import lru_cache, partial, wraps
from typing import Callable, Optional, Tuple

import openai
import pandas as pd
from dotenv import load_dotenv

from check_llm_answer import (
    TEST_DATA,
    TEST_DATA_EN,
    create_pair_of_sentences_from_combinations,
    load_data
)
from prompts.en.ct_prompt import BASE_PROMPT as ct
from prompts.en.fs_prompt import BASE_PROMPT as fs, CONTENT_PROMPT
from prompts.en.zs_prompt import BASE_PROMPT as zs

# Pull variables from a local .env file into the process environment.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fail fast when the Together key is missing: with api_key=None the OpenAI SDK
# falls back to OPENAI_API_KEY and would send that key to Together's endpoint.
API_KEY = os.getenv("TOGETHER_API_KEY")
if not API_KEY:
    raise RuntimeError(
        "TOGETHER_API_KEY is not set; export it or add it to a .env file."
    )

# OpenAI-compatible client pointed at the Together API.
client = openai.OpenAI(
    api_key=API_KEY,
    base_url="https://api.together.xyz",
)

In [None]:
# Words to score; this is the English test list imported from check_llm_answer.
target_words = TEST_DATA_EN

# Sentence pairs to score, looked up per (NFC-normalized) target word in the scoring loop.
path_to_data = "dwug_en/data"
sentence_pair_per_words = create_pair_of_sentences_from_combinations(path_to_data)

In [None]:
# Prompting-strategy key -> base prompt message (first entry of the chat messages).
# Keys are also used as output sub-folder names.
# Presumably zs/fs = zero-shot/few-shot, per the prompt module names; ct unverified.
PROMPTINGS = dict(zs=zs, fs=fs, ct=ct)

In [None]:
# Display the available prompting-strategy keys (zs, fs, ct).
PROMPTINGS.keys()

In [None]:
# Completion settings bound into every call made through `queries`.
_COMPLETION_KWARGS = {
    "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "max_tokens": 10,  # short answers; presumably just a score — confirm against prompts
    "temperature": 0.7,
}

# Usage: queries(messages=[...]) -> chat completion response.
queries = partial(client.chat.completions.create, **_COMPLETION_KWARGS)


def cache_with_logging(maxsize=1024) -> Callable:
    """Decorator factory: like ``functools.lru_cache`` but logs every cache hit.

    Args:
        maxsize: forwarded to ``lru_cache``. ``None`` means unbounded; ``0``
            disables caching entirely, so no hits are ever recorded or logged.

    Returns:
        A decorator. Decorated functions keep ``lru_cache``'s ``cache_info()``
        and ``cache_clear()`` helpers. All arguments must be hashable.
    """

    def decorator(func):
        cached_func = lru_cache(maxsize=maxsize)(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # A hit is detected by the hit counter increasing across the call.
            hits_before = cached_func.cache_info().hits
            result = cached_func(*args, **kwargs)
            if cached_func.cache_info().hits > hits_before:
                logger.info("Cache hit for args: %s, kwargs: %s", args, kwargs)
            return result

        # Re-export lru_cache's introspection helpers, which `wrapper` would hide.
        wrapper.cache_info = cached_func.cache_info
        wrapper.cache_clear = cached_func.cache_clear
        return wrapper

    return decorator


# Answers are cached per (target_word, sentence1, sentence2, prompt) call.
# maxsize=0 disables the cache entirely: every call hits the API and nothing is
# logged as a hit. Use the number of distinct query/context pairs, or None for
# an unbounded cache, to reuse answers (not in production).
@cache_with_logging(maxsize=0)
def gen_query(target_word: str, sentence1: str, sentence2: str, prompt: str) -> str:
    """Query the LLM about ``target_word`` in a sentence pair; return its raw answer.

    Args:
        target_word: word substituted into ``CONTENT_PROMPT``.
        sentence1: first context sentence.
        sentence2: second context sentence.
        prompt: key into ``PROMPTINGS`` selecting the base prompt message
            (e.g. "zs", "fs", "ct").

    Returns:
        The completion text with surrounding whitespace stripped; an empty
        string when the API returns no message content.

    Raises:
        KeyError: if ``prompt`` is not a key of ``PROMPTINGS``.
        Any error raised by ``queries`` (the API call) propagates unchanged.
    """
    # NOTE(review): the user turn always uses CONTENT_PROMPT from
    # prompts.en.fs_prompt, whatever `prompt` is — confirm that is intended.
    messages = [
        PROMPTINGS[prompt],
        {
            "role": "user",
            "content": CONTENT_PROMPT.format(
                target_word=target_word, sentence1=sentence1, sentence2=sentence2
            ),
        },
    ]
    response = queries(messages=messages)
    # message.content is Optional in the SDK; treat a missing answer as "".
    content = response.choices[0].message.content
    return (content or "").strip()

In [None]:
# Run tag; not used by the cells visible in this notebook.
version = "v1"

# Output-directory template; `{folder}` is filled with a PROMPTINGS key
# (e.g. "zs") when the scoring loop writes its score files.
# NOTE(review): the path says "llama3.1" but `queries` targets
# mistralai/Mixtral-8x7B-Instruct-v0.1 — confirm results are not being filed
# (and possibly overwriting earlier results) under the wrong model name.
output = "/".join(("outputs", "llama3.1", "dwug_en", "{folder}"))

In [None]:
def extract_number(text: Optional[str]) -> Optional[str]:
    r"""Return the first standalone run of digits in ``text``, or ``None``.

    "Standalone" means bounded by word boundaries (pattern ``\b\d+\b``), so
    digits glued to letters (e.g. ``"abc123"``) are ignored. The match is
    returned as a string, not converted to ``int``. Empty or ``None`` input,
    or text without such a number, yields ``None``.
    """
    if not text:
        return None
    match = re.search(r"\b\d+\b", text)
    return match.group() if match else None

In [None]:
%%time

for key in PROMPTINGS.keys():
    for index, word in enumerate(TEST_DATA_EN):
        
        word_ = unicodedata.normalize("NFC", word.strip())
        df = sentence_pair_per_words[word_]

        scores_per_sentence = []

        for row in df:
            tw = word.split("_")[0]
            completition = gen_query(
                tw,
                row.sentence1,
                row.sentence2,
                key,
            )
            score = extract_number(completition)

            ans = {
                "identifier1": row.identifier1,
                "identifier2": row.identifier2,
                "score": "-" if score is None else score
            }
            scores_per_sentence.append(copy.deepcopy(ans))

        with open(f"{output.format(folder=key)}/test.{word}.scores", "w") as f_out:
            json.dump(scores_per_sentence, f_out)


