<a href="https://colab.research.google.com/github/jxtngx/torchtune-cookbook/blob/main/summarization/L1_Summarization_with_Llama3_2_1B_and_torchtune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> Make certain to switch to a T4 GPU Runtime!

First, let's follow the installation instructions given in the torchtune README.

In [None]:
%%bash
pip install torch torchvision torchao -q
pip install git+https://github.com/pytorch/torchtune.git -q

Now that installation is complete, let's download the Llama 3.2 1B Instruct model.

> Make certain to set the output dir to <br/>
> ${PWD}/Llama-3.2-1B-Instruct

In [None]:
%%bash
tune download meta-llama/Llama-3.2-1B-Instruct \
--output-dir ${PWD}/Llama-3.2-1B-Instruct \
--ignore-patterns "original/consolidated.00.pth" \
--hf-token <<YOUR_HF_KEY>>
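
Before moving on, it can help to confirm that the files used later in this notebook actually landed in the output directory. The following is a minimal sketch and not part of torchtune: `missing_files` is a hypothetical helper, and the two relative paths are simply the ones this notebook loads further down (`model.safetensors` and `original/tokenizer.model`). It assumes the download was run from `/content`, as in Colab.

```python
from pathlib import Path

# Files this notebook reads later on, relative to the download directory.
EXPECTED_FILES = ("model.safetensors", "original/tokenizer.model")


def missing_files(model_dir, expected=EXPECTED_FILES):
    """Return the expected files that are absent from model_dir (hypothetical helper)."""
    root = Path(model_dir)
    return [name for name in expected if not (root / name).is_file()]


# Assumption: the download cell above was run from /content.
missing = missing_files("/content/Llama-3.2-1B-Instruct")
print("missing files:", missing or "none")
```

An empty result means both files are in place. Anything else usually points to a wrong `--output-dir` or an `--ignore-patterns` value that excluded too much.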

Now, let's load the data from the Lesson 0 notebook on data acquisition and preprocessing.

In [None]:
import json

In [None]:
with open("/content/drive/MyDrive/intelligent-agents/intelligent_agent.json", "r") as fp:
    data = json.load(fp)

In [None]:
# check that the key is "Intelligent Agent"
data.keys()

In [None]:
# reminder: the article key should map to another dict
# with "url" and "body" keys
data["Intelligent Agent"].keys()
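
The two checks above can also be written as assertions so that a malformed file fails early. This is a minimal sketch: `validate_article` is a hypothetical helper, and the sample dict only mimics the layout described in the comments above (an article key that maps to a dict with "url" and "body" keys). Its values are placeholders, not real data.

```python
def validate_article(data, title="Intelligent Agent"):
    """Return the article body after checking the expected layout (hypothetical helper)."""
    if title not in data:
        raise KeyError(f"expected top-level key {title!r}, found {sorted(data)}")
    article = data[title]
    missing = {"url", "body"} - set(article)
    if missing:
        raise KeyError(f"article is missing keys: {sorted(missing)}")
    if not isinstance(article["body"], str) or not article["body"].strip():
        raise ValueError("article body is empty or not a string")
    return article["body"]


# Placeholder data with the same shape as the Lesson 0 output.
sample = {"Intelligent Agent": {"url": "https://example.com", "body": "Some article text."}}
print(validate_article(sample))
```

In the notebook, calling `validate_article(data)` on the loaded JSON would play the same role as the two `.keys()` cells.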

In [None]:
# the text data we are after is in the body key
# let's save that to a variable named msg
# and be certain to prepend the instruction to summarize the article
# before using torchtune.generation.generate
msg = "please summarize the following article: " + data["Intelligent Agent"]["body"]

It's time to run generation. We can use torchtune for this.

> `torchtune.generation.generate` returns token IDs and logits. The function is not designed as a chat interface.
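
To make the note above concrete, here is a pure-Python toy of a generation loop that hands back raw token IDs and per-step logits rather than text. It is a sketch under stated assumptions, not torchtune's implementation: `toy_generate` and `cyclic_model` are hypothetical names, the "model" is a made-up function, and decoding is greedy, whereas the real function samples from the logits.

```python
def toy_generate(next_token_logits, prompt_ids, max_generated_tokens):
    """Greedy decoding loop that returns raw token IDs and per-step logits (toy sketch)."""
    tokens = list(prompt_ids)
    all_logits = []
    for _ in range(max_generated_tokens):
        logits = next_token_logits(tokens)
        all_logits.append(logits)
        # Greedy choice: take the index of the largest logit.
        tokens.append(max(range(len(logits)), key=logits.__getitem__))
    return tokens, all_logits


def cyclic_model(tokens, vocab_size=5):
    """Made-up 'model' that always favors the token after the last one."""
    favored = (tokens[-1] + 1) % vocab_size
    return [1.0 if i == favored else 0.0 for i in range(vocab_size)]


token_ids, logits = toy_generate(cyclic_model, [0, 1], max_generated_tokens=3)
print(token_ids)    # integer IDs, not text; decoding is a separate step
print(len(logits))  # one logits vector per generated token
```

The point of the toy is the shape of the result: integers and scores come back, and turning them into readable text is left to the tokenizer.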

In [None]:
from pathlib import Path
from time import perf_counter

import torch
import torchtune
import torchao
from torchao.quantization.quant_api import int8_weight_only
from torchtune.generation import generate
from torchtune.training.checkpointing import FullModelHFCheckpointer
from torchtune.models.llama3_2 import llama3_2_1b
from torchtune.models.llama3 import llama3_tokenizer

First, let's load the checkpoint we downloaded from Hugging Face.

In [None]:
# create a checkpointer
ckptr = FullModelHFCheckpointer(
    "/content/Llama-3.2-1B-Instruct/",
    checkpoint_files = ["model.safetensors"],
    model_type="LLAMA3_2",
    output_dir="/content/output/Llama-3.2-1B-Instruct"
    )

In [None]:
# load the checkpoint
# note: this returns a model state dict that needs to be loaded to a model in the following cell
model_sd = ckptr.load_checkpoint()

In [None]:
# instantiate a model and load the state_dict
model = llama3_2_1b()
model.load_state_dict(model_sd["model"])

In [None]:
# quantize the model in place with torchao (int8 weights only)
torchao.quantize_(model, int8_weight_only(group_size=32))
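
For intuition about what weight-only int8 quantization with a group size does, here is a pure-Python toy. It is a simplified sketch, not torchao's implementation: it assumes symmetric quantization with one scale per group of weights, and `quantize_groups` and `dequantize_groups` are hypothetical names. torchao's actual rounding, value range, and kernels may differ.

```python
def quantize_groups(weights, group_size=32):
    """Symmetric int8 quantization with one scale per group (toy sketch)."""
    q_values, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        max_abs = max(abs(w) for w in group)
        # Map the largest magnitude in the group to 127; avoid dividing by zero.
        scale = max_abs / 127 if max_abs > 0 else 1.0
        scales.append(scale)
        q_values.extend(max(-127, min(127, round(w / scale))) for w in group)
    return q_values, scales


def dequantize_groups(q_values, scales, group_size=32):
    """Recover approximate float weights from int8 values and per-group scales."""
    return [q * scales[i // group_size] for i, q in enumerate(q_values)]


weights = [((i * 37) % 101 - 50) / 100 for i in range(64)]  # made-up values
q, scales = quantize_groups(weights)
restored = dequantize_groups(q, scales)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(f"{len(scales)} scales, max abs error {max_error:.5f}")
```

In this toy, a smaller group size means more scales to store but a tighter fit to each group's range, which is the trade-off behind a setting such as `group_size=32`.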

In [None]:
type(model)

Now, let's load the tokenizer.

> Llama 3.2 models in torchtune reuse the Llama 3 tokenizer, `llama3_tokenizer`.

In [None]:
tokenizer = llama3_tokenizer("/content/Llama-3.2-1B-Instruct/original/tokenizer.model")

Let's create a basic prompt for a first pass at generation:

In [None]:
# skip the end-of-sequence token so the model continues the prompt
prompt = tokenizer.encode("Hi my name is", add_bos=True, add_eos=False)

In [None]:
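# use whichever device the model lives on, so the generator,
# the prompt, and the weights all agree
device = next(model.parameters()).device
rng = torch.Generator(device=device)
rng.manual_seed(42)

start = perf_counter()
output, logits = generate(
    model,
    torch.tensor(prompt, device=device),
    max_generated_tokens=100,
    pad_id=0,
    rng=rng,
)
end = perf_counter()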

In [None]:
f"generation took {(end - start) / 60:.2f} minutes"

Let's inspect the output token IDs:

In [None]:
output

And finally, let's decode the token IDs with the tokenizer:

In [None]:
tokenizer.decode(output[0].tolist(), truncate_at_eos=False)

#TODO Create summarization example