## Package versions:
- transformers: 4.47.0.dev0
- torch: 2.4.0+cu121
- torch_xla: 2.4.0+libtpu

### Log in to Hugging Face to download a gated model and upload the trained model

In [None]:
# Authenticate with the Hugging Face Hub. Replace the placeholder with a real
# access token before running, and do not commit the real token with the notebook.
!huggingface-cli login --token your_hf_token_here

### Install packages

In [None]:
# Upgrade pip itself before installing anything else.
!pip install --upgrade pip
# huggingface_hub with the hf_transfer extra (the checkpoint download later in
# this notebook sets HF_HUB_ENABLE_HF_TRANSFER=1).
!pip install "huggingface_hub[hf_transfer]"
!pip3 install datasets peft -q
# transformers is installed from the GitHub repository, so its version is not pinned here.
!pip install git+https://github.com/huggingface/transformers.git -qq
# Force-reinstall torch 2.4.0 from the CPU wheel index, then torch_xla 2.4.0 with the tpu extra.
!pip install --force-reinstall -v torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu -q
!pip install --force-reinstall -v torch_xla[tpu]==2.4.0 -f https://storage.googleapis.com/libtpu-releases/index.html -q
!pip uninstall tensorflow -y # If we don't do this, TF will take over TPU and cause permission error for PT

Supported models for TPU SPMD:
- GPTNeoX
- T5
- Llama
- CLIP
- CLIPVision
- Llava
- Gemma
- Mistral
- GPT2
- Qwen2
- Mixtral
- Phi

In [None]:
import os
import pandas as pd
import numpy as np
import datasets
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch
import torch.nn as nn
import re
import torch_xla.distributed.spmd as xs
from transformers import logging as hf_logging
import torch_xla.runtime as xr
import torch_xla.distributed.parallel_loader as pl
xr.use_spmd()
from torch_xla.distributed.spmd import Mesh
import torch.nn as nn
from tqdm.auto import tqdm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, CLIPConfig, CLIPVisionConfig, LlavaConfig, GemmaConfig,
    MistralConfig, GPT2Config, Qwen2Config, MixtralConfig, PhiConfig, AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSequenceClassification, AutoConfig, Gemma2Config
)
from peft import LoraConfig, TaskType, get_peft_model
from datasets import Dataset, load_dataset, concatenate_datasets
from dataclasses import dataclass
from tqdm import tqdm
import transformers
import datasets
import pandas as pd
import numpy as np
from datasets import Dataset
import random

# Per-architecture sharding rule tables consumed by partition_module().
#
# Each rule is a pair: (regex searched in the module name, partition spec that
# is passed to xs.mark_sharding for that module's weight). The strings inside a
# spec refer to mesh axis names. partition_module() tries the rules of a table
# in order and stops at the first match, so the order inside a table matters.
#
# Patterns ending in `$` only match module names that end there; per the
# original author this prevents sharding LoRA parameters.
GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),

    # attention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),

    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),

    # output
    ("embed_out", ("fsdp", "mp")),
)
T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),

    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)
