In [1]:
%load_ext autoreload
%autoreload 2
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import os, json, re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline, logging
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
print( torch.__version__)
print(torch.cuda.is_available())

2.3.0+cu121
True


# Load pretrained model and tokenizer

In [3]:
# Config model + quantization
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16
)

# Load in NF4 to fit 15GB GPU
tokenizer = AutoTokenizer.from_pretrained(llama3_8b, padding_side="right")
model = AutoModelForCausalLM.from_pretrained(llama3_8b, 
                                             quantization_config=nf4_config,
                                             device_map="auto", )

# Set pad token to eos token
tokenizer.pad_token_id = tokenizer.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:08<00:00,  2.09s/it]


# Prompt_Helper

### prompts

In [4]:
instructions = """

Wordle is game about guessting a 5 letter word. Here are the instructions:

1.  Each guess must be a valid five-letter word.

2.  A letter may not be in the word

3.  A letter can be in the word, but in the wrong position

4.  A letter can be in the word, and is in the correct position

There will be a list of allowed letter and invalid letters.
Guesses should avoid using letters in the invalid letter list.

There will also be a list of previously guessed words. Do not repeat guesses .

Output the guess in a json object :
{
    "guess" : "[your guess]",
    "reasoning" : "[your reasoning]"

}
Do not use other formats.

You will start the game with an initial guess, and the user will give feedback on your guesses
Let's start a game:
"""

prompt_j1 = """
Example game:
Guess:
{
    "turn_type" : "Guess",
    "guess" : "BAGEL",
    "reason" : "an initial guess
    
        In this guess, the order of the letters is in the following order:

        1. B
        2. A
        3. G
        4. E
        5. L "
},

Feedback:
{   "turn_type": "feedback", 
    "Feedback" : "
        Your guess was: 'BAGEL'

        The feedback for each letter is:

        1. the letter at position 1, which is 'B', is not in the word
        2. the letter at position 2, which is 'A', is not in the word
        3. the letter at position 3, which is 'G', is not in the word
        4. the letter at position 4, which is 'E', IS IN the word
        5. the letter at position 5, which is 'L', is not in the word",
    "next_action" : "guess again"
},

Guess: 
{
    "turn_type" : "Guess",
    "guess" : "THREW",
    "reason" : "
        To reiterate: 
        1. The letter 'B' is not in the word at position 1 
        2. The letter 'A' is not in the word at position 2 
        3. The letter 'G' is not in the word at position 3
        4. The letter 'E' IS IN in the word at posiiton 4
        5. The letter 'L' is not in the word at position 5
    
        Here is the list of invalid letters: [B, A, G, L]

        The next guess is the word "Threw" is a good guess, because:
        1. The letter 'T' at position 1 is in the list of invalid letters.
        2. The letter 'H' is position 2 is in the list of invalid letters.
        3. The letter 'R' at position 3 is in the list of invalid letters.
        4. The letter 'E' at position 4 is not in the list of invalid letters, and is in the correct position.
        5. The letter 'W' at position 5 is in the list of invalid letters. "
}

A:
{
    "turn_type" : "Feedback",
    "Feedback" : "
        Our guess was: 'THREW'

        The feedback is: 'GGGGW'

        The feedback for each letter is:

        1. the letter at position 1, which is 'T', IS IN in the word
        2. the letter at position 2, which is 'H', IS IN in the word
        3. the letter at position 3, which is 'R', IS IN in the word
        4. the letter at position 4, which is 'E', IS IN the word
        5. the letter at position 5, which is 'W', is not in the word "
} , 
Guess Again

Guess:
{ 
    turn_type: "Guess",
    "guess: "THREE", 
    "reason" : 
        "Our previous guess was THREW
        To reiterate: 
        1. The letter 'T' IS IN in the word at position 1
        2. The letter 'H' IS IN in the word at position 2
        3. The letter 'R' IS IN in the word at position 3
        4. The letter 'E' IS IN in the word at posiiton 4
        5. The letter 'W' is not in the word at position 5
    
        Here is the list of invalid letters: [B, A, G, L, W]

        The next guess is the word "THREW" is a good guess, because:
        1. The letter 'T' at position 1 is not in the list of invalid letters, and is in the correct position.
        2. TThe letter 'H' at position 2 is not in the list of invalid letters, and is in the correct position.
        3. The letter 'R' at position 3 is not in the list of invalid letters, and is in the correct position.
        4. The letter 'E' at position 4 is not in the list of invalid letters, and is in the correct position.
        5. The letter 'W' at position 5 is not in the list of invalid letters."
},

Feedback:
{   "turn_type" : "Feedback",
    "feedback" : "
        Our guess was: 'THREW'

        The feedback for each letter is:

        1. the letter at position 1, which is 'T', IS IN the word, and is in the correct position
        2. the letter at position 2, which is 'H', IS IN the word, and is in the correct position
        3. the letter at position 3, which is 'R', IS IN the word, and is in the correct position
        4. the letter at position 4, which is 'E', IS IN the word, and is in the correct position
        5. the letter at position 5, which is 'E', IS IN the word, and is in the correct position

        Because all letters are in the word and is in the correct position, we guessed the word correctly!",
    "next_action" : "END"
)  


Next Game:
Guess:

"""
# {
#     "turn_type" : "Guess",
#     "guess" : "?", 
#     "reason" : "Random initial guess"
# }

