<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/pytorch_tpu_starter_mistral_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference: [AdamW optimizer in PyTorch (DataCamp tutorial)](https://www.datacamp.com/tutorial/adamw-optimizer-in-pytorch)


# Introduction
PyTorch XLA is a PyTorch library for XLA support. XLA (Accelerated Linear Algebra) is a domain-specific compiler that was originally meant for compiling and accelerating TensorFlow models. However, other packages, such as JAX and now PyTorch XLA, can also compile programs with XLA to accelerate code. TPUs are programmed with XLA programs, and PyTorch XLA provides this interface by compiling our PyTorch code into XLA programs that run on TPU devices.

In this kernel, I provide an in-depth look into how you can use PyTorch XLA to train a PyTorch model on the TPU for the Feedback Effectiveness Prize competition.


## Installs & Imports


In [None]:
!pip install colab-env -q
import colab_env

In [None]:
!pip install --upgrade transformers -q
!pip install datasets -q
!pip install --upgrade huggingface_hub -q

In [3]:
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import warnings, transformers, logging, torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch.optim.adamw as AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
import warnings
warnings.filterwarnings('ignore')
logging.disable(logging.WARNING)
import datasets
from datasets import load_dataset, Dataset, DatasetDict
from sklearn.metrics import log_loss
from pathlib import Path
import torch.nn.functional as F
import os
import gc

# PyTorch XLA specific imports
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.debug.metrics as met
import torch
import torch.nn as nn
import torch.nn.functional as F

## Dataset creation


In [None]:
# List the contents of the input directory that the cells below read the CSV files from.
!ls /content/gdrive/MyDrive/TPU/input/

In [5]:
# Load the training and test tables from the input directory.
df = pd.read_csv('/content/gdrive/MyDrive/TPU/input/train.csv')
test_df = pd.read_csv('/content/gdrive/MyDrive/TPU/input/test.csv')

In [6]:
# Local path the tokenizer is loaded from.
model_id = '/content/gdrive/MyDrive/TPU/Mistral-7B-v0.1'
#tokz = AutoTokenizer.from_pretrained('microsoft/deberta-v3-large', model_max_length=512)
# In this cell the tokenizer is only used to pick the separator string below.
tokz = AutoTokenizer.from_pretrained(model_id, model_max_length=512)

# Check if sep_token is None and use eos_token or a space as fallback
sep = tokz.sep_token if tokz.sep_token is not None else (tokz.eos_token if tokz.eos_token is not None else " ")

# Text fed to the model: discourse type, separator, then the discourse text.
df['inputs'] = df.discourse_type + sep +df.discourse_text
# Map the three effectiveness strings to integer class ids 0/1/2. The outer key
# names the column, so only values in discourse_effectiveness are replaced.
new_label = {"discourse_effectiveness": {"Ineffective": 0, "Adequate": 1, "Effective": 2}}
df = df.replace(new_label)
# Rename the target column to "labels".
df = df.rename(columns = {"discourse_effectiveness": "labels"})

In [7]:
# Split by essay id rather than by row: every row whose essay_id is selected
# for validation goes to the validation side.
essay_ids = df.essay_id.unique()
np.random.seed(42)  # fixed seed so the shuffle below is repeatable
np.random.shuffle(essay_ids)
essay_ids[:5]  # not the last expression of the cell, so nothing is displayed
val_prop = 0.2
val_sz = int(len(essay_ids)*val_prop)
val_essay_ids = essay_ids[:val_sz]
# Boolean mask over the rows of df: True where the row's essay is a validation essay.
is_val = np.isin(df.essay_id, val_essay_ids)
# Positional row indices for each side of the split (used with Dataset.select in get_dds).
idxs = np.arange(len(df))
val_idxs = idxs[ is_val]
trn_idxs = idxs[~is_val]
len(val_idxs),len(trn_idxs)

(7181, 29584)

In [8]:
def get_dds(df, train=True):
    """Tokenize a dataframe and optionally split it into train/validation parts.

    Relies on notebook-level names: ``tok_func`` (defined in a later cell) and
    the positional index arrays ``trn_idxs`` / ``val_idxs``.

    Args:
        df: DataFrame holding the columns listed in ``raw_columns`` below.
        train: When true, return a ``DatasetDict`` with "train" and "test"
            splits; otherwise return the tokenized dataset unsplit.

    Returns:
        A ``DatasetDict`` (``train=True``) or a single tokenized ``Dataset``.
    """
    # Columns dropped after tokenization.
    raw_columns = ['discourse_text','discourse_type','inputs','discourse_id','essay_id']
    tokenized = Dataset.from_pandas(df).map(
        tok_func,
        batched=True,
        remove_columns=raw_columns,
    )
    if not train:
        return tokenized

    splits = {
        "train": tokenized.select(trn_idxs),
        "test": tokenized.select(val_idxs),
    }
    return DatasetDict(splits)

## Training code



In [None]:
# Define global variables
# NOTE(review): the lower-case settings below are not referenced by any code
# visible in this notebook; main() reads LR, BATCH_SIZE, NUM_WORKERS and
# WARMUP_PCT from the UPPER_CASE constants.
learning_rate = 1e-4
weight_decay = 1e-2
num_epochs = 10
num_warmup_epochs = 10
BATCH_SIZE = 2 # Reduced batch size
EPOCHS = 1
NUM_WORKERS = 1
OPTIM = "AdamW"
LR = 1e-5  # base learning rate; main() multiplies it by the XLA world size
WD = 0.01
WARMUP_PCT = 0.1  # fraction of the training steps used as scheduler warm-up

import os
# Set XLA_FLAGS in this process's environment.
# NOTE(review): torch_xla is already imported in an earlier cell — confirm the
# flag is still picked up when it is set this late.
os.environ['XLA_FLAGS'] = '--xla_debug_info'

# Dataset creation
# Repeats the steps of the earlier "Dataset creation" cells (same paths, same
# seed, same validation proportion).
df = pd.read_csv('/content/gdrive/MyDrive/TPU/input/train.csv')
test_df = pd.read_csv('/content/gdrive/MyDrive/TPU/input/test.csv')

model_id = '/content/gdrive/MyDrive/TPU/Mistral-7B-v0.1'
# Here `tokz` is only used to choose the separator between type and text.
tokz = AutoTokenizer.from_pretrained(model_id, model_max_length=512)
# Prefer sep_token, then eos_token, then a single space.
sep = tokz.sep_token if tokz.sep_token is not None else (tokz.eos_token if tokz.eos_token is not None else ' ')
# Model input text: discourse type, separator, discourse text.
df['inputs'] = df.discourse_type + sep + df.discourse_text
# Map effectiveness strings to class ids 0/1/2 (the outer key selects the column).
new_label = {"discourse_effectiveness": {"Ineffective": 0, "Adequate": 1, "Effective": 2}}
df = df.replace(new_label)
df = df.rename(columns={"discourse_effectiveness": "labels"})
# Split by essay id: shuffle the unique ids with a fixed seed and take the
# first val_prop share as validation essays.
essay_ids = df.essay_id.unique()
np.random.seed(42)
np.random.shuffle(essay_ids)
val_prop = 0.2
val_sz = int(len(essay_ids) * val_prop)
val_essay_ids = essay_ids[:val_sz]
# Row mask and positional indices for each side of the split (used by get_dds).
is_val = np.isin(df.essay_id, val_essay_ids)
idxs = np.arange(len(df))
val_idxs = idxs[is_val]
trn_idxs = idxs[~is_val]

# Training code
# Tokenizer used by tok_func. Unlike `tokz`, it is loaded without model_max_length.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Write the tokenizer files to the local 'tokenizer' directory.
tokenizer.save_pretrained('tokenizer')

def tok_func(x):
    """Tokenize the "inputs" entry of a batch to a fixed length.

    Args:
        x: Mapping with an "inputs" entry holding the text(s) to tokenize
            (the batch that ``Dataset.map(..., batched=True)`` passes in).

    Returns:
        The tokenizer output for those texts, padded or truncated to 512 tokens.
    """
    # `tokenizer` is loaded without `model_max_length`, so the length limit has
    # to be given explicitly for 'max_length' padding and truncation to have a
    # concrete target. 512 matches the limit `tokz` is loaded with, and a fixed
    # length gives every batch the same shape.
    return tokenizer(x["inputs"], padding='max_length', truncation=True, max_length=512)

def get_dds(df, train=True):
    """Tokenize a dataframe and optionally split it into train/validation parts.

    Uses the notebook-level ``tok_func`` and the positional index arrays
    ``trn_idxs`` / ``val_idxs``.

    Args:
        df: DataFrame containing the columns listed in ``to_remove``.
        train: When true, return a ``DatasetDict`` with "train" and "test"
            splits; otherwise return the tokenized dataset unsplit.

    Returns:
        A ``DatasetDict`` (``train=True``) or a single tokenized ``Dataset``.
    """
    ds = Dataset.from_pandas(df)
    # Drop every raw column once tokenized. 'essay_id' has to go as well: the
    # training and eval loops call `.to(device)` on every batch entry and pass
    # the whole batch to the model as keyword arguments, so a leftover id
    # column would break both.
    to_remove = ['discourse_text', 'discourse_type', 'inputs', 'discourse_id', 'essay_id']
    tok_ds = ds.map(tok_func, batched=True, remove_columns=to_remove)
    if train:
        return DatasetDict({"train": tok_ds.select(trn_idxs), "test": tok_ds.select(val_idxs)})
    return tok_ds

# Tokenize and split the training frame, then save the result to the local
# 'dds' directory; main() reloads it from there with datasets.load_from_disk.
dds = get_dds(df)
dds.save_to_disk('dds')

# Let's now define our training and validation functions.
def train_loop_fn(data_loader, loss_fn, model, optimizer, device, scheduler):
    """Run one pass over ``data_loader``, updating ``model`` after every batch.

    Args:
        data_loader: Iterable of batches; each batch is a dict of tensors that
            includes a 'labels' entry and is passed to the model as keyword
            arguments.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning the loss.
        model: Model whose output exposes ``['logits']``.
        optimizer: Optimizer over the model parameters.
        device: Device every batch tensor is moved to.
        scheduler: Optional learning-rate scheduler stepped once per batch;
            pass ``None`` to skip scheduling.

    Side effects:
        Every 50 batches the loss values gathered by ``xm.mesh_reduce`` are
        averaged and printed via ``xm.master_print``. The model is switched
        back to eval mode before returning.
    """
    model.train()
    for bi, d in enumerate(data_loader):
        batch = {name: tensor.to(device) for name, tensor in d.items()}

        optimizer.zero_grad()
        outputs = model(**batch)
        loss = loss_fn(outputs['logits'], batch['labels'])

        if bi % 50 == 0:
            loss_reduced = xm.mesh_reduce('loss_reduce', loss, lambda x: sum(x) / len(x))
            xm.master_print(f'bi={bi}, loss={loss_reduced}')

        loss.backward()
        # barrier=True: batches are moved to the device by hand here instead of
        # coming from an XLA parallel loader, so nothing else issues the step
        # barrier that cuts the lazily traced graph after each update. With an
        # XLA loader the extra barrier is redundant but harmless.
        xm.optimizer_step(optimizer, barrier=True)

        if scheduler is not None:
            scheduler.step()
    model.eval()

def eval_loop_fn(data_loader, loss_fn, model, device):
    """Score ``model`` on ``data_loader`` and print the validation loss.

    Logits and labels of every batch are collected on the host as Python
    lists; the loss is computed once over all of them, combined through
    ``xm.mesh_reduce`` (mean of the gathered values) and printed with
    ``xm.master_print``.

    Args:
        data_loader: Iterable of batches; each batch is a dict of tensors with
            a 'labels' entry, passed to the model as keyword arguments.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning the loss.
        model: Model whose output exposes ``['logits']``.
        device: Device every batch tensor is moved to.
    """
    all_labels, all_logits = [], []
    for batch in data_loader:
        for name, tensor in batch.items():
            batch[name] = tensor.to(device)

        # Inference only: no autograd bookkeeping for the forward pass.
        with torch.no_grad():
            logits = model(**batch)['logits']

        all_labels.extend(batch['labels'].cpu().detach().tolist())
        all_logits.extend(logits.cpu().detach().tolist())

        # Release the per-batch result before forcing a collection.
        del logits
        gc.collect()

    loss = loss_fn(torch.tensor(all_logits), torch.tensor(all_labels))
    loss_reduced = xm.mesh_reduce('loss_reduce', loss, lambda vals: sum(vals) / len(vals))
    xm.master_print(f'val. loss={loss_reduced}')

# Finally, we define a main function that we will run on each of the 8 cores
def main():
    """Per-process entry point: build the loaders, model, optimizer and scheduler.

    Loads the tokenized dataset saved under 'dds', shards it across the XLA
    processes with ``DistributedSampler``, loads the sequence-classification
    model from ``model_id`` onto this process's XLA device, and sets up AdamW
    with a cosine warm-up schedule. The epoch loop itself is still commented
    out (the "isolation" section), so nothing is trained or saved yet.

    Reads the notebook-level settings ``BATCH_SIZE``, ``NUM_WORKERS``, ``LR``,
    ``WD``, ``EPOCHS``, ``WARMUP_PCT`` and ``model_id``.
    """
    # Local import on purpose: at notebook level the name `AdamW` is bound to
    # the `torch.optim.adamw` *module* (`import torch.optim.adamw as AdamW`),
    # and calling a module raises TypeError. Bind the optimizer class here.
    from torch.optim import AdamW

    dds = datasets.load_from_disk('dds')
    dds.set_format('torch')
    train_dataset = dds['train']

    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=NUM_WORKERS
    )

    valid_dataset = dds['test']
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False)

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=NUM_WORKERS
    )

    device = xm.xla_device()

    # Instantiate the model within the main function and enable gradient checkpointing
    # This is crucial for multi-processing with PyTorch XLA
    model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=3)
    # Gradient checkpointing is switched on through the model API rather than
    # as a from_pretrained keyword.
    model.gradient_checkpointing_enable()
    # The tokenizer pads with the EOS token (pad_token = eos_token), but the
    # model config carries its own pad_token_id. Sequence-classification heads
    # need it to locate the last real token when batches hold more than one
    # padded sequence.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = model.config.eos_token_id
    model.to(device)
    xm.master_print('done loading model and dataloader')

    # Split the parameters into a weight-decayed group and a decay-free group.
    # Each group must carry its own 'weight_decay'; without it both groups
    # silently fall back to the optimizer default and the split does nothing.
    # NOTE(review): these are BERT-style name patterns — confirm they match the
    # bias/norm parameter names of the model actually loaded.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    optimizer_grouped_parameters = [
        group for group in (
            {'params': decay_params, 'weight_decay': WD},
            {'params': no_decay_params, 'weight_decay': 0.0},
        )
        if group['params']  # skip a group no parameter name matched
    ]

    # Scale the base learning rate with the number of replicas.
    lr = LR * world_size
    # The scheduler is stepped once per batch, so its horizon is the number of
    # batches this replica sees per epoch times the number of epochs.
    num_train_steps = len(train_data_loader) * EPOCHS

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_train_steps * WARMUP_PCT),
        num_training_steps=num_train_steps
    )

    xm.master_print(f'num_training_steps = {num_train_steps}, world_size={world_size}')

    # --- Start Isolation ---
    xm.master_print(f'Process {rank}: Starting training loop...')
    # for epoch in range(EPOCHS):
    #     gc.collect()
    #     train_loop_fn(train_data_loader, loss_fn, model, optimizer, device, scheduler)
    #     gc.collect()
    #     eval_loop_fn(valid_data_loader, loss_fn, model, device)
    #     gc.collect()

    #     xm.rendezvous('save_model')
    #     xm.master_print('save model')
    #     xm.save(model.state_dict(), f'xla_trained_model_epoch_{epoch}.pth')
    # --- End Isolation ---

# Kick off the training process using the main function and xmp.spawn
def _mp_fn(rank):
    """Target function handed to xmp.spawn; ``rank`` is accepted but not used."""
    main()

In [None]:
if __name__ == '__main__':
    # Launch _mp_fn through torch_xla's multiprocessing helper.
    # NOTE(review): nprocs=None presumably leaves the process count to
    # torch_xla — confirm. 'forkserver' (like 'spawn') starts fresh
    # interpreters, which presumably cannot import functions that exist only
    # in notebook cells — verify whether 'fork' is needed when run from Colab.
    #xmp.spawn(_mp_fn, args=(), nprocs=None, start_method='spawn') # Changed start_method to 'spawn'
    xmp.spawn(_mp_fn, args=(), nprocs=None, start_method='forkserver')