# Also reused for MistralConfig in ALL_RULES below.
LLAMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
)
# Used for both CLIPConfig and CLIPVisionConfig in ALL_RULES below.
CLIP_RULES = (
    # The patch embedding spec has four entries, unlike the two-entry specs below.
    ("patch_embedding$", ("fsdp", "mp", None, None)),
    ("position_embedding$", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)$", ("fsdp", "mp")),
    ("self_attn\\.out_proj$", ("mp", "fsdp")),
    ("mlp\\.fc1$", ("fsdp", "mp")),
    ("mlp\\.fc2$", ("mp", "fsdp")),
    ("visual_projection$", ("fsdp", "mp")),
    ("text_projection$", ("fsdp", "mp")),
)
# Projector rules first, followed by the Llama and CLIP tables.
LLAVA_RULES = (
    ("multi_modal_projector\\.linear_1$", ("fsdp", "mp")),
    ("multi_modal_projector\\.linear_2$", ("mp", "fsdp")),
    *LLAMA_RULES,
    *CLIP_RULES,
)
# Used for both GemmaConfig and Gemma2Config in ALL_RULES below.
# NOTE(review): these specs use an "sp" axis, but the mesh built later in this
# notebook only names ('dp', 'fsdp', 'mp') — confirm the mesh before
# partitioning a Gemma model with these rules.
GEMMA_RULES = (
    ("model\\.embed_tokens", ("mp", ("fsdp", "sp"))),
    ("self_attn\\.(q_proj|k_proj|v_proj)", (("fsdp", "sp"), "mp")),
    ("self_attn\\.o_proj", ("mp", ("fsdp", "sp"))),
    ("mlp\\.gate_proj", (("fsdp", "sp"), "mp")),
    ("mlp\\.down_proj", ("mp", ("fsdp", "sp"))),
    ("mlp\\.up_proj", (("fsdp", "sp"), "mp")),
    ("lm_head", (("fsdp", "sp"), "mp")),
    ("score", (("fsdp", "sp"), "mp")),
)
GPT2_RULES = (
    # embeddings
    ("wte", ("mp", "fsdp")),
    ("wpe", ("mp", "fsdp")),

    # attention
    ("c_attn", ("fsdp", "mp")),
    ("c_proj", ("mp", "fsdp")),

    # mlp
    ("c_fc", ("fsdp", "mp")),
    # NOTE: same pattern and spec as the attention "c_proj" rule above; since the
    # first matching rule wins, this entry is never reached.
    ("c_proj", ("mp", "fsdp")),

    # output
    ("lm_head", ("fsdp", "mp")),
)
QWEN_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
)
MIXTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("w1", ("fsdp", "mp")),
    ("w2", ("mp", "fsdp")),
    ("w3", ("fsdp", "mp")),
    ("gate", ("mp", "fsdp")),
    ("lm_head", ("fsdp", "mp")),
)
PHI_RULES = (
    ### (regex) linear modules, (list[sharding methods]) )
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.dense", ("mp", "fsdp")),
    ("mlp\\.fc2", ("mp", "fsdp")),
    ("mlp\\.fc1", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
)
# Registry mapping a transformers config class to its rule table; find_rule()
# scans it in order.
ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
    (CLIPConfig, CLIP_RULES),
    (CLIPVisionConfig, CLIP_RULES),
    (LlavaConfig, LLAVA_RULES,),
    (GemmaConfig, GEMMA_RULES),
    (MistralConfig, LLAMA_RULES),
    (GPT2Config, GPT2_RULES),
    (Qwen2Config, QWEN_RULES),
    (MixtralConfig, MIXTRAL_RULES),
    (PhiConfig,PHI_RULES),
    (Gemma2Config, GEMMA_RULES)
]

def find_rule(model):
    """Return the sharding rule table registered for ``model``'s config class.

    The lookup scans ``ALL_RULES``. An entry whose config class is exactly the
    class of ``model.config`` is preferred. If there is none, the first entry
    whose config class is a base class of ``model.config`` is used, so models
    with a subclassed config are still supported.

    Args:
        model: object exposing a ``config`` attribute.

    Returns:
        The rule table (second element of the matching ``ALL_RULES`` entry).

    Raises:
        ValueError: if no entry in ``ALL_RULES`` applies to the config class.
    """
    config_cls = type(model.config)

    # Exact matches take priority so existing lookups resolve exactly as before.
    for registered_cls, rules in ALL_RULES:
        if config_cls is registered_cls:
            return rules

    # Fallback: accept a config that derives from a registered class.
    for registered_cls, rules in ALL_RULES:
        if issubclass(config_cls, registered_cls):
            return rules

    raise ValueError("unsupported model to partitioning " + str(config_cls))

def partition_module(model, mesh, device='xla', verbose=True):
    """Move ``model`` to ``device`` and mark every module weight for sharding.

    The rule table for the model is obtained from ``find_rule``. For each
    sub-module that owns an ``nn.Parameter`` called ``weight``, the rules are
    tried in order and the spec of the first pattern found in the module name
    is passed to ``xs.mark_sharding``. A weight that matches no rule is marked
    with an all-``None`` spec (one ``None`` per dimension). Only ``weight`` is
    handled; other parameters of a module are not touched here.

    Args:
        model: model whose ``config`` selects the rule table.
        mesh: device mesh forwarded to ``xs.mark_sharding``.
        device: target passed to ``model.to``.
        verbose: when true, show the progress bar and print each decision.
    """
    rules = find_rule(model)
    model.to(device)

    progress = tqdm(model.named_modules(), desc="partitioning model", disable=not verbose, position=0)
    for name, module in progress:
        weight = getattr(module, "weight", None)
        if not isinstance(weight, nn.Parameter):
            # No weight attribute, or it is not a trainable parameter.
            continue

        for pattern, spec in rules:
            if re.search(pattern, name) is None:
                continue
            if verbose:
                print("match", pattern, name, spec)
                print(f"y match {module}", name, weight.size(), weight.dim())
            xs.mark_sharding(weight, mesh, spec)
            break
        else:
            # No rule applied to this module name.
            if verbose:
                print(f"no match {module}", name, weight.size(), weight.dim())
            xs.mark_sharding(weight, mesh, (None,) * weight.dim())

