In [1]:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from onnxruntime.training import artifacts
import onnxruntime.training.api as ort_api
import torch
import onnx
import transformers
import numpy as np
from datasets import load_dataset
from functools import partial
import os


  from .autonotebook import tqdm as notebook_tqdm


# Set parameters

In [33]:
# modelpath="models/TinyLlama-1.1B-intermediate-step-1431k-3T"
modelpath="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
dataset_name="g-ronimo/oasst2_top1_en"
lr=0.00002      # learning rate
bs=2            # batch size
bs_eval=16      # batch size for evals
ga_steps=16     # gradient acc. steps
epochs=4
max_length=2048      # samples max. length
output_dir="out"

# Load model and tokenizer

In [3]:
# model = AutoModelForCausalLM.from_pretrained(
#     modelpath,    
#     device_map="auto",
#     torch_dtype=torch.bfloat16,
#     # attn_implementation="flash_attention_2",
# )

tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False)    # fast tokenizer sometimes ignores added tokens

# Add ChatML tokens 

In [4]:
tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))

1

# Load and prepare OA2 dataset

In [50]:
# Load Dataset
dataset = load_dataset(dataset_name)
dataset = dataset["train"].train_test_split(test_size=0.1)

# chatML Template and tokenize dataset
templates=[
    "<|im_start|>assistant\n{msg}<|im_end|>",
    "<|im_start|>user\n{msg}<|im_end|>"
]
IGNORE_INDEX=-100

def get_position_ids(attention_mask):
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    # Shape: (batch_size, sequence_length)
    return position_ids

# tokenize dataset, set input_ids and attention_mask to train on assistant outputs only
def tokenize(input, max_length):
    input_ids, attention_mask, position_ids, labels = [], [], [], []

    for i,msg in enumerate(input["conversation"]):
        isHuman = msg["role"]=="user"
        msg_chatml=templates[isHuman].format(msg=msg["content"])
        msg_tokenized=tokenizer(msg_chatml, truncation=False, add_special_tokens=False)
    
        input_ids+=msg_tokenized["input_ids"]
        attention_mask+=msg_tokenized["attention_mask"]
        labels+=[IGNORE_INDEX]*len(msg_tokenized["input_ids"]) if isHuman else msg_tokenized["input_ids"]

    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "position_ids": get_position_ids(torch.tensor(attention_mask[:max_length])),
        "labels": labels[:max_length],
    }

dataset_tokenized = dataset.map(
    partial(tokenize, max_length=max_length), 
    batched=False, 
    # num_proc=os.cpu_count(),    # multithreaded
    remove_columns=dataset["train"].column_names  # don't need this anymore, we have tokens from here on
)

Map: 100%|██████████| 4877/4877 [00:06<00:00, 759.76 examples/s]
Map: 100%|██████████| 542/542 [00:00<00:00, 742.31 examples/s]


In [51]:
dataset

DatasetDict({
    train: Dataset({
        features: ['conversation'],
        num_rows: 4877
    })
    test: Dataset({
        features: ['conversation'],
        num_rows: 542
    })
})

In [52]:
dataset_tokenized

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'position_ids', 'labels'],
        num_rows: 4877
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'position_ids', 'labels'],
        num_rows: 542
    })
})

In [53]:
# collate function - to transform list of dictionaries [ {input_ids: [123, ..]}, {.. ] to single batch dictionary { input_ids: [..], labels: [..], attention_mask: [..] }
def collate(elements):
    tokens=[e["input_ids"] for e in elements]
    tokens_maxlen=max([len(t) for t in tokens])

    for i,sample in enumerate(elements):
        input_ids=sample["input_ids"]
        labels=sample["labels"]
        attention_mask=sample["attention_mask"]

        pad_len=tokens_maxlen-len(input_ids)

        input_ids.extend( pad_len * [tokenizer.pad_token_id] )   
        labels.extend( pad_len * [IGNORE_INDEX] )    
        attention_mask.extend( pad_len * [0] ) 

    batch={
        "input_ids": torch.tensor( [e["input_ids"] for e in elements] ).numpy(),
        "labels": torch.tensor( [e["labels"] for e in elements] ).numpy(),
        "position_ids": torch.tensor( [e["position_ids"] for e in elements] ).numpy(),
        "attention_mask": torch.tensor( [e["attention_mask"] for e in elements] ).numpy(),
    }

    return batch

