# Lesson 5. Model training

Pretraining is very expensive! Please check costs carefully before starting a pretraining project.

You can get a rough estimate of your training job cost using [this calculator](https://huggingface.co/training-cluster) from Hugging Face. For training on other infrastructure, e.g. AWS or Google Cloud, please consult those providers for up-to-date cost estimates.
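For a quick sanity check before opening a calculator, a back-of-the-envelope estimate can be sketched with the common rule of thumb that training takes roughly 6 FLOPs per parameter per token. The sketch below is only an approximation under stated assumptions: the token count, GPU throughput, utilization, and hourly price are hypothetical placeholders rather than values from this lesson, and the function name `estimate_training_cost` is made up for illustration.

```python
def estimate_training_cost(n_params, n_tokens, peak_flops_per_gpu,
                           utilization, usd_per_gpu_hour):
    """Rough cost estimate using the ~6 * N * D FLOPs rule of thumb."""
    # Approximate total compute for forward + backward passes
    total_flops = 6.0 * n_params * n_tokens
    # Sustained throughput is only a fraction of the advertised peak
    effective_flops_per_sec = peak_flops_per_gpu * utilization
    gpu_hours = total_flops / effective_flops_per_sec / 3600.0
    return {
        "total_flops": total_flops,
        "gpu_hours": gpu_hours,
        "usd": gpu_hours * usd_per_gpu_hour,
    }

# All inputs below are ASSUMPTIONS chosen for illustration only
estimate = estimate_training_cost(
    n_params=308e6,            # roughly the size suggested by "TinySolar-308m"
    n_tokens=1e9,              # hypothetical token budget
    peak_flops_per_gpu=100e12, # hypothetical accelerator peak (FLOP/s)
    utilization=0.4,           # hypothetical fraction of peak actually achieved
    usd_per_gpu_hour=2.0,      # hypothetical price
)
print(f"{estimate['gpu_hours']:.1f} GPU-hours, about ${estimate['usd']:.2f}")
```

Techniques such as gradient checkpointing trade extra compute for memory, so real runs can cost more than this estimate suggests.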

In [4]:
import warnings
warnings.filterwarnings('ignore')
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

## 1. Load the model to be trained

Load the upscaled model from the previous lesson:

In [5]:
import torch
from transformers import AutoModelForCausalLM

pretrained_model = AutoModelForCausalLM.from_pretrained(
    "./data/TinySolar-308m-4k-init",
    device_map="cpu", 
    torch_dtype=torch.bfloat16,
    use_cache=False,
)

In [6]:

# Ready for pretraining!
print(pretrained_model)


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=256, bias=False)
          (v_proj): Linear(in_features=1024, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (up_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (down_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((1024,), eps=1e-06)
    (rotary_emb): 
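
The layer shapes printed above are enough to tally the parameter count by hand. The sketch below does that arithmetic in plain Python. The printed output is cut off before any output head, so the `lm_head` term is an assumption: it supposes an untied output projection of shape 32000 x 1024, which would be consistent with the "308m" in the model name.

```python
# Shapes read from the printed architecture
vocab_size, hidden = 32000, 1024
kv_dim, mlp_dim, n_layers = 256, 4096, 16

embed = vocab_size * hidden                        # embed_tokens
attn = 2 * hidden * hidden + 2 * hidden * kv_dim   # q_proj, o_proj + k_proj, v_proj
mlp = 3 * hidden * mlp_dim                         # gate_proj, up_proj, down_proj
norms = 2 * hidden                                 # two RMSNorm weights per layer
per_layer = attn + mlp + norms

backbone = embed + n_layers * per_layer + hidden   # + final RMSNorm weight

# ASSUMPTION: an untied output head of the same shape as the embedding
lm_head = vocab_size * hidden
total = backbone + lm_head

print(f"backbone: {backbone:,}")
print(f"with assumed lm_head: {total:,}")
```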

## 2. Load dataset

Here you'll update two methods on the `Dataset` object to allow it to interface with the trainer. These will be applied when you specify the dataset you created in Lesson 3 as the training data in the next section.

Note that the code has additional comment strings that don't appear in the video. These are to help you understand what each part of the code is doing.

In [7]:
import datasets
from torch.utils.data import Dataset
import torch

class CustomDataset(Dataset):
    def __init__(self, args, split="train"):
        self.args = args
        self.dataset = datasets.load_dataset(
            "parquet",
            data_files=args.dataset_name,
            split=split
        )        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        # Convert the lists to a LongTensor for PyTorch
        input_ids = torch.LongTensor(self.dataset[index]["input_ids"])
        labels = torch.LongTensor(self.dataset[index]["input_ids"])
        # NOTE: labels are set equal to input_ids because the objective
        # is next-token prediction

        # Return the sample as a dictionary
        return {"input_ids": input_ids, "labels": labels}
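
Setting `labels` equal to `input_ids` works because Hugging Face causal language models shift the labels internally when computing the loss, so each position is scored against the token that follows it. The plain-Python sketch below only illustrates that pairing. The token IDs are made up, and `next_token_pairs` is a hypothetical helper, not part of any library.

```python
def next_token_pairs(input_ids):
    """Pair each token with the token that follows it in the sequence."""
    contexts = input_ids[:-1]  # every position except the last
    targets = input_ids[1:]    # the same sequence shifted left by one
    return list(zip(contexts, targets))

sample = [101, 7, 42, 9]  # hypothetical token IDs
pairs = next_token_pairs(sample)
for current_token, target_token in pairs:
    print(f"after token {current_token}, predict token {target_token}")
```

A sequence of length n therefore yields n - 1 prediction targets.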

## 3. Configure Training Arguments

Here you set up the training run. The training dataset you created in Lesson 3 is specified in the Dataset configuration section.

Note: there are comment strings in the cell below that don't appear in the video. These have been included to help you understand what each parameter does.

In [None]:
from dataclasses import dataclass, field
import transformers

@dataclass
class CustomArguments(transformers.TrainingArguments):
    # Dataset configuration
    dataset_name: str = field(
        default="data/packaged_pretrain_dataset.parquet"
    )
    num_proc: int = field(default=1)
    max_seq_length: int = field(default=32)

    # Core training configuration
    optim: str = field(default="adamw_torch")
    max_steps: int = field(default=30)
    per_device_train_batch_size: int = field(default=2)

    # Other training configurations
    seed: int = field(default=0)
    learning_rate: float = field(default=5e-5)
    weight_decay: float = field(default=0)
    warmup_steps: int = field(default=10)
    lr_scheduler_type: str = field(default="linear")
    gradient_checkpointing: bool = field(default=True)
    
    bf16: bool = field(default=True)
    gradient_accumulation_steps: int = field(default=1)
    dataloader_num_workers: int = field(default=0)
    
    # Logging configuration
    logging_steps: int = field(default=3)
    report_to: str = field(default="none")

    # Saving configurations
    # save_strategy: str = field(default="steps")
    # save_steps: int = field(default=3)
    # save_total_limit: int = field(default=2)

In [9]:
from dataclasses import dataclass, field
import transformers

@dataclass
class CustomArguments(transformers.TrainingArguments):
    dataset_name: str = field(                           # Dataset configuration
        default="./data/packaged_pretrain_dataset.parquet")
    num_proc: int = field(default=16)                    # Number of subprocesses for data preprocessing
    max_seq_length: int = field(default=32)              # Maximum sequence length

    # Core training configurations
    seed: int = field(default=0)                         # Random seed for initialization, ensuring reproducibility
    optim: str = field(default="adamw_torch")            # Optimizer, here it's AdamW implemented in PyTorch
    max_steps: int = field(default=30)                   # Number of maximum training steps
    per_device_train_batch_size: int = field(default=2)  # Batch size per device during training

    # Other training configurations
    learning_rate: float = field(default=5e-5)           # Initial learning rate for the optimizer
    weight_decay: float = field(default=0)               # Weight decay
    warmup_steps: int = field(default=10)                # Number of steps for the learning rate warmup phase
    lr_scheduler_type: str = field(default="linear")     # Type of learning rate scheduler
    gradient_checkpointing: bool = field(default=True)   # Enable gradient checkpointing to save memory
    dataloader_num_workers: int = field(default=0)       # Number of subprocesses for data loading
    bf16: bool = field(default=True)                     # Use bfloat16 precision for training on supported hardware
    gradient_accumulation_steps: int = field(default=1)  # Number of steps to accumulate gradients before updating model weights
    
    # Logging configuration
    logging_steps: int = field(default=3)                # Frequency of logging training information
    report_to: str = field(default="none")               # Destination for logging (e.g., WandB, TensorBoard)

    # Saving configuration
    # save_strategy: str = field(default="steps")          # Can be replaced with "epoch"
    # save_steps: int = field(default=3)                   # Frequency of saving training checkpoint
    # save_total_limit: int = field(default=2)             # The total number of checkpoints to be saved
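
To see how `learning_rate`, `warmup_steps`, `lr_scheduler_type="linear"`, and `max_steps` interact, the schedule can be sketched in plain Python. This is a minimal re-implementation of a linear warmup followed by linear decay, written for illustration. The function `linear_warmup_decay_lr` is hypothetical, not a `transformers` API, and the library's scheduler may differ in small details.

```python
def linear_warmup_decay_lr(step, base_lr=5e-5, warmup_steps=10, max_steps=30):
    """Learning rate at a given step: linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        # Ramp up from 0 to base_lr over the warmup phase
        return base_lr * step / max(1, warmup_steps)
    # Decay from base_lr to 0 over the remaining steps
    remaining = max(0, max_steps - step)
    return base_lr * remaining / max(1, max_steps - warmup_steps)

# Learning rate at a few steps, using the defaults from the arguments above
schedule = {step: linear_warmup_decay_lr(step) for step in (0, 5, 10, 20, 30)}
for step, lr in schedule.items():
    print(f"step {step:2d}: lr = {lr:.2e}")
```

With these defaults the rate peaks at step 10 and reaches zero at step 30, so a third of this short run is spent warming up.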

Parse the custom arguments and set the output directory where the model will be saved: 

In [10]:
parser = transformers.HfArgumentParser(CustomArguments)
args, = parser.parse_args_into_dataclasses(
    args=["--output_dir","output"]
)

Set up the training dataset:

In [11]:
train_dataset = CustomDataset(args=args)

Check the shape of the dataset:

In [12]:
print("Input shape: ", train_dataset[0]['input_ids'].shape)

Input shape:  torch.Size([32])
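
Since every sample holds 32 tokens, the number of tokens the run will process follows directly from the training arguments. The sketch below works that out for the defaults used in this lesson. It assumes a single device (the `num_devices` value is an assumption, not something the lesson states), and `tokens_seen` is a hypothetical helper.

```python
def tokens_seen(per_device_batch_size, seq_len, max_steps,
                grad_accum_steps=1, num_devices=1):
    """Return (tokens per optimizer step, total tokens over the run)."""
    tokens_per_step = (
        per_device_batch_size * grad_accum_steps * num_devices * seq_len
    )
    return tokens_per_step, tokens_per_step * max_steps

# Defaults from the training arguments; num_devices=1 is an ASSUMPTION
per_step, total = tokens_seen(per_device_batch_size=2, seq_len=32, max_steps=30)
print(f"{per_step} tokens per step, {total} tokens in total")
```

That is a tiny token budget, which is why this configuration is only suitable as a demonstration and not as a real pretraining run.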


## 4. Run the trainer and monitor the loss

First, set up a callback to log the loss values during training (note this cell is not shown in the video):

In [13]:
from transformers import Trainer, TrainingArguments, TrainerCallback

# Define a custom callback that records each log entry (including the loss)
class LossLoggingCallback(TrainerCallback):
    def __init__(self):
        self.logs = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # The Trainer passes the logged values as the `logs` keyword argument
        if logs is not None:
            self.logs.append(logs)

# Initialize the callback
loss_logging_callback = LossLoggingCallback()
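
Once training has run, the entries collected by the callback can be reduced to a list of loss values. The sketch below shows one way to do that, assuming the callback has gathered dictionaries like those the `Trainer` logs. The example entries are made up for illustration (only the 4.5163 value echoes the output in this lesson), and `extract_losses` is a hypothetical helper.

```python
def extract_losses(log_entries):
    """Return the training-loss values, skipping entries without a 'loss' key."""
    return [entry["loss"] for entry in log_entries if "loss" in entry]

# Hypothetical log entries shaped like Trainer logs. The final summary entry
# has no 'loss' key, so it is skipped.
example_logs = [
    {"loss": 4.5163, "learning_rate": 1.5e-05, "epoch": 0.0},
    {"loss": 4.4, "learning_rate": 3.0e-05, "epoch": 0.0},
    {"train_runtime": 12.3, "train_loss": 4.45, "epoch": 0.0},
]
losses = extract_losses(example_logs)
print(losses)
```

In the notebook, the same helper could be applied to `loss_logging_callback.logs` after `trainer.train()` finishes.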

Then, create an instance of the Hugging Face `Trainer` object from the `transformers` library. Call the `train()` method of the trainer to start the training run:

In [None]:
trainer = Trainer(
    model=pretrained_model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=None,
    callbacks=[loss_logging_callback],
    
    )

trainer.train()

Step,Training Loss
3,4.5163


KeyboardInterrupt: 