# Training Small Language Models on the TinyStories Dataset

TinyStories is a synthetic dataset with a deliberately constrained vocabulary and a narrow range of content. This makes it a valuable resource for researchers studying how language models work and how they acquire abilities such as writing coherent English text, in-context reasoning, and creative writing. Small language models can be fully pre-trained on it with little compute and time, yet still produce fluent and coherent English.

TinyStories is made up of short stories in the style of fairy tales, written using only words that a typical 3- to 4-year-old would understand; the stories were generated by GPT-3.5 and GPT-4. The dataset consists of 2.1 million short stories, each only a few sentences long. The stories are all about everyday events that a young child might experience, such as going to the park, playing with friends, or learning something new.

The IPU is a good platform for experimenting with the TinyStories dataset and researching novel techniques and architectures, because it is well suited to accelerating small models. The IPU relies on very fast on-chip SRAM and works best when the full model fits into a single IPU-M2000 appliance, which contains 4 IPUs connected by fast IPU-Links. The smallest model needs as little as 5 minutes of training to generate coherent English sentences.

Sample outputs:

- TinyStories Model with 5M params (5 mins training - 4 IPUs): **This was a sunny day** of the beach. One day, a little girl named Amy was playing in the park. She saw a big tree with lots of leaves. She wanted to pick it up. She picked it up and put it in her pocket. Amy picked up the tree and started to eat it. She was so happy to have a friend to play with. Amy was so happy to have a new friend to play with.


- TinyStories Model with 50M params (30 mins training - 4 IPUs): **This was a sunny day** and the little girl was feeling very happy. The girl went to the park and she saw a big tree. She wanted to climb the tree and see what was up there. She started to climb the tree and it was very tall. The little girl kept climbing and climbing until she was very high up. She looked around and saw lots of fun things. She saw a big tree, a pond and a playground. The little girl was having so much fun that she didn't want to leave. But then she heard a voice calling her. It was her mom. The little girl was so happy that she had climbed the tree. She ran to her mom and hugged her. The mom smiled and hugged her back. She was so proud of her little girl. The little girl and her mom went home.


- TinyStories Model with 200M params (4h training - 4 IPUs): **This was a sunny day** and there were many kids playing. A little girl named Lily saw a boy named Tim. She walked up to him and said, "Hi, I'm Lily. Do you want to play with me?" Tim smiled and said, "Yes! Let's play together!" They played on the swings, the slide, and the seesaw. They had so much fun. After playing, Lily and Tim were tired. They sat down on the grass to relax. They talked and laughed together. The sun was warm and the grass was soft. They became good friends and promised to play together again soon.


- Off-the-shelf GPT-2 XL 1.5B params: **This was a sunny day** in downtown New York. The sidewalks were packed with people on sidewalks and in front of stores and businesses. There weren't a lot of cars running from building to building at this time of night. The nightlife was in full swing. 


arXiv paper: [TinyStories: How Small Can Language Models Be and Still Speak Coherent English?](https://arxiv.org/abs/2305.07759).

## Setup

In [None]:
# Download libraries (ipywidgets is used by the single-run configuration cell)
! pip install datasets matplotlib einops tokenizers wandb pandas ipywidgets

In [None]:
# Imports
%load_ext autoreload
%autoreload 2

import json
import os
import sys
import math
import dataclasses
from IPython import display
from dataclasses import dataclass
from itertools import islice
from typing import *

import torch
import poptorch
from torch import nn, Tensor
import torch.nn.functional as F
import einops

import datasets
import tokenizers

import tqdm
import wandb
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

In [None]:
NUMBER_OF_IPUS = int(os.getenv("NUM_AVAILABLE_IPU", 4))
MODEL_SAVING = True

In [None]:
# Show the status of the available IPUs
! gc-monitor

## Configs

In [None]:
@dataclass
class ModelConfig:
    hidden_size: int
    depth: int
    seq_length: int
    head_size: int
    dtype: Literal['half', 'float', 'float16', 'float32']
    pipeline_stages: int
    checkpointing: bool

@dataclass
class TrainingConfig:
    lr: float
    steps: int
    batch_size: int
    compute_batch_size: int
    replicas: int
    device_iterations: int
    generation_temperature: float
    offloading: bool
    wd: float = 0.01

@dataclass
class Experiment:
    name: str
    model: ModelConfig
    train: TrainingConfig
    profiling: bool
    wandb: bool
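Several training quantities are derived from these configuration fields rather than set directly: the number of attention heads, the number of data-parallel replicas, the gradient-accumulation factor and the number of host steps. The sketch below recomputes them in plain Python so a configuration can be sanity-checked before launching a run. `derived_quantities` is a hypothetical helper that mirrors the integer arithmetic in `Attention.__init__`, `Trainer.get_poptorch_options`, `Trainer.__init__` and the training loop; it is not part of the notebook.

```python
def derived_quantities(
    hidden_size: int,
    head_size: int,
    batch_size: int,
    compute_batch_size: int,
    steps: int,
    device_iterations: int,
    num_ipus: int,
    pipeline_stages: int,
) -> dict:
    """Hypothetical helper mirroring the integer arithmetic used in the notebook."""
    replicas = num_ipus // pipeline_stages  # data-parallel copies of the pipeline
    return {
        "n_heads": hidden_size // head_size,  # as in Attention.__init__
        "replicas": replicas,
        # micro-batches accumulated per optimizer step on each replica
        "gradient_accumulation": batch_size // replicas // compute_batch_size,
        # each host call runs `device_iterations` optimizer steps on the IPUs
        "host_steps": steps // device_iterations,
        # samples fed to the poptorch model in one host call
        "samples_per_host_step": device_iterations * batch_size,
    }


q = derived_quantities(hidden_size=256, head_size=64, batch_size=128, compute_batch_size=2,
                       steps=20_000, device_iterations=100, num_ipus=4, pipeline_stages=4)
print(q)
```

With the default single-run settings (hidden size 256, 4 pipeline stages on 4 IPUs, compute batch size 2) this gives 4 heads, 1 replica, 64 accumulated micro-batches per optimizer step and 200 host steps. Because the divisions truncate, it is safest to keep `batch_size` divisible by `replicas * compute_batch_size`.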

## Transformer Model definition

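This section contains all the components that make up the transformer model. In particular:

- an attention block
- a feed-forward block
- a pre-norm residual wrapper (layer normalization plus a residual connection) that turns each of the above blocks into a sub-layer
- a transformer layer consisting of two such sub-layers: attention followed by feed-forward
- the full model, consisting of a token embedding, the stack of transformer layers, and an output (de-embedding) projection onto the vocabulary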

Note: this implementation does not use dropout, and it uses the [ALiBi](https://arxiv.org/abs/2108.12409) positional encoding scheme.
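ALiBi adds a head-specific linear penalty `-m * (i - j)` to the attention score between query position `i` and key position `j`, with the slopes `m` forming a geometric sequence across heads. `Attention.gen_alibi_mask` builds the bias from a cumulative sum of the causal mask instead, which yields `m * (j + 1)` on the unmasked positions. The two forms differ by `m * (i + 1)`, a constant for each query row, and softmax is unchanged when a constant is added to a whole row. The plain-Python sketch below (helper names are illustrative, not from the notebook) reproduces the slope rule of `gen_slopes`, including the non-power-of-two case that occurs with 12 heads (hidden size 768, head size 64), and checks that equivalence numerically.

```python
import math


def alibi_slopes(n_heads: int) -> list:
    """Slopes as in Attention.gen_slopes: 2^(-8k/n) for k = 1..n, where n is the
    largest power of two <= n_heads, then odd powers of 2^(-4/n) for the rest."""
    n = 2 ** math.floor(math.log2(n_heads))
    slopes = [2.0 ** (-8.0 * k / n) for k in range(1, n + 1)]
    if n < n_heads:
        slopes += [2.0 ** (-4.0 * k / n) for k in range(1, 1 + 2 * (n_heads - n), 2)]
    return slopes


def softmax(xs: list) -> list:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def relative_bias(i: int, slope: float) -> list:
    # Standard ALiBi form over the causal keys j = 0..i
    return [-slope * (i - j) for j in range(i + 1)]


def cumsum_bias(i: int, slope: float) -> list:
    # Form produced by causal_mask.cumsum(-1) * slope on the unmasked entries
    return [slope * (j + 1) for j in range(i + 1)]


scores = [0.3, -1.2, 2.0, 0.5, -0.1]  # arbitrary attention scores for query i = 4
i, slope = 4, alibi_slopes(8)[0]
p_rel = softmax([s + b for s, b in zip(scores, relative_bias(i, slope))])
p_cum = softmax([s + b for s, b in zip(scores, cumsum_bias(i, slope))])
print(max(abs(a - b) for a, b in zip(p_rel, p_cum)))  # zero up to rounding error
```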

In [None]:
# Transformer attention module with ALiBi (relative) positional biases
class Attention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.hidden_size // config.head_size
        self.qkv = nn.Linear(config.hidden_size, 3 * self.n_heads * config.head_size)
        self.out = nn.Linear(self.n_heads * config.head_size, config.hidden_size)
        self.attention_bias = nn.Parameter(self.gen_attention_bias(), requires_grad=False)
        
    def gen_attention_bias(self) -> Tensor:
        causal_mask = torch.tril(torch.ones((self.config.seq_length, self.config.seq_length), dtype=torch.float16))
        causal_mask = causal_mask.view(1, 1, self.config.seq_length, self.config.seq_length)
        alibi_mask = self.gen_alibi_mask(causal_mask)
        causal_mask = (1.0 - causal_mask) * -10_000
        return alibi_mask + causal_mask
    
    # Based on https://nn.labml.ai/transformers/alibi/index.html
    def gen_alibi_mask(self, causal_mask: Tensor) -> Tensor:
        distances = causal_mask.to(torch.float32).cumsum(dim=-1)
        slopes = self.gen_slopes()
        return distances.to(torch.float16) * slopes.view(1, self.n_heads, 1, 1)
    
    def gen_slopes(self) -> Tensor:
        n = 2 ** math.floor(math.log2(self.n_heads))
        m_0 = 2.0 ** (-8.0 / n)
        m = torch.pow(m_0, torch.arange(1, 1 + n))
        if n < self.n_heads:
            m_hat_0 = 2.0 ** (-4.0 / n)
            m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (self.n_heads - n), 2))
            m = torch.cat([m, m_hat])
        return m

    def forward(self, x: Tensor) -> Tensor:
        s = x.shape[1]
        q, k, v = einops.rearrange(
            self.qkv(x), "b s (M n d) -> M b n s d", M=3, n=self.n_heads
        )
        a = torch.einsum("bnsd, bntd -> bnst", q, k) * q.shape[-1] ** -0.5
        a += self.attention_bias[:, :, :s, :s]
        mix = torch.einsum("bnst, bntd -> bnsd", torch.softmax(a, -1), v)
        return self.out(einops.rearrange(mix, "b n s d -> b s (n d)"))

In [None]:
# FFN module with a GeLU non-linearity
class FFN(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.up = nn.Linear(config.hidden_size, 4 * config.hidden_size)
        self.down = nn.Linear(self.up.out_features, self.up.in_features)

    def forward(self, x: Tensor) -> Tensor:
        return self.down(F.gelu(self.up(x)))

In [None]:
# Normalization wrapper with a Pre-Norm configuration
class PreNormResidual(nn.Module):
    def __init__(self, config: ModelConfig, body: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm([config.hidden_size])
        self.body = body

    def forward(self, x: Tensor) -> Tensor:
        return x + self.body(self.norm(x))

In [None]:
# Transformer layer
class TransformerLayer(nn.Sequential):
    def __init__(self, config: ModelConfig):
        super().__init__(
            PreNormResidual(config, Attention(config)),
            PreNormResidual(config, FFN(config)),
        )

In [None]:
# Full model
class Model(nn.Module):
    def __init__(self, config: ModelConfig, tokenizer: tokenizers.Tokenizer):
        super().__init__()
        self.config = c = config
        self.tokenizer = tokenizer

        self.pre = nn.Sequential(
            nn.Embedding(tokenizer.get_vocab_size(), c.hidden_size),
            nn.LayerNorm([c.hidden_size]),
        )
        self.core = nn.Sequential(*(TransformerLayer(config) for _ in range(c.depth)))
        self.post = nn.Sequential(
            nn.LayerNorm([c.hidden_size]),
            nn.Linear(c.hidden_size, tokenizer.get_vocab_size()),
        )
        self.pre[0] = poptorch.BeginBlock(self.pre[0], "Token embedding", ipu_id=0)
        for i in range(c.depth):
            if c.checkpointing:
                self.recomputation_checkpoint(self.core[i])
            self.core[i] = poptorch.BeginBlock(
                self.core[i], ipu_id=(i * c.pipeline_stages) // c.depth
            )
        self.post[1] = poptorch.BeginBlock(self.post[1], ipu_id=c.pipeline_stages - 1)
        self.model = nn.Sequential(self.pre, self.core, self.post)
        self.to(getattr(torch, c.dtype))
        
    @staticmethod
    def recomputation_checkpoint(module: nn.Module) -> torch.utils.hooks.RemovableHandle:
        """Annotates the output of a module to be checkpointed instead of
        recomputed."""
        def recompute_outputs(module, inputs, outputs):
            if isinstance(outputs, torch.Tensor):
                return poptorch.recomputationCheckpoint(outputs)
            elif isinstance(outputs, tuple):
                return tuple(poptorch.recomputationCheckpoint(y) for y in outputs)

        # Return the handle, as the type annotation promises, so the hook can be removed later
        return module.register_forward_hook(recompute_outputs)

    def forward(self, indices: Tensor, debug=False) -> Tensor:
        logits = self.model(indices).float()
        return F.cross_entropy(logits[:, :-1, :].flatten(0, -2), indices[:, 1:].flatten())

    @torch.no_grad()
    def generate(self, prompt: str, count: int, temperature: float) -> str:
        prompt_ids = self.tokenizer.encode(prompt).ids
        completion_ids = []
        for _ in range(count):
            # Condition on at most seq_length tokens (the ALiBi bias is precomputed for that length)
            context = (prompt_ids + completion_ids)[-self.config.seq_length:]
            logits = self.model(torch.tensor(context)[None])[0, -1]
            # Gumbel-max sampling: argmax(logits + T * g) with g ~ Gumbel(0, 1)
            # draws a token from softmax(logits / T)
            gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
            sample = torch.argmax(logits + temperature * gumbel)
            completion_ids.append(int(sample))
            print(self.tokenizer.decode(prompt_ids + completion_ids), end="\r")
        print("\n")
        return self.tokenizer.decode(completion_ids)
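`Model.generate` picks each token by adding random noise to the logits and taking the argmax. The underlying idea is the Gumbel-max trick: if `g = -log(-log(u))` with `u` uniform on (0, 1), then `argmax(logits / T + g)` is a sample from `softmax(logits / T)`, and `argmax(logits + T * g)` selects the same token whenever `T > 0`. The standalone sketch below (plain Python, illustrative names) checks empirically that the sampled frequencies match the softmax probabilities.

```python
import math
import random
from collections import Counter


def gumbel_noise(rng: random.Random) -> float:
    u = rng.random()
    while u == 0.0:  # log(0) is undefined, so resample the rare exact zero
        u = rng.random()
    return -math.log(-math.log(u))  # standard Gumbel(0, 1) sample


def sample_token(logits: list, temperature: float, rng: random.Random) -> int:
    # argmax(logits + T * g) == argmax(logits / T + g) for T > 0,
    # so this samples from softmax(logits / T)
    noisy = [x + temperature * gumbel_noise(rng) for x in logits]
    return max(range(len(noisy)), key=noisy.__getitem__)


def softmax(xs: list, temperature: float) -> list:
    scaled = [x / temperature for x in xs]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
logits, temperature, n = [2.0, 1.0, 0.0, -1.0], 0.5, 50_000
counts = Counter(sample_token(logits, temperature, rng) for _ in range(n))
empirical = [counts[k] / n for k in range(len(logits))]
expected = softmax(logits, temperature)
print([round(p, 3) for p in empirical])
print([round(p, 3) for p in expected])
```

The sign of the noise matters: `log(-log(u))` is the negative of a Gumbel sample, so adding `T * log(-log(u))` to the logits does not sample from the softmax distribution.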

## Data Pipeline

In [None]:
@dataclass
class Dataset:
    data: Dict[str, Tensor]
    tokenizer: tokenizers.Tokenizer

    def batches(
        self, seq_length: int, batch_size: int, split: str
    ) -> Iterable[Tensor]:
        tokens = self.data[split]
        while True:
            offsets = torch.randint(
                0, len(tokens) - seq_length, size=(batch_size,)
            )
            yield torch.stack([tokens[i : i + seq_length].long() for i in offsets])

    @classmethod
    def create(cls, vocab_size: int, path: Path) -> "Dataset":
        if not (path / "tokenizer.json").exists() or not (path / "data.pt").exists():
            path.mkdir(exist_ok=True, parents=True)
            original_data = datasets.load_dataset("roneneldan/TinyStories")
            tokenizer = tokenizers.Tokenizer(
                tokenizers.models.BPE(end_of_word_suffix="</w>")
            )
            tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.WhitespaceSplit()
            tokenizer.decoder = tokenizers.decoders.BPEDecoder(
                suffix=tokenizer.model.end_of_word_suffix
            )
            tokenizer.train_from_iterator(
                original_data["train"][: int(1e5)]["text"],
                tokenizers.trainers.BpeTrainer(
                    vocab_size=vocab_size,
                    limit_alphabet=512,
                    special_tokens=["<pad>", "</s>"],
                    end_of_word_suffix=tokenizer.model.end_of_word_suffix,
                    show_progress=False,
                ),
            )
            tokenizer.save(str(path / "tokenizer.json"))
            torch.save(
                {
                    name: torch.tensor(
                        [
                            token
                            for batch in tqdm.tqdm(
                                split.iter(1000),
                                total=len(split) // 1000,
                                desc=f"generating {name}",
                            )
                            for seq in tokenizer.encode_batch(batch["text"])
                            for token in [tokenizer.token_to_id("</s>")] + seq.ids
                        ],
                        dtype=torch.int16,
                    )
                    for name, split in original_data.items()
                },
                path / "data.pt",
            )
        return cls(
            data=torch.load(path / "data.pt"),
            tokenizer=tokenizers.Tokenizer.from_file(str(path / "tokenizer.json")),
        )
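`Dataset.create` tokenizes every story, prefixes it with the `</s>` token and concatenates everything into one flat `int16` tensor per split. `Dataset.batches` then samples random windows of `seq_length` tokens from that stream, so a training sequence can span story boundaries, with `</s>` marking where each new story starts. A plain-Python sketch of the same scheme, using made-up token IDs instead of the real tokenizer (the helper names and the `</s>` id are hypothetical):

```python
import random


def flatten_stories(stories: list, eos_id: int) -> list:
    """Concatenate tokenized stories, each preceded by the </s> id (as in Dataset.create)."""
    return [tok for ids in stories for tok in [eos_id] + ids]


def random_windows(tokens: list, seq_length: int, batch_size: int, rng: random.Random) -> list:
    """Sample batch_size windows of seq_length contiguous tokens (as in Dataset.batches)."""
    # randrange excludes its upper bound, like torch.randint
    offsets = [rng.randrange(0, len(tokens) - seq_length) for _ in range(batch_size)]
    return [tokens[i : i + seq_length] for i in offsets]


EOS = 1  # hypothetical id; the notebook looks it up with tokenizer.token_to_id("</s>")
stories = [[10, 11, 12], [20, 21], [30, 31, 32, 33]]  # made-up token ids
stream = flatten_stories(stories, EOS)
print(stream)  # [1, 10, 11, 12, 1, 20, 21, 1, 30, 31, 32, 33]
batch = random_windows(stream, seq_length=5, batch_size=3, rng=random.Random(0))
print(batch)
```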

## Trainer 

In [None]:
class Trainer:
    
    def __init__(self, experiment: Experiment, dataset: Dataset):
        self.train_config = tc = experiment.train
        self.model_config = mc = experiment.model
        self.wandb = experiment.wandb
        self.name = experiment.name
        self.dataset = dataset
        
        if experiment.profiling:
            self.set_up_profiling(mc, tc)
        
        torch.seed()
        self.options = self.get_poptorch_options(mc, tc)
        
        
        self.model = Model(mc, dataset.tokenizer)
        self.optimizer = poptorch.optim.AdamW(
            self.model.parameters(), lr=tc.lr, weight_decay=tc.wd, loss_scaling=mc.seq_length, 
        )
        self.host_steps = tc.steps // tc.device_iterations
        # Decay the learning rate linearly from lr to 0 over the host steps
        self.lr_schedule = torch.optim.lr_scheduler.LinearLR(
            self.optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.host_steps
        )
        self.poptorch_trainer = poptorch.trainingModel(self.model, self.options, self.optimizer)
        
        if self.wandb:
            self.set_up_wandb(experiment)
        
    @staticmethod
    def set_up_profiling(model_config: ModelConfig, train_config: TrainingConfig):
        """
        Enable memory profiling on IPU.
        """
        out = Path("profiles/latest")
        out.mkdir(exist_ok=True, parents=True)
        os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(
            {
                "autoReport.outputGraphProfile": True,
                "autoReport.directory": str(out),
                "autoReport.outputArchive": True,
                "debug.outputAllSymbols": True,
            }
        )
        (out / "app.json").write_text(
            json.dumps(dict(model=model_config.__dict__, training=train_config.__dict__))
        )     

    @staticmethod
    def get_poptorch_options(model_config: ModelConfig, train_config: TrainingConfig) -> poptorch.Options:
        """
        Configure poptorch options to run the model on IPU
        """
        options = poptorch.Options()
        options.output_mode = poptorch.OutputMode.Sum
        options.device_iterations = train_config.device_iterations
        options.replication_factor = train_config.replicas
        options.Training.gradient_accumulation = (
            train_config.batch_size // train_config.replicas // train_config.compute_batch_size
        )
        options.Precision.setPartialsType(torch.half)
        options.Precision.enableFloatingPointExceptions(True)
        
        if train_config.offloading:
            proportions = {f"IPU{ipu}": 0.4 for ipu in range(model_config.pipeline_stages)}
            options.setAvailableMemoryProportion(proportions)
            options.TensorLocations.setOptimizerLocation(
                poptorch.TensorLocationSettings().useOnChipStorage(False))
        
        return options
    
    def set_up_wandb(self, experiment: Experiment)-> None:
        os.environ["WANDB_SILENT"] = "true"
        os.environ["WANDB_DIR"] = "/tmp/wandb"
        Path("/tmp/wandb").mkdir(exist_ok=True)
        wandb.init(
            project="tinystories",
            config=dict(**dataclasses.asdict(experiment)),
            reinit=True,
        )
        wandb.summary["n_parameters"] = sum(p.nelement() for p in self.model.parameters())
    
    def summary(self) -> None:
        print(
            f"Running: {self.name}"
            f"\n{self.train_config}\n{self.model_config}"
            f"\n({sum(p.nelement() for p in self.model.parameters())/1e6:.1f} million parameters)\n",
            file=sys.stderr,
        )
        
    def outputs(self) -> Iterable[Tensor]:
        batches = self.dataset.batches(
            self.model_config.seq_length,
            self.train_config.device_iterations * self.train_config.batch_size,
            "train",
        )
        for host_step, batch in enumerate(islice(batches, self.host_steps)):
            self.poptorch_trainer.setOptimizer(self.optimizer)
            loss = (
                self.poptorch_trainer(batch).sum()
                * self.train_config.compute_batch_size
                / (self.train_config.batch_size * self.train_config.device_iterations)
            )
            self.lr_schedule.step()
            yield host_step * self.train_config.device_iterations, float(loss)

            
    def _train_generator(self):
        it = iter(self.outputs())
        yield next(it)  # compile (before tqdm)
        yield from tqdm.tqdm(
            it, initial=1, total=self.host_steps, desc="training", ncols=120
        )
            
    def train(self, online_plot: bool):
        training_loss = []
        colour = np.random.rand(3)  # random RGB colour for this run's loss curve
        
        if online_plot:
            fig, ax = plt.subplots(figsize=(8,8))

        for step, loss in self._train_generator():
            training_loss.append(loss)
            
            if self.wandb:
                wandb.log(dict(loss=loss), step=step)
            if online_plot:
                di = self.train_config.device_iterations
                ax.plot(np.arange(0,len(training_loss)*di,di),training_loss, c=colour)
                ax.set(xlabel='Steps', ylabel='Training loss',
                       title='Small Language Model trained on TinyStories')
                display.clear_output(wait=True)
                display.display(plt.gcf())
        
        if online_plot:
            display.clear_output(wait=True)
        
        return training_loss
            
    def validate(self):
        torch.manual_seed(2937852)
        valid_batches = islice(
            self.dataset.batches(self.model.config.seq_length, 64, "validation"), 32
        )
        with torch.no_grad():
            self.validation_loss = float(
                torch.mean(torch.stack(list(map(self.model, valid_batches))))
            )
        if self.wandb:
            wandb.log(dict(validation_loss=self.validation_loss), step=self.host_steps)
            
    def run(self, validation: bool, online_plot: bool):
        self.training_loss = self.train(online_plot)
        self.model.float()
        self.model.train(False)
        if validation:
            self.validate()
            print(f"Validation loss: {self.validation_loss}")
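`Trainer.outputs` rescales the value returned by the poptorch model, and the learning-rate schedule decays linearly. Assuming that `poptorch.OutputMode.Sum` (followed by `.sum()`) adds up the mean loss of every micro-batch processed in one host step, which is what the normalization in `Trainer.outputs` implies, there are `device_iterations * batch_size / compute_batch_size` such terms, and multiplying by `compute_batch_size / (batch_size * device_iterations)` turns the sum back into a mean. `LinearLR` with `start_factor=1` and `end_factor=0` scales the learning rate by `1 - step / total_iters`. The sketch below reproduces both pieces of arithmetic in plain Python; the function names are illustrative and the schedule is written in closed form.

```python
def mean_loss_from_sum(summed_loss: float, compute_batch_size: int, batch_size: int,
                       device_iterations: int) -> float:
    # Number of micro-batch losses added together in one host step
    n_micro_batches = device_iterations * batch_size / compute_batch_size
    # Equivalent to summed_loss * compute_batch_size / (batch_size * device_iterations)
    return summed_loss / n_micro_batches


def linear_lr(base_lr: float, host_step: int, total_host_steps: int) -> float:
    # Closed form of LinearLR(start_factor=1, end_factor=0): factor = 1 - step / total
    factor = 1.0 - min(host_step, total_host_steps) / total_host_steps
    return base_lr * factor


# 100 device iterations x 128 samples / micro-batches of 2 = 6400 micro-batch losses
per_micro_batch = [2.5] * 6400
print(mean_loss_from_sum(sum(per_micro_batch), compute_batch_size=2, batch_size=128,
                         device_iterations=100))
print([linear_lr(3e-4, s, 200) for s in (0, 100, 200)])
```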

## Hyperparameter settings

Here we specify the run configurations of the models to train; they determine the shapes of the model components and the overall model size. All configurations below are taken from the TinyStories paper and tuned for throughput on the IPU. Additional configurations can be defined manually, but it is recommended not to deviate too far from the proportions below, to avoid out-of-memory errors.

Run either the sweep cell or the single-run cell: each one sets `run_configs`, which the training loop iterates over.
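To see how hidden size and depth translate into model size, the parameter count can be worked out from the layer definitions above. Per transformer layer, the QKV and output projections contribute `4h^2 + 4h` weights, the feed-forward block `8h^2 + 5h` and the two LayerNorms `4h`; the token embedding and output projection add `2Vh + V`, and the LayerNorms around the layer stack add `4h`. This sketch assumes the vocabulary size of 8192 used in the training loop and counts trainable weights only. `Trainer.summary` counts every `nn.Parameter`, including the fixed ALiBi bias (shape `1 x n_heads x seq_length x seq_length` per layer), so it reports a larger number; `alibi_bias_elements` is a hypothetical helper for that difference.

```python
def count_parameters(hidden_size: int, depth: int, vocab_size: int = 8192) -> int:
    """Trainable weights of the Model defined above (biases included)."""
    h, V = hidden_size, vocab_size
    attention = (h * 3 * h + 3 * h) + (h * h + h)  # qkv + out projections
    ffn = (h * 4 * h + 4 * h) + (4 * h * h + h)    # up + down projections
    norms = 2 * (2 * h)                            # two LayerNorms (weight + bias)
    per_layer = attention + ffn + norms            # = 12h^2 + 13h
    embedding_and_head = V * h + (h * V + V)       # embedding + output projection
    outer_norms = 2 * (2 * h)                      # LayerNorms after embedding and before head
    return embedding_and_head + outer_norms + depth * per_layer


def alibi_bias_elements(hidden_size: int, depth: int, head_size: int = 64,
                        seq_length: int = 512) -> int:
    """Elements of the non-trainable ALiBi bias parameters across all layers."""
    return depth * (hidden_size // head_size) * seq_length ** 2


for hidden, depth in [(64, 1), (256, 8), (768, 12), (1024, 12)]:
    print(f"h={hidden:4d} depth={depth:2d}: "
          f"{count_parameters(hidden, depth) / 1e6:.1f}M trainable, "
          f"{alibi_bias_elements(hidden, depth) / 1e6:.1f}M ALiBi bias elements")
```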

### Sweep

In [None]:
columns = ['hidden_dim', 'depth', 'pipeline_stages', 'compute_batch_size', 'offloading', 'checkpointing'] 

data = [
    [64, 1, 1, 4, False, False],
    [64, 2, 1, 4, False, False],
    [64, 4, 1, 4, False, False],
    [64, 8, 1, 4, False, False],
    [64, 12, 1, 4, False, False],
    [128, 1, 1, 4, False, False],
    [128, 2, 1, 4, False, False],
    [128, 4, 1, 4, False, False],
    [128, 8, 1, 4, False, False],
    [128, 12, 1, 4, False, False],
    [256, 1, 1, 4, False, False],
    [256, 2, 1, 4, False, False],
    [256, 4, 1, 4, False, False],
    [256, 8, 2, 4, False, True],
    [256, 12, 2, 4, False, True],
    [512, 1, 2, 2, False, True],
    [512, 2, 2, 2, False, True],
    [512, 4, 2, 2, False, True],
    [512, 8, 2, 2, False, True],
    [512, 12, 2, 2, False, True],
    [768, 1, 2, 2, False, False],
    [768, 2, 2, 2, False, True],
    [768, 4, 4, 2, False, True],
    [768, 8, 4, 2, False, True],
    [768, 12, 4, 2, True, True],
    [1024, 1, 2, 2, False, False],
    [1024, 2, 2, 1, False, True],
    [1024, 4, 4, 1, False, True],
    [1024, 8, 4, 1, True, True],
    [1024, 12, 4, 1, True, True],
]

# Create a DataFrame
run_configs = pd.DataFrame(data, columns=columns)
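Each row also fixes how the model is laid out on the IPUs. In `Model.__init__`, transformer layer `i` is placed on IPU `(i * pipeline_stages) // depth`, the token embedding on IPU 0 and the output projection on the last stage, while the training loop creates `replicas = NUMBER_OF_IPUS // pipeline_stages` data-parallel copies. The helpers below are a plain-Python illustration of that placement rule (hypothetical, not part of the notebook), assuming 4 IPUs as in the default `NUMBER_OF_IPUS`.

```python
def layer_placement(depth: int, pipeline_stages: int) -> list:
    """IPU id of each transformer layer, using the rule from Model.__init__."""
    return [(i * pipeline_stages) // depth for i in range(depth)]


def layout(depth: int, pipeline_stages: int, num_ipus: int = 4) -> dict:
    return {
        "embedding": 0,
        "layers": layer_placement(depth, pipeline_stages),
        "output_projection": pipeline_stages - 1,
        "replicas": num_ipus // pipeline_stages,
    }


for depth, stages in [(8, 2), (12, 4), (1, 2)]:
    print(f"depth={depth}, stages={stages}: {layout(depth, stages)}")
```

For a single layer split over two stages, the layer runs on IPU 0 and IPU 1 holds only the output projection, with two replicas of that pipeline on 4 IPUs.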

### Or: single run

In [None]:
import ipywidgets as widgets
from ipywidgets import interactive
from IPython import display

def hyperparameters(hidden_dim, depth, pipeline_stages, compute_batch_size, offloading, checkpointing):
    return [int(hidden_dim), int(depth), int(pipeline_stages), int(compute_batch_size), offloading, checkpointing]

# Defaults on the first run; afterwards, keep the values last selected in the widgets.
# Re-run this cell after changing the widgets to update run_configs.
if 'interactive_data' not in globals():
    hd, depth, ps, cbs, offloading, checkpointing = '256', '8', 4, 2, False, True
else:
    hd, depth, ps, cbs, offloading, checkpointing = interactive_data.result
    hd = str(hd)
    depth = str(depth)

interactive_data = interactive(
    hyperparameters,
    hidden_dim=widgets.Dropdown(options=['64', '128', '256', '512', '768', '1024'],
                                value=hd, description='Hidden dim:'),
    depth=widgets.Dropdown(options=['1', '2', '4', '8', '12'], value=depth, description='Depth:'),
    pipeline_stages=widgets.IntSlider(value=ps, min=1, max=4, step=1, description='Pipeline:'),
    compute_batch_size=widgets.IntSlider(value=cbs, min=1, max=16, step=1, description='Batch size:'),
    offloading=offloading,
    checkpointing=checkpointing,
)

display.display(interactive_data)
run_configs = pd.DataFrame([interactive_data.result], columns=columns)

## Training loop

In [None]:
dataset = Dataset.create(vocab_size=8192, path=Path("data"))

for _, run_config in run_configs.iterrows():
    hidden_size, depth, pipeline_stages, compute_batch_size, offloading, checkpointing = run_config

    name = f"h{hidden_size}_l{depth}_model"
    experiment = Experiment(
        name,
        ModelConfig(hidden_size=hidden_size, depth=depth, seq_length=512, head_size=64,
                    dtype="half", pipeline_stages=pipeline_stages, checkpointing=checkpointing),
        TrainingConfig(lr=3e-4, steps=20_000, batch_size=128, compute_batch_size=compute_batch_size,
                       replicas=NUMBER_OF_IPUS // pipeline_stages, offloading=offloading,
                       device_iterations=100, generation_temperature=0.5),
        profiling=False,
        wandb=False,
    )

    trainer = Trainer(experiment, dataset)
    trainer.summary()

    try:
        trainer.run(validation=True, online_plot=True)
    finally:
        # Release the IPUs so that the next run can attach to them
        trainer.poptorch_trainer.destroy()

    if MODEL_SAVING:
        os.makedirs('checkpoints', exist_ok=True)
        torch.save(trainer.model.state_dict(), os.path.join('checkpoints', f"{name}.pt"))

## Generations

In [None]:
TEST_PROMPTS = [
    "</s>",
    "When Sally woke up, she",
    "When I grow up, I",
    "This was a sunny day",
    """Once upon a time there was a little girl named Lucy. She was very adventurous.
She loved to explore the world around her, especially when it was bright and sunny outside.
One day, while exploring the nearby park, Lucy came across a ladder leaning on a wall.
She was curious to see what's on top, so she climbed the ladder, but when she reached the top, the ladder fell and she was stuck.
A nearby park ranger noticed her and shouted out \"""",
]

In [None]:
# Generate 150 tokens for each prompt with the last trained model
generations = [
    (prompt, trainer.model.generate(prompt, 150, experiment.train.generation_temperature))
    for prompt in TEST_PROMPTS
]