# torchtune Demo



In [1]:
import sys
import yaml
from typing import Dict, Any
import torch
import torchtune
from torchtune.models.llama2 import llama2_7b

  from .autonotebook import tqdm as notebook_tqdm


### Define a configuration (this would typically be in a YAML file)

In [None]:
# YAML configuration defines parameters for the model, dataset, training, and hardware.
# The learning rate is written as `1.0e-4` on purpose: PyYAML resolves floats with
# YAML 1.1 rules, which require a dot in the mantissa, so a bare `1e-4` would be
# loaded as the string '1e-4' instead of a float.
config_yaml = """
model:
  name: "llama2_7b"
  pretrained: true
dataset:
  name: "alpaca"
  split: "train[:1000]"
training:
  batch_size: 4
  learning_rate: 1.0e-4
  num_epochs: 3
  max_seq_length: 128
  weight_decay: 0.01
  warmup_steps: 100
hardware:
  device: "cpu"
  dtype: "float32"
"""

# Save the YAML to a file so the rest of the demo reads it the way a real
# job would read a config file from disk.
with open('config.yaml', 'w', encoding='utf-8') as f:
    f.write(config_yaml)

# Simulate the command-line arguments a config-driven launcher would receive.
# Nothing in this cell parses sys.argv; the path is opened directly below.
sys.argv = ['ipykernel_launcher.py', '--config', 'config.yaml']

# Load the configuration back into a plain dict (safe_load never constructs
# arbitrary Python objects).
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

### Parse the configuration and run fine-tuning

In [3]:
def main(cfg: Dict[str, Any]):
    """Fine-tune a Llama2 7B model on the Alpaca dataset as described by ``cfg``.

    Args:
        cfg: Parsed configuration dictionary. This function reads
            ``hardware.device``, ``hardware.dtype``, ``dataset`` (forwarded
            verbatim as keyword arguments to the dataset builder),
            ``training.batch_size``, ``training.learning_rate``,
            ``training.weight_decay`` and ``training.num_epochs``.
            ``model.*``, ``training.max_seq_length`` and
            ``training.warmup_steps`` are not read here.

    Side effects:
        Prints the loss after every batch, writes the model's state dict to
        ``fine_tuned_model.pt`` in the working directory, and prints a
        completion message.
    """
    training_cfg = cfg['training']

    # Set up the device and dtype
    # torchtune lets you pick the hardware and data type used for training.
    device = torch.device(cfg['hardware']['device'])
    dtype = getattr(torch, cfg['hardware']['dtype'])

    # Initialize the model
    # NOTE(review): cfg['model'] is never consulted, so `pretrained: true` has
    # no effect in this function; the builder is called with no checkpoint
    # loading step -- confirm how pretrained weights are meant to be loaded.
    model = llama2_7b()
    model.to(device=device, dtype=dtype)

    # Prepare the dataset
    # torchtune loads the Alpaca dataset, and the DataLoader uses torchtune's
    # padding collate function for supervised fine-tuning (SFT).
    # NOTE(review): torchtune's documented alpaca_dataset builder takes a
    # tokenizer as its first argument; cfg['dataset'] is forwarded unchanged
    # and carries none -- confirm against the installed torchtune version.
    dataset = torchtune.datasets.alpaca_dataset(**cfg['dataset'])
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=training_cfg['batch_size'],
        collate_fn=torchtune.data.padded_collate_sft,
        shuffle=True
    )

    # Set up the optimizer
    # float() guards against YAML loaders that hand back scientific-notation
    # values such as "1e-4" as strings.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(training_cfg['learning_rate']),
        weight_decay=float(training_cfg['weight_decay'])
    )

    # Simple training loop
    for epoch in range(training_cfg['num_epochs']):
        for batch in dataloader:
            tokens = batch['tokens'].to(device)
            labels = batch['labels'].to(device)

            # The torchtune decoder is called with the token tensor and
            # returns logits; it does not compute a loss itself, so the
            # next-token cross-entropy is computed here.
            logits = model(tokens)

            # Shift so that the prediction at position i is scored against
            # the token at position i + 1.
            shifted_logits = logits[:, :-1, :]
            targets = labels[:, 1:]

            # -100 is cross_entropy's ignore index; positions carrying it
            # (padding / masked prompt tokens) do not contribute to the loss.
            loss = torch.nn.functional.cross_entropy(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-100
            )

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    # Save the fine-tuned model
    torch.save(model.state_dict(), "fine_tuned_model.pt")
    print("Fine-tuning complete!")

# Kick off fine-tuning with the configuration dictionary loaded earlier
main(cfg=config)

## So what is torchtune doing?

- It provides a standardized way to configure and set up fine-tuning jobs.
- It gives us pre-built model architectures (like Llama2) that are ready for fine-tuning.
- It simplifies dataset loading and preparation.
- It handles the complexities of working with large language models, including proper input formatting and output processing.
- It allows for easy configuration of training parameters, hardware usage, and data types.