# Generating artifacts

In [34]:

transformers_model = transformers.LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", ignore_mismatched_sizes=True)
tokenizer = transformers.AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
dataloader = torch.utils.data.DataLoader(dataset_tokenized["train"], batch_size=bs, shuffle=True, collate_fn = collate)

batch = {}
for batch_from_dl in dataloader:
    batch = batch_from_dl
    break

# inputs = (torch.tensor(batch['input_ids'], dtype=torch.int64), torch.tensor(batch['attention_mask'], dtype=torch.int64))
# print(inputs[0])
# print(inputs[1].shape)



In [47]:
onnx_model_path = "tinyllama-with-export-script/rank_0_TinyLlama-1.1B-Chat-v1.0_decoder_merged_model_fp32.onnx"
onnx_model = onnx.load(onnx_model_path)
requires_grad = [param.name for param in onnx_model.graph.initializer] # if param.name not in requires_grad]
frozen_params = []
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    # loss=artifacts.LossType.CrossEntropyLoss,
    artifact_directory="artifacts_generated_full",
    optimizer=artifacts.OptimType.AdamW,
    ort_format=False,
    # loss_input_names=["loss"]
)

# Train

In [48]:
# state = ort_api.CheckpointState.load_checkpoint('artifacts_generated_l1/checkpoint')
# training_model = ort_api.Module('artifacts_generated_l1/training_model_corrected_labels.onnx', state, 'artifacts_generated_l1/eval_model.onnx')
# optimizer = ort_api.Optimizer('artifacts_generated_l1/optimizer_model.onnx', training_model)

state = ort_api.CheckpointState.load_checkpoint('artifacts_generated_full/checkpoint')
training_model = ort_api.Module('artifacts_generated_full/training_model.onnx', state, 'artifacts_generated_full/eval_model.onnx')
optimizer = ort_api.Optimizer('artifacts_generated_full/optimizer_model.onnx', training_model)

In [42]:
dataloader = torch.utils.data.DataLoader(dataset_tokenized["train"], batch_size=bs, shuffle=True, collate_fn = collate)

In [43]:
training_model.input_names()

['input_ids',
 'attention_mask',
 'position_ids',
 'past_key_values.0.key',
 'past_key_values.0.value',
 'labels']

In [44]:
def trainEpoch():
    training_model.train()
    losses = []
    i = 0
    for batch in dataloader:
        print(i, 'out of', len(dataloader))
        forward_inputs = [batch["input_ids"], batch["attention_mask"], batch["labels"]]
        print("input ids shape", batch["input_ids"].shape)
        print("attention mask shape", batch["attention_mask"].shape)
        print("labels shape", batch["labels"].shape)

        loss, _ = training_model(*forward_inputs)
        print('after training acll')
        optimizer.step()
        training_model.lazy_reset_grad()
        losses.append(loss)
        print(loss)
        i += 1

In [45]:
trainEpoch()

0 out of 2439
input ids shape (2, 162)
attention mask shape (2, 162)
labels shape (2, 162)


RuntimeError: C:\a\_work\1\s\orttraining\orttraining\training_api\module.cc:538 onnxruntime::training::api::Module::TrainStep [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : feed names has 31 elements, but feed has 28 elements.


In [20]:
import onnx

model = onnx.load("artifacts_generated_l1/training_model.onnx")


In [21]:
print(model.graph.input[2])
import copy
labels_input = copy.deepcopy(model.graph.input[0])
labels_input.name = "labels"
labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64
model.graph.input[2].CopyFrom(labels_input)
print(model.graph.input[2].type.tensor_type.shape)

name: "labels"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "Castloss_dim_0"
      }
      dim {
        dim_value: 32000
      }
    }
  }
}

dim {
  dim_param: "batch_size"
}
dim {
  dim_param: "sequence_length"
}



In [22]:
onnx.save(model, "artifacts_generated_l1/training_model_corrected_labels.onnx")