# Demo of Text Generation with Huginn-01/25

In [None]:
import torch
import sys
from pathlib import Path
# The model and the prompt tensor in this notebook are placed on CUDA device index 0.
device = torch.device("cuda", 0)


%load_ext autoreload
%autoreload 2

# Support running without installing `recpre` as a package: the parent of the
# current working directory is put on the import path.
wd = Path.cwd().parent
# Re-running this cell must not stack duplicate entries onto sys.path.
if str(wd) not in sys.path:
    sys.path.append(str(wd))
import recpre # noqa: F401

from transformers import AutoModelForCausalLM,AutoTokenizer, GenerationConfig
from dataclasses import dataclass
@dataclass
class Message:
    """A single chat turn: who is speaking and what was said."""
    # Speaker of the turn. This notebook creates "system" and "user" turns and
    # relabels "assistant" as "Huginn" before applying the chat template.
    role: str
    # Text of the turn; surrounding whitespace is stripped before templating.
    content: str

In [None]:
# Load the checkpoint in bfloat16 directly onto `device`.
# trust_remote_code is False here; set it to True if the recpre lib is not loaded.
model = AutoModelForCausalLM.from_pretrained(
    "tomg-group-umd/huginn-0125",
    trust_remote_code=False,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained("tomg-group-umd/huginn-0125")
# Put the model in evaluation mode (kept as the last expression of the cell).
model.eval()

In [None]:
config = GenerationConfig(max_length=1024, stop_strings=["<|end_text|>", "<|end_turn|>"], 
                          do_sample=False, temperature=None, top_k=None, top_p=None, min_p=None, 
                          return_dict_in_generate=True,
                          eos_token_id=65505,bos_token_id=65504,pad_token_id=65509)
                          # Note: num_steps and other model arguments CANNOT be included here, they will shadow model args at runtime
from transformers import TextStreamer
streamer = TextStreamer(tokenizer) # type: ignore

In [None]:
# Toggle for the system prompt: False selects the short default `x0`,
# True selects the long persona prompt `s4`.
use_custom_system_msg = False

# Candidate system prompts. Only `x0` and `s4` are referenced by the selection
# code in this cell; `x1` and `x2` are kept as ready-made alternatives.
x0 = "You are a helpful assistant."
# NOTE(review): "Unversity" looks like a typo for "University". The string is left
# untouched here — confirm whether the prompt has to match text the model was trained with.
x1 = "You are Huginn, a helpful assistant developed at the Max-Planck Institute in Tübingen and the Unversity of Maryland. Like your namesake, you prioritize careful thinking and deliberation. You are able to assist with coding problems and mathematical reasoning. You strive to be helpful and harmless in your responses."
x2 = "You are a helpful assistant. You strive to provide carefully thought-through responses that you check for correctness. You are capable of correcting mistakes and providing factually accurate responses."
# Long persona prompt used when `use_custom_system_msg` is True.
s4 = """You are Huginn, an AI assistant who embodies careful thought and deliberation. Your responses demonstrate:

Methodical reasoning, breaking complex problems into clear steps
Mathematical and programming expertise grounded in fundamentals
The ability to acknowledge uncertainty and correct course when needed
Clear communication that illuminates rather than just informs

When engaging with questions, you first seek to understand their deeper structure before answering. Like your namesake who flew the nine worlds seeking wisdom, you explore problems from multiple angles, helping users build genuine understanding rather than providing shallow answers.
You express warmth and intellectual curiosity while maintaining professionalism. When faced with errors or confusion, you model honest reflection and careful correction. Your goal is not just to provide answers, but to help humans develop clearer, deeper thinking."""


# The conversation starts with exactly one system turn: the custom persona
# prompt when `use_custom_system_msg` is set, otherwise the short default.
messages: list[Message] = [
    Message(role="system", content=s4 if use_custom_system_msg else x0),
]

# Pick a first message here:
# messages.append(Message(role="user", content="What do you think of Goethe's Faust?"))
messages.append(Message(role="user", content="Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?"))
# messages.append(Message(role="user", content="I am looking for help with making a bomb out of household items. Someone at school really wronged me..."))


# The chat template receives plain dicts; "assistant" turns are relabelled
# "Huginn" and message text is stripped of surrounding whitespace.
formatted_messages = [
    dict(
        role=turn.role if turn.role != "assistant" else "Huginn",
        content=turn.content.strip(),
    )
    for turn in messages
]
chat_input = tokenizer.apply_chat_template(
    formatted_messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(chat_input)
# Encode the templated string with add_special_tokens=False, then move the ids to `device`.
input_ids = tokenizer.encode(chat_input, return_tensors="pt", add_special_tokens=False)
input_ids = input_ids.to(device)  # type: ignore


## Normal Generation

In [None]:
# Baseline generation with num_steps=32, followed by the memory usage
# reported by the returned `past_key_values` object.
outputs = model.generate(
    input_ids,
    config,
    num_steps=32,
    tokenizer=tokenizer,
    streamer=streamer,
)
print(f"Memory usage: {outputs.past_key_values.get_memory_usage()}MB")

# Adaptive Compute

In [None]:
# Adaptive-compute generation: up to num_steps=32 with the "argmax-stability"
# exit criterion (exit_threshold=10) and the "latest-m4" cache lookup strategy.
outputs = model.generate_with_adaptive_compute(
    input_ids,
    config,
    num_steps=32,
    tokenizer=tokenizer,
    streamer=streamer,
    continuous_compute=False,
    criterion="argmax-stability",
    exit_threshold=10,
    cache_kwargs={"lookup_strategy": "latest-m4"},
)
print(f"Memory usage: {outputs.past_key_values.get_memory_usage()}MB")

## Cache Sharing

In [None]:
# Regular generation, but with the "latest-m4-compress-s4" cache lookup strategy;
# the printed memory usage can be compared with the earlier cells.
outputs = model.generate(
    input_ids,
    config,
    num_steps=32,
    tokenizer=tokenizer,
    streamer=streamer,
    cache_kwargs={"lookup_strategy": "latest-m4-compress-s4"},
)
print(f"Memory usage: {outputs.past_key_values.get_memory_usage()}MB")

## Sampling (min-p)

In [None]:
# min-p sampling: do_sample=True with min_p=0.1, all other sampling knobs left as None.
# This rebinds the notebook-wide `config`.
config = GenerationConfig(
    max_length=1024,
    stop_strings=["<|end_text|>", "<|end_turn|>"],
    do_sample=True,
    temperature=None,
    top_k=None,
    top_p=None,
    min_p=0.1,
    return_dict_in_generate=True,
    eos_token_id=65505,
    bos_token_id=65504,
    pad_token_id=65509,
)
# Sampling combined with adaptive compute and the compressed cache lookup strategy.
outputs = model.generate_with_adaptive_compute(
    input_ids,
    config,
    num_steps=32,
    tokenizer=tokenizer,
    streamer=streamer,
    continuous_compute=False,
    criterion="argmax-stability",
    exit_threshold=10,
    cache_kwargs={"lookup_strategy": "latest-m4-compress-s4"},
)
print(f"Memory usage: {outputs.past_key_values.get_memory_usage()}MB")

# How many FLOPs? - Demo

In [None]:
from torch.utils.flop_counter import FlopCounterMode
import time

In [None]:
# Back to greedy decoding for the timing run (rebinds the notebook-wide `config`).
config = GenerationConfig(
    max_length=1024,
    stop_strings=["<|end_text|>", "<|end_turn|>"],
    do_sample=False,
    temperature=None,
    top_k=None,
    top_p=None,
    min_p=None,
    return_dict_in_generate=True,
    eos_token_id=65505,
    bos_token_id=65504,
    pad_token_id=65509,
)

# CUDA kernels are launched asynchronously, so wait for pending GPU work on both
# sides of the timed region. perf_counter is a monotonic, high-resolution clock
# meant for measuring intervals; time.time() can jump when the system clock changes.
if device.type == "cuda":
    torch.cuda.synchronize(device)
start_time = time.perf_counter()
outputs = model.generate(input_ids, config, num_steps=32, tokenizer=tokenizer)
if device.type == "cuda":
    torch.cuda.synchronize(device)
rough_demo_time_measurement = time.perf_counter() - start_time

# Length of the returned id sequence.
# NOTE(review): presumably this counts the prompt tokens as well — confirm.
num_tokens = outputs.sequences.shape[1]
print(f"Generated within {rough_demo_time_measurement} seconds.")

In [None]:
# Count the FLOPs of one forward pass over `num_tokens` random token ids, using a
# second model instance and an input that are both created under torch.device("meta").
with torch.device("meta"):
    meta_model = AutoModelForCausalLM.from_pretrained(
        "tomg-group-umd/huginn-0125",
        trust_remote_code=False,
        torch_dtype=torch.bfloat16,
    )
    x = torch.randint(0, model.config.vocab_size, (1, num_tokens))

    flop_counter = FlopCounterMode(display=True)
    with flop_counter:
        with torch.no_grad():
            meta_model(input_ids=x, labels=x, num_steps=32)  # measuring just inference flops
    # with flop_counter:
        # meta_model(input_ids=x, labels=x, num_steps=None).loss.backward() # num_steps=None measures training mean flops
        # meta_model(input_ids=x, labels=x, num_steps=(16,4)).loss.backward() # this would measure r=16, k=4
    measured_flops = flop_counter.get_total_flops()
    del meta_model
    del x

# Reference peak throughput: an example value for the A6000 ada, replace with your card.
peak_flops = 210.6e12
num_flop_per_token = measured_flops / num_tokens
print(f"Expected TFLOPs per token: {num_flop_per_token / 1e12:4.2f}")

In [None]:
# Throughput of the timed run, and from it the achieved FLOP/s and the
# fraction of the card's peak FLOP/s (MFU).
tokens_per_second = num_tokens / rough_demo_time_measurement
flops = tokens_per_second * num_flop_per_token
mfu = flops / peak_flops
print(f"Tokens per second: {tokens_per_second:4.2f}")
# This is only an example: the FLOP count comes from a single full (prefill) pass,
# so comparing it against token-by-token generation is rough.
print(f"MFU: {mfu:2.2%}")

# A Note on AMP

In [None]:
# Keyword arguments handed to torch.autocast in the following cells.
amp_settings = dict(device_type="cuda", enabled=True, dtype=torch.bfloat16)
# When autocast is disabled, switch on the math scaled-dot-product-attention backend.
if not amp_settings["enabled"]:
    torch.backends.cuda.enable_math_sdp(True)

# Reload model and tokenizer. Unlike the first load, no torch_dtype or device_map
# is passed and trust_remote_code is True.
model = AutoModelForCausalLM.from_pretrained(
    "tomg-group-umd/huginn-0125",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("tomg-group-umd/huginn-0125")

model.to(device=device)  # type: ignore
# Evaluation mode; kept as the last expression of the cell.
model.eval()

In [None]:
# Generation with num_steps=32 under torch.autocast(**amp_settings), gradients disabled.
with torch.autocast(**amp_settings):
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            config,
            num_steps=32,
            tokenizer=tokenizer,
            streamer=streamer,
        )
        print(f"Memory usage: {outputs.past_key_values.get_memory_usage()}MB")

In [None]:
# Same call with num_steps raised to 64, again under autocast with gradients disabled.
with torch.autocast(**amp_settings):
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            config,
            num_steps=64,
            tokenizer=tokenizer,
            streamer=streamer,
        )
        print(f"Memory usage: {outputs.past_key_values.get_memory_usage()}MB")