# Finetuning to follow instructions

<img src="https://camo.githubusercontent.com/6736ab7968f8da6bd6fc747de22ef9afa9d840373749005ce3e96fc6ead7ed8c/68747470733a2f2f73656261737469616e72617363686b612e636f6d2f696d616765732f4c4c4d732d66726f6d2d736372617463682d696d616765732f636830375f636f6d707265737365642f636861707465722d6f766572766965772d312e776562703f31" width=700>

## Stage 1: Preparing the dataset

### 1. Dataset download and preparation

In [1]:
import json
import os
import urllib

In [2]:
def download_and_load_file(file_path, url):

    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)

    # The book originally contained this unnecessary "else" clause:
    #else:
    #    with open(file_path, "r", encoding="utf-8") as file:
    #        text_data = file.read()

    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data


file_path = "instruction-data.json"
url = (
    "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch"
    "/main/ch07/01_main-chapter-code/instruction-data.json"
)

data = download_and_load_file(file_path, url)
print("Number of entries:", len(data))

Number of entries: 1100


In [3]:
data[50:53]

[{'instruction': 'Identify the correct spelling of the following word.',
  'input': 'Ocassion',
  'output': "The correct spelling is 'Occasion.'"},
 {'instruction': "What type of figurative language is used in 'She is the apple of my eye'?",
  'input': '',
  'output': 'The figurative language used is a metaphor.'},
 {'instruction': 'Correct the spelling error in the sentence.',
  'input': 'I will atend the meeting tomorrow.',
  'output': "The correct spelling is 'attend', not 'atend'."}]

---

<img src="https://camo.githubusercontent.com/56327c274257475f53fbb0a25fac50e703cbd67af30ff930eb3611c3356f6da6/68747470733a2f2f73656261737469616e72617363686b612e636f6d2f696d616765732f4c4c4d732d66726f6d2d736372617463682d696d616765732f636830375f636f6d707265737365642f70726f6d70742d7374796c652e776562703f31" width=700>

Left is ***Alpaca*** prompt style and right is ***Phi-3*** prompt style developped by Microsoft.

We will use Alpaca prompt style.

In [4]:
def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text

In [5]:
print(format_input(data[50]))
print(f"\n### Response:\n{data[50]['output']}")

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Identify the correct spelling of the following word.

### Input:
Ocassion

### Response:
The correct spelling is 'Occasion.'


In [6]:
print(format_input(data[51]))
print(f"\n### Response:\n{data[51]['output']}")

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What type of figurative language is used in 'She is the apple of my eye'?

### Response:
The figurative language used is a metaphor.


---

In [56]:
train_portion = int(len(data) * .85)
test_portion = int(len(data) * .1)
val_portion = len(data) - train_portion - test_portion

train_data = data[:train_portion]
test_data = data[train_portion:train_portion + test_portion]
val_data = data[train_portion + test_portion:]

f'{train_portion=}, {test_portion=}, {val_portion=}'

'train_portion=935, test_portion=110, val_portion=55'

### 2. Batching the dataset

<img src="https://camo.githubusercontent.com/85ba6fcb03b3337a5d339092f86afe331e90cc77985da2cacff8f66cb26c4f59/68747470733a2f2f73656261737469616e72617363686b612e636f6d2f696d616765732f4c4c4d732d66726f6d2d736372617463682d696d616765732f636830375f636f6d707265737365642f64657461696c65642d6261746368696e672e776562703f31" width=600>

#### 2.1 & 2.2 - Format dataset entry

<img src="https://camo.githubusercontent.com/b9b0cd632b51d3b6490e2ada4cf470a1ca61ff6649370c0a7cdcafe2bf7dbc77/68747470733a2f2f73656261737469616e72617363686b612e636f6d2f696d616765732f4c4c4d732d66726f6d2d736372617463682d696d616765732f636830375f636f6d707265737365642f707265746f6b656e697a696e672e77656270" width=75%>

In [8]:
import torch
from torch.utils.data import Dataset

