In [None]:
import subprocess
import sys


!pip install torch -U
!pip install transformers -U
!pip install -i https://pypi.org/simple/ -U bitsandbytes
!pip install accelerate -U


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Release cached-but-unused CUDA memory (e.g. left over from an earlier run of
# these cells) before the two quantized models below are loaded.
torch.cuda.empty_cache()

In [None]:
model_name = "abacusai/Llama-3-Smaug-8B"

# 4-bit quantization for the guesser model (amodel, used by generate_text when
# isGuesser=True). Option names must match BitsAndBytesConfig's parameters
# exactly; misspelled keys (e.g. "bnb_4bit_quanty_type") are not applied, which
# would leave nested/double quantization disabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "fp4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant = True,
)

# NOTE(review): trust_remote_code runs Python shipped with the model repo;
# Llama architectures are built into transformers, so it is likely unnecessary.
amodel = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config = bnb_config,
    torch_dtype = torch.float16,
    device_map = "auto",
    trust_remote_code = True,
)
atokenizer = AutoTokenizer.from_pretrained(model_name)


In [None]:
model_name = "Qwen/Qwen2-7B-Instruct"

# 4-bit quantization for the answerer model (bmodel, used by generate_text when
# isGuesser=False). Option names must match BitsAndBytesConfig's parameters
# exactly; misspelled keys (e.g. "bnb_4bit_use_double_quanty") are not applied.
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "fp4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant = True,
)

# NOTE(review): trust_remote_code runs Python shipped with the model repo;
# Qwen2 is built into recent transformers releases, so it is likely unnecessary.
bmodel = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config = bnb_config,
    torch_dtype = torch.bfloat16,
    device_map = "auto",
    trust_remote_code = True,
)
btokenizer = AutoTokenizer.from_pretrained(model_name)


In [None]:
import pandas as pd

# Candidate keywords for guess_with_dataset(), which reads the 'Keyword' and
# 'Category' columns of this table.
keywords = pd.read_csv('datasets/finalkeywords.csv')

In [None]:
# Module-level agent state shared by the functions defined below.
VERBOSE = False  # when True, agent_fn prints its role and output
questions, answers, guesses = [], [], []  # running game history
turns = 0  # number of questions asked so far
thingCategory = False  # guess among 'things' instead of 'place' keywords
guessWithModel = True  # True: guess_with_dataset; False: guess_based_on_qa
guesses_set = set()  # Set to track unique guesses
import string

def generate_text(prompt, sys_prompt="", max_new_tokens=700, top_p=0.9, top_k=50, isGuesser=False):
    """Generate text using the model based on the given prompt and system prompt.

    The guesser pair (``amodel``/``atokenizer``) is used when ``isGuesser`` is
    True, otherwise the answerer pair (``bmodel``/``btokenizer``). Decoding is
    sampled (``do_sample=True``) with the given ``top_p`` / ``top_k``.

    Args:
        prompt: Content of the user message.
        sys_prompt: Content of the system message.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        isGuesser: Use the guesser model instead of the answerer model.

    Returns:
        The newly generated text, stripped. If it spans several lines, only the
        last line (stripped) is returned. Empty string if nothing was generated.
    """
    global amodel, atokenizer, bmodel, btokenizer
    if isGuesser:
        model, tokenizer = amodel, atokenizer
    else:
        model, tokenizer = bmodel, btokenizer

    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # The rendered chat template already carries the special tokens it needs,
    # so don't let the tokenizer prepend another BOS (this matches what
    # apply_chat_template(tokenize=True) does). Send inputs to the model's own
    # device instead of a hard-coded 'cuda'.
    model_inputs = tokenizer([text], return_tensors="pt", add_special_tokens=False).to(model.device)

    # An explicit None check: `pad_token_id or ...` would discard a valid id 0.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        pad_token_id=pad_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k
    )

    # generate() returns prompt + continuation; decode only the continuation so
    # the prompt text and chat role headers never appear in the response and no
    # string-level removal of words like "assistant"/"system" is needed.
    prompt_length = model_inputs.input_ids.shape[1]
    response = tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True).strip()

    # Models sometimes preface the answer with a lead-in line; keep the last one.
    response_lines = response.split('\n')
    if len(response_lines) > 1:
        response = response_lines[-1].strip()

    return response

