In [1]:
import pandas as pd
import pytorch_lightning as pl
import torch
from datasets import Dataset, load_dataset
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AdamW, Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast

# Load data

In [2]:
# Load the dataset in streaming mode
stories_stream = load_dataset(
    "roneneldan/TinyStories", streaming=True, trust_remote_code=True
)
stories_stream

IterableDatasetDict({
    train: IterableDataset({
        features: ['text'],
        num_shards: 4
    })
    validation: IterableDataset({
        features: ['text'],
        num_shards: 1
    })
})

In [3]:
n_rows = 1100

# Get the first 100 rows
rows = list(stories_stream["train"].take(n_rows))

# Count the total number of characters
total_chars = sum(len(row["text"]) for row in rows)
total_chars

1014715

In [4]:
stories = Dataset.from_list(rows)

print(stories)

Dataset({
    features: ['text'],
    num_rows: 1100
})


In [36]:
stories[0]

{'text': 'One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.'}

In [5]:
batch_size = 100


def batch_iterator():
    for i in tqdm(range(0, len(stories), batch_size)):
        yield stories[i : i + batch_size]["text"]

# Train Tokenizer

In [6]:
base_tokenizer = Qwen2TokenizerFast.from_pretrained(
    "Qwen/Qwen2.5-0.5B", errors="ignore"
)

In [7]:
base_tokenizer.decode([51461], errors="replace")

' �'

In [10]:
vocab_size = 1024
tokenizer = base_tokenizer.train_new_from_iterator(
    batch_iterator(), vocab_size=vocab_size
)

100%|██████████| 11/11 [00:00<00:00, 422.17it/s]









# Test untrained model

## Initialize model

In [41]:
hidden_size = 64

config = Qwen2Config(
    num_hidden_layers=3,
    hidden_size=hidden_size,
    intermediate_size=hidden_size * 4,  # MLP hidden dim, following GPT-2 approach x4
    num_attention_heads=8,
    num_key_value_heads=2,  # if equal to the num_attention heads, the MHA if 1 then MQA, else GQA
    vocab_size=vocab_size,
    max_position_embeddings=512,  # Maximum sequence length
    attention_probs_dropout_prob=0.1,
)

config

Qwen2Config {
  "attention_dropout": 0.0,
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "silu",
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_position_embeddings": 512,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 8,
  "num_hidden_layers": 3,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "transformers_version": "4.46.3",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 1024
}

In [42]:
model = Qwen2ForCausalLM(config)

In [43]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

# Calculate the total size in bytes (assuming float32, 4 bytes per parameter)
total_size_bytes = total_params * 4

# Convert to megabytes
total_size_mb = total_size_bytes / (1024 * 1024)

print(f"Total size of the model: {total_size_mb:.2f} MB")

Total number of parameters: 309,984
Total size of the model: 1.18 MB


## Generate text from the model

In [44]:
def decode_tokens_to_dataframe(tokenizer, inputs):
    """Decodes token IDs to tokens and returns them as a transposed DataFrame."""
    decoded_tokens = [tokenizer.decode(token_id) for token_id in inputs[0]]
    token_ids = inputs[0].tolist()

    # Create and return a transposed DataFrame
    df = pd.DataFrame({"Token": decoded_tokens, "Token ID": token_ids})

    return df.T

In [45]:
text = "One day a little girl, wakanda"
inputs = tokenizer(text, return_tensors="pt")

In [46]:
decode_tokens_to_dataframe(tokenizer, inputs["input_ids"])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
Token,One,day,a,little,girl,",",wa,k,and,a
Token ID,446,371,272,406,451,25,283,88,711,78


In [47]:
outputs = model.generate(inputs["input_ids"], max_length=16)

In [48]:
decode_tokens_to_dataframe(tokenizer, outputs)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
Token,One,day,a,little,girl,",",wa,k,and,a,Timmy,Timmy,Timmy,Timmy,Timmy,Timmy
Token ID,446,371,272,406,451,25,283,88,711,78,546,546,546,546,546,546


In [49]:
token_to_id = tokenizer.get_vocab()
id_to_token = {v: k for k, v in token_to_id.items()}

