# Fine-tuning the model to make a chatbot

This is the big guacamole at the end of the rainbow. We'll be fine-tuning one of the OpenAI models to respond sort of like ChatGPT. I think there's an example of trying to do this on the foundation model in `openai.ipynb` without fine-tuning, and right now it _sucks_.

In [2]:
import import_ipynb
import openai # type:ignore
import gpt # type:ignore
import torch
import urllib.request
import ssl
import os
import json
from pprint import pprint
from typing import TypedDict
from torch.utils.data import Dataset, DataLoader
import tiktoken

def get_device() -> torch.device:
    if torch.cuda.is_available(): # type: ignore[attr-defined]
        return torch.device("cuda")
    elif torch.backends.mps.is_available(): # type: ignore[attr-defined]
        return torch.device("mps:0")
    else:
        return torch.device("cpu")

## Download the instruction training data

These are 1,100 instruction-response pairs (some also have a third field called `input`) that were made specifically for the book.

In [None]:
class InstructionExample(TypedDict):
    instruction: str  # A description of the task to be performed
    input: str        # Optional parameter for the task
    output: str       # The expected result of performing the task

def download_and_load_file(file_path: str, url: str) -> list[InstructionExample]:
    # NOTE: certificate verification is turned off here, which is insecure.
    # Prefer the default context if your local certificates work.
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    # Download the file only if there's no local copy yet
    if not os.path.exists(file_path):
        with urllib.request.urlopen(url, context=ssl_context) as response: # type:ignore
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)

    # Either way, parse the JSON from the local copy
    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data

file_path = "instruction-data.json"
url = (
    "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch"
    "/main/ch07/01_main-chapter-code/instruction-data.json"
)

data = download_and_load_file(file_path, url)
print("Number of entries:", len(data))
print("Example:")
pprint(data[1])

Number of entries: 1100
Example:
{'input': 'He go to the park every day.',
 'instruction': 'Edit the following sentence for grammar.',
 'output': 'He goes to the park every day.'}
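
A quick sanity check on the "some have a third field" remark is to count how many entries actually carry an input. This is a minimal sketch on toy entries, assuming that an entry without an input either stores an empty string or omits the key; `count_with_input` is a hypothetical helper, not something from the notebook.

```python
def count_with_input(entries: list[dict[str, str]]) -> int:
    """Count entries whose "input" field is present and non-empty."""
    return sum(1 for entry in entries if entry.get("input"))

# Toy entries: the first is the example printed above, the other two are made up
sample = [
    {"instruction": "Edit the following sentence for grammar.",
     "input": "He go to the park every day.",
     "output": "He goes to the park every day."},
    {"instruction": "Name a primary color.", "input": "", "output": "Red."},
    {"instruction": "Say hello.", "output": "Hello."},
]

print(count_with_input(sample), "of", len(sample), "entries have an input")
```

Running the same function over the real `data` list should show how common the `input` field is.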


## Convert the examples to Stanford Alpaca format

The [format](https://github.com/tatsu-lab/stanford_alpaca) looks like this:

```
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
```

Or, if there's no input:

```
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
```
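
A literal rendering of the two templates might look like the sketch below, which picks the preamble based on whether there's an input. The names `PROMPT_WITH_INPUT`, `PROMPT_NO_INPUT`, and `build_alpaca_prompt` are hypothetical and only here for illustration.

```python
# Hypothetical constants mirroring the two templates shown above
PROMPT_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. Write a response that appropriately "
    "completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:"
)
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. Write a response that "
    "appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Response:"
)

def build_alpaca_prompt(instruction: str, input_text: str = "") -> str:
    """Pick the template based on whether there is a non-empty input."""
    if input_text:
        return PROMPT_WITH_INPUT.format(instruction=instruction, input=input_text)
    return PROMPT_NO_INPUT.format(instruction=instruction)

with_input = build_alpaca_prompt(
    "Edit the following sentence for grammar.",
    "He go to the park every day.",
)
no_input = build_alpaca_prompt("Name a primary color.")
print(with_input)
```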

In [None]:
def format_input(entry: InstructionExample, include_response: bool = True) -> str:
    # Note: unlike the two templates above, this uses the same preamble
    # whether or not the entry has an input.
    instruction_text = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    # Single quotes inside the f-strings keep this valid before Python 3.12
    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""
    response_text = f"\n\n### Response:\n{entry['output']}" if include_response else ""

    return instruction_text + input_text + response_text

print(format_input(data[1]))

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Edit the following sentence for grammar.

### Input:
He go to the park every day.

### Response:
He goes to the park every day.


In [None]:
class InstructionDataset(Dataset):
    def __init__(self, data: list[InstructionExample], tokenizer: tiktoken.Encoding):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts: list[list[int]] = []
        for entry in data:
            full_text = format_input(entry)
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index: int) -> list[int]:
        return self.encoded_texts[index]

    def __len__(self) -> int:
        return len(self.data)


## Custom collate function

Passing in a custom collate function lets us easily pad out shorter sequences in each batch to match the longest one.
The padding token is `<|endoftext|>` (EOT). It stays that way in the inputs, but in the targets only the first EOT token
survives and the rest of the padding is replaced with `-100`, so the loss ignores it.

The collate function is responsible for:
1. Finding the longest sequence in the batch
2. Appending an EOT token to each sequence and padding it out to that length
3. Shifting by one position to get the inputs and the targets
4. Replacing the extra EOT tokens in the targets with `-100`
5. Converting the token lists to tensors and transferring them to the target device.
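
For a single sequence, the padding and shifting in the list above can be sketched with plain Python lists (no tensors, no device transfer). `pad_and_shift` is a hypothetical helper for illustration only, and the token ids are arbitrary; `50256` and `-100` are assumed to be the EOT id and the ignore index.

```python
EOT = 50256          # <|endoftext|>, also used as the padding token
IGNORE_INDEX = -100  # target value the loss is meant to skip

def pad_and_shift(tokens: list[int], batch_max_length: int) -> tuple[list[int], list[int]]:
    # Append one EOT, then pad with more EOTs up to batch_max_length
    padded = tokens + [EOT]
    padded += [EOT] * (batch_max_length - len(padded))

    # Inputs drop the last token; targets drop the first (shifted by one)
    inputs, targets = padded[:-1], padded[1:]

    # Keep the first EOT in the targets, replace later ones with IGNORE_INDEX
    seen_eot = False
    masked_targets = []
    for tok in targets:
        if tok == EOT and seen_eot:
            masked_targets.append(IGNORE_INDEX)
        else:
            masked_targets.append(tok)
            seen_eot = seen_eot or tok == EOT
    return inputs, masked_targets

# The longest sequence in this pretend batch has 5 tokens, so the length is 5 + 1
inputs, targets = pad_and_shift([5, 6], batch_max_length=6)
print(inputs)   # [5, 6, 50256, 50256, 50256]
print(targets)  # [6, 50256, -100, -100, -100]
```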


### We're not masking the instructions

We could use `-100` to mask out the instructions from each example. That would avoid rewarding the model for memorizing
worthless bits like "Below is an instruction…", and some people think that's helpful. But it's controversial, and there's at least
one paper, ["Instruction Tuning with Loss Over Instructions,"](https://arxiv.org/abs/2405.14394) that argues that it's
better to train on the whole thing.

Maybe I'll try adding instruction masking later, but for now I'm leaving it out.
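
For reference, instruction masking would amount to something like the sketch below, written with plain lists. `mask_instruction_targets` and `prompt_len` (the number of tokens in the instruction part of an example) are hypothetical names; nothing like this is in the notebook yet. Because the targets are shifted by one position, only the first `prompt_len - 1` target positions are prompt tokens.

```python
IGNORE_INDEX = -100  # assumed ignore index, matching the padding mask

def mask_instruction_targets(
    targets: list[int],
    prompt_len: int,
    ignore_index: int = IGNORE_INDEX,
) -> list[int]:
    # targets[i] is the token at position i + 1 of the full sequence, so the
    # prompt occupies the first prompt_len - 1 target positions
    n_masked = min(max(prompt_len - 1, 0), len(targets))
    return [ignore_index] * n_masked + targets[n_masked:]

# Toy example: tokens 10, 11, 12 are the prompt; 20, 21 are the response
full_sequence = [10, 11, 12, 20, 21, 50256]  # with the EOT appended
targets = full_sequence[1:]                  # shifted by one
print(mask_instruction_targets(targets, prompt_len=3))
# [-100, -100, 20, 21, 50256]
```

With this in place, the loss would only cover the response tokens and the final EOT.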

In [None]:
def custom_collate_fn(
        batch: list[list[int]],
        pad_token_id: int=50256, # i.e., <|endoftext|>
        ignore_index: int=-100, # this is the default ignore index for torch.nn.CrossEntropyLoss
        allowed_max_length: int|None=None,
        device: str|torch.device="cpu"
) -> tuple[torch.Tensor, torch.Tensor]:
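    # +1 leaves room for the EOT token appended to every sequence
    batch_max_length = max(len(item) + 1 for item in batch)

    inputs_lst, targets_lst = [], []

    for item in batch:
        # Append one EOT token, then pad up to the batch's max length
        new_item = item.copy()
        new_item += [pad_token_id]
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        # Targets are the inputs shifted left by one position
        inputs = torch.tensor(padded[:-1])
        targets = torch.tensor(padded[1:])

        # Keep the first EOT in the targets and replace the rest with ignore_index
        mask = targets == pad_token_id # bool tensor, same shape as targets
        indices = torch.nonzero(mask).squeeze(-1)
        if indices.numel() > 1:
            # Note: we only do this -100 thing in the targets tensor
            targets[indices[1:]] = ignore_index

        # Optionally truncate both tensors to allowed_max_length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]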
        
        inputs_lst.append(inputs)
        targets_lst.append(targets)

    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor


In [5]:
batch = [
    [0, 1, 2, 3, 4],
    [5, 6],
    [7, 8, 9]
]

inputs, targets = custom_collate_fn(batch)
print(inputs)
print(targets)

tensor([[    0,     1,     2,     3,     4],
        [    5,     6, 50256, 50256, 50256],
        [    7,     8,     9, 50256, 50256]])
tensor([[    1,     2,     3,     4, 50256],
        [    6, 50256,  -100,  -100,  -100],
        [    8,     9, 50256,  -100,  -100]])
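
The point of the `-100` entries in the targets is that those positions are left out of the loss. Here is a small pure-Python sketch of that behaviour (mean cross-entropy over the non-ignored positions). It illustrates the idea rather than PyTorch's implementation, and `cross_entropy_ignoring` is a made-up name.

```python
import math

def cross_entropy_ignoring(
    logits: list[list[float]],
    targets: list[int],
    ignore_index: int = -100,
) -> float:
    losses = []
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored positions contribute nothing
        # Negative log-softmax of the target class
        log_sum_exp = math.log(sum(math.exp(x) for x in row))
        losses.append(log_sum_exp - row[target])
    # Average over the positions that were actually counted
    return sum(losses) / len(losses) if losses else float("nan")

logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]
with_padding = cross_entropy_ignoring(logits, [0, -100])
first_only = cross_entropy_ignoring(logits[:1], [0])
print(math.isclose(with_padding, first_only))  # True
```

The ignored second position doesn't change the result, which is why padding the targets with `-100` keeps the padding from affecting training.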
