## Excercise 7.2 Instruction and input masking

In [1]:
import json
import os
import urllib


def download_and_load_file(file_path, url):

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)

    # The book originally contained this unnecessary "else" clause:
    #else:
    #    with open(file_path, "r", encoding="utf-8") as file:
    #        text_data = file.read()

    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data


file_path = "instruction-data.json"
url = (
    "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch"
    "/main/ch07/01_main-chapter-code/instruction-data.json"
)

data = download_and_load_file(file_path, url)
print("Number of entries:", len(data))

Number of entries: 1100


In [2]:
# Alpaca-style input format

def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describe a task. "
        f"Write a response that appropriately complete the request."
        f"\n\n### Instruction: \n{entry['instruction']}"
    )

    input_text = (
        f"\n\n### Input: \n{entry['input']}" if entry['input'] else ""
    )

    return instruction_text + input_text

In [3]:
train_portion = int(len(data) * 0.85)
test_portion = int(len(data) * 0.1)
val_portion = len(data) - train_portion - test_portion

train_data = data[:train_portion]
test_data = data[train_portion: train_portion+test_portion]
val_data = data[train_portion+test_portion:]

print("Training set length:", len(train_data))
print("Validation set length:", len(val_data))
print("Test set length:", len(test_data))

Training set length: 935
Validation set length: 55
Test set length: 110


In [4]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")
print(tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))

from previous_chapters import (
    generate,
    text_to_token_ids,
    token_ids_to_text
)

[50256]


In [5]:
import torch
from torch.utils.data import Dataset

class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.encoded_texts = []
        self.len_instruct_input = []

        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text

            self.encoded_texts.append(tokenizer.encode(full_text))
            self.len_instruct_input.append(len(tokenizer.encode(instruction_plus_input)))

    def __getitem__(self, index):
        return self.len_instruct_input[index], self.encoded_texts[index]
    
    def __len__(self):
        return len(self.data)

In [6]:
def custom_collate_fn(batch, pad_token_id=50256, ignore_index=-100, allowed_max_length=None, device="cpu"):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for _, item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for instruction_len, item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # New: Replace all but the first padding tokens in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        targets[:instruction_len] = -100

        # New: Optionally truncate to maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]
        
        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert list of inputs and targets to tensors and transfer to target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)
    print(inputs_tensor.shape, targets_tensor.shape)
    return inputs_tensor, targets_tensor

In [7]:
device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
print(device)

from functools import partial

customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)

xpu


In [8]:
from torch.utils.data import DataLoader

num_workers = 0
batch_size = 8

torch.manual_seed(123)

train_dataset = InstructionDataset(train_data, tokenizer)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers
)

val_dataset = InstructionDataset(val_data, tokenizer)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

test_dataset = InstructionDataset(test_data, tokenizer)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

In [9]:
print("Train loader:")
for inputs, targets in train_loader:
    print(inputs.shape, targets.shape)

Train loader:
torch.Size([8, 63]) torch.Size([8, 63])
torch.Size([8, 63]) torch.Size([8, 63])
torch.Size([8, 77]) torch.Size([8, 77])
torch.Size([8, 77]) torch.Size([8, 77])
torch.Size([8, 75]) torch.Size([8, 75])
torch.Size([8, 75]) torch.Size([8, 75])
torch.Size([8, 70]) torch.Size([8, 70])
torch.Size([8, 70]) torch.Size([8, 70])
torch.Size([8, 67]) torch.Size([8, 67])
torch.Size([8, 67]) torch.Size([8, 67])
torch.Size([8, 73]) torch.Size([8, 73])
torch.Size([8, 73]) torch.Size([8, 73])
torch.Size([8, 82]) torch.Size([8, 82])
torch.Size([8, 82]) torch.Size([8, 82])
torch.Size([8, 69]) torch.Size([8, 69])
torch.Size([8, 69]) torch.Size([8, 69])
torch.Size([8, 64]) torch.Size([8, 64])
torch.Size([8, 64]) torch.Size([8, 64])
torch.Size([8, 77]) torch.Size([8, 77])
torch.Size([8, 77]) torch.Size([8, 77])
torch.Size([8, 63]) torch.Size([8, 63])
torch.Size([8, 63]) torch.Size([8, 63])
torch.Size([8, 70]) torch.Size([8, 70])
torch.Size([8, 70]) torch.Size([8, 70])
torch.Size([8, 69]) torch.

In [10]:
print(inputs[0])
print(token_ids_to_text(inputs[0], tokenizer))

tensor([21106,   318,   281, 12064,   326,  6901,   257,  4876,    13, 19430,
          257,  2882,   326, 20431,  1844,   262,  2581,    13,   198,   198,
        21017, 46486,    25,   220,   198, 30003,  6525,   262,  6827,  1262,
          257,   985,   576,    13,   198,   198, 21017, 23412,    25,   220,
          198,   464,  5156,   318,   845, 13779,    13,   198,   198, 21017,
        18261,    25,   198,   464,  5156,   318,   355, 13779,   355,   257,
         4936,    13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256], device='xpu:0')
Below is an instruction that describe a task. Write a response that appropriately complete the request.

### Instruction: 
Rewrite the sentence using a simile.

### Input: 
The baby is very cute.

### Response:
The baby is as cute as a button.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>


In [11]:
print(targets[0])
print(token_ids_to_text(targets[0][targets[0] != -100], tokenizer))

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,   198, 21017, 18261,
           25,   198,   464,  5156,   318,   355, 13779,   355,   257,  4936,
           13, 50256,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100], device='xpu:0')

### Response:
The baby is as cute as a button.<|endoftext|>
