# Set parameters

In [26]:
# modelpath="models/TinyLlama-1.1B-intermediate-step-1431k-3T"
modelpath="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
dataset_name="g-ronimo/oasst2_top1_en"
lr=0.00002      # learning rate
bs=1            # batch size
bs_eval=16      # batch size for evals
ga_steps=16     # gradient acc. steps
epochs=4
# max_length=2048      # samples max. length
max_length=32000      # samples max. length
output_dir="out"

# Load model and tokenizer

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    modelpath,    
    device_map="auto",
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2",
)

tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False)    # fast tokenizer sometimes ignores added tokens

  from .autonotebook import tqdm as notebook_tqdm


# Add ChatML tokens 

In [3]:
tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))
model.resize_token_embeddings(len(tokenizer))
    # pad_to_multiple_of=64)   # phi2 default is 64, see configuration_phi.py
model.config.eos_token_id = tokenizer.eos_token_id

# Load and prepare OA2 dataset

In [31]:
from datasets import load_dataset
from functools import partial
import os

# Load Dataset
dataset = load_dataset(dataset_name)
dataset = dataset["train"].train_test_split(test_size=0.1)

# chatML Template and tokenize dataset
templates=[
    "<|im_start|>assistant\n{msg}<|im_end|>",
    "<|im_start|>user\n{msg}<|im_end|>"
]
IGNORE_INDEX=-100

# tokenize dataset, set input_ids and attention_mask to train on assistant outputs only
def tokenize(input, max_length):
    input_ids, attention_mask, labels = [], [], []

    for i,msg in enumerate(input["conversation"]):
        isHuman = msg["role"]=="user"
        msg_chatml=templates[isHuman].format(msg=msg["content"])
        msg_tokenized=tokenizer(msg_chatml, truncation=False, add_special_tokens=False, max_length=max_length, padding="max_length")
    
        input_ids+=msg_tokenized["input_ids"]
        attention_mask+=msg_tokenized["attention_mask"]
        labels+=[IGNORE_INDEX]*len(msg_tokenized["input_ids"]) if isHuman else msg_tokenized["input_ids"]

    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "labels": labels[:max_length],
    }

dataset_tokenized = dataset.map(
    partial(tokenize, max_length=max_length), 
    batched=False, 
    # num_proc=os.cpu_count(),    # multithreaded
    remove_columns=dataset["train"].column_names  # don't need this anymore, we have tokens from here on
)

Map: 100%|██████████| 4877/4877 [01:36<00:00, 50.79 examples/s] 
Map: 100%|██████████| 542/542 [00:10<00:00, 49.74 examples/s] 


In [28]:
dataset

DatasetDict({
    train: Dataset({
        features: ['conversation'],
        num_rows: 4877
    })
    test: Dataset({
        features: ['conversation'],
        num_rows: 542
    })
})

In [22]:
dataset_tokenized

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 4877
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 542
    })
})

In [23]:
# collate function - to transform list of dictionaries [ {input_ids: [123, ..]}, {.. ] to single batch dictionary { input_ids: [..], labels: [..], attention_mask: [..] }
def collate(elements):
    tokens=[e["input_ids"] for e in elements]
    tokens_maxlen=max([len(t) for t in tokens])

    for i,sample in enumerate(elements):
        input_ids=sample["input_ids"]
        labels=sample["labels"]
        attention_mask=sample["attention_mask"]

        pad_len=tokens_maxlen-len(input_ids)

        input_ids.extend( pad_len * [tokenizer.pad_token_id] )   
        labels.extend( pad_len * [IGNORE_INDEX] )    
        attention_mask.extend( pad_len * [0] ) 

    batch={
        "input_ids": torch.tensor( [e["input_ids"] for e in elements] ).numpy(),
        "labels": torch.tensor( [e["labels"] for e in elements] ).numpy(),
        "attention_mask": torch.tensor( [e["attention_mask"] for e in elements] ).numpy(),
    }

    return batch

# Train

In [12]:
import torch
import onnxruntime.training.api as ort_api

In [11]:
state = ort_api.CheckpointState.load_checkpoint('artifacts_generated_l1/checkpoint')
training_model = ort_api.Module('artifacts_generated_l1/training_model.onnx', state, 'artifacts_generated_l1/eval_model.onnx')
optimizer = ort_api.Optimizer('artifacts_generated_l1/optimizer_model.onnx', training_model)

In [29]:
dataloader = torch.utils.data.DataLoader(dataset_tokenized["train"], batch_size=bs, shuffle=True, collate_fn = collate)

In [16]:
training_model.input_names()

['input_ids', 'attention_mask', 'labels']

In [17]:
def trainEpoch():
    training_model.train()
    losses = []
    i = 0
    for batch in dataloader:
        print(i, 'out of', len(dataloader))
        forward_inputs = [batch["input_ids"], batch["attention_mask"], batch["labels"]]
        print("input ids shape", batch["input_ids"].shape)
        print("attention mask shape", batch["attention_mask"].shape)
        print("labels shape", batch["labels"].shape)

        loss, _ = training_model(*forward_inputs)
        print('after training acll')
        optimizer.step()
        training_model.lazy_reset_grad()
        losses.append(loss)
        print(loss)
        i += 1

In [30]:
trainEpoch()

0 out of 4877
input ids shape (1, 32000)
attention mask shape (1, 32000)
labels shape (1, 32000)


RuntimeError: C:\a\_work\1\s\orttraining\orttraining\training_api\module.cc:632 onnxruntime::training::api::Module::TrainStep [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Gather node. Name:'/model/model/embed_tokens/Gather' Status Message: indices element out of data bounds, idx=32000 must be within the inclusive range [-32000,31999]
