# How ChatGPT Works Part 1: Supervised Fine-Tuning


> Up to a point, overfitting the model to the SFT dataset can continue to increase human preference ratings.

## Collecting Prompts

Submit some prompts [here](https://docs.google.com/forms/d/e/1FAIpQLSfSBFzS4yrdUwy3DjEj2kskTc1JXk-T47TbmK8TaEgSt4fkcA/viewform?usp=sf_link).

## Collecting Ideal Prompt Responses

These days, there are more sophisticated ways to manage data collection.
One platform is [Scale AI](https://scale.com).

Submit your email [here](https://docs.google.com/forms/d/e/1FAIpQLScsMfW1Fh0bwget7cZKCzm6TQ-1c0AsvQFHtBain2l1mjnIcQ/viewform?usp=sf_link) so I can allow you to join my demo labelling workforce and you can access the Scale labelling platform.


In [6]:
import torch
import json

class SFTDataset(torch.utils.data.Dataset):
    """Supervised Fine-Tuning Dataset

    Returns:
        prompt: str
        response: str
    """
    def __init__(self):
        with open("raw_data.json") as f:
            self.data = json.load(f)
    
    def __len__(self):
        """Defines the length of the dataset."""
        return len(self.data)
    
    def __getitem__(self, idx):
        """Defines how to get a sample from the dataset by indexing it.

        Returns:
            prompt: str
            response: str
        """
        return self.data[idx]["prompt"], self.data[idx]["response"]
    

dataset = SFTDataset()
print(dataset[0])

('The best reasons to study at Oxford university:\n- Full of culture\n- Attracts great people\n- Excellent staff', '- Every college has its own bar and drink')


Unfortunately, GPT-3.5 is not currently available to download. 
The model parameters are closed source and owned by OpenAI.
So instead, we'll work with GPT-2 - a smaller predecessor of the model trained on the same task of language modelling. 

[GPT-2](https://huggingface.co/gpt2) (and original [GPT](https://huggingface.co/openai-gpt)) model are available through HuggingFace.

Here are the key differences:
- Size:
    - GPT has 117M parameters
    - GPT-2 has 1.5B parameters
    - GPT-3 has 175B parameters (800GB storage required)
    
- Training data size: 
    <!-- - GPT: -->
    - GPT-2: 40GB (8M webpages)
    - GPT-3: 45TB
<!-- - Training data variety:
The GPT models were trained on increasingly larger and more diverse datasets, with GPT-3 trained on a massive corpus of web pages, books, and other text sources. -->
- Task performance: GPT-3 has demonstrated better performance on a wide range of natural language tasks, including question answering, language translation, and natural language generation. It has also shown an ability to perform some common sense reasoning and to generate coherent and informative responses even to complex prompts.

- Speed and efficiency: Because of its size, GPT-3 is slower and more resource-intensive to run than GPT-2 or the original GPT. However, it can generate high-quality outputs with fewer prompts or examples.

- Release date: The original GPT was released in 2018, GPT-2 in 2019, and GPT-3 in 2020. Each new model represents a significant advance in natural language processing capabilities.

Let's load in GPT-2 and make sure we're comfortable with how it can be used to generate new text:

In [17]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config

configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)


tokenizer = GPT2Tokenizer.from_pretrained("gpt2", config=configuration) # Load the tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2") # Load the model

# generate a sequence of tokens using the model's forward method
prompt = "Hello, I am a language model."
input_ids = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(input_ids, max_length=50, do_sample=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hello, I am a language model.

To understand our problems, it must be explained how most languages do it.

Language Models

So how do we know languages are valid?

Languages are all about the relationships between


Now let's implement a function to initiate a back and forth chat with the model.

In [None]:
def chat():
    """Chat with the model."""
    prompt = ""
    while True:
        # GET USER INPUT
        next_input = "You: " + input("You: ") + "\nBot: "
        print(next_input)
        prompt += next_input

        # GENERATE A SEQUENCE OF TOKENS USING THE MODEL'S FORWARD METHOD
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        output = model.generate(input_ids, max_length=100, do_sample=True, top_k=50, top_p=0.95, temperature=0.7)

        # PRINT THE RESPONSE AND UPDATE THE PROMPT
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        print(response)
        prompt += response

chat()


In [34]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# print(model.modules)
# scsdc

def train(epochs=10):
    # Create the dataset and dataloader
    dataset = SFTDataset()
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Create the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.95)) # as used in the InstructGPT paper

    # Set up logging
    writer = SummaryWriter() # for logging our loss to TensorBoard
    batch_idx = 0 # for setting the x-axis of our TensorBoard plots (loss vs. batch index)

    # Train the model
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}")
        for batch in tqdm(dataloader):
            # Get the data
            prompt, response = batch
            prompt = prompt[0]
            response = response[0]

            # Encode the data
            entire_text = prompt + response
            context_dict = tokenizer(
                '<|startoftext|>' + entire_text + '<|endoftext|>',
                                    #    truncation=True, 
                                    #    max_length=max_length, 
                                    #    padding="max_length"
            )

            # input_ids = tokenizer.encode(prompt, return_tensors="pt")
            # print(input_ids)
            # print(tokenizer.decode(input_ids[0], skip_special_tokens=False))
            # print(len(input_ids))
            # labels = tokenizer.encode(response, return_tensors="pt")
            # print(tokenizer.decode(labels[0], skip_special_tokens=False))
            # print(len(labels))
            # sdcs

            input_ids = torch.tensor(context_dict.input_ids)
            labels = torch.tensor(context_dict.input_ids)
            attention_mask = torch.tensor(context_dict.attention_mask)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                labels=labels,
                attention_mask=attention_mask,
            )

            # logits = outputs.logits


            loss = outputs.loss

            # Backward pass
            loss.backward()
            optimizer.step()

            # Zero the gradients
            optimizer.zero_grad()

            # Log the loss
            print(f"Loss: {loss.item()}", batch_idx)
            writer.add_scalar("Loss/train", loss.item(), batch_idx)
            batch_idx += 1

