In [1]:
# !pip install --upgrade pip -q
# !pip install transformers datasets sentencepiece peft -q
# !pip install huggingface_hub -q
# !pip uninstall tensorflow -y # If we don't do this, TF will take over TPU and cause permission error for PT

In [2]:
!export USE_TORCH=True # To use transformers library in TPU

In [3]:
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
from peft import LoraConfig, TaskType, get_peft_model

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import sys
import importlib
sys.path.append('')
fsdp_util = importlib.import_module('utils.fsdp')
importlib.reload(fsdp_util)

<module 'utils.fsdp' from '/home/tunerX/utils/fsdp.py'>

In [5]:
from huggingface_hub import login

login(token="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")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [6]:
# "TinyLlama/TinyLlama-1.1B-step-50K-105b"
# "meta-llama/Meta-Llama-3-8B" 
# "meta-llama/Llama-2-7b-hf" 
MODEL_NAME = "meta-llama/Meta-Llama-3-8B" 

In [7]:
def init_model(*, model_name):
    config = AutoConfig.from_pretrained(model_name, use_auth_token=True)
    model = AutoModelForCausalLM.from_config(config)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
        config.pad_token_id = tokenizer.pad_token_id
    
    return model, tokenizer

In [8]:
def apply_lora(model):
    # TODO: pass lora config as argument to function.
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                             inference_mode=False,
                             r=8,
                             lora_alpha=32,
                             lora_dropout=0.1)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model

In [9]:
def fsdp_wrapper(x):
    return FSDP(x, shard_param_on_dim_0=True, pin_layout_in_collective_ops=True, disable_reshard_on_root=False, reshard_after_forward=True)

In [10]:
def apply_fsdp(model):
    # Apply on layers within model.
    fsdp_util.apply_fsdp(model, ["LlamaDecoderLayer"])

    # Apply on the model itself.
    model = fsdp_wrapper(model)
    return model

In [11]:
def get_dataset(*, tokenizer, batch_size: int = 1):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Define formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    # Tokenize the dataset.
    def _tokenize(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

    # Load and preprocess the dataset.
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    ds['train'] = ds['train'].map(_tokenize, batched=True, remove_columns=dataset.column_names)
    ds['test'] = ds['test'].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoader
    train_dataloader = torch.utils.data.DataLoader(
        ds['train'],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=default_data_collator,
    )
    
    test_dataloader = torch.utils.data.DataLoader(
        ds['test'],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=default_data_collator,
    )

    return train_dataloader, test_dataloader

In [12]:
def create_dummy_batch(batch_size=1, sequence_length=128):
    device = xm.xla_device()
    
    input_ids = torch.ones(batch_size, sequence_length, dtype=torch.long).to(device)
    attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long).to(device)
    labels = torch.ones(batch_size, sequence_length, dtype=torch.long).to(device)
    
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

In [13]:
def train(index):
    # dist.init_process_group('xla', init_method='xla://')
    torch.manual_seed(99)
    device = xm.xla_device()
    
    model, tokenizer = init_model(model_name=MODEL_NAME)
    model = apply_lora(model)
    model = apply_fsdp(model)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    
    train_dataloader, test_dataloader = get_dataset(tokenizer=tokenizer, batch_size=1)
    train_dataloader, test_dataloader = pl.MpDeviceLoader(train_dataloader, device), pl.MpDeviceLoader(test_dataloader, device)
    
    for epoch in range(1):
        xm.master_print(f'Epoch {epoch} train begin {test_utils.now()}')
        
        for step, batch in enumerate(train_dataloader):
            optimizer.zero_grad()

            # labels = batch['input_ids'].clone()
            # labels[:, :-1] = batch['input_ids'][:, 1:]
            # labels[:, -1] = -100
            # batch = {k: v.to(device) for k, v in batch.items()}
            
            output = model(input_ids=batch['input_ids'],
                           attention_mask=batch['attention_mask'],
                           labels=batch['input_ids'])
            loss = output.loss
            
            loss.backward()
            optimizer.step()
            
            xm.master_print(f'Loss: {loss:.2f}')

In [14]:
# if __name__ == '__main__':
#     xmp.spawn(train)

In [15]:
xmp.spawn(train, start_method="fork")

I0000 00:00:1721519999.904513   51782 tpu_initializer_framework_helper.cc:78] Libtpu path is: /usr/local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1721519999.905489   51783 tpu_initializer_framework_helper.cc:78] Libtpu path is: /usr/local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1721519999.907571   51784 tpu_initializer_framework_helper.cc:78] Libtpu path is: /usr/local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1721519999.911000   51785 tpu_initializer_framework_helper.cc:78] Libtpu path is: /usr/local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1721520002.188099   51783 pjrt_c_api_client.cc:110] PjRtCApiClient created.
I0000 00:00:1721520002.195237   51782 pjrt_c_api_client.cc:110] PjRtCApiClient created.
I0000 00:00:1721520002.196222   51785 pjrt_c_api_client.cc:110] PjRtCApiClient created.
I0000 00:00:1721520002.343996   51784 pjrt_c_api_client.cc:110] PjRtCApiClient created.
Special tokens h

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.04241987003816259
trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.04241987003816259


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.04241987003816259
trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.04241987003816259


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.04241987003816259
trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.04241987003816259


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.04241987003816259
trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.04241987003816259


Map:  99%|█████████▊| 51000/51760 [00:01<00:00, 45705.18 examples/s]
Map:   0%|          | 0/51760 [00:00<?, ? examples/s][A
Map:  12%|█▏        | 6000/51760 [00:00<00:00, 47404.34 examples/s][A
Map:  23%|██▎       | 12000/51760 [00:00<00:00, 49469.87 examples/s][A
Map: 100%|██████████| 51760/51760 [00:02<00:00, 21869.50 examples/s][A

Map:   0%|          | 0/43996 [00:00<?, ? examples/s].58 examples/s][A
Map:  58%|█████▊    | 30000/51760 [00:01<00:01, 21197.99 examples/s][A
Map:   2%|▏         | 1000/43996 [00:00<00:31, 1346.25 examples/s]s][A
Map:  71%|███████▏  | 37000/51760 [00:01<00:01, 14461.18 examples/s][A
Map:   5%|▍         | 2000/43996 [00:01<00:30, 1392.09 examples/s]s][A
Map:  83%|████████▎ | 43000/51760 [00:02<00:00, 11667.08 examples/s][A
Map:   7%|▋         | 3000/43996 [00:02<00:30, 1362.46 examples/s]] [A
Map:  93%|█████████▎| 48000/51760 [00:03<00:00, 10218.76 examples/s][A
Map:   9%|▉         | 4000/43996 [00:02<00:29, 1375.75 examples/s]] [A
Map: 100%|

Epoch 0 train begin 00:16:14







Map: 100%|██████████| 43996/43996 [00:57<00:00, 763.37 examples/s][A


Loss: 12.48


RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 224.00M. That was not possible. There are 181.39M free.; (0x0x0_HBM0)

In [None]:
# model = model.cpu()
# print('now saving the model')
# model.push_to_hub(
#     "felarof01/llama3-test", 
#     tokenizer=tokenizer,
#     private=False,
#     create_pr=False,
#     max_shard_size="2GB", # Sharding isn't as important as before since hardware is better now but who cares anyway
# )