In [3]:
# Free GPU memory left over from earlier runs.
# Run the garbage collector first: it drops unreachable Python objects (including
# tensors), which returns their blocks to PyTorch's caching allocator. Only after
# that can empty_cache() release the now-unoccupied cached memory.
import gc

import torch

gc.collect()
torch.cuda.empty_cache()


31

In [4]:
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
from transformers import default_data_collator



In [5]:
# Locate the checkpoint relative to the directory the notebook runs from:
# four levels up, then into local-models/Llama-3.2-1B.
notebook_dir = os.getcwd()
model_dir = os.path.join(notebook_dir, *([".."] * 4), "local-models/Llama-3.2-1B")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [6]:
# Tokenizer and weights are both read from the same local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# torch_dtype is left at the library default; pass torch_dtype=torch.float16
# here to use half precision and reduce memory usage.
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    low_cpu_mem_usage=True,  # memory-optimized loading path
    device_map=device,       # put the model on the device selected earlier ("cuda" or "cpu")
)


: 

In [None]:
# Register a dedicated padding token on the tokenizer.
tokenizer.add_special_tokens(dict(pad_token="[PAD]"))

# Resize the embedding table to the tokenizer's current vocabulary size, then
# record the pad id on the model config.
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id


In [None]:
# Ask the tokenizer to pad on the left whenever it pads by itself.
# NOTE: tokenize_function further down left-pads by hand and never reads this
# attribute, so this setting matters only for padding done through the tokenizer.
tokenizer.padding_side = "left"

In [None]:
# Training inputs: seven short stories. Parallel to `summaries` (same index = same sample).
stories = [
    "A lonely robot found a broken music box in an abandoned city. As he fixed it, the melody attracted other robots. Together, they created the city's first robot orchestra.",
    "The last seed on Earth was planted by a child in her grandmother's garden. Against all odds, it grew into a magical tree that produced seeds of every plant that had been lost.",
    "In a world where dreams were visible as floating bubbles, a young girl discovered she could weave them into blankets. Her creations brought comfort to those suffering from nightmares.",
    "An old lighthouse keeper discovered that his beacon didn't guide ships, but rather lost stars back to their constellations. Each night, he helped rebuild the night sky.",
    "Deep in the digital forest, a virus learned to heal corrupted files instead of destroying them. Other programs began calling it the Digital Doctor.",
    "A time-traveling mailman accidentally delivered letters to the wrong centuries. The resulting mix-ups created unexpected friendships across time.",
    "The last bookstore on Mars housed a librarian who could bring characters to life by reading aloud. She used this gift to help homesick colonists feel less alone."
]

# Training targets: one-sentence summary for each entry of `stories`.
summaries = [
    "A robot repairs a music box and builds community through music.",
    "A child's last seed miraculously restores Earth's lost plants.",
    "A girl turns visible dreams into comforting blankets for nightmare sufferers.",
    "A lighthouse keeper helps lost stars find their way back to constellations.",
    "A benevolent virus becomes known for healing corrupted files.",
    "A mailman's time-travel mistakes lead to cross-century friendships.",
    "A Martian librarian uses her power to comfort colonists with living stories."
]

# Validation inputs: four stories, the first three being reworded variants of training stories.
alternative_stories = [
    "In a forgotten city, a solitary robot stumbled upon an ancient piano. As it played, the harmonious notes summoned other robots, and together they formed a symphony that echoed through the empty streets.",
    "A young girl planted a mysterious seed in a barren land. To everyone's amazement, it sprouted into a tree that bore fruits of every extinct plant, reviving the world's lost flora.",
    "In a realm where dreams floated like clouds, a girl learned to capture them in jars. Her bottled dreams provided solace to those haunted by restless nights.",
    # NOTE(review): this entry is a verbatim copy of the fourth training story, so the
    # validation split overlaps the training split on this sample.
    "An old lighthouse keeper discovered that his beacon didn't guide ships, but rather lost stars back to their constellations. Each night, he helped rebuild the night sky.",

]

# Validation targets: one-sentence summary for each entry of `alternative_stories`.
alternative_summaries = [
    "A robot finds a piano and unites others through music.",
    "A girl's seed grows into a tree that revives extinct plants.",
    "A girl captures dreams to comfort those with nightmares.",
    # NOTE(review): identical to the fourth training summary (see the matching story above).
    "A lighthouse keeper helps lost stars find their way back to constellations.",

]
# Put each (story, summary) list pair into a two-column DataFrame, convert it to a
# Dataset, and collect both splits under the keys the later cells look up.
pairs_train = Dataset.from_pandas(
    pd.DataFrame({"text": stories, "target": summaries})
)
pairs_val = Dataset.from_pandas(
    pd.DataFrame({"text": alternative_stories, "target": alternative_summaries})
)
pairs = DatasetDict(
    {
        "train": pairs_train,
        "validation": pairs_val,
    }
)