prompt_lb = """
{{$history}}

Invalid letters: {{$invalid_letters}}
Valid letters: {{$valid_letters}}
Reasoning:  """

In [5]:
from wordle import Wordle

class Prompt_Interface:
    def __init__(self, instructions):
        self.instructions = instructions,
        self.examples = [
            {
                "answer": 'THREE',
                "guesses": ['BAGEL', 'COVER', 'THREE']

            },
            {
                "answer": 'STAFF',
                "guesses": ['WHILE', 'RADIO', 'START', 'STAGE', 'STAFF']
            },
            {
                "answer": 'PROWL',
                "guesses": ['TASTE', 'HOARD', 'PROVE', 'PROWL']
            }
        ]
    


    def feedback(self, guess, response, guess_number):
        # response given from wordle object
        # guess generated by the model
        # guess_number is how many guesses the model has made
#       )  
        
        is_in_word = ["IS IN" if response[i] == 'G' else 'is not in' for i in range(5)]
        valid_letters = [guess[i] for i in range(5) if response[i] == 'G']
        invalid_letters = [guess[i] for i in range(5) if response[i] == 'W']

        correct = False

        if response == "GGGGG": 
            victory_speech = f"Congrats you found the correct word {guess} in {guess_number+1} guesses, because all letters were in the word and in the right position!"
            correct = True
        else:
            victory_speech =""

        feedback = f"""
         Feedback: 
         {{ 
                Your guess was: '{guess}'

                The feedback for each letter is: 
                
                1. 'the letter at position 1, which is '{guess[0]}', {is_in_word[0]} the word,
                2. 'the letter at position 2, which is '{guess[1]}', {is_in_word[1]} the word
                3. 'the letter at position 3, which is '{guess[2]}', {is_in_word[2]} the word
                4. 'the letter at position 4, which is '{guess[3]}', {is_in_word[3]} the word
                5. 'the letter at position 5, which is '{guess[4]}', {is_in_word[4]} the word

                {victory_speech}
        }}
        Guess: 
        
        """
         

        feedback_dict = {
            "feedback": feedback,
            "valid_letters": valid_letters,
            "invalid_letters": invalid_letters,
            "correct": correct # might be helpful ??
        }

        return feedback_dict


    def get_zero_shot_instructions(self):
        message = [
            {"role": "system", "content": self.instructions},
            {"role": "user", "content": "Enter your first guess in a format like COVER"},
        ]
        return message

    
    def load_few_shot_examples(self):
        message = [{"role": "system", "content": self.instructions}]

        for example in self.examples:
            w = Wordle()
            w.set_answer(example["answer"])
            target = example["answer"]
            history = []
            valid_letters = []
            invalid_letters = []
                
            guesses = example["guesses"]
            message.append({"role": "user", "content": "Let's play a new game with a different word"})

 

            # Process first guess before giving context to next guesses
            init = guesses[0]

            init_guess_str = f"""
                Guess:
                {{ 
                    "guess: "{init}",
                    "reasoning" : "
                        We are starting a new game, and this is a random initial guess
                        "
                     
                }}
            """
            message.append({"role": "assistant", "content": init_guess_str})

            init_dict = self.feedback(guess=init, response=w.turn(init), guess_number=0)
            init_feedback_str = init_dict["feedback"]
            valid_letters = init_dict["valid_letters"]
            invalid_letters = init_dict["invalid_letters"]
            history.append(init)

            message.append({"role": "user", "content": init_feedback_str})

            # process rest of the examples
            for i in range(1,len(guesses)):
                prev_guess = guesses[i-1]

                prev_result = w.turn(prev_guess) # returns colors
                position_info = []
                for color in prev_result:
                    if color == "G":
                        position_info.append("in the string and in the correct position")
                    elif color == "Y":
                        position_info.append("in the string but in the wrong position")
                    elif color == "W":
                        position_info.append("not in the string")


                guess = guesses[i]
                guess_str = f"""
                Guess:
                {{ 
                    "guess: "{guess}",
                    "reasoning" : "
                        The list of previously guessed words are: {history}
                        The list of valid letters are: {valid_letters}
                        The list of invalid letters are: {invalid_letters}

                        The previous guess was: {prev_guess}
                        Given the previous feedback, we know that
                        1. The letter at position 1 ({prev_guess[0]}) was {position_info[0]} 
                        2. The letter at position 2 ({prev_guess[1]}) was {position_info[1]} 
                        3. The letter at position 3 ({prev_guess[2]}) was {position_info[2]} 
                        4. The letter at position 4 ({prev_guess[3]}) was {position_info[3]} 
                        5. The letter at position 5 ({prev_guess[4]}) was {position_info[4]} 

                        Improving on this feedback,
                        the new guess is {guess}
                        "
                     
                }}
                """
                history.append(guess)

                feedback_dict = self.feedback(guess=guess, response=w.turn(guess), guess_number=i)
                feedback_str = feedback_dict["feedback"]

                new_valid_letters = feedback_dict["valid_letters"]
                new_invalid_letters = feedback_dict["invalid_letters"]
                
                valid_letters = list(set(valid_letters + new_valid_letters))
                invalid_letters = list(set(invalid_letters + new_invalid_letters))

                message.append({"role": "assistant", "content": guess_str})
                message.append({"role": "user", "content": feedback_str})

        
        message.append({"role": "user", "content": "let's play another game. Start with your random initial guess"})
        
        return message
        
        
        

# Create Wordle game and prompt

In [6]:
from wordle import Wordle
game = Wordle()
p = Prompt_Interface(instructions)

# Get a few-shot instruction
conversation = p.load_few_shot_examples()
conversation

[{'role': 'system',
  'content': ('\n\nWordle is game about guessting a 5 letter word. Here are the instructions:\n\n1.  Each guess must be a valid five-letter word.\n\n2.  A letter may not be in the word\n\n3.  A letter can be in the word, but in the wrong position\n\n4.  A letter can be in the word, and is in the correct position\n\nThere will be a list of allowed letter and invalid letters.\nGuesses should avoid using letters in the invalid letter list.\n\nThere will also be a list of previously guessed words. Do not repeat guesses .\n\nOutput the guess in a json object :\n{\n    "guess" : "[your guess]",\n    "reasoning" : "[your reasoning]"\n\n}\nDo not use other formats.\n\nYou will start the game with an initial guess, and the user will give feedback on your guesses\nLet\'s start a game:\n',)},
 {'role': 'user', 'content': "Let's play a new game with a different word"},
 {'role': 'assistant',
  'content': '\n                Guess:\n                { \n                    "guess:

In [7]:
def extract_guess(json_string):
    # Example
    # {
    #     "turn_type" : "Guess",
    #     "guess" : "?", 
    #     "reason" : "Random initial guess"
    # }
    guess_idx = json_string.find("guess")
    remaining = json_string[guess_idx+len('guess'):]
    end = remaining.find("\n")
    guess = remaining[:end]
    guess = re.sub(r"[^a-zA-Z]+","", guess)

    if len(guess) != 5:
        return json_string
    
    return guess

# Simple demo

In [8]:
# reset and init
from wordle import Wordle
game = Wordle()
p = Prompt_Interface(instructions)

print("BEFORE MODEL OUTPUT: THE CORRECT ANSWER IS:", game.get_answer())

# Keep track of guesses:
guess_history = []

# Get a few-shot instruction
conversation = p.load_few_shot_examples()
print(conversation)

i = 0
max_turns = 15
for t in range(max_turns):
    print("Turn:", t)

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    outputs = model.generate(
        input_ids,
        max_new_tokens=500,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.4,
        top_p=0.9,
    )

    # process model output
    response = outputs[0][input_ids.shape[-1]:]
    response_str = tokenizer.decode(response, skip_special_tokens=True)
    print("raw output:", response_str)

    lines = response_str.splitlines()

    if len(response_str) == 5:
        print("\tExtracted Guess (Method 1):", response_str)
        guess_str = response_str
    elif len(tokenizer.decode(response, skip_special_tokens=True)) == 5:
        guess_str = tokenizer.decode(response, skip_special_tokens=True)
        print("\tExtracted Guess (Method 2):", guess_str)
    else:
        guess_str = tokenizer.decode(response, skip_special_tokens=True)
        guess_str = extract_guess(guess_str)
        print("\tExtracted Guess (Method 3):", guess_str)

    guess_history.append(guess_str)

    # Evaluate model output and give feedback
    guess_eval = ""
    try:
        guess_eval = game.turn(guess_str)
        user_response = p.feedback(guess_str, guess_eval, i)["feedback"]
        print("feedback:\n", user_response)
        conversation.extend([
            {'role': 'assistant', 'content': guess_str},
            {'role': 'user', 'content': user_response}
        ])
    except ValueError as e:
        print("feedback:", e)
        conversation.extend([
            {'role': 'assistant', 'content': guess_str},
            {'role': 'user', 'content': str(e)}
        ])
    
    # check if model got the correct answer
    if guess_eval == 'GGGGG':
        conversation.append({'role': 'system', 'content': f"Correct! The answer is {game.get_answer()}"})
    elif i >= 100:
        conversation.append({'role': 'system', 'content': f"100 guesses exhausted! The correct answer is {game.get_answer()}"})
        break
    else:
        i += 1

BEFORE MODEL OUTPUT: THE CORRECT ANSWER IS: novel
[{'role': 'system', 'content': ('\n\nWordle is game about guessting a 5 letter word. Here are the instructions:\n\n1.  Each guess must be a valid five-letter word.\n\n2.  A letter may not be in the word\n\n3.  A letter can be in the word, but in the wrong position\n\n4.  A letter can be in the word, and is in the correct position\n\nThere will be a list of allowed letter and invalid letters.\nGuesses should avoid using letters in the invalid letter list.\n\nThere will also be a list of previously guessed words. Do not repeat guesses .\n\nOutput the guess in a json object :\n{\n    "guess" : "[your guess]",\n    "reasoning" : "[your reasoning]"\n\n}\nDo not use other formats.\n\nYou will start the game with an initial guess, and the user will give feedback on your guesses\nLet\'s start a game:\n',)}, {'role': 'user', 'content': "Let's play a new game with a different word"}, {'role': 'assistant', 'content': '\n                Guess:\n   

In [None]:
print(conversation)

[{'role': 'system', 'content': ('\n\nWordle is game about guessting a 5 letter word. Here are the instructions:\n\n1.  Each guess must be a valid five-letter word.\n\n2.  A letter may not be in the word\n\n3.  A letter can be in the word, but in the wrong position\n\n4.  A letter can be in the word, and is in the correct position\n\nEach guess you will make will have a "reason" category in the json object. Please fill that out.\n\nReturn the json object\n\n',)}, {'role': 'user', 'content': "let's play another game"}, {'role': 'assistant', 'content': '\n                Guess:\n                { \n                    "guess: "BAGEL",\n                    "reasoning" : "\n                        The list of previously guessed words are: []\n                        The list of valid letters are: []\n                        The list of invalid letters are: []\n\n                        Improving on previous feedback about the positions of letters,\n                        the new guess is B

In [None]:
for turn in conversation:
    print(turn)

{'role': 'system', 'content': ('\n\nWordle is game about guessting a 5 letter word. Here are the instructions:\n\n1.  Each guess must be a valid five-letter word.\n\n2.  A letter may not be in the word\n\n3.  A letter can be in the word, but in the wrong position\n\n4.  A letter can be in the word, and is in the correct position\n\nEach guess you will make will have a "reason" category in the json object. Please fill that out.\n\nReturn the json object\n\n',)}
{'role': 'user', 'content': "let's play another game"}
{'role': 'assistant', 'content': '\n                Guess:\n                { \n                    "guess: "BAGEL",\n                    "reasoning" : "\n                        The list of previously guessed words are: []\n                        The list of valid letters are: []\n                        The list of invalid letters are: []\n\n                        Improving on previous feedback about the positions of letters,\n                        the new guess is BAGE