In [1]:
from config import hf_cache_dir

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model = AutoModelForCausalLM.from_pretrained(
    "tomg-group-umd/huginn-0125", torch_dtype=torch.bfloat16, trust_remote_code=True,
    cache_dir=hf_cache_dir)
tokenizer = AutoTokenizer.from_pretrained("tomg-group-umd/huginn-0125")


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
ctrl_model = AutoModelForCausalLM.from_pretrained(
    "tomg-group-umd/step-00010720-baseline_2_0", torch_dtype=torch.bfloat16, trust_remote_code=True,
    cache_dir=hf_cache_dir)
ctrl_tokenizer = AutoTokenizer.from_pretrained("tomg-group-umd/step-00010720-baseline_2_0")
ctrl_model.config.test_time_noise = 0
ctrl_model.config.test_time_noise_type = "fixed"

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
device = 'cuda:0'

In [5]:
model = model.to(device)
ctrl_model = ctrl_model.to(device)

In [6]:
model.eval()
ctrl_model.eval()

RavenForCausalLM(
  (transformer): ModuleDict(
    (wte): Embedding(65536, 5280)
    (prelude): ModuleList(
      (0-1): 2 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (adapter): Linear(in_features=10560, out_features=5280, bias=False)
    (core_block): ModuleList(
      (0-3): 4 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_feature

# Prompt completion with GenerationConfig

In [7]:
prompts = [
    ["The capital of the country where the Eiffel tower is located is ", "France", "Eiffel tower"],  # ctrl can do this, r=1 can't
    ["The capital of the country where the Big Ben is located is ", "England", "Big Ben"],  # ctrl and r=1 can't do this
    ["The capital of the country where the Grand Canyon is located is ", "the United States", "Grand Canyon"],  # no one can do this
    ["Buckingham Palace is located in the country where the capital is ", "England", "Buckingham Palance"],  # ctrl and r=1 can't do this
    ["The Louvre is located in the country where the capital is ", "France", "The Louvre"],  # ctrl can do this, r=1 can't
    ["The Great Sphinx of Giza is located in the country where the capital is ", "Egypt", "The Great Sphinx of Giza"],  # ctrl and r=1 can't do this
    ["Osaka is located in the country where the capital is ", "Japan", "Osaka"],  # ctrl and r=1 can't do this
    ["Munich is located in the country where the capital is ", "Germany", "Munich"],  # ctrl and r=1 can't do this
    ["Toronto is located in the country where the capital is ", "Canada", "Toronto"],  # ctrl can do this, r=1 can't
    ["Valencia is located in the country where the capital is ", "Spain", "Valencia"],  # ctrl can do this, r=1 can't
    ["Granada is located in the country where the capital is ", "Spain", "Granada"],  #
    ["Milan is located in the country where the capital is ", "Italy", "Milan"],  # ctrl can do this, r=1 can't
]

In [14]:
def run_prompt(prompt):
    config = GenerationConfig(max_length=20, stop_strings=["<|end_text|>", "<|end_turn|>"], 
                              use_cache=True,
                              do_sample=False, temperature=None, top_k=None, top_p=None, min_p=None, 
                              return_dict_in_generate=True,
                              eos_token_id=65505,bos_token_id=65504,pad_token_id=65509)
    input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(device)

    print("="*20)
    print(f"Prompt: {prompt}")

    outputs = model.generate(input_ids, config, tokenizer=tokenizer, num_steps=32)
    print("r=32: " + tokenizer.decode(outputs['sequences'].squeeze()))

    outputs = model.generate(input_ids, config, tokenizer=tokenizer, num_steps=10)
    print("r=10: " + tokenizer.decode(outputs['sequences'].squeeze()))

    outputs = model.generate(input_ids, config, tokenizer=tokenizer, num_steps=9)
    print("r=9: " + tokenizer.decode(outputs['sequences'].squeeze()))

    outputs = model.generate(input_ids, config, tokenizer=tokenizer, num_steps=8)
    print("r=8: " + tokenizer.decode(outputs['sequences'].squeeze()))

    outputs = model.generate(input_ids, config, tokenizer=tokenizer, num_steps=7)
    print("r=7: " + tokenizer.decode(outputs['sequences'].squeeze()))

    outputs = model.generate(input_ids, config, tokenizer=tokenizer, num_steps=5)
    print("r=5: " + tokenizer.decode(outputs['sequences'].squeeze()))
    
    outputs = model.generate(input_ids, config, tokenizer=tokenizer, num_steps=4)
    print("r=4: " + tokenizer.decode(outputs['sequences'].squeeze()))

    outputs = model.generate(input_ids, config, tokenizer=tokenizer, num_steps=3)
    print("r=3: " + tokenizer.decode(outputs['sequences'].squeeze()))

    outputs = model.generate(input_ids, config, tokenizer=tokenizer, num_steps=2)
    print("r=2: " + tokenizer.decode(outputs['sequences'].squeeze()))

    outputs = ctrl_model.generate(input_ids, config, tokenizer=tokenizer, num_steps=1)
    print("Control: " + tokenizer.decode(outputs['sequences'].squeeze()))

    print("="*20)

def run_prompt_tuple(prompt_tuple):
    prompt, country, landmark = prompt_tuple

    run_prompt(prompt)
    run_prompt(f"The capital of {country} is ")
    run_prompt(f"{landmark} is located in the country of ")



In [17]:
run_prompt_tuple(["Giza is located in the country where the capital is ", "Egypt", "Giza"])

Prompt: Giza is located in the country where the capital is 
r=32: <|begin_text|>Giza is located in the country where the capital is ​Cairo. The city is located on the west bank of the Nile River, about 10
r=10: <|begin_text|>Giza is located in the country where the capital is ​Cairo. The city is located on the west bank of the Nile River, on the border
r=9: <|begin_text|>Giza is located in the country where the capital is ​Cairo. The city is located on the west bank of the Nile River, which is the
r=8: <|begin_text|>Giza is located in the country where the capital is ​Cairo. The city is located on the west bank of the Nile River, which is the
r=7: <|begin_text|>Giza is located in the country where the capital is ​Cairo. The city is located on the west bank of the Nile River, which is the
r=5: <|begin_text|>Giza is located in the country where the capital is ​Khalkhul. The city is located in the Khalkhul district of the G
r=4: <|begin_text|>Giza is located in the country where the capi

In [13]:
for prompt_tuple in prompts:
    run_prompt_tuple(prompt_tuple)

Prompt: The capital of the country where the Eiffel tower is located is 
r=42: <|begin_text|>The capital of the country where the Eiffel tower is located is ​Paris. The Eiffel tower is a symbol of Paris. The Eiffel tower is a symbol of Paris
r=32: <|begin_text|>The capital of the country where the Eiffel tower is located is ​Paris. The Eiffel tower is a symbol of Paris. The Eiffel tower is a symbol of Paris
r=10: <|begin_text|>The capital of the country where the Eiffel tower is located is ​Paris. The Eiffel tower is a symbol of Paris and the whole of France. The Eiffel tower
r=7: <|begin_text|>The capital of the country where the Eiffel tower is located is ​Paris. The Eiffel tower is a symbol of Paris and is one of the most famous monuments
r=5: <|begin_text|>The capital of the country where the Eiffel tower is located is rue de la Chambre de l’Etoile.
The Eiffel Tower is a symbol of
r=4: <|begin_text|>The capital of the country where the Eiffel tower is located is rue de la Chambre d

In [43]:
run_prompt("The capital of Italy is ")

Prompt: The capital of Italy is 
r=42: <|begin_text|>The capital of Italy is ​the city of Rome, which is located in the central part of the country. The city is
r=32: <|begin_text|>The capital of Italy is ​the city of Rome, which is located in the central part of the country. The city is
r=10: <|begin_text|>The capital of Italy is ​Rome. The city is located in the center of the country, on the banks of the
r=7: <|begin_text|>The capital of Italy is ​Rome. It is located in the central region of the country. The city is located on
r=5: <|begin_text|>The capital of Italy is ​the city of Rome. The city is located in the central part of the country, in the
r=3: <|begin_text|>The capital of Italy is ​the largest city in the world. It is the second largest city in the world after the United
r=2: <|begin_text|>The capital of Italy is ​the capital of the country. The capital is the city of Rome. The capital is the city
Control: <|begin_text|>The capital of Italy is ​the city of Rome. It is the 

In [46]:
run_prompt("Milan is located in the country where the capital is ")

Prompt: Milan is located in the country where the capital is 
r=42: <|begin_text|>Milan is located in the country where the capital is ​Rome. The city is located in the north-eastern part of Italy, in the
r=32: <|begin_text|>Milan is located in the country where the capital is ​Rome. The city is located in the central part of Italy, in the region of Em
r=10: <|begin_text|>Milan is located in the country where the capital is ​Rome. The city is located in the central part of Italy, in the region of Em
r=7: <|begin_text|>Milan is located in the country where the capital is ​Milan. The city is located in the region of Lombardy. Milan is
r=5: <|begin_text|>Milan is located in the country where the capital is ​the city of Milan.
The city of Milan is located in the country where the
r=3: <|begin_text|>Milan is located in the country where the capital is ​the largest city in the country.
The city is located in the province of Šó
r=2: <|begin_text|>Milan is located in the country where the capi

# Chat template

In [30]:
messages = []
#messages.append({"role": "system", "content" : "You are a helpful assistant."})
messages.append({"role": "user", "content" : "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"})
chat_input = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(chat_input)
input_ids = tokenizer.encode(chat_input, return_tensors="pt", add_special_tokens=False).to(device)

output_ids = model.generate(input_ids, config, num_steps=40, tokenizer=tokenizer)


<|begin_text|><|begin_header|>user<|end_header|>

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|end_turn|><|begin_header|>Huginn<|end_header|>