In [None]:
# Bare expression: the notebook shows the DatasetDict's repr for a quick look at the splits.
pairs

In [None]:

def tokenize_function(example, tokenizer):
    """Turn a batch of (text, target) pairs into left-padded causal-LM features.

    Every sample becomes ``text_ids + target_ids + [eos]``. The labels repeat the
    target part and the EOS, and hold ``-100`` over the text part and over the
    padding (``-100`` is the label value PyTorch's cross-entropy loss ignores by
    default). All sequences of the batch are left-padded to the longest one.

    Args:
        example: Batch mapping with parallel lists under ``"text"`` and ``"target"``.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` lists; must expose ``eos_token_id`` and
            ``pad_token_id``.

    Returns:
        The tokenizer output for ``example["text"]`` whose ``"input_ids"``,
        ``"attention_mask"`` and ``"labels"`` entries are lists of equal-length
        sequences, one per sample. An empty batch yields three empty lists.

    Raises:
        ValueError: If padding is required but ``tokenizer.pad_token_id`` is None.
    """
    # Specials are added by the tokenizer on the text side only (the original note
    # says this amounts to a BOS); the EOS after the target is appended by hand.
    inputs = tokenizer(example["text"], add_special_tokens=True)
    targets = tokenizer(example["target"], add_special_tokens=False)

    eos_suffix = [tokenizer.eos_token_id]
    full_ids = []
    full_labels = []
    for i, text_ids in enumerate(inputs["input_ids"]):
        answer_ids = targets["input_ids"][i] + eos_suffix
        full_ids.append(text_ids + answer_ids)
        # The text part is masked out so only the target and EOS are supervised.
        full_labels.append([-100] * len(text_ids) + answer_ids)

    # default=0 keeps an empty batch from raising inside max().
    max_length = max((len(ids) for ids in full_ids), default=0)

    pad_id = tokenizer.pad_token_id
    if pad_id is None and any(len(ids) < max_length for ids in full_ids):
        raise ValueError(
            "tokenizer.pad_token_id is None; set a pad token before tokenizing "
            "batches that need padding"
        )

    def _left_pad(seq, fill):
        # Prepend `fill` until the sequence reaches the batch's max length.
        return [fill] * (max_length - len(seq)) + seq

    inputs["input_ids"] = [_left_pad(ids, pad_id) for ids in full_ids]
    # Real tokens get 1, padding positions get 0.
    inputs["attention_mask"] = [_left_pad([1] * len(ids), 0) for ids in full_ids]
    inputs["labels"] = [_left_pad(labels, -100) for labels in full_labels]

    return inputs


In [None]:
# Tokenize every split of `pairs` and drop the raw text columns.
# tokenize_function pads to the longest sequence of the batch it receives, so each
# split is handed over as ONE batch (batch_size=None). With the default batch size
# of 1000, a larger split would be padded to a different length per map batch and
# its rows could no longer be stacked into equal-length tensors by the collator.
pairs_dataset = pairs.map(
    lambda batch: tokenize_function(batch, tokenizer=tokenizer),
    batched=True,
    batch_size=None,
    remove_columns=pairs["train"].column_names,
    load_from_cache_file=False,  # always recompute instead of reusing a cached result
    desc="Running tokenizer on dataset",
)

In [None]:
# Dump every tokenized sample, split by split: each feature followed by its length.
for split in pairs_dataset:
    print(split)
    for row in pairs_dataset[split]:
        for column in ("input_ids", "attention_mask", "labels"):
            print(row[column])
            print(len(row[column]))
        print("-"*100)


In [None]:
# Batches of two samples; default_data_collator turns the already-padded feature
# lists into tensors.
train_dataloader = DataLoader(
    pairs_dataset["train"],
    batch_size=2,
    shuffle=True,  # reshuffle the training samples every epoch
    collate_fn=default_data_collator,
)
# Validation gains nothing from a random order; keep it fixed so that inspected
# batches and evaluation runs are reproducible.
dev_dataloader = DataLoader(
    pairs_dataset["validation"],
    batch_size=2,
    shuffle=False,
    collate_fn=default_data_collator,
)

In [None]:
# Inspect the collated validation batches: every feature of every sample and its length.
for batch in dev_dataloader:
    print("-"*100)
    print("dev batch")  # this loop walks dev_dataloader, not the training loader
    for i in range(len(batch["input_ids"])):
        print(f"number {i}")
        for key in ("input_ids", "attention_mask", "labels"):
            print(batch[key][i])
            print(len(batch[key][i]))
        print("-"*100)
