In [1]:
%load_ext autoreload
%autoreload 2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from vllm import LLM, SamplingParams

from wordle import Wordle
from prompting import Prompt_Helper
from run_models import extract_guess

# Load pretrained model and tokenizer

## Normal inference

In [2]:
# Config model + quantization
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
bnb_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16
)
# bnb_config = BitsAndBytesConfig(
#    load_in_8bit=True
# )
# bnb_config = None

# Load in NF4 to fit 15GB GPU
tokenizer = AutoTokenizer.from_pretrained(
   llama3_8b,
   padding_side="right")
model = AutoModelForCausalLM.from_pretrained(llama3_8b, 
                                             quantization_config=bnb_config,
                                             device_map="auto", )

# Set pad token to eos token
tokenizer.pad_token_id = tokenizer.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

## VLLM

In [2]:
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
gpu_memory_utilization = 0.8

llm = LLM(
    model=llama3_8b, 
    tokenizer=llama3_8b, 
    tensor_parallel_size=torch.cuda.device_count(), 
    dtype='bfloat16', 
    trust_remote_code=True, 
    gpu_memory_utilization=gpu_memory_utilization
)

INFO 06-04 18:03:29 llm_engine.py:161] Initializing an LLM engine (v0.4.3) with config: model='meta-llama/Meta-Llama-3-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=meta-llama/Meta-Llama-3-8B-Instruct)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


INFO 06-04 18:03:31 weight_utils.py:207] Using model weights format ['*.safetensors']
INFO 06-04 18:03:34 model_runner.py:146] Loading model weights took 14.9595 GB
INFO 06-04 18:03:35 gpu_executor.py:83] # GPU blocks: 10821, # CPU blocks: 2048
INFO 06-04 18:03:37 model_runner.py:854] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 06-04 18:03:37 model_runner.py:858] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 06-04 18:03:41 model_runner.py:924] Graph capturing finished in 4 secs.


In [3]:
tokenizer = llm.get_tokenizer()
# Set pad token to eos token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Create Wordle game and prompt

In [6]:
game = Wordle()
p = Prompt_Helper()

# Get a few-shot instruction
conversation = p.load_few_shot_examples()
# print(conversation)

# Simple demo

In [7]:
victory = False

for turn_idx in range(5):
    # input_ids = tokenizer.apply_chat_template(
    #     conversation,
    #     add_generation_prompt=True,
    #     return_tensors="pt"
    # ).to(model.device)

    # terminators = [
    #     tokenizer.eos_token_id,
    #     tokenizer.convert_tokens_to_ids("<|eot_id|>")
    # ]

    # outputs = model.generate(
    #     input_ids,
    #     max_new_tokens=500,
    #     eos_token_id=terminators,
    #     pad_token_id=tokenizer.eos_token_id,
    #     do_sample=True,
    #     temperature=0.3,
    #     top_p=0.9,
    # )
    # response = outputs[0][input_ids.shape[-1]:]
    # response_str = tokenizer.decode(response, skip_special_tokens=True)
    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response", "### Instruction"]
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=500, stop=stop_tokens)
    inputs = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    outputs = llm.generate(inputs, sampling_params)
    # print(outputs)
    response_str = outputs[0].outputs[0].text
    print(response_str)
    
    guess_eval = ""
    try:
        guess_str = extract_guess(response_str)
        guess_eval = game.turn(guess_str)
        user_response = p.jb_feedback(guess_str, guess_eval, turn_idx)["feedback"]
        conversation.extend([
            {'role': 'assistant', 'content': response_str},
            {'role': 'user', 'content': user_response}
        ])
    except ValueError as error:
        print("VALUE ERROR GAME", error)
        print(response_str)
        # conversation.extend([
        #     {'role': 'assistant', 'content': response_str},
        #     {'role': 'user', 'content': str(e)}
        # ])
        break
    except IndexError as error:
        print("INDEX ERROR GAME", error)
        print(response_str)
        break


    if guess_eval == 'GGGGG':
        conversation.append({'role': 'system', 'content': f"Correct! The answer is {game.get_answer()}"})
        victory = True
        break

