In [1]:
!pip install --upgrade pip -q
!pip install transformers datasets sentencepiece peft -q
!pip install huggingface_hub -q
!pip uninstall tensorflow -y # If we don't do this, TF will take over TPU and cause permission error for PT

[0m

In [3]:
import os

# `!export ...` runs in a throwaway subshell, so the variable never reaches the
# kernel process. Setting it on os.environ makes it visible to libraries that
# are imported afterwards in this kernel.
os.environ["USE_TORCH"] = "True"  # To use transformers library in TPU

In [14]:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
from peft import LoraConfig, TaskType, get_peft_model

In [5]:
import sys
import importlib
# An empty string on sys.path makes imports resolve against the current working
# directory, which is how the project-local `utils` package is found.
sys.path.append('')
# Import the project-local FSDP helpers by dotted module name.
fsdp_util = importlib.import_module('utils.fsdp')
# Reload so edits to utils/fsdp.py are picked up without restarting the kernel.
# As the last expression of the cell, the reloaded module object is echoed.
importlib.reload(fsdp_util)

<module 'utils.fsdp' from '/home/tunerX/utils/fsdp.py'>

In [6]:
import os
from getpass import getpass

from huggingface_hub import login

# Never hardcode access tokens in a notebook: they end up in version control
# and in every shared copy. Read the token from the HF_TOKEN environment
# variable and fall back to an interactive, non-echoing prompt.
# SECURITY: the token that was previously embedded in this cell must be treated
# as leaked and revoked in the Hugging Face account settings.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=hf_token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [7]:
# Hub id of the checkpoint loaded by the from_pretrained calls below.
# The trailing comment keeps an alternative checkpoint id for quick swapping.
MODEL = "TinyLlama/TinyLlama-1.1B-step-50K-105b" # "meta-llama/Llama-2-7b-hf"

In [8]:
# Load the configuration, the causal-LM weights and the tokenizer for MODEL.
config = AutoConfig.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL)
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# The tokenizer may ship without a padding token; reuse EOS so that
# padding="max_length" works during tokenization.
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token

# `config` above is a standalone object that is never handed to the model, so
# setting the pad id only there has no effect on `model`. Keep the standalone
# config and the model's own config in sync with the tokenizer.
config.pad_token_id = tokenizer.pad_token_id
model.config.pad_token_id = tokenizer.pad_token_id



In [9]:
# LoRA hyper-parameters gathered in one place, then expanded into the config.
lora_settings = {
    "task_type": TaskType.CAUSAL_LM,
    "inference_mode": False,
    "r": 8,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
}
peft_config = LoraConfig(**lora_settings)

# Wrap the base model with the LoRA adapters and report how many parameters
# remain trainable.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 1,126,400 || all params: 1,101,174,784 || trainable%: 0.10229075496156657


In [10]:
# Hand the model to the project-local helper in utils/fsdp.py together with the
# class name(s) of the layers it should target.
# NOTE(review): the return value is discarded — presumably apply_fsdp modifies
# `model` in place; confirm against utils/fsdp.py.
fsdp_util.apply_fsdp(model, ["LlamaDecoderLayer"])

I0000 00:00:1721406780.209360   10472 tpu_initializer_framework_helper.cc:78] Libtpu path is: /usr/local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1721406783.256449   10472 pjrt_c_api_client.cc:110] PjRtCApiClient created.


In [11]:
# Define Alpaca prompt template.
# The three positional `{}` slots are filled, in order, with the instruction,
# the input and the response of one example (see formatting_prompts_func).
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction: {}

### Input: {}

### Response: {}"""

# End-of-sequence marker appended to every formatted example.
EOS_TOKEN = tokenizer.eos_token

# Define formatting function
def formatting_prompts_func(examples):
    """Render a batch of examples into full prompt strings.

    Args:
        examples: mapping with parallel sequences under the keys
            "instruction", "input" and "output".

    Returns:
        A dict with a single key "text" holding one formatted prompt per
        example, each terminated with EOS_TOKEN.
    """
    rows = zip(examples["instruction"], examples["input"], examples["output"])
    return {"text": [alpaca_prompt.format(*row) + EOS_TOKEN for row in rows]}

# Load and preprocess the dataset: fetch the train split of
# yahma/alpaca-cleaned, then add a "text" column containing the fully formatted
# prompt for every row (batched=True passes whole column batches to the function).
dataset = load_dataset("yahma/alpaca-cleaned", split="train")
dataset = dataset.map(formatting_prompts_func, batched=True)

In [15]:
# Tokenize the dataset
def tokenize_function(examples):
    """Tokenize the "text" column of a batch.

    Every sequence is truncated and padded to exactly 512 tokens so that all
    examples share one static shape.
    """
    tokenizer_kwargs = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 512,
    }
    return tokenizer(examples["text"], **tokenizer_kwargs)

# Hold out 15% of the data for validation. A fixed seed keeps the split
# identical across kernel restarts, so results are comparable between runs.
ds = dataset.train_test_split(test_size=0.15, seed=42)

# Replace the raw columns with the tokenizer outputs.
ds['train'] = ds['train'].map(tokenize_function, batched=True, remove_columns=dataset.column_names)
ds['test'] = ds['test'].map(tokenize_function, batched=True, remove_columns=dataset.column_names)

