ooks

## Import model and tokenizer


In [None]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup


# Checkpoint to load: the Hub repository id and the revision (here a PR ref) to pull it from.
model_id = "karpathy/nanochat-d32"
revision = "refs/pr/1"

# Run on CUDA when torch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Tokenizer and model are fetched from the same repo/revision so they stay in sync.
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

# Weights are loaded as bfloat16 on CUDA and as float32 on any other device.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    torch_dtype=torch.float32 if device.type != "cuda" else torch.bfloat16,
)
model = model.to(device)


## Set up LoRA


In [None]:
from peft import LoraConfig, get_peft_model

# LoRA hyperparameters: rank 1, alpha 2, no adapter dropout, configured for causal LM.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=1,
    lora_alpha=2,
    lora_dropout=0.0,
    # Names of the modules that receive adapters.
    # NOTE(review): presumably the attention projections plus the two MLP layers
    # of this architecture -- confirm against the model's module names.
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "fc1",
        "fc2",
    ],
)

# NOTE: `model` is rebound to the wrapped model here, so re-running this cell
# passes the already-wrapped model back into get_peft_model.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


## Demo the model


In [None]:
# Demo 1: greedy continuation of a raw text prompt (the chat template is not applied).
print("=" * 80, "TEST 1: Plain Autoregressive Prompt", "=" * 80, sep="\n")

prompt = "The Eiffel Tower stands in Paris and"
test_inputs = tokenizer(prompt, return_tensors="pt")
test_inputs = test_inputs.to(device)

# Gradients are not needed for generation; do_sample=False disables sampling.
with torch.no_grad():
    test_outputs = model.generate(
        **test_inputs,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,
        max_new_tokens=64,
    )

# Skip the first `input_ids.shape[1]` positions of the first sequence (the prompt
# length) so that only the tokens after the prompt are decoded.
generated_tokens = test_outputs[0][test_inputs["input_ids"].shape[1] :]

print(f"Prompt: {prompt}")
print(f"\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}")
print("=" * 80)


In [None]:
# Demo 2: the same greedy generation, but with the prompt built by the tokenizer's chat template.
print("=" * 80, "TEST 2: Chat Template", "=" * 80, sep="\n")

# Single-turn conversation containing only a user message.
conversation = [{"role": "user", "content": "What is the capital of France?"}]

# Tokenize via the chat template, returning a dict of PyTorch tensors;
# add_generation_prompt=True is requested so the model can answer as the next turn.
inputs = tokenizer.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
inputs = inputs.to(device)

# Show exactly what the model receives, both as text and as raw token ids.
print(f"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}")
print(f"Input IDs: {inputs['input_ids'][0].tolist()}")

with torch.no_grad():
    outputs = model.generate(**inputs, do_sample=False, max_new_tokens=64)

# Drop the prompt positions; decode() is called without skip_special_tokens,
# so any special tokens in the continuation remain visible in the printed text.
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
print(f"\nGenerated: {tokenizer.decode(generated_tokens)}")
print("=" * 80)


## Dataset