def context_formatter(questions, answers, guesses, reserve_tokens=1024):
    """Format the Q&A history and wrong guesses as one context string.

    Each question/answer pair becomes ``"Q: <q> A: <a>"``. If there are any
    guesses, a ``"Guessed keywords: ... are wrong"`` entry is appended. Entries
    are joined with ``". "``. When the result is longer than the guesser
    model's ``max_position_embeddings`` minus ``reserve_tokens`` tokens, only
    the most recent tokens are kept.

    Args:
        questions: Questions asked so far.
        answers: Answers received so far, paired with ``questions`` by
            position (``zip`` ignores the unmatched tail of the longer list).
        guesses: Wrong guesses made so far.
        reserve_tokens: Tokens left free for the surrounding prompt and the
            generated continuation.

    Returns:
        The (possibly left-truncated) context string.
    """
    global amodel, atokenizer

    tokenizer = atokenizer
    # Budget the context so that prompt + context + generated tokens can still
    # fit in the model's position range.
    max_length = max(1, amodel.config.max_position_embeddings - reserve_tokens)

    context_pairs = [f"Q: {q} A: {a}" for q, a in zip(questions, answers)]
    if guesses:
        context_pairs.append(f"Guessed keywords: {', '.join(guesses)} are wrong")
    formatted_context = ". ".join(context_pairs)

    # Only counting tokens here, so plain id lists suffice (no tensors / GPU).
    context_ids = tokenizer.encode(formatted_context, add_special_tokens=False)
    if len(context_ids) > max_length:
        return tokenizer.decode(context_ids[-max_length:], skip_special_tokens=True)
    return formatted_context

def generate_question_with_llm(context):
    """Produce the next yes/no question for an "ask" turn.

    The first two turns use scripted openers (place-vs-thing, then a follow-up
    chosen from the first answer). Later turns ask the guesser model, retrying
    up to five times for a non-empty question not already in ``questions``.
    If all attempts fail, the last generated text is returned anyway.
    """
    global turns

    if turns == 0:
        return "Is it a city or country and not a thing or object? If it is an object, thing, or animal say no."
    if turns == 1:
        first_answer_yes = len(answers) > 0 and answers[0].lower() == "yes"
        return "Is it a city?" if first_answer_yes else "Is it a living thing?"

    prompt = f"""
    You are playing a game of 20 Questions. Generate a new question to help narrow down the possibilities.
    Previous Q&A:
    {context}
    Next question:
    """
    sys_prompt = "Generate a clear and specific yes or no question based on the context provided. Provide only the question."

    candidate = ""
    attempts_left = 5
    while attempts_left:
        attempts_left -= 1
        candidate = generate_text(prompt, sys_prompt, max_new_tokens=900, isGuesser=True)
        if candidate and candidate not in questions:
            break
    return candidate

def guess_based_on_qa(context):
    """Ask the guesser model for a free-form keyword guess.

    Punctuation is stripped from each attempt. Up to five attempts are made to
    get a non-empty guess that is not yet in ``guesses_set``; the accepted
    guess is recorded there. If every attempt fails, the last (possibly empty
    or repeated) guess is returned.
    """
    global guesses_set

    prompt = f"""
    Based on the following questions and answers, guess the keyword:
    {context}
    Your guess:
    """
    sys_prompt = """
    Review the context and give your best guess in one or two words.
    Only provide the keyword, without any additional text or formatting."""

    no_punctuation = str.maketrans('', '', string.punctuation)

    guess = ""
    for _attempt in range(5):
        raw = generate_text(prompt, sys_prompt, max_new_tokens=5, isGuesser=True)
        guess = raw.translate(no_punctuation).strip()
        if not guess or guess in guesses_set:
            continue
        guesses_set.add(guess)
        return guess

    return guess

def guess_with_dataset(context):
    """Ask the guesser model to pick a keyword from the candidate dataset.

    Candidates are the rows of the module-level ``keywords`` DataFrame whose
    'Category' is 'things' (when ``thingCategory`` is set) or 'place',
    excluding keywords already in ``guesses_set`` (case-insensitive). Up to
    five attempts are made to get a non-empty guess not made before. An
    accepted guess is added to ``guesses_set``, and every row whose 'Keyword'
    equals it (case-insensitive) is removed from ``keywords``.

    Args:
        context: Formatted Q&A history.

    Returns:
        The accepted guess, or the last attempt's guess (possibly empty or a
        repeat) if no attempt produced a new one.
    """
    global keywords, thingCategory, guesses_set

    # Filter out guessed keywords (case-insensitive)
    already_guessed = {g.lower() for g in guesses_set}
    remaining_keywords = keywords[~keywords['Keyword'].str.lower().isin(already_guessed)]

    category = 'things' if thingCategory else 'place'
    keys = ','.join(remaining_keywords[remaining_keywords['Category'] == category]['Keyword'])

    prompt = f"""
    Based on the following questions and answers, guess the keyword:
    {context}
    You can choose from the following keywords: {keys}
    Your guess:
    """

    sys_prompt = """
    Review the context and the list of potential keywords. Give your best guess from the list in one or two words.
    Only provide the keyword, without any additional text or formatting."""

    strip_punctuation = str.maketrans('', '', string.punctuation)

    guess = ""
    max_attempts = 5  # Limit the number of attempts to find a unique guess

    for _ in range(max_attempts):
        guess = generate_text(prompt, sys_prompt, max_new_tokens=3, isGuesser=True)
        guess = guess.translate(strip_punctuation).strip()

        # Retry on empty output or on anything guessed before.
        if not guess or guess.lower() in already_guessed:
            continue

        if VERBOSE:
            print(guess)
            print(guesses_set)
        guesses_set.add(guess)
        # Drop rows whose keyword equals the guess exactly (case-insensitive);
        # a substring test on the joined `keys` string would also accept
        # fragments such as "Par" for "Paris".
        keywords = keywords[keywords['Keyword'].str.lower() != guess.lower()]
        return guess

    # If all attempts failed, return the last (possibly empty or repeated) guess
    return guess