In [9]:
class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.encoded_text = []
        
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            
            full_text = instruction_plus_input + response_text
            self.encoded_text.append(tokenizer.encode(full_text))
            
    def __getitem__(self, index):
        return self.encoded_text[index]

    def __len__(self):
        return len(self.data)

#### 2.3 - Padding with `50256`

<img src="https://camo.githubusercontent.com/3f1bcae9afed840d168ac596c1b10eb9f29a2c96938fe4ec86f5d50008829a2e/68747470733a2f2f73656261737469616e72617363686b612e636f6d2f696d616765732f4c4c4d732d66726f6d2d736372617463682d696d616765732f636830375f636f6d707265737365642f70616464696e672e77656270" width=60%>

In [10]:
def custom_collate_draft_1(batch, pad_token_id=50256, device='cpu'):
    batch_max_length = max(len(item) + 1 for item in batch)
    inputs_list = []
    
    for item in batch:
        new_item = item.copy()
        new_item += [pad_token_id]
        
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        
        inputs = torch.tensor(padded[:-1])
        inputs_list.append(inputs)
        
    inputs_tensor = torch.stack(inputs_list).to(device)
    
    return inputs_tensor

In [11]:
inputs_1 = [0, 1, 2, 3, 4]
inputs_2 = [5, 6]
inputs_3 = [7, 8, 9]
batch = (
    inputs_1,
    inputs_2,
    inputs_3
)
print(custom_collate_draft_1(batch))

tensor([[    0,     1,     2,     3,     4],
        [    5,     6, 50256, 50256, 50256],
        [    7,     8,     9, 50256, 50256]])


#### 2.4 - Create target token IDs

In [12]:
def custom_collate_draft_2(batch, pad_token_id=50256, device='cpu'):
    batch_max_length = max(len(item) + 1 for item in batch)
    inputs_list  = []
    targets_list = []
    
    for item in batch:
        new_item = item.copy()
        new_item += [pad_token_id]
        
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        
        inputs = torch.tensor(padded[:-1])
        targets = torch.tensor(padded[1:])
        inputs_list.append(inputs)
        targets_list.append(targets)
        
    inputs_tensor  = torch.stack(inputs_list).to(device)
    targets_tensor = torch.stack(targets_list).to(device)
    
    return inputs_tensor, targets_tensor

In [13]:
custom_collate_draft_2(batch)

(tensor([[    0,     1,     2,     3,     4],
         [    5,     6, 50256, 50256, 50256],
         [    7,     8,     9, 50256, 50256]]),
 tensor([[    1,     2,     3,     4, 50256],
         [    6, 50256, 50256, 50256, 50256],
         [    8,     9, 50256, 50256, 50256]]))

#### 2.5 - Replace padding tokens with placeholders

<img src="https://camo.githubusercontent.com/b8fcb2f5ace86849d40ea7ca7fef93c12614ebbae302783c79a97342fd777200/68747470733a2f2f73656261737469616e72617363686b612e636f6d2f696d616765732f4c4c4d732d66726f6d2d736372617463682d696d616765732f636830375f636f6d707265737365642f69676e6f72652d696e6465782e776562703f31" width=75%>

In [155]:
def custom_collate_fn(batch, pad_token_id=50256, ignore_index=-100, allowed_max_length=None, device='cpu'):
    batch_max_length = max(len(item) + 1 for item in batch)
    inputs_list  = []
    targets_list = []
    
    for item in batch:
        new_item = item.copy()
        new_item += [pad_token_id]
        
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        
        inputs = torch.tensor(padded[:-1])
        targets = torch.tensor(padded[1:])
        
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index
            
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]
        
        inputs_list.append(inputs)
        targets_list.append(targets)
        
    inputs_tensor  = torch.stack(inputs_list).to(device)
    targets_tensor = torch.stack(targets_list).to(device)
    
    return inputs_tensor, targets_tensor

In [156]:
custom_collate_fn(batch)

(tensor([[    0,     1,     2,     3,     4],
         [    5,     6, 50256, 50256, 50256],
         [    7,     8,     9, 50256, 50256]]),
 tensor([[    1,     2,     3,     4, 50256],
         [    6, 50256,  -100,  -100,  -100],
         [    8,     9, 50256,  -100,  -100]]))