!export USE_TORCH=True
!export XLA_USE_BF16=1
os.environ["PJRT_DEVICE"] = "TPU"
try:
    os.environ.pop('TPU_PROCESS_ADDRESSES')
except:
    pass
hf_logging.set_verbosity_error()

In [None]:
import random


def set_seeds(seed: int=42):
    """Seed Python's ``random``, transformers and torch (CPU and CUDA) with ``seed``."""
    random.seed(seed)
    transformers.set_seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seeds(42)

Here I use the Llama 3.2 3B model and train it on unstructured medical corpus text data.

In [None]:
# Download the checkpoint up front, with HF_HUB_ENABLE_HF_TRANSFER=1 set for this command.
!HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download meta-llama/Llama-3.2-3B
model_name = 'meta-llama/Llama-3.2-3B'  # checkpoint loaded by from_pretrained below
BATCH_SIZE = 4  # batch size given to the DataLoader
epochs = 2  # number of passes train_loop makes over the dataloader
seq_length = 4096  # token limit used when packing documents into training sequences

In [None]:
# Tokenizer and bfloat16 model for the chosen checkpoint. The tokenizer reuses
# its EOS token id as the padding id.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Medical textbook corpus: take the train split of every config whose name
# contains none of the excluded markers.
excluded_markers = ["128", "annotate", "paraphrase", "paraphrase_2", "paraphrase_3", "summary", "summary_long"]
medicaltextbookconfig = datasets.get_dataset_config_names("zxvix/MedicalTextbook")
medicaltextbook = []
for config_name in medicaltextbookconfig:
    if any(marker in config_name for marker in excluded_markers):
        continue
    medicaltextbook.extend(load_dataset("zxvix/MedicalTextbook", config_name)['train']['text'])

# Medical transcriptions: train split followed by test split.
transcription_splits = load_dataset("rungalileo/medical_transcription_40")
medicaltranscription = transcription_splits["train"]["text"]
medicaltranscription.extend(transcription_splits["test"]["text"])

# Wikipedia medical term pages.
wikimedicalterms = load_dataset("gamino/wiki_medical_terms")["train"]["page_text"]

# Concatenate the three sources and tokenize every document.
fulldataset = medicaltextbook + medicaltranscription + wikimedicalterms
fulldatasettoken = tokenizer(fulldataset)['input_ids']

In [None]:
def _pack_token_sequences(token_lists, max_length):
    """Pack tokenized documents into sequences of at most ``max_length - 1`` tokens.

    The first token of every document is stripped (presumably the BOS token the
    tokenizer prepends — TODO confirm). Documents are appended to the current
    sequence while they fit; a document that does not fit starts a new sequence
    and, if it is itself too long, is split into consecutive chunks.

    Compared with the previous inline loop this:
      * keeps the final partial chunk of an over-long document (it used to be
        dropped),
      * no longer loses one token at every chunk boundary after the first, and
      * no longer seeds the first sequence with a placeholder token id ``0``.

    Args:
        token_lists: iterable of token-id lists, one per document.
        max_length: length limit; one position is kept free, matching the
            original accounting in which the stripped leading token counted
            against the limit.

    Returns:
        A list of token-id lists, each non-empty and at most ``max_length - 1`` long.

    Raises:
        ValueError: if ``max_length`` is smaller than 2.
    """
    budget = max_length - 1
    if budget < 1:
        raise ValueError("max_length must be at least 2")

    packed = []
    for tokens in token_lists:
        body = tokens[1:]
        if not body:
            # Nothing left once the leading token is removed.
            continue
        if packed and len(packed[-1]) + len(body) <= budget:
            packed[-1].extend(body)
            continue
        # Start a new sequence; range() also yields the trailing partial chunk.
        for start in range(0, len(body), budget):
            packed.append(body[start:start + budget])
    return packed


dataset = _pack_token_sequences(tqdm(fulldatasettoken), seq_length)
random.shuffle(dataset)

In [None]:
# Total number of token ids across all tokenized documents, reported in millions.
total = sum(len(token_ids) for token_ids in tqdm(fulldatasettoken))
print(f"{total/1000000}M Tokens")