def get_yes_no(question, keyword):
    """Answer ``question`` about ``keyword`` using the answerer model.

    Returns "yes" when the model's lower-cased reply contains "yes", and "no"
    for everything else, including empty or unexpected replies.
    """
    prompt = f"""
    You are playing 20 Questions. Answer 'yes' or 'no' for this keyword: {keyword}

    Question: "{question}"

    Rules:
    1. Consider only your keyword: {keyword}.
    2. Answer strictly 'yes' or 'no'. Nothing else.
    3. Do not provide any explanations or additional information.
    Your answer:
    """
    sys_prompt = "Answer strictly 'yes' or 'no' to the following question and nothing else."

    # A single new token is enough for a bare yes/no.
    reply = generate_text(prompt, sys_prompt, max_new_tokens=1, isGuesser=False).strip().lower()

    # Anything not recognisably "yes" is treated as "no".
    return "yes" if "yes" in reply else "no"

def agent_fn(obs, cfg):
    """Game entry point: act on ``obs.turnType`` ("ask", "guess" or answer).

    The observation's ``questions`` / ``answers`` / ``guesses`` lists, when
    present, are copied into the module-level history first. A questioner
    that never plays the "answer" turn therefore still sees the answers, and
    state does not leak from a previous game.

    Args:
        obs: Observation with ``turnType``, ``keyword`` and (optionally)
            ``questions``, ``answers``, ``guesses`` lists.
        cfg: Unused.

    Returns:
        The generated question, the guess, or "yes"/"no".
    """
    global turns, amodel, questions, answers, guesses, guesses_set, VERBOSE, guessWithModel, thingCategory

    # The observation is the source of truth for the game history; copies keep
    # the caller's lists untouched when we append below.
    if getattr(obs, "questions", None) is not None:
        questions = list(obs.questions)
    if getattr(obs, "answers", None) is not None:
        answers = list(obs.answers)
    if getattr(obs, "guesses", None) is not None:
        guesses = list(obs.guesses)

    if obs.turnType == "ask":
        # generate_question_with_llm uses `turns` to pick its scripted openers.
        turns = len(questions)
        context = context_formatter(questions, answers, guesses)
        response = generate_question_with_llm(context)
        questions.append(response)
        turns += 1
    elif obs.turnType == "guess":
        context = context_formatter(questions, answers, guesses)

        if guessWithModel:
            # The first question separates places from things; a "no" means
            # guess among 'things'.
            thingCategory = bool(answers) and answers[0].lower() == 'no'
            response = guess_with_dataset(context)
        else:
            response = guess_based_on_qa(context)
        if response:  # Only add non-empty guesses
            guesses.append(response)
    else:  # obs.turnType == "answer"
        response = get_yes_no(obs.questions[-1], obs.keyword)
        answers.append(response)

    # Display role (the answerer runs on bmodel, Qwen2-7B-Instruct)
    if VERBOSE:
        if obs.turnType == "answer":
            print("Team 2 - Answerer - ### Agent QWEN2 7B ###")
        else:
            print("\nTeam 2 - Questioner - ### Agent LLAMA 8B ###")
        print(f"OUTPUT = '{response}'")

    return response


In [None]:
torch.cuda.empty_cache()