Here we mask the padding token IDs with `-100` because it is a value that is ignored by the `cross_entropy` function!

Additionnally, we can use the same masking on the instructions token IDs.

<img src="https://camo.githubusercontent.com/e5061d1720dc8a56c789519562aa425a777971cce962c62a9ceab8b8c77fe627/68747470733a2f2f73656261737469616e72617363686b612e636f6d2f696d616765732f4c4c4d732d66726f6d2d736372617463682d696d616765732f636830375f636f6d707265737365642f6d61736b2d696e737472756374696f6e732e776562703f31" width=75%>

I will probably do it after training to see if it improves the model.

### 3. Creating data loaders

In [16]:
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
device

device(type='mps')

In [47]:
from functools import partial
# this allows us to create a new version of a function
# with parameters that are prefilled

In [48]:
customized_collate_fn = partial(
    custom_collate_fn,
    device=device,
    allowed_max_length=1024
)

In [49]:
from torch.utils.data import DataLoader
import tiktoken

In [50]:
num_workers = 0
batch_size = 8

tokenizer = tiktoken.get_encoding('gpt2')
torch.manual_seed(123)

train_dataset = InstructionDataset(train_data, tokenizer)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers
)

In [51]:
val_dataset = InstructionDataset(val_data, tokenizer)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)
test_dataset = InstructionDataset(test_data, tokenizer)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=customized_collate_fn,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

In [22]:
print('Train loader:')
for inputs, targets in train_loader:
    print('\t', inputs.shape)
    print('\t', targets.shape)
    print('\t ...')
    break

Train loader:
	 torch.Size([8, 61])
	 torch.Size([8, 61])
	 ...


---

## Stage 2: Fine-tuning the LLM

#### 4. Loading a pretrained LLM

Instead of loading the smallest 124 million parameter model,<br>
we load the **medium version** with 355 million parameters since the 124 million model is too small for achieving qualitatively reasonable results via instruction finetuning.

In [23]:
from llms_from_scratch.ch04 import GPTModel
from llms_from_scratch.ch05 import download_and_load_gpt2, load_weights_into_gpt


BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

CHOOSE_MODEL = "gpt2-medium (355M)"

BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(
    model_size=model_size,
    models_dir="gpt2"
)

model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
model.eval();

File already exists and is up-to-date: gpt2/355M/checkpoint
File already exists and is up-to-date: gpt2/355M/encoder.json
File already exists and is up-to-date: gpt2/355M/hparams.json
File already exists and is up-to-date: gpt2/355M/model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2/355M/model.ckpt.index
File already exists and is up-to-date: gpt2/355M/model.ckpt.meta
File already exists and is up-to-date: gpt2/355M/vocab.bpe


In [24]:
torch.manual_seed(123)
input_text = format_input(val_data[0])
print(input_text)

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Convert the active sentence to passive: 'The chef cooks the meal every day.'


In [28]:
from llms_from_scratch.ch05 import generate, text_to_token_ids, token_ids_to_text

In [29]:
%%time
token_ids = generate(
    model=model,
    idx=text_to_token_ids(input_text, tokenizer),
    max_new_tokens=35,
    context_size=BASE_CONFIG["context_length"],
    eos_id=50256,
)

CPU times: user 28 s, sys: 1.46 s, total: 29.5 s
Wall time: 7.45 s


In [31]:
generated_text = token_ids_to_text(token_ids, tokenizer)
print(generated_text[len(input_text):].strip())

### Response:

The chef cooks the meal every day.

### Instruction:

Convert the active sentence to passive: 'The chef cooks the


#### 5. Instruction fine-tuning the LLM

In [32]:
from llms_from_scratch.ch05 import calc_loss_loader, train_model_simple

In [35]:
model.to(device);

In [36]:
torch.manual_seed(123)
with torch.no_grad():
    train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)
    val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)


train_loss, val_loss

(3.825910711288452, 3.761934995651245)

---

