## *Textbooks Are All You Need* Notebook Presentation

#### By Jack Sun and Quentin Clark, University of Toronto
#### For CS 2541: Large Models



This notebook is meant to explain the main ideas behind the paper *Textbooks Are All You Need* ([see here](https://openreview.net/forum?id=Fq8tKtjACC)) with a toy example.

The main thesis of this paper is that, in sequence-generation tasks such as large language modelling, training a *smaller model* on *higher-quality data* can sometimes lead to better performance on downstream tasks than training a larger model on more, lower-quality data.

This notebook has two sections: one that lets you play around with the final phi-1 model from Microsoft Research, and another that walks you through a toy example with a mini-GPT model on a synthetic math dataset.


In [None]:
# import traceback
# import numpy as np
# import jax
# import jax.numpy as jnp
# try:
#   from penzai import pz
# except ImportError:
#   !pip install penzai[notebook]
#   from penzai import pz
# # Install necessary libraries
# !pip install transformers torch --quiet
# import treescope
# from penzai.core.named_axes import wrap, nmap
# from tqdm.notebook import tqdm
# import matplotlib.pyplot as plt
# import humanize
# treescope.basic_interactive_setup(autovisualize_arrays=True)

In [None]:


# Import required modules
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the model and tokenizer
model_name = "microsoft/phi-1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")




In [None]:
function_text = """def quicksort(lst: list):
    \"""
    This takes a list of integers and sorts the list in ascending order.
    \"""
"""

print(function_text)

# Test the model
inputs = tokenizer(function_text, return_tensors="pt")

# Generate output
output = model.generate(**inputs, max_length=200)

# Decode and print the output
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print("Input:", function_text)
print("Output:", decoded_output)

def quicksort(lst: list):
    """
    This takes a list of integers and sorts the list in ascending order.
    """

Input: def quicksort(lst: list):
    """
    This takes a list of integers and sorts the list in ascending order.
    """

Output: def quicksort(lst: list):
    """
    This takes a list of integers and sorts the list in ascending order.
    """
    if len(lst) <= 1:
        return lst
    else:
        pivot = lst[0]
        left = []
        right = []
        for num in lst[1:]:
            if num < pivot:
                left.append(num)
            else:
                right.append(num)
        return quicksort(left) + [pivot] + quicksort(right)



from typing import List

def find_smallest_multiple_of_list(li: List[int]) -> int:
    """
    Returns the smallest positive integer that is divisible by all the numbers in the input list.

    Args:
    li (List[int]): A list
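
Because `max_length=200` lets generation continue after the function is finished, the output above runs on into an unrelated function. One way to keep only the first function is sketched below. `truncate_to_first_function` is a hypothetical helper written for this notebook (it is not part of `transformers`), and it assumes the completion begins with a top-level `def` whose body is indented.

```python
def truncate_to_first_function(text: str) -> str:
    """Keep the first top-level function and drop whatever follows it."""
    kept = []
    seen_def = False
    for line in text.splitlines():
        # A non-empty line that starts in column 0 is a top-level statement.
        is_top_level = bool(line.strip()) and not line[0].isspace()
        if is_top_level and seen_def:
            break  # a new top-level statement begins, so the first function is over
        if line.startswith("def "):
            seen_def = True
        kept.append(line)
    return "\n".join(kept).rstrip()


generated = '''def f(x):
    return x + 1


from typing import List

def g(y):
    return y
'''
print(truncate_to_first_function(generated))
```

Applied to `decoded_output`, this would keep the `quicksort` definition and discard the trailing `find_smallest_multiple_of_list` fragment.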


In [None]:
def quicksort(lst: list):
    """
    This takes a list of integers and sorts the list in ascending order.
    """
    if len(lst) <= 1:
        return lst
    else:
        pivot = lst[0]
        left = []
        right = []
        for i in range(1, len(lst)):
            if lst[i] < pivot:
                left.append(lst[i])
            else:
                right.append(lst[i])
        return quicksort(left) + [pivot] + quicksort(right)
lst = quicksort([3,2,1])
print(lst)

[1, 2, 3]
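
A single input is weak evidence that generated code is correct. As a sketch, the completion can be compared against Python's built-in `sorted` on many random lists. The `quicksort` body below is copied from the phi-1 output; `agrees_with_sorted`, the number of trials, and the value range are arbitrary choices made for this illustration.

```python
import random


def quicksort(lst: list):
    # Body as generated by phi-1 above.
    if len(lst) <= 1:
        return lst
    else:
        pivot = lst[0]
        left = []
        right = []
        for num in lst[1:]:
            if num < pivot:
                left.append(num)
            else:
                right.append(num)
        return quicksort(left) + [pivot] + quicksort(right)


def agrees_with_sorted(fn, trials=200, seed=0):
    """Return True if fn matches sorted() on `trials` random integer lists."""
    rng = random.Random(seed)
    for _ in range(trials):
        data = [rng.randint(-50, 50) for _ in range(rng.randint(0, 20))]
        if fn(list(data)) != sorted(data):
            return False
    return True


print(agrees_with_sorted(quicksort))
```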


In [None]:
function_text = """def sort_concat_square_deduplicate(list1, list2, my_threshold):
    \"""
    This function takes two lists of integers, sorts each of them in ascending order,
    concatenates them, squares the entries at even indices, filters out entries
    smaller than my_threshold and then removes duplicates. The resulting list is
    returned.
    \"""
"""
print(function_text)


def sort_concat_square_deduplicate(list1, list2, my_threshold):
    """
    This function takes two lists of integers, sorts each of them in ascending order,
    concatenates them, squares the entries at even indices, filters out entries
    smaller than my_threshold and then removes duplicates. The resulting list is
    returned.
    """



In [None]:
# Test the model
inputs = tokenizer(function_text, return_tensors="pt")

# Generate output
output = model.generate(**inputs, max_length=200)

# Decode and print the output
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print("Input:", function_text)
print("Output:", decoded_output)

Input: def sort_concat_square_deduplicate(list1, list2, my_threshold):
    """
    This function takes two lists of integers, sorts each of them in ascending order,
    concatenates them, squares the entries at even indices, filters out entries
    smaller than my_threshold and then removes duplicates. The resulting list is
    returned.
    """

Output: def sort_concat_square_deduplicate(list1, list2, my_threshold):
    """
    This function takes two lists of integers, sorts each of them in ascending order,
    concatenates them, squares the entries at even indices, filters out entries
    smaller than my_threshold and then removes duplicates. The resulting list is
    returned.
    """
    # Sort and concatenate the two lists
    combined_list = sorted(list1 + list2)
    
    # Square the entries at even indices
    squared_list = [num**2 for i, num in enumerate(combined_list) if i % 2 == 0]
    
    # Filter out entries smaller than my_threshold and remove duplicates
    filtered_l

In [None]:
# phi-1-generated code
def sort_concat_square_deduplicate(list1, list2, my_threshold):
    """
    This function takes two lists of integers, sorts each of them in ascending order,
    concatenates them, squares the entries at even indices, filters out entries
    smaller than my_threshold and then removes duplicates. The resulting list is
    returned.
    """
    # Sort and concatenate the two lists
    combined_list = sorted(list1 + list2)

    # Square the entries at even indices
    squared_list = [num**2 for i, num in enumerate(combined_list) if i % 2 == 0]

    # Filter out entries smaller than my_threshold and remove duplicates
    filtered_list = list(set(squared_list))

    return filtered_list

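# Almost correct, but it never filters out entries smaller than my_threshold.
# The "if i % 2 == 0" clause also discards the entries at odd indices
# rather than leaving them unsquared.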
test_list = [1,2,3,4,5]
test_list2 = [6,7,8,9,10]
result = sort_concat_square_deduplicate(test_list, test_list2, 5)
print(result)

[1, 9, 81, 49, 25]
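
For comparison, here is a sketch of one literal reading of the docstring. The name `sort_concat_square_deduplicate_ref` is ours, and the choices to keep entries equal to the threshold and to preserve first-occurrence order when deduplicating are assumptions, since the docstring does not settle them.

```python
def sort_concat_square_deduplicate_ref(list1, list2, my_threshold):
    # Sort each list separately, then concatenate.
    combined = sorted(list1) + sorted(list2)
    # Square entries at even indices; leave odd-index entries unchanged.
    squared = [v ** 2 if i % 2 == 0 else v for i, v in enumerate(combined)]
    # Drop entries smaller than the threshold.
    kept = [v for v in squared if v >= my_threshold]
    # Remove duplicates while preserving first-occurrence order.
    return list(dict.fromkeys(kept))


print(sort_concat_square_deduplicate_ref([1, 2, 3, 4, 5], [6, 7, 8, 9, 10], 5))
# [9, 25, 6, 49, 8, 81, 10]
```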


In [None]:
# HumanEval:
function_text = """You are given a non-empty list of positive
integers. Return the greatest integer that
is greater than zero, and has a frequency
greater than or equal to the value of the
integer itself. The frequency of an integer
is the number of times it appears in the list."""

In [None]:
# Test the model
inputs = tokenizer(function_text, return_tensors="pt")

# Generate output
output = model.generate(**inputs, max_length=400)

# Decode and print the output
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print("Input:", function_text)
print("Output:", decoded_output)

Input: You are given a non-empty list of positive
integers. Return the greatest integer that
is greater than zero, and has a frequency
greater than or equal to the value of the
integer itself. The frequency of an integer
is the number of times it appears in the list.
Output: You are given a non-empty list of positive
integers. Return the greatest integer that
is greater than zero, and has a frequency
greater than or equal to the value of the
integer itself. The frequency of an integer
is the number of times it appears in the list.

Example:
greatest_frequency_above_zero([1, 2, 2, 3, 3, 3, 4, 4, 4, 4]) -> 4
greatest_frequency_above_zero([0, 0, 0, 0, 0]) -> 0
greatest_frequency_above_zero([1, 1, 1, 1, 1]) -> 1
"""

from typing import List

def greatest_frequency_above_zero(li: List[int]) -> int:
    """
    Returns the greatest integer that is greater than zero, and has a frequency
    greater than or equal to the value of the integer itself.
    
    Args:
    li: A list of positive int

In [None]:
from typing import List

def greatest_frequency_above_zero(li: List[int]) -> int:
    """
    Returns the greatest integer that is greater than zero, and has a frequency
    greater than or equal to the value of the integer itself.

    Args:
    li: A list of positive integers

    Returns:
    The greatest integer that is greater than zero, and has a frequency greater
    than or equal to the value of the integer itself. If no such integer exists,
    returns 0.
    """
    freq = {}
    for num in li:
        if num > 0:
            freq[num] = freq.get(num, 0) + 1
    max_freq = 0
    max_num = 0
    for num, count in freq.items():
        if num >= count and num > max_num:
            max_freq = count
            max_num = num
    return max_num
# The generated condition (num >= count) is backwards, so it fails to reject numbers whose
# frequency is smaller than their value. 5 appears only once, so this should return 4.
print(greatest_frequency_above_zero([1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5]))
print(greatest_frequency_above_zero([0, 0, 0, 0, 0]))

5


0
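
The failure comes down to the direction of one comparison. A minimal corrected sketch is shown below; the name `greatest_frequency_above_zero_fixed` is ours, and returning 0 when no integer qualifies follows the docstring the model generated rather than any official specification.

```python
from collections import Counter
from typing import List


def greatest_frequency_above_zero_fixed(li: List[int]) -> int:
    # Count occurrences of each positive integer.
    freq = Counter(num for num in li if num > 0)
    # Keep integers whose frequency is at least their own value.
    candidates = [num for num, count in freq.items() if count >= num]
    return max(candidates, default=0)


print(greatest_frequency_above_zero_fixed([1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5]))  # 4
print(greatest_frequency_above_zero_fixed([0, 0, 0, 0, 0]))                    # 0
```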

## Section 2 - Math-GPT

To demonstrate the main result of the *Textbooks* paper in a toy environment, we will construct our own small GPT model trained to perform basic arithmetic.

We will make two models - a larger one trained on a larger amount of low-quality data (many of the examples will be incorrect) and a smaller one trained on a smaller amount of high-quality data (all of the examples will be correct).


In [None]:
# Basic imports
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


Here, we define our dataset, which can be either high quality (no mistakes, smaller size) or low quality (wrong answers roughly a third of the time, larger size).
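
Each problem is rendered as a fixed-width digit string, and the model is trained to predict each digit from the ones before it. The sketch below mirrors what `__getitem__` in the next cell does for a single, uncorrupted problem, using plain lists instead of tensors. `encode_problem` is a name made up for this illustration.

```python
def encode_problem(a: int, b: int, ndigit: int = 2):
    c = a + b
    # Zero-pad a and b to ndigit digits and the sum to ndigit + 1 digits.
    render = f"{a:0{ndigit}d}{b:0{ndigit}d}{c:0{ndigit + 1}d}"
    digits = [int(ch) for ch in render]
    x = digits[:-1]  # model input
    y = digits[1:]   # next-token targets
    # Targets that are still part of the question are masked with -100,
    # so only the answer digits contribute to the loss.
    for i in range(ndigit * 2 - 1):
        y[i] = -100
    return render, x, y


print(encode_problem(3, 25))
# ('0325028', [0, 3, 2, 5, 0, 2], [-100, -100, -100, 0, 2, 8])
```

The value -100 is the default `ignore_index` of `F.cross_entropy`, which is why those positions are skipped when the loss is computed.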

In [None]:
from torch.utils.data import Dataset

class AdditionDataset(Dataset):
    """
    Define the addition dataset
    """

    def __init__(self, ndigit, split, quality = 'High'):
        self.quality = quality
        self.split = split # train/test
        self.ndigit = ndigit
        self.vocab_size = 10 # 10 possible digits 0..9
        # +1 due to potential carry overflow, but then -1 because very last digit doesn't plug back
        self.block_size = ndigit + ndigit + ndigit + 1 - 1

        # split up all addition problems into either training data or test data
        num = (10**self.ndigit)**2 # total number of possible combinations
        r = np.random.RandomState(1337) # make deterministic
        perm = r.permutation(num)
        num_test = min(int(num*0.2), 1000) # 20% of the whole dataset, or only up to 1000
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]
        if quality == 'High' and split != 'test': # keep only a fifth of the training problems
          self.len = int(self.ixes.size/5)
          self.ixes = self.ixes[:self.len]
        elif quality == 'Low' or split == 'test': # keep every problem in the split
          self.len = self.ixes.size
        else:
          raise ValueError(f'Unknown quality: {quality!r}')
    def __len__(self):
        return self.ixes.size

    def __getitem__(self, idx):
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx]
        nd = 10**self.ndigit
        a = idx // nd
        b = idx %  nd
        c = a + b
        # if low quality, randomly introduce a mistake to the answer
        if self.quality == 'Low' and self.split == 'train':
          r = np.random.RandomState(idx) # seeded by the problem index, so the same problems are corrupted every epoch
          if r.rand() < 0.33: # about a third of the time, give a wrong answer
            # add an offset in [-5, 4] (randint excludes its upper bound; an offset of 0 leaves the answer correct)
            c = c + np.random.RandomState(idx).randint(-5,5) # the offset comes from a fresh generator seeded with idx
            c = max(0,c)


        render = f'%0{self.ndigit}d%0{self.ndigit}d%0{self.ndigit+1}d' % (a,b,c) # e.g. 03+25=28 becomes "0325028"
        dix = [int(s) for s in render] # convert each character to its token index
        # x will be input to GPT and y will be the associated expected outputs
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence
        # mask the loss on the question digits; -100 is the default ignore_index of F.cross_entropy
        y[:self.ndigit*2-1] = -100
        return x, y
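
How noisy is the low-quality training set in practice? The sketch below is a rough back-of-the-envelope estimate, assuming the corruption decision and the offset are independent and ignoring the rare `max(0, c)` clamp. Since `randint(-5, 5)` excludes its upper bound, one of the ten possible offsets is 0 and leaves the answer correct.

```python
from fractions import Fraction

p_corrupt = Fraction(33, 100)   # probability that a training example is selected for corruption
offsets = list(range(-5, 5))    # the values randint(-5, 5) can return: -5 ... 4
p_nonzero = Fraction(sum(1 for o in offsets if o != 0), len(offsets))
p_wrong = p_corrupt * p_nonzero  # probability that the stored answer is actually wrong

print(offsets[0], offsets[-1], float(p_wrong))  # -5 4 0.297
```

Under these assumptions, a little under 30% of the low-quality training answers are wrong.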

In [None]:
# create a dataset for e.g. 2-digit addition
ndigit = 2
train_dataset_high_quality = AdditionDataset(ndigit=ndigit, split='train', quality = 'High')
train_dataset_low_quality = AdditionDataset(ndigit=ndigit, split='train', quality = 'Low')
test_dataset = AdditionDataset(ndigit=ndigit, split='test')

print('High-Quality Training Set Size:',train_dataset_high_quality.len)
print('Low-Quality Training Set Size:',train_dataset_low_quality.len)
print('Test Set Size:',test_dataset.len)

High-Quality Training Set Size: 1800
Low-Quality Training Set Size: 9000
Test Set Size: 1000


As promised, our high-quality dataset is a fifth the size of our low-quality one. We also defined a held-out test set.
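
The sizes printed above follow directly from the split logic in `AdditionDataset.__init__`. A minimal sketch of that arithmetic is given below; `split_sizes` and `shrink_factor` are names made up for this illustration.

```python
def split_sizes(ndigit: int, shrink_factor: int = 5):
    num = (10 ** ndigit) ** 2              # every (a, b) pair of ndigit-digit numbers
    num_test = min(int(num * 0.2), 1000)   # 20% of all problems, capped at 1000
    num_train_low = num - num_test         # the low-quality set keeps all remaining problems
    num_train_high = int(num_train_low / shrink_factor)  # the high-quality set keeps a fifth
    return num_train_high, num_train_low, num_test


print(split_sizes(2))  # (1800, 9000, 1000)
```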

The next cell is hidden for brevity, but it defines our Transformer architecture as a pretty standard GPT-2 style decoder-only model with multihead attention.
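
The one ingredient of that architecture worth seeing in isolation is the causal mask: position `i` may only attend to positions `j <= i`. The sketch below applies it to a small matrix of raw attention scores for a single head in plain Python. `causal_attention_weights` is a name made up for this illustration, and dropping the hidden positions before the softmax is equivalent to filling them with `-inf` as the model code does.

```python
import math


def causal_attention_weights(scores):
    """scores[i][j] is the raw score of query position i for key position j."""
    weights = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]                     # keys at positions j <= i
        m = max(visible)                           # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        # Future positions receive exactly zero weight.
        weights.append([e / total for e in exps] + [0.0] * (len(row) - i - 1))
    return weights


w = causal_attention_weights([[0.0, 5.0, 5.0],
                              [0.0, 0.0, 5.0],
                              [0.0, 0.0, 0.0]])
print(w[0])  # [1.0, 0.0, 0.0]
print(w[1])  # [0.5, 0.5, 0.0]
```

Large scores on future positions (the 5.0 entries) have no effect, which is what prevents the model from peeking at the answer digits it is supposed to predict.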

In [None]:
# @title Architecture Definition (unimportant)
class GPT(nn.Module):
    """  the full GPT language model, with a sequence size of block_size """

    def __init__(self, vocab_size, n_embd, n_head, block_size, n_layer, embd_pdrop=0.1, attn_pdrop=0.1,resid_pdrop=0.1):
        super().__init__()

        # input embedding stem
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[TransformerBlock(n_embd, n_head, block_size, attn_pdrop, resid_pdrop)
                                      for _ in range(n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        self.block_size = block_size
        self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))


    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """
        Build an AdamW optimizer that applies weight decay only to the weights of Linear layers.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, x, targets=None):
        b, t = x.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        # forward the GPT model
        token_embeddings=self.tok_emb(x)
        position_embeddings=self.pos_emb[:,:t,:]
        x=self.drop(token_embeddings+position_embeddings)
        x=self.blocks(x)
        x=self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
          loss=F.cross_entropy(logits.view(-1,logits.size(-1)),targets.view(-1))

        return logits, loss
class TransformerBlock(nn.Module):
    """ a pre-LayerNorm Transformer block: self-attention and an MLP, each with a residual connection """

    def __init__(self, n_embd, n_head, block_size, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadSelfAttention(n_embd, n_head, block_size, attn_pdrop, resid_pdrop)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        x=x+self.attn(self.ln1(x))
        x=x+self.mlp(self.ln2(x))
        return x
class MultiHeadSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    You can also use torch.nn.MultiheadAttention to validate your implementation

    """

    def __init__(self, n_embd, n_head, block_size, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        # define key, query, value projections for all heads
        self.key = nn.Linear(n_embd,n_embd)
        self.query = nn.Linear(n_embd,n_embd)
        self.value = nn.Linear(n_embd,n_embd)
        # Dropout layers
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd,n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        #self.mask = torch.tril(torch.ones(block_size,block_size)).view(1,1,block_size,block_size)
        self.register_buffer("mask",torch.tril(torch.ones(block_size,block_size)).view(1,1,block_size,block_size))



    def forward(self, x, layer_past=None):
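        B, T, C = x.size() # B = batch size, T = sequence length, C = embedding size
        # project to keys, queries and values, then split into heads: (B, n_head, T, C // n_head)
        k=self.key(x).view(B,T,self.n_head,C//self.n_head).transpose(1,2)
        q=self.query(x).view(B,T,self.n_head,C//self.n_head).transpose(1,2)
        v=self.value(x).view(B,T,self.n_head,C//self.n_head).transpose(1,2)

        # scaled dot-product attention with a causal mask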
        att=(q @ k.transpose(-2,-1))*(1.0/math.sqrt(k.size(-1)))
        att=att.masked_fill(self.mask[:,:,:T,:T]==0,float('-inf'))
        att=F.softmax(att,dim=-1)
        att=self.attn_drop(att)
        y=att @ v
        y=y.transpose(1,2).contiguous().view(B,T,C)
        output=self.resid_drop(self.proj(y))

        return output
import math
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

Next, we instantiate a large and a small version of our GPT model and define some training parameters.

In [None]:
logger = logging.getLogger(__name__)
class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9 # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = None
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            setattr(self, k, v)
# initialize a baby GPT model
model_high_quality = GPT(vocab_size = train_dataset_high_quality.vocab_size, n_embd=128, n_head=4, block_size =  train_dataset_high_quality.block_size, n_layer=2)
model_low_quality = GPT(vocab_size = train_dataset_high_quality.vocab_size, n_embd=248, n_head=4, block_size =  train_dataset_high_quality.block_size, n_layer=4)

print('Model Size: High-Quality Model: ',sum(p.numel() for p in model_high_quality.parameters()))
print('Model Size:  Low-Quality Model:',sum(p.numel() for p in model_low_quality.parameters()))

Model Size: High-Quality Model:  400128
Model Size:  Low-Quality Model: 2972032


As in the *Textbooks* paper, the model that we will train on high-quality data is much smaller than the model trained on low-quality data (roughly 7 times fewer parameters).
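
The parameter counts printed above can be reproduced by hand from the architecture. The sketch below tallies the weights and biases of each component defined in the `GPT` class; `count_gpt_params` is a name made up for this illustration. The causal mask is a buffer, not a parameter, so it is not counted.

```python
def count_gpt_params(vocab_size: int, n_embd: int, block_size: int, n_layer: int) -> int:
    # Token embedding table plus learned position embeddings.
    embeddings = vocab_size * n_embd + block_size * n_embd
    # Key, query, value and output projections, each n_embd x n_embd with a bias.
    attention = 4 * (n_embd * n_embd + n_embd)
    # Two-layer MLP with a 4x hidden expansion, both layers with biases.
    mlp = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    # Two LayerNorms per block, each with a weight and a bias vector.
    layer_norms = 2 * (2 * n_embd)
    block = attention + mlp + layer_norms
    # Final LayerNorm plus the bias-free output head.
    final = 2 * n_embd + n_embd * vocab_size
    return embeddings + n_layer * block + final


small = count_gpt_params(vocab_size=10, n_embd=128, block_size=6, n_layer=2)
large = count_gpt_params(vocab_size=10, n_embd=248, block_size=6, n_layer=4)
print(small, large, round(large / small, 2))  # 400128 2972032 7.43
```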

Next, we define some boilerplate training code (we've also hidden this for ease of reading). It is standard material that you will see in any PyTorch training loop. One note: we increase the number of epochs for the high-quality model so that it sees roughly the same number of gradient updates as the low-quality model (about 1,000 versus 900 at a batch size of 512). We need to do this because the high-quality model has less data, so the same number of epochs would give it fewer gradient steps.
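
The number of gradient updates each model receives can be checked with a little arithmetic, using the dataset sizes printed earlier and the epoch counts and batch size set in the next cell. `gradient_steps` is a name made up for this illustration. It rounds up because the `DataLoader` keeps the final partial batch by default.

```python
import math


def gradient_steps(num_examples: int, batch_size: int, epochs: int) -> int:
    # One optimizer step per batch, including the last partial batch of each epoch.
    return epochs * math.ceil(num_examples / batch_size)


print(gradient_steps(1800, 512, 250))  # 1000 (high-quality model)
print(gradient_steps(9000, 512, 50))   # 900  (low-quality model)
```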

In [None]:
# @title Training Loop
config_high_quality = TrainerConfig(max_epochs=250, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset_high_quality)*(ndigit+1),
                      num_workers=4)

config_low_quality = TrainerConfig(max_epochs=50, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset_low_quality)*(ndigit+1),
                      num_workers=4)
from tqdm import trange
def train_gpt(config,model,dataset):
  device = 'cpu'
  if torch.cuda.is_available():
    device = torch.cuda.current_device()
    model.to(device)
  optimizer = model.configure_optimizers(config)



  tokens = 0
  for epoch in trange(config.max_epochs):
      model.train()
      data = dataset
      loader = DataLoader(data, shuffle=True, pin_memory=True,
                          batch_size=config.batch_size,
                          num_workers=config.num_workers)
      losses = []
      #pbar = tqdm(enumerate(loader), total=len(loader))
      for iter, (x, y) in enumerate(loader):
          # place data on the correct device
          x = x.to(device)
          y = y.to(device)
          # forward the model
          logits, loss = model(x, y)
          loss = loss.mean()
          model.zero_grad()
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
          optimizer.step()
          # decay the learning rate based on our progress
          if config.lr_decay:
              tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
              if tokens < config.warmup_tokens:
                  # linear warmup
                  lr_mult = float(tokens) / float(max(1, config.warmup_tokens))
              else:
                  # cosine learning rate decay
                  progress = float(tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                  lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
              lr = config.learning_rate * lr_mult
              for param_group in optimizer.param_groups:
                  param_group['lr'] = lr
          else:
              lr = config.learning_rate
          # report progress
          #pbar.set_description(f"epoch {epoch+1} iter {iter}: train loss {loss.item():.5f}. lr {lr:e}")
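
The learning-rate schedule inside the loop above is easier to see as a standalone function. The sketch below reproduces the multiplier computed in `train_gpt`; `lr_multiplier` is a name made up for this illustration, and the example values use the high-quality configuration (`warmup_tokens=1024`, `final_tokens=50*1800*3`).

```python
import math


def lr_multiplier(tokens: float, warmup_tokens: float, final_tokens: float) -> float:
    if tokens < warmup_tokens:
        # Linear warmup from 0 to 1.
        return tokens / max(1, warmup_tokens)
    # Cosine decay, floored at 10% of the base learning rate.
    progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
    return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))


warmup, final = 1024, 50 * 1800 * 3
print(lr_multiplier(512, warmup, final))     # 0.5, halfway through warmup
print(lr_multiplier(1024, warmup, final))    # 1.0, warmup complete
print(lr_multiplier(final, warmup, final))   # 0.1, the floor
```

Note that `progress` is not clamped at 1, so if training continues past `final_tokens`, the cosine term makes the multiplier rise again. This applies to the high-quality run, where `final_tokens` corresponds to 50 epochs but 250 epochs are trained.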



Finally, we can train our model. This will take a little while, so be patient!

In [None]:
train_gpt(config_high_quality,model_high_quality,train_dataset_high_quality)
train_gpt(config_low_quality,model_low_quality,train_dataset_low_quality)

100%|██████████| 250/250 [01:06<00:00,  3.75it/s]
100%|██████████| 50/50 [04:21<00:00,  5.24s/it]


Almost done! Now it is time to see how our models perform on our smaller test dataset. The next cell defines some basic sampling utilities. For the final time, we condense them for brevity.
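
One of those utilities is top-k filtering, which keeps the `k` largest logits and sets the rest to negative infinity so that they receive zero probability after the softmax. The sketch below shows the idea on a plain Python list; `top_k_filter` is a name made up for this illustration. The evaluation in this notebook calls `sample` with its defaults (`sample=False`, `top_k=None`), so it uses greedy decoding and never triggers this filter.

```python
def top_k_filter(logits, k):
    # The k-th largest value is the cutoff; anything strictly below it is removed.
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else float("-inf") for x in logits]


print(top_k_filter([1.0, 3.0, 2.0, 0.5], k=2))  # [-inf, 3.0, 2.0, -inf]
```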

In [None]:
# @title Sampling Utilities
def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

def sample(train_dataset,model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time.
    """
    block_size = train_dataset.block_size
    model.eval()
    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        logits, _ = model(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)
    return x
def Addition_GPT(model,dataset, batch_size=32, max_batches=-1):
    device = 'cpu'
    if torch.cuda.is_available():
      device = torch.cuda.current_device()
    results = []
    loader = DataLoader(dataset, batch_size=batch_size)
    for b, (x, y) in enumerate(loader):
        x = x.to(device)
        d1d2 = x[:, :ndigit*2]
        d1d2d3 = sample(dataset,model, d1d2, ndigit+1)
        d3 = d1d2d3[:, -(ndigit+1):]
        factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(device)
        # decode the integers from individual digits
        d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)
        d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)
        d3i_pred = (d3 * factors).sum(1)
        d3i_gt = d1i + d2i
        correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line, lol
        for i in range(x.size(0)):
            results.append(int(correct[i]))
            # Uncomment to print each mistake:
            # if not correct[i]:
            #     print("GPT claims that %03d + %03d = %03d (gt is %03d)"
            #           % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]))

        if max_batches >= 0 and b+1 >= max_batches:
            break

    print("final score: %d/%d = %.2f%% correct" % (np.sum(results), len(results), 100*np.mean(results)))
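
The scoring in `Addition_GPT` relies on turning digit tokens back into integers by weighting them with powers of ten. A minimal sketch of that decoding step on a single sequence, using plain lists (`digits_to_int` is a name made up for this illustration):

```python
def digits_to_int(digits):
    # Most significant digit first: weights are ..., 100, 10, 1.
    factors = [10 ** i for i in range(len(digits))][::-1]
    return sum(d * f for d, f in zip(digits, factors))


ndigit = 2
sequence = [0, 3, 2, 5, 0, 2, 8]           # "03" + "25" followed by the predicted "028"
a = digits_to_int(sequence[:ndigit])
b = digits_to_int(sequence[ndigit:2 * ndigit])
c_pred = digits_to_int(sequence[-(ndigit + 1):])
print(a, b, c_pred, a + b == c_pred)        # 3 25 28 True
```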

Finally, it is time to see the results and run each of our models on the test dataset:

In [None]:
print('High-Quality Model Results:')
Addition_GPT(model_high_quality,test_dataset, batch_size=1024, max_batches=10)
print('Low-Quality Model Results:')
Addition_GPT(model_low_quality,test_dataset, batch_size=1024, max_batches=10)

High-Quality Model Results:
final score: 998/1000 = 99.80% correct
Low-Quality Model Results:
final score: 563/1000 = 56.30% correct


As expected, our larger model trained on a larger dataset with a significant number of mistakes performs much worse (56.30% correct) than our smaller model trained on a smaller, curated set of examples (99.80% correct)!