perform_eval = True
if perform_eval:
    torch.cuda.empty_cache()
    import requests
    import json
    import re
    import typing as t
    import random
    import time
    from IPython.display import display, Markdown
    import signal

    # ===============FETCH LATEST KEYWORDS FROM KAGGLE GITHUB================
    # This does not train the model to look for the keywords.
    # This just fetches the list so it can be used in the evaluation.

    url = "https://raw.githubusercontent.com/Kaggle/kaggle-environments/master/kaggle_environments/envs/llm_20_questions/keywords.py"
    # Bounded wait so a stalled connection cannot hang the notebook.
    response = requests.get(url, timeout=30)

    # The evaluation below cannot run without keywords_dict, so fail here with
    # a clear message instead of a NameError later.
    if response.status_code != 200:
        raise RuntimeError(f"Request failed with status code: {response.status_code}")
    match = re.search(r'KEYWORDS_JSON = """(.*?)"""', response.text, re.DOTALL)
    if not match:
        raise RuntimeError("Could not find the KEYWORDS_JSON variable in the file.")
    keywords_dict = json.loads(match.group(1))

    def select_random_keyword(keywords_dict: t.List[dict]) -> str:
        """Pick a random category, then a random word from it; return its keyword lower-cased."""
        chosen_words = random.choice(keywords_dict)['words']
        chosen_entry = random.choice(chosen_words)
        return chosen_entry['keyword'].lower()

    #===============20 QUESTIONS EVALUATION SESSION=====================
    def input_timeout(prompt, timeout):
        """Read a line from the user, giving up after ``timeout`` seconds.

        Returns the entered string, or None (after printing the error) if the
        SIGALRM timeout fires first. Other exceptions from ``input()`` (e.g.
        EOFError when stdin is closed) propagate. In every case the alarm is
        cancelled and the previous SIGALRM handler restored, so a pending alarm
        can never fire later, e.g. during model generation. Where SIGALRM does
        not exist (Windows), ``input()`` is called without a timeout.
        """
        def timeout_handler(signum, frame):
            raise TimeoutError("Input timed out after {} seconds".format(timeout))

        if not hasattr(signal, "SIGALRM"):
            return input(prompt)

        previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(timeout)
        try:
            return input(prompt)
        except TimeoutError as e:
            print(e)
            return None
        finally:
            signal.alarm(0)
            # signal.signal() returns None if the old handler was not set from Python.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    timeout = 10

    try:
        DEBUG_input = input_timeout("Enable verbose debug mode? (y/n) [default is n] ", timeout)
        DEBUG = DEBUG_input.strip().lower() == 'y' if DEBUG_input else False
    except Exception:
        # No usable stdin (e.g. a non-interactive run): default to quiet mode.
        DEBUG = False
        print("No input received, defaulting to false.")

    # agent_fn prints its role and output only when VERBOSE is set, so the
    # answer above has to be applied there to have any effect.
    VERBOSE = DEBUG

    class MockObservation:
        """Minimal stand-in for the observation object passed to agent_fn.

        Stores the step counter, role, turn type, secret keyword, category and
        the running questions/answers/guesses lists (kept by reference).
        """

        def __init__(self, step: int, role: str, turnType: str, keyword: str, category: str, questions: list[str], answers: list[str], guesses: list[str]):
            # Turn metadata
            self.step, self.role, self.turnType = step, role, turnType
            # The secret word being played
            self.keyword, self.category = keyword, category
            # Game history lists, shared with the caller
            self.questions, self.answers, self.guesses = questions, answers, guesses

    def test_20_questions():
        """Play one local game of 20 Questions against agent_fn.

        A random keyword is drawn, then agent_fn is cycled through
        ask -> answer -> guess turns for at most 60 turns. The game ends early
        on a correct guess or when a single turn takes longer than 60 seconds.
        The full history is displayed at the end.
        """
        step = 0
        role = "answerer"
        turnType = "ask"
        keyword = select_random_keyword(keywords_dict)
        category = ""
        # One list per turn type; the named aliases are what the observation sees.
        history = {"ask": [], "answer": [], "guess": []}
        questions, answers, guesses = history["ask"], history["answer"], history["guess"]
        following_turn = {"ask": "answer", "answer": "guess", "guess": "ask"}

        display(Markdown("# Starting 20 questions eval game..."))
        display(Markdown(f"### **Keyword:** {keyword}"))

        for _ in range(60):
            obs = MockObservation(step, role, turnType, keyword, category, questions, answers, guesses)

            started_at = time.time()
            response = agent_fn(obs, None)
            response_time = time.time() - started_at

            if response_time > 60:
                display(Markdown(f"**WARNING:** Response time too long and may be disqualified from the game: {response_time:.2f} sec. Make sure you have GPU acceleration enabled in the session options on the right side panel."))
                break

            # File the response under the turn that produced it, then advance.
            history[turnType].append(response)
            if turnType == "guess":
                if response.lower() == keyword.lower():
                    display(Markdown(f"## **Keyword '{keyword}' guessed correctly! Ending game.**"))
                    break
                step += 1
            turnType = following_turn[turnType]

            display(Markdown(f"Step {step} | Response: {response} | {response_time:.2f} sec"))

        display(Markdown(f"Final Questions: {', '.join(questions)}"))
        display(Markdown(f"Final Answers: {', '.join(answers)}"))
        display(Markdown(f"Final Guesses: {', '.join(guesses)}"))

    # Run the test
    test_20_questions()