Creating PyTorch dataset and dataloader

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that turns packed token-id lists into padded training examples.

    Each item is decoded back to text and re-tokenized with padding/truncation
    to ``max_length``. Labels are a copy of the input ids in which every
    position with attention mask 0 is set to -100.

    Args:
        tokenizer: callable tokenizer that also provides ``decode``; calling it
            must return ``input_ids`` and ``attention_mask``.
        dataset: sequence of token-id lists.
        max_length: length every example is padded/truncated to. Defaults to
            4096, the value that used to be hard-coded.
    """

    def __init__(self, tokenizer, dataset, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.data = self.dataset
        self.max_length = max_length

    def __len__(self):
        """Return the number of packed sequences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, labels)`` tensors for item ``idx``."""
        text = self.tokenizer.decode(self.data[idx])
        encoded = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True)
        input_ids = torch.tensor(encoded['input_ids'])
        attn_masks = torch.tensor(encoded['attention_mask'])
        labels = input_ids.clone()
        # Mask by attention mask rather than by pad id: the pad id may be shared
        # with a real token (this notebook sets pad = EOS), and masking by id
        # would also hide genuine occurrences of that token.
        labels[attn_masks == 0] = -100
        return (input_ids, attn_masks, labels)
# Resolve the XLA device, then wrap a plain DataLoader over the packed dataset
# in an MpDeviceLoader bound to that device.
device = xm.xla_device()
datasetth = Dataset(tokenizer, dataset)
dataloader = pl.MpDeviceLoader(
    torch.utils.data.DataLoader(datasetth, batch_size=BATCH_SIZE),
    device,
)

Applying SPMD

In [None]:
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

# Apply torch_xla's nn.Linear patch using the SPMD forward before sharding.
# NOTE(review): the exact effect of the patch is defined by torch_xla — see its docs.
model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)
# NOTE(review): `config` is not referenced again in the visible cells.
config = transformers.AutoConfig.from_pretrained(model_name)

# Build a 3-axis mesh named ('dp', 'fsdp', 'mp') in which all devices lie along
# the 'fsdp' axis; the 'dp' and 'mp' axes have size 1.
num_devices = xr.global_runtime_device_count()
mesh_shape = (1, num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))

# Move the model to the XLA device and mark each module weight with its sharding spec.
partition_module(model, mesh)

Create the training loop and automatically upload the model to HF repo

In [None]:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

def save_model(model, tokenizer, optimizer, scheduler, repo_id="Llama-3.2-3B-contpretrain-Medical"):
    """Move ``model`` to CPU and push it to the Hugging Face Hub.

    Args:
        model: model exposing ``cpu()`` and ``push_to_hub()``.
        tokenizer: accepted for call-site compatibility; currently unused, so
            the tokenizer is not uploaded by this function.
        optimizer: accepted for call-site compatibility; currently unused.
        scheduler: accepted for call-site compatibility; currently unused.
        repo_id: name of the Hub repository to push to. Defaults to the
            repository name that used to be hard-coded.
    """
    model = model.cpu()
    model.push_to_hub(repo_id)

def train_loop(training_loader, optimizer, scheduler):
    """Train the global ``model`` for ``epochs`` passes, then upload it.

    Every batch is expected to hold input ids, attention mask and labels at
    indices 0, 1 and 2. Each of the three tensors is marked for sharding on the
    global ``mesh`` with spec ``(0, 1)`` before the forward pass. The loss is
    printed on every 100th step of an epoch (the step counter restarts at 1 for
    each epoch). After the last epoch ``save_model`` is called.

    Args:
        training_loader: iterable yielding the batches described above.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per batch.
    """
    model.train()
    for _ in range(epochs):
        for step, batch in enumerate(tqdm(training_loader), start=1):
            input_ids = batch[0]
            attention_mask = batch[1]
            labels = batch[2]
            for tensor in (input_ids, attention_mask, labels):
                xs.mark_sharding(tensor, mesh, (0, 1))

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            should_log = step % 100 == 0
            if should_log:
                xm.master_print(outputs.loss.detach().cpu().item())
            # Drop references to the inputs before running the backward pass.
            del input_ids, attention_mask
            outputs.loss.backward()
            del outputs
            optimizer.step()
            xm.mark_step()
            scheduler.step()
            optimizer.zero_grad()
    save_model(model, tokenizer, optimizer, scheduler)

# Optimisation hyperparameters.
base_lr = 6e-5
warmup_steps = 100
epsilon = 1e-8

# One scheduler step is taken per batch, so the schedule spans every batch of every epoch.
total_steps = epochs * len(dataloader)

optimizer = AdamW(model.parameters(), lr=base_lr, eps=epsilon)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)
def train():
    """Print the model being trained, then run ``train_loop`` with the global dataloader, optimizer and scheduler."""
    banner = f"""Training {model}"""
    print(banner)
    train_loop(dataloader, optimizer, scheduler)

In [None]:
# Start training; train_loop calls save_model once all epochs have finished.
train()