# Create DataLoader
train_dataloader = torch.utils.data.DataLoader(
    ds['train'],
    shuffle=True,
    batch_size=1,
    collate_fn=default_data_collator,
)

# Validation order does not influence the metric, so keep it deterministic.
val_dataloader = torch.utils.data.DataLoader(
    ds['test'],
    shuffle=False,
    batch_size=1,
    collate_fn=default_data_collator,
)

Map: 100%|██████████| 43996/43996 [00:25<00:00, 1714.63 examples/s]
Map: 100%|██████████| 7764/7764 [00:04<00:00, 1761.43 examples/s]


In [16]:
# Smoke-test the DataLoader: fetch one batch and describe each entry.
print("Testing DataLoader:")
batch = next(iter(train_dataloader))
for name in batch:
    value = batch[name]
    if not isinstance(value, torch.Tensor):
        # Non-tensor entries are reported by type only.
        print(f"{name}: {type(value)}")
        continue
    print(f"{name}: shape {value.shape}, dtype {value.dtype}")

Testing DataLoader:
input_ids: shape torch.Size([1, 512]), dtype torch.int64
attention_mask: shape torch.Size([1, 512]), dtype torch.int64


In [22]:
# Re-print the trainable-parameter summary; the recorded output is identical to
# the one printed right after get_peft_model.
model.print_trainable_parameters()

trainable params: 1,126,400 || all params: 1,101,174,784 || trainable%: 0.10229075496156657


In [39]:
device = xm.xla_device()

# Move the model to the XLA device BEFORE building the optimizer: moving a
# module across device types can replace its Parameter objects, which would
# leave an optimizer created earlier holding the old CPU tensors.
model = model.to(device)
model.train()

optim = torch.optim.Adam(model.parameters(), lr=0.0001)

In [23]:
# Query how many devices take part in replication; the recorded output is 1.
# NOTE(review): newer torch_xla releases expose this as
# torch_xla.runtime.world_size() — confirm against the installed version.
xm.xrt_world_size()

1

In [24]:
# Echo the device handle created above to confirm which XLA device is in use.
device

device(type='xla', index=0)

In [40]:
# Wrap the CPU DataLoader so that the batches it yields are delivered to `device`.
train_device_loader = pl.MpDeviceLoader(train_dataloader, device)

In [41]:
# Pull one batch for a manual smoke test. This rebinds `batch`, replacing the
# batch fetched from the plain DataLoader earlier in the notebook.
batch = next(iter(train_device_loader))

In [42]:
# Hugging Face causal-LM heads shift the labels by one position internally when
# computing the loss, so labels must be an UNSHIFTED copy of input_ids; shifting
# here as well would train the model to predict the token two steps ahead.
# Padding positions (attention_mask == 0) are set to -100 so the loss ignores
# them. masked_fill returns a new tensor, so input_ids is left untouched.
batch['labels'] = batch['input_ids'].masked_fill(batch['attention_mask'] == 0, -100)

In [43]:
# Make sure every tensor in the batch lives on `device`.
# NOTE(review): the batch was produced by MpDeviceLoader, so this is presumably
# a no-op for the loader-provided entries — confirm.
batch = {k: v.to(device) for k, v in batch.items()}

In [45]:
# Single forward pass with labels as a smoke test. The recorded run was stopped
# manually (KeyboardInterrupt), so no output is available for this cell.
model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'])

KeyboardInterrupt: 

In [None]:
def train(args):
    """Run one synthetic forward/backward/optimizer step per sequence length.

    Uses the module-level `model` and feeds it `torch.arange(i)` as a dummy
    batch of one sequence, with the mean of the logits as a stand-in loss.
    This is a memory/throughput probe, not real training.

    Args:
        args: unused; kept so the function matches the signature expected by
            multiprocessing launchers, which pass arguments to the target.
    """
    optim = torch.optim.Adam(model.parameters(), lr=0.0001)

    for i in [2048]: # [128, 256, 512, 1024, 2048]:
        print("forwarding...", i)
        input_ids = torch.arange(i).unsqueeze(0).to(xm.xla_device())

        # Clear gradients left over from a previous iteration; without this
        # they accumulate across sequence lengths.
        optim.zero_grad()

        output = model(input_ids=input_ids)

        # Stand-in loss (use output.last_hidden_state.mean() for a headless model).
        loss = output.logits.mean()

        loss.backward()
        optim.step()

        # XLA executes lazily: mark_step cuts the traced graph and runs it.
        # No device loader drives this loop, so it has to be called explicitly.
        xm.mark_step()

        # Print after the step so materialising the value does not force an
        # extra, separate graph execution before backward.
        print(loss)

In [None]:
# Target repository on the Hugging Face Hub.
hub_repo_id = "felarof01/llama3-test"

# Bring the weights back to host memory before serialising them.
model = model.cpu()
print('now saving the model')
model.push_to_hub(
    hub_repo_id,
    private=False,
    create_pr=False,
    max_shard_size="2GB",  # upper bound for each uploaded weight shard
)

# `push_to_hub` on the model has no `tokenizer` parameter, so the tokenizer has
# to be uploaded with its own call to land in the same repository.
tokenizer.push_to_hub(hub_repo_id, private=False, create_pr=False)