In [42]:
# this will take a long time on my personal machine
# run this on a powerful gpu or on the cloud
%time
if False:
    torch.manual_seed(123)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=0.00005, weight_decay=0.1
    )
    num_epochs = 2
    
    train_losses, val_losses, tokens_seen = train_model_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=5, eval_iter=5,
        start_context=format_input(val_data[0]), tokenizer=tokenizer
    )
else:
    # loading a model trained for 2 epochs
    model_state_dict = torch.load('instruction_model.pth', map_location=device)
    model.load_state_dict(model_state_dict)

CPU times: user 4 μs, sys: 14 μs, total: 18 μs
Wall time: 21.9 μs


<img src="plot2.png">

## Stage 3: Evaluating the LLM

### 7. Extracting responses

In [57]:
%%time
torch.manual_seed(123)

for entry in test_data[:1]:
    input_text = format_input(entry)
    token_ids = generate(
        model=model,
        idx=text_to_token_ids(input_text, tokenizer).to(device),
        max_new_tokens=256,
        context_size=BASE_CONFIG["context_length"],
        eos_id=50256
    )
    generated_text = token_ids_to_text(token_ids, tokenizer)

    response_text = (
        generated_text[len(input_text):]
        .replace("### Response:", "")
        .strip()
    )
    print(input_text)
    print(f"\nCorrect response:\n>> {entry['output']}")
    print(f"\nModel response:\n>> {response_text.strip()}")
    print("-------------------------------------")

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Rewrite the sentence using a simile.

### Input:
The car is very fast.

Correct response:
>> The car is as fast as lightning.

Model response:
>> The car is as fast as a bullet.
-------------------------------------
CPU times: user 1.25 s, sys: 2.59 s, total: 3.84 s
Wall time: 11.6 s


### 8. Qualitative evaluation

Model evaluation is not as straightforward as in the previous chapter, where we just had to calculate the percentage of correct spam/non-spam class labels to obtain the classification accuracy.

In practice, instruction-finetuned LLMs such as chatbots are evaluated via multiple approaches