> From [Qwen/tokenization_note.md](https://github.com/QwenLM/Qwen/blob/main/tokenization_note.md): The regular tokens are BPE tokens learned from byte sequences of texts encoded using the UTF-8 encoding. While this allows tokenization of all texts and no unknown token exists, it may fall back to using single bytes when tokenizing uncommon texts. You may encounter UTF-8 decoding errors and as the errors are default to replace, thus the replacement character (�) in incomplete generation.

In [50]:
print(id_to_token[189])
print(tokenizer.decode(189))
print(tokenizer.convert_ids_to_tokens(189))

ó
�
ó


In [51]:
print(id_to_token[271])
print(tokenizer.decode(271))
print(tokenizer.convert_ids_to_tokens(271))

Ġt
 t
Ġt


# Train language model

## Create Lightning Data Module

In this case, we are creating a simple Lightning module when we have already loaded the Dataset. We would normally want to use this in a more complex manner, like with an iterableDataset from a series of Parquet files

In [52]:
class DataModule(pl.LightningDataModule):
    def __init__(
        self,
        dataset_name: str,
        n_train_rows: int,
        n_val_rows: int,
        batch_size: int,
        max_seq_length: int,
        num_workers: int,
        tokenizer: Qwen2TokenizerFast,
        random_seed: int = 42,
    ):
        """
        :param dataset_name: Name of the dataset.
        :param n_train_rows: Number of training rows.
        :param n_val_rows: Number of validation rows.
        :param batch_size: Batch size.
        :param max_seq_length: Max sequence length.
        :param num_workers: Number of workers.
        :param random_seed: Random seed.
        """
        super().__init__()
        self.dataset_name = dataset_name
        self.n_train_rows = n_train_rows
        self.n_val_rows = n_val_rows
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.num_workers = num_workers
        self.tokenizer = tokenizer
        self.random_seed = random_seed

    def setup(self, stage: str):
        # Load dataset in streaming mode
        ds = load_dataset(self.dataset_name, streaming=True, trust_remote_code=True)

        # Create dataset
        self.train_ds = self._create_dataset(
            ds=ds,
            split="train",
            n_rows=self.n_train_rows,
        )
        self.val_ds = self._create_dataset(
            ds=ds,
            split="validation",
            n_rows=self.n_val_rows,
        )

        # Tokenizer
        # TODO: In reality, we would initialize the tokenizer here
        self.tokenizer = tokenizer

    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_ds,
            batch_size=self.batch_size,
            collate_fn=self._collate_batch,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_ds,
            batch_size=self.batch_size,
            collate_fn=self._collate_batch,
            num_workers=self.num_workers,
        )

    def _create_dataset(self, ds, split, n_rows):
        # Download and load N rows
        rows = list(stories_stream[split].take(n_rows))
        return Dataset.from_list(rows)

    def _collate_batch(self, batch):
        # Extract text from batch
        batch_text = [item["text"] for item in batch]

        # Tokenize texts
        batch_tokenized = self.tokenizer(
            batch_text,
            truncation=True,
            padding="longest",
            max_length=self.max_seq_length,
            return_tensors="pt",
        )

        # Prepare labels by shifting input_ids
        labels = batch_tokenized["input_ids"].clone()
        labels[:, :-1] = batch_tokenized["input_ids"][:, 1:]
        labels[:, -1] = self.tokenizer.pad_token_id

        # Add labels to the returned dictionary
        batch_tokenized["labels"] = labels

        return batch_tokenized

### Test the `DataModule`

In [53]:
data_module = DataModule(
    dataset_name="roneneldan/TinyStories",  # The dataset name
    n_train_rows=10,  # For testing, load only 10 rows
    n_val_rows=10,  # For testing, load only 10 validation rows
    batch_size=2,  # Smaller batch size for testing
    max_seq_length=128,  # Choose a max sequence length (e.g., 128)
    num_workers=0,  # No need for multiple workers in a test scenario
    tokenizer=tokenizer,  # Pass the tokenizer here
    random_seed=42,  # Optional, for reproducibility
)

# Set up the data module
data_module.setup(stage="fit")

# Get the first batch from the train dataloader
train_dataloader = data_module.train_dataloader()

# Iterate over the dataloader to get the first batch
first_batch = next(iter(train_dataloader))

# Print the first batch to inspect it
print(first_batch)