if not victory:
    conversation.append({'role': 'system', 'content': f"5 guesses exhausted! The correct answer is {game.get_answer()}"})

print(conversation[-1])

Processed prompts: 100%|█| 1/1 [00:01<00:00,  1.19s/it, Generation Speed: 21.09 


Let's play another game.

My random initial guess is: "FABLE"

Please provide the feedback for each letter.


Processed prompts: 100%|█| 1/1 [00:05<00:00,  5.69s/it, Generation Speed: 36.36 


Guess:
{ 
    "guess: "TABLE",
    "reasoning" : "
        The list of previously guessed words are: ['FABLE']
        The list of valid letters are: ['A', 'L', 'E']
        The list of invalid letters are: ['F', 'B']

        The previous guess was: FABLE
        Given the previous feedback, we know that
        1. The letter at position 1 (F) was not in the string 
        2. The letter at position 2 (A) was in the string but in the wrong position 
        3. The letter at position 3 (B) was not in the string 
        4. The letter at position 4 (L) was in the string but in the wrong position 
        5. The letter at position 5 (E) was in the string but in the wrong position 

        Improving on this feedback,
        the new guess is TABLE
        "
        
}


Processed prompts: 100%|█| 1/1 [00:05<00:00,  5.94s/it, Generation Speed: 36.02 


Guess:
{ 
    "guess: "GABLE",
    "reasoning" : "
        The list of previously guessed words are: ['FABLE', 'TABLE']
        The list of valid letters are: ['A', 'L', 'E']
        The list of invalid letters are: ['F', 'B', 'T']

        The previous guess was: TABLE
        Given the previous feedback, we know that
        1. The letter at position 1 (T) was not in the string 
        2. The letter at position 2 (A) was in the string but in the wrong position 
        3. The letter at position 3 (B) was not in the string 
        4. The letter at position 4 (L) was in the string but in the wrong position 
        5. The letter at position 5 (E) was in the string but in the wrong position 

        Improving on this feedback,
        the new guess is GABLE
        "
        
}


Processed prompts: 100%|█| 1/1 [00:06<00:00,  6.19s/it, Generation Speed: 35.89 


Guess:
{ 
    "guess: "CABLE",
    "reasoning" : "
        The list of previously guessed words are: ['FABLE', 'TABLE', 'GABLE']
        The list of valid letters are: ['A', 'L', 'E']
        The list of invalid letters are: ['F', 'B', 'T', 'G']

        The previous guess was: GABLE
        Given the previous feedback, we know that
        1. The letter at position 1 (G) was not in the string 
        2. The letter at position 2 (A) was in the string but in the wrong position 
        3. The letter at position 3 (B) was not in the string 
        4. The letter at position 4 (L) was in the string but in the wrong position 
        5. The letter at position 5 (E) was in the string but in the wrong position 

        Improving on this feedback,
        the new guess is CABLE
        "
        
}


Processed prompts: 100%|█| 1/1 [00:06<00:00,  6.42s/it, Generation Speed: 35.66 

Guess:
{ 
    "guess: "SABLE",
    "reasoning" : "
        The list of previously guessed words are: ['FABLE', 'TABLE', 'GABLE', 'CABLE']
        The list of valid letters are: ['A', 'L', 'E']
        The list of invalid letters are: ['F', 'B', 'T', 'G', 'C']

        The previous guess was: CABLE
        Given the previous feedback, we know that
        1. The letter at position 1 (C) was not in the string 
        2. The letter at position 2 (A) was in the string but in the wrong position 
        3. The letter at position 3 (B) was not in the string 
        4. The letter at position 4 (L) was in the string but in the wrong position 
        5. The letter at position 5 (E) was in the string but in the wrong position 

        Improving on this feedback,
        the new guess is SABLE
        "
        
}
{'role': 'system', 'content': '5 guesses exhausted! The correct answer is equal'}