- short-answer and multiple choice benchmarks such as MMLU ("Measuring Massive Multitask Language Understanding", https://arxiv.org/abs/2009.03300), which test the knowledge of a model
- human preference comparison to other LLMs, such as LMSYS chatbot arena (https://arena.lmsys.org)
- automated conversational benchmarks, where another LLM like GPT-4 is used to evaluate the responses, such as AlpacaEval (https://tatsu-lab.github.io/alpaca_eval/)

In the next section, we will use an approach similar to AlpacaEval and use another LLM to evaluate the responses of our model; however, we will use our own test set instead of using a publicly available benchmark dataset

For this, we add the model response to the `test_data` dictionary and save it as a `"instruction-data-with-response.json"` file for record-keeping so that we can load and analyze it in separate Python sessions if needed.

In [71]:
# instruction-data-with-response.json$
with open("instruction-data-with-response.json", 'r') as f:
    print(json.dumps(f.read().split("\n")[:30], indent=2))

[
  "[",
  "    {",
  "        \"instruction\": \"Rewrite the sentence using a simile.\",",
  "        \"input\": \"The car is very fast.\",",
  "        \"output\": \"The car is as fast as lightning.\",",
  "        \"model_response\": \"The car is as fast as a bullet.\"",
  "    },",
  "    {",
  "        \"instruction\": \"What type of cloud is typically associated with thunderstorms?\",",
  "        \"input\": \"\",",
  "        \"output\": \"The type of cloud typically associated with thunderstorms is cumulonimbus.\",",
  "        \"model_response\": \"The type of cloud associated with thunderstorms is a cumulus cloud.\"",
  "    },",
  "    {",
  "        \"instruction\": \"Name the author of 'Pride and Prejudice'.\",",
  "        \"input\": \"\",",
  "        \"output\": \"Jane Austen.\",",
  "        \"model_response\": \"The author of 'Pride and Prejudice' is Jane Austen.\"",
  "    },",
  "    {",
  "        \"instruction\": \"What is the periodic symbol for chlorine?\",",
  

### 9. Scoring the responses

In [84]:
import psutil

def check_if_running(process_name):
    running = False
    for proc in psutil.process_iter(["name"]):
        if process_name in proc.info["name"]:
            running = True
            break
    return running

ollama_running = check_if_running("ollama")

if not ollama_running:
    raise RuntimeError(
        "Ollama not running. Launch ollama before proceeding."
)
print("Ollama running:", check_if_running("ollama"))

Ollama running: True


---

In [83]:
# This cell is optional; it allows you to restart the notebook
# and only run section 9. without rerunning any of the previous code
import json
from tqdm import tqdm

file_path = "instruction-data-with-response.json"

with open(file_path, "r") as file:
    test_data = json.load(file)


def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text

---

In [85]:
import urllib.request

def query_model(prompt, model="llama3", url="http://localhost:11434/api/chat"):
    # Create the data payload as a dictionary
    data = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "options": {     # Settings below are required for deterministic responses
            "seed": 123,
            "temperature": 0,
            "num_ctx": 2048
        }
    }


    # Convert the dictionary to a JSON formatted string and encode it to bytes
    payload = json.dumps(data).encode("utf-8")

    # Create a request object, setting the method to POST and adding necessary headers
    request = urllib.request.Request(
        url,
        data=payload,
        method="POST"
    )
    request.add_header("Content-Type", "application/json")

    # Send the request and capture the response
    response_data = ""
    with urllib.request.urlopen(request) as response:
        # Read and decode the response
        while True:
            line = response.readline().decode("utf-8")
            if not line:
                break
            response_json = json.loads(line)
            response_data += response_json["message"]["content"]

    return response_data

In [90]:
model = "llama3.2"
result = query_model("What do Llamas eat?", model)
print(result)

Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:

1. Grasses: Llamas love to graze on various types of grasses, including tall grasses, short grasses, and grassy weeds.
2. Hay: High-quality hay, such as timothy hay or alfalfa hay, is a staple in a llama's diet. It provides essential nutrients like fiber, protein, and vitamins.
3. Grains: Llamas may also be fed grains like oats, barley, or corn, but these should not make up more than 10% of their diet.
4. Fruits and vegetables: Fresh fruits and vegetables, such as apples, carrots, and sweet potatoes, can be given to llamas as treats or added to their hay.
5. Browse: Llamas may also eat browse, which includes leaves, twigs, and other vegetation from trees and shrubs.

It's essential to note that llamas have a unique digestive system, with a four-chambered stomach, which allows them to break down and extract nutrients from plant material more efficiently than many other 

In [93]:
for entry in test_data[:3]:
    prompt = (
        f"Given the input `{format_input(entry)}` "
        f"and correct output `{entry['output']}`, "
        f"score the model response `{entry['model_response']}`"
        f" on a scale from 0 to 100, where 100 is the best score. "
    )
    print("\nDataset response:")
    print(">>", entry['output'])
    print("\nModel response:")
    print(">>", entry["model_response"])
    print("\nScore:")
    print(">>", query_model(prompt, model))
    print("\n-------------------------")


Dataset response:
>> The car is as fast as lightning.

Model response:
>> The car is as fast as a bullet.

Score:
>> To rewrite the sentence using a simile, we need to compare the speed of the car to something else.

Correct output: The car is as fast as lightning.

Score: 100

The model response "The car is as fast as a bullet" is close, but not perfect. A simile should use "like" or "as" to make the comparison, whereas "as fast as a bullet" implies that the car is literally a bullet, which isn't the intended meaning.

A better score for the model response would be around 80-90, as it's close to the correct form but not quite there.

-------------------------

Dataset response:
>> The type of cloud typically associated with thunderstorms is cumulonimbus.

Model response:
>> The type of cloud associated with thunderstorms is a cumulus cloud.

Score:
>> I would rate the model response a 20.

The reason for this low score is that the model response contains an error in its classificatio

In [88]:
def generate_model_scores(json_data, json_key, model="llama3.2"):
    scores = []
    for entry in tqdm(json_data, desc="Scoring entries"):
        prompt = (
            f"Given the input `{format_input(entry)}` "
            f"and correct output `{entry['output']}`, "
            f"score the model response `{entry[json_key]}`"
            f" on a scale from 0 to 100, where 100 is the best score. "
            f"Respond with the integer number only."   #1
        )
        score = query_model(prompt, model)
        try:
            scores.append(int(score))
        except ValueError:
            print(f"Could not convert score: {score}")
            continue

    return scores

In [96]:
# I will generate the scores on a more powerful GPU because mine is too weak ⚰
"""
scores = generate_model_scores(test_data, "model_response")
print(f"Number of scores: {len(scores)} of {len(test_data)}")
print(f"Average score: {sum(scores)/len(scores):.2f}\n")
""";
# Number of scores: 96 of 110
# Average score: 50.56

Finally, we get a metric to evaluate our model : `50.56`.

## Bonus: 2.4, replace mask instructions to see if it improves the model

<div class="alert alert-block alert-info">
<b>Reminder :</b>
Here we mask the padding token IDs with `-100` because it is a value that is ignored by the `cross_entropy` function!

Additionnally, we can use the same masking on the instructions token IDs.

<img src="https://camo.githubusercontent.com/e5061d1720dc8a56c789519562aa425a777971cce962c62a9ceab8b8c77fe627/68747470733a2f2f73656261737469616e72617363686b612e636f6d2f696d616765732f4c4c4d732d66726f6d2d736372617463682d696d616765732f636830375f636f6d707265737365642f6d61736b2d696e737472756374696f6e732e776562703f31" width=75%>
</div>

In [251]:
def custom_collate_fn_mask(batch, pad_token_id=50256, instruction_length=3, ignore_index=-100, allowed_max_length=None, device='cpu'):
    batch_max_length = max(len(item) + 1 for item in batch)
    inputs_list  = []
    targets_list = []

    instruction_indices = torch.arange(0, instruction_length - 1, dtype=torch.int32)
    
    for item in batch:
        new_item = item.copy()
        new_item += [pad_token_id]
        
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        
        inputs = torch.tensor(padded[:-1])
        targets = torch.tensor(padded[1:])

        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        
        if indices.numel() > 1:
            #indices = torch.cat([instruction_indices, indices[1:]]))
            print(indices[1:])
            print(torch.cat([instruction_indices, indices[1:]]))
            indices_2 = torch.cat([instruction_indices, indices[1:]])
            #targets[indices[1:]] = ignore_index
            targets[indices_2] = ignore_index
            
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]
        
        inputs_list.append(inputs)
        targets_list.append(targets)
        
    inputs_tensor  = torch.stack(inputs_list).to(device)
    targets_tensor = torch.stack(targets_list).to(device)
    
    return inputs_tensor, targets_tensor

In [255]:
inputs_1 = [-1, -1, -1, 0, 1, 2, 3, 4]
inputs_2 = [-1, -1, -1, 5, 6]
inputs_3 = [-1, -1, -1, 7, 8, 9]
batch = (
    inputs_1,
    inputs_2,
    inputs_3
)
custom_collate_fn(batch)

tensor([5, 6, 7])
tensor([0, 1, 5, 6, 7])
tensor([6, 7])
tensor([0, 1, 6, 7])


(tensor([[   -1,    -1,    -1,     0,     1,     2,     3,     4],
         [   -1,    -1,    -1,     5,     6, 50256, 50256, 50256],
         [   -1,    -1,    -1,     7,     8,     9, 50256, 50256]]),
 tensor([[   -1,    -1,     0,     1,     2,     3,     4, 50256],
         [ -100,  -100,     5,     6, 50256,  -100,  -100,  -100],
         [ -100,  -100,     7,     8,     9, 50256,  -100,  -100]]))

---

---

In [252]:
customized_collate_fn_mask = partial(
    custom_collate_fn_mask,
    device=device,
    instruction_length=24,
    allowed_max_length=1024
)

Trained for 2 epochs, the loss is lower with the instructions masking :


<img src="plot2_mask.png">

And yet, evaluating the model with Llama 3.2 we get a lower score...
```
Number of scores: 95 of 110
Average score: 46.53