train()


Epoch 1


 25%|██▌       | 1/4 [00:02<00:06,  2.20s/it]

Loss: 0.04907530918717384 0


 50%|█████     | 2/4 [00:03<00:03,  1.94s/it]

Loss: 0.12487148493528366 1


 75%|███████▌  | 3/4 [00:05<00:01,  1.71s/it]

Loss: 0.2410813570022583 2


100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Loss: 0.6012824177742004 3
Epoch 2


 25%|██▌       | 1/4 [00:01<00:04,  1.45s/it]

Loss: 0.044717010110616684 4


 50%|█████     | 2/4 [00:02<00:03,  1.51s/it]

Loss: 0.03734680637717247 5


 75%|███████▌  | 3/4 [00:04<00:01,  1.44s/it]

Loss: 0.08754557371139526 6


100%|██████████| 4/4 [00:06<00:00,  1.61s/it]


Loss: 0.44542935490608215 7
Epoch 3


 25%|██▌       | 1/4 [00:01<00:04,  1.62s/it]

Loss: 0.037838295102119446 8


 50%|█████     | 2/4 [00:03<00:03,  1.54s/it]

Loss: 0.08416501432657242 9


 75%|███████▌  | 3/4 [00:05<00:01,  1.82s/it]

Loss: 0.3369210362434387 10


100%|██████████| 4/4 [00:06<00:00,  1.67s/it]


Loss: 0.05517173931002617 11
Epoch 4


 25%|██▌       | 1/4 [00:01<00:04,  1.53s/it]

Loss: 0.08211680501699448 12


 50%|█████     | 2/4 [00:03<00:03,  1.95s/it]

Loss: 0.23072893917560577 13


 75%|███████▌  | 3/4 [00:05<00:01,  1.77s/it]

Loss: 0.04853944852948189 14


100%|██████████| 4/4 [00:07<00:00,  1.76s/it]


Loss: 0.041240278631448746 15
Epoch 5


 25%|██▌       | 1/4 [00:02<00:07,  2.40s/it]

Loss: 0.14014741778373718 16


 50%|█████     | 2/4 [00:04<00:04,  2.09s/it]

Loss: 0.05555743724107742 17


 75%|███████▌  | 3/4 [00:06<00:01,  1.94s/it]

Loss: 0.07117579877376556 18


100%|██████████| 4/4 [00:07<00:00,  1.94s/it]


Loss: 0.03610841929912567 19
Epoch 6


 25%|██▌       | 1/4 [00:01<00:05,  1.82s/it]

Loss: 0.05602844059467316 20


 50%|█████     | 2/4 [00:03<00:03,  1.75s/it]

Loss: 0.03551921993494034 21


 75%|███████▌  | 3/4 [00:05<00:01,  1.99s/it]

Loss: 0.0998191386461258 22


100%|██████████| 4/4 [00:07<00:00,  1.88s/it]


Loss: 0.07416810840368271 23
Epoch 7


 25%|██▌       | 1/4 [00:02<00:07,  2.48s/it]

Loss: 0.05246926471590996 24


 50%|█████     | 2/4 [00:04<00:04,  2.02s/it]

Loss: 0.036311373114585876 25


 75%|███████▌  | 3/4 [00:05<00:01,  1.90s/it]

Loss: 0.06711934506893158 26


100%|██████████| 4/4 [00:07<00:00,  1.95s/it]


Loss: 0.061407703906297684 27
Epoch 8


 25%|██▌       | 1/4 [00:01<00:05,  1.78s/it]

Loss: 0.05729812756180763 28


 50%|█████     | 2/4 [00:03<00:03,  1.64s/it]

Loss: 0.05340489372611046 29


 75%|███████▌  | 3/4 [00:05<00:01,  1.95s/it]

Loss: 0.055069807916879654 30


100%|██████████| 4/4 [00:07<00:00,  1.96s/it]


Loss: 0.10297903418540955 31
Epoch 9


 25%|██▌       | 1/4 [00:01<00:04,  1.57s/it]

Loss: 0.04357868432998657 32


 50%|█████     | 2/4 [00:02<00:02,  1.46s/it]

Loss: 0.06658864766359329 33


 75%|███████▌  | 3/4 [00:04<00:01,  1.44s/it]

Loss: 0.048337750136852264 34


100%|██████████| 4/4 [00:06<00:00,  1.60s/it]


Loss: 0.054456938058137894 35
Epoch 10


 25%|██▌       | 1/4 [00:02<00:06,  2.04s/it]

Loss: 0.04976287856698036 36


 50%|█████     | 2/4 [00:03<00:03,  1.75s/it]

Loss: 0.03473042696714401 37


 75%|███████▌  | 3/4 [00:05<00:01,  1.61s/it]

Loss: 0.043598826974630356 38


100%|██████████| 4/4 [00:06<00:00,  1.60s/it]

Loss: 0.07417843490839005 39