{'input_ids': tensor([[446, 371,  25, 272, 406, 451, 572, 398, 616, 272, 831, 322, 333, 326,
         758,  27, 325, 789, 318, 294, 291, 602,  83, 477,  98,  89,  97, 280,
         377, 353, 318, 883, 318, 294, 391, 301,  93,  27, 398, 461, 280, 902,
         276, 831, 322, 353, 326, 399,  25, 356, 348, 480, 430, 100, 272, 433,
          97, 299, 361, 326, 391, 327,  97, 319, 370, 491, 280, 326, 399, 278,
         343,  25, 350, 810,  25, 346, 616, 745, 831, 322,  27, 880, 314, 366,
         902, 318, 353, 516, 278, 430, 100, 625, 391, 327,  97, 610, 900, 399,
         508, 278, 343,  25, 350, 922,  25, 398,  25, 373, 481, 902, 276, 831,
         322, 278, 966, 101, 637, 391, 327,  97, 503,  65,  92, 576,  25, 380,
         391, 606],
        [455, 475, 272, 420,  25, 424, 294, 272, 406, 569, 572, 382,  82, 580,
          27, 382,  82, 580, 560, 280, 449, 848, 278, 377, 333, 276, 749,  27,
         382,  82, 580, 294, 272, 292, 416,  97,  85, 102, 569, 883, 292, 683,
         375, 614,

## Create ModelModule

#### Shifting Inputs and Labels

In language model pretraining, the goal is to predict the next token given the previous tokens. This is achieved by shifting the input and labels:

- **Labels Shift**: `labels[..., 1:]` means "take all dimensions, but slice off the first token from the last dimension."

**Example:**

- **Original input**: `[1, 2, 3, 4]`
- **Shifted input**: `[2, 3, 4]`

This creates an offset, ensuring that each input token predicts the next token.

#### Why Shifting?

In language modeling, the objective is to predict the next token. Therefore, the input `[1, 2, 3]` should predict the next tokens `[2, 3, 4]`. This is distinct from tasks like classification, where labels match exactly.

**Example to Illustrate:**

- **Input**: `[start, "I", "love", "machine"]`
- **Labels**: `["I", "love", "machine", "learning"]`
- **Shifted Input**: `["I", "love", "machine"]`
- **Shifted Labels**: `["love", "machine", "learning"]`

#### Loss Computation

The `CrossEntropyLoss` combines `log softmax` and `NLL (Negative Log Likelihood)` loss. The `reduction='mean'` option averages the loss across all tokens.

- **Logits Transformation**: `logits.view(-1, logits.size(-1))` flattens the logits to a 2D tensor of shape `(batch_size * sequence_length, vocab_size)`.
- **Labels Transformation**: `shift_labels.view(-1)` flattens the labels to a 1D tensor of shape `(batch_size * sequence_length)`.

This setup ensures that the model is trained to predict the next token in the sequence accurately.

In [54]:
# TODO: Make it general, not focused around Qwen


class ModelModule(pl.LightningModule):
    def __init__(
        self,
        qwen_model_config: dict,
        learning_rate: float,
    ):
        super().__init__()
        self.model = Qwen2ForCausalLM(qwen_model_config)
        self.learning_rate = learning_rate

    def forward(self, input_ids, attention_mask):
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        )
        return outputs.logits

    def common_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]

        # Shift labels for next-token prediction
        shift_labels = labels[..., 1:].contiguous()
        shift_input_ids = input_ids[..., 1:].contiguous()
        shift_attention_mask = attention_mask[..., 1:].contiguous()

        # Get logits
        logits = self(shift_input_ids, shift_attention_mask)

        # Compute loss
        loss_fct = torch.nn.CrossEntropyLoss(reduction="mean")
        loss = loss_fct(logits.view(-1, logits.size(-1)), shift_labels.view(-1))

        return loss

    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("test_loss", loss)
        return loss

    def predict_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        # Generate predictions
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=50,  # Adjust as needed
        )

        return generated_ids

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box usually works fine
        return AdamW(self.parameters(), lr=self.learning_rate)

In [55]:
hidden_size = 64

config = Qwen2Config(
    num_hidden_layers=3,
    hidden_size=hidden_size,
    intermediate_size=hidden_size * 4,  # MLP hidden dim, following GPT-2 approach x4
    num_attention_heads=8,
    num_key_value_heads=2,  # if equal to the num_attention heads, the MHA if 1 then MQA, else GQA
    vocab_size=vocab_size,
    max_position_embeddings=512,  # Maximum sequence length
    attention_probs_dropout_prob=0.1,
)

model_module = ModelModule(
    qwen_model_config=config,
    learning_rate=1e-5,
)

In [56]:
model_module

ModelModule(
  (model): Qwen2ForCausalLM(
    (model): Qwen2Model(
      (embed_tokens): Embedding(1024, 64)
      (layers): ModuleList(
        (0-2): 3 x Qwen2DecoderLayer(
          (self_attn): Qwen2SdpaAttention(
            (q_proj): Linear(in_features=64, out_features=64, bias=True)
            (k_proj): Linear(in_features=64, out_features=16, bias=True)
            (v_proj): Linear(in_features=64, out_features=16, bias=True)
            (o_proj): Linear(in_features=64, out_features=64, bias=False)
            (rotary_emb): Qwen2RotaryEmbedding()
          )
          (mlp): Qwen2MLP(
            (gate_proj): Linear(in_features=64, out_features=256, bias=False)
            (up_proj): Linear(in_features=64, out_features=256, bias=False)
            (down_proj): Linear(in_features=256, out_features=64, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): Qwen2RMSNorm((64,), eps=1e-06)
          (post_attention_layernorm): Qwen2RMSNorm((64,), eps=1e-06)

## Set up Trainer

In [57]:
def setup_trainer():
    # Set up callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="./model_checkpoints",  # Directory to save checkpoints
        filename="qwen-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,  # Save the top 3 models according to the monitored quantity
        verbose=True,
        monitor="val_loss",  # Metric to monitor
        mode="min",  # We want to minimize the validation loss
    )

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        patience=3,  # Number of epochs with no improvement after which training will be stopped
        verbose=False,
        mode="min",
    )

    # Set up logger
    logger = TensorBoardLogger(save_dir="./tb_logs", name="qwen-pretrain-logs")

    # Create the trainer
    trainer = pl.Trainer(
        max_epochs=10,  # Number of training epochs
        accelerator="auto",  # Automatically use GPU if available
        devices=1,  # Use 1 GPU or CPU
        precision="16-mixed",  # Use mixed precision training
        callbacks=[checkpoint_callback, early_stop_callback],
        logger=logger,
        fast_dev_run=False,  # Set to True for a quick test run
        gradient_clip_val=1.0,  # Gradient clipping
        deterministic=True,  # For reproducibility
    )

    return trainer

In [61]:
# Set random seed for reproducibility
pl.seed_everything(42)

# Create model and data modules (as defined in previous context)
model_module = ModelModule(
    qwen_model_config=config,
    learning_rate=1e-5,
)

data_module = DataModule(
    dataset_name="roneneldan/TinyStories",
    n_train_rows=256,
    n_val_rows=128,
    batch_size=16,
    max_seq_length=128,
    num_workers=0,
    tokenizer=tokenizer,
    random_seed=42,
)

# Setup the data module
data_module.setup(stage="fit")

# Create trainer
trainer = setup_trainer()

Seed set to 42
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


## Run training

In [62]:
# Fit the model
trainer.fit(model=model_module, datamodule=data_module)


  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | Qwen2ForCausalLM | 309 K  | train
---------------------------------------------------
309 K     Trainable params
0         Non-trainable params
309 K     Total params
1.240     Total estimated model params size (MB)
49        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/fernando/Documents/GitHub/tiny-lm/.venv/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (16) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 16: 'val_loss' reached 6.91907 (best 6.91907), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=00-val_loss=6.92.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 32: 'val_loss' reached 6.90322 (best 6.90322), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=01-val_loss=6.90.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 48: 'val_loss' reached 6.88753 (best 6.88753), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=02-val_loss=6.89.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 64: 'val_loss' reached 6.87166 (best 6.87166), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=03-val_loss=6.87.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 80: 'val_loss' reached 6.85561 (best 6.85561), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=04-val_loss=6.86.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 96: 'val_loss' reached 6.83963 (best 6.83963), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=05-val_loss=6.84.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 112: 'val_loss' reached 6.82397 (best 6.82397), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=06-val_loss=6.82.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 128: 'val_loss' reached 6.80882 (best 6.80882), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=07-val_loss=6.81.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 144: 'val_loss' reached 6.79426 (best 6.79426), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=08-val_loss=6.79.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 160: 'val_loss' reached 6.78031 (best 6.78031), saving model to '/Users/fernando/Documents/GitHub/tiny-lm/examples/model_checkpoints/qwen-epoch=09-val_loss=6.78.ckpt' as top 3
`Trainer.fit` stopped: `max_epochs=10` reached.
