In [1]:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from onnxruntime.training import artifacts
import onnxruntime.training.api as ort_api
import torch
import onnx
import transformers
import numpy as np
from datasets import load_dataset
from functools import partial
import os


  from .autonotebook import tqdm as notebook_tqdm


# Set parameters

In [2]:
# Run configuration: model/dataset identifiers and training hyperparameters.
# NOTE(review): lr, bs_eval, ga_steps, epochs and output_dir are not referenced by any
# other cell visible in this notebook; in particular lr is never handed to the ORT
# optimizer (TODO: presumably optimizer.set_learning_rate(lr)) — confirm before relying on them.
# modelpath = "models/TinyLlama-1.1B-intermediate-step-1431k-3T"  # local checkpoint alternative
modelpath = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
dataset_name = "g-ronimo/oasst2_top1_en"

lr = 2e-5            # learning rate
bs = 2               # training batch size
bs_eval = 16         # evaluation batch size
ga_steps = 16        # gradient accumulation steps
epochs = 4
max_length = 2048    # max tokens kept per sample (tokenize truncates longer ones)
output_dir = "out"

# Load tokenizer (the model itself is trained via ONNX Runtime artifacts below)

In [3]:
# Only the tokenizer is loaded here; the model is trained through the ONNX Runtime
# artifacts further down, so no PyTorch model instance is needed.
# use_fast=False: the fast tokenizer sometimes ignores added tokens.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=modelpath,
    use_fast=False,
)

# Add ChatML tokens 

In [4]:
# Extend the vocabulary with the ChatML start marker and a dedicated pad token,
# then make <|im_end|> the end-of-sequence token.
# NOTE(review): this grows the vocabulary, while the ONNX graph's logits have a fixed
# last dim of 32000 (see the graph-output printout further down) — confirm the
# exported embedding/lm_head actually cover the new token ids.
tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})

1

# Load and prepare the OASST2 dataset

In [5]:
# Load the conversations and carve out a 10% held-out test split.
# A fixed seed keeps the split identical across Restart & Run All; without it every
# run trains and evaluates on a different random partition.
dataset = load_dataset(dataset_name)
dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)

# ChatML templates, indexed by the boolean `role == "user"` in tokenize():
# index 0 (False) -> assistant turn, index 1 (True) -> user turn.
templates=[
    "<|im_start|>assistant\n{msg}<|im_end|>",
    "<|im_start|>user\n{msg}<|im_end|>"
]
# Label value for tokens that should not be trained on (user turns, padding).
# -100 is the PyTorch/HF CrossEntropy default ignore_index; presumably the exported
# ONNX loss uses the same value — confirm.
IGNORE_INDEX=-100

def get_position_ids(attention_mask):
    """Derive position ids from an attention mask.

    Attended tokens are numbered 0, 1, 2, ... along the last dimension; slots whose
    mask value is 0 are filled with 1.

    Args:
        attention_mask: 0/1 (or bool) tensor; the last dimension is the sequence.
            tokenize() passes a 1-D mask for a single example.

    Returns:
        int64 tensor with the same shape as ``attention_mask``.
    """
    mask = attention_mask.long()
    positions = torch.cumsum(mask, dim=-1) - 1
    return positions.masked_fill(attention_mask == 0, 1)

def tokenize(input, max_length):
    """Render one conversation as ChatML, tokenize it, and build assistant-only labels.

    Each turn uses ``templates[role == "user"]`` and is tokenized without extra
    special tokens. User-turn tokens are labelled ``IGNORE_INDEX``; every other
    turn uses its own token ids as labels. The concatenated sequences are cut to
    ``max_length`` tokens.

    Args:
        input: dataset row whose ``"conversation"`` is a list of dicts with
            ``"role"`` and ``"content"`` keys.
        max_length: maximum number of tokens to keep.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` (lists) and
        ``position_ids`` (tensor produced by ``get_position_ids``).
    """
    ids, mask, targets = [], [], []

    for turn in input["conversation"]:
        from_user = turn["role"] == "user"
        text = templates[from_user].format(msg=turn["content"])
        encoded = tokenizer(text, truncation=False, add_special_tokens=False)
        turn_ids = encoded["input_ids"]

        ids.extend(turn_ids)
        mask.extend(encoded["attention_mask"])
        # Only non-user (assistant) tokens contribute to the loss.
        targets.extend([IGNORE_INDEX] * len(turn_ids) if from_user else turn_ids)

    ids, mask, targets = ids[:max_length], mask[:max_length], targets[:max_length]
    return {
        "input_ids": ids,
        "attention_mask": mask,
        "position_ids": get_position_ids(torch.tensor(mask)),
        "labels": targets,
    }

# Tokenize each conversation (one row at a time) and drop the raw text columns;
# only token-level fields are needed from here on.
dataset_tokenized = dataset.map(
    partial(tokenize, max_length=max_length),
    remove_columns=dataset["train"].column_names,
    batched=False,
    # num_proc=os.cpu_count(),  # enable for multi-process tokenization
)

Map:  24%|██▎       | 1149/4877 [00:04<00:15, 240.74 examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2222 > 2048). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 4877/4877 [00:19<00:00, 246.97 examples/s]
Map: 100%|██████████| 542/542 [00:02<00:00, 249.32 examples/s]


In [6]:
# Raw conversations after the 90/10 train/test split (single "conversation" column).
dataset

DatasetDict({
    train: Dataset({
        features: ['conversation'],
        num_rows: 4877
    })
    test: Dataset({
        features: ['conversation'],
        num_rows: 542
    })
})

In [7]:
# Tokenized splits: input_ids / attention_mask / position_ids / labels per example.
dataset_tokenized

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'position_ids', 'labels'],
        num_rows: 4877
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'position_ids', 'labels'],
        num_rows: 542
    })
})

In [8]:
def collate(elements, pad_token_id=None, ignore_index=None):
    """Right-pad tokenized samples to a common length and stack them into one batch.

    Turns ``[{input_ids: [...], labels: [...], ...}, ...]`` into
    ``{input_ids: array, labels: array, ...}``. Every field is padded by the amount
    its sample's ``input_ids`` falls short of the longest sample in the batch:

    * input_ids      -> ``pad_token_id``
    * labels         -> ``ignore_index`` (padding is excluded from the loss)
    * position_ids   -> 1 (same fill value get_position_ids uses for masked slots)
    * attention_mask -> 0

    The input samples are not modified; the original implementation extended
    them in place.

    Args:
        elements: non-empty list of dicts with equal-length ``input_ids``,
            ``labels``, ``position_ids`` and ``attention_mask`` sequences.
        pad_token_id: id used to pad input_ids; defaults to ``tokenizer.pad_token_id``.
        ignore_index: label used for padding; defaults to ``IGNORE_INDEX``.

    Returns:
        dict (keys in the order input_ids, labels, position_ids, attention_mask)
        of int64 numpy arrays shaped ``(len(elements), longest_len)`` — int64 as
        expected by the ORT training graph's integer inputs.

    Raises:
        ValueError: if ``elements`` is empty (``max`` of an empty sequence).
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    if ignore_index is None:
        ignore_index = IGNORE_INDEX

    # Output keys (in batch order) and the value each one is padded with.
    pad_values = {
        "input_ids": pad_token_id,
        "labels": ignore_index,
        "position_ids": 1,
        "attention_mask": 0,
    }

    longest = max(len(sample["input_ids"]) for sample in elements)
    batch = {}
    for key, fill in pad_values.items():
        rows = [
            # Build a new list so the caller's sample is left untouched.
            list(sample[key]) + [fill] * (longest - len(sample["input_ids"]))
            for sample in elements
        ]
        batch[key] = np.asarray(rows, dtype=np.int64)
    return batch

# Generate ONNX Runtime training artifacts

In [9]:

# Shuffled training loader over the tokenized train split, batched with collate().
dataloader = torch.utils.data.DataLoader(
    dataset_tokenized["train"],
    batch_size=bs,
    shuffle=True,
    collate_fn=collate,
)

# Pull a single batch (empty dict if the loader yields nothing) to sanity-check shapes.
batch = next(iter(dataloader), {})
for key, value in batch.items():
    print(key, value.shape)


input_ids (2, 410)
labels (2, 410)
position_ids (2, 410)
attention_mask (2, 410)


In [12]:
# Build the ORT on-device training artifacts (training / eval / optimizer graphs and a
# checkpoint, loaded from artifacts_generated_full/ in the Train section) from the exported decoder.
onnx_model_path = "tinyllama-full-with-export-script/rank_0_TinyLlama-1.1B-Chat-v1.0_decoder_merged_model_fp32.onnx"
# Weights live in an external data file; they are deliberately not pulled into memory here.
onnx_model = onnx.load(onnx_model_path, load_external_data=False)
# Full fine-tuning: every initializer is trainable and nothing is frozen.
# NOTE(review): this also lists any non-float initializers the exporter emitted (e.g. int64
# shape constants), which cannot receive gradients — TODO confirm all initializers are
# float, or filter on param.data_type.
requires_grad = [param.name for param in onnx_model.graph.initializer]
frozen_params = []
# NOTE(review): the recorded run of this cell failed with a ValidationError. The missing file
# was named without the "tinyllama-full-with-export-script/" prefix, so the external data was
# presumably resolved relative to the current working directory rather than the model's
# directory. TODO confirm, e.g. by running from inside that directory.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    # loss=artifacts.LossType.CrossEntropyLoss,
    artifact_directory="artifacts_generated_full",
    optimizer=artifacts.OptimType.AdamW,
    ort_format=False,
    # The exported graph already has a scalar "loss" output (see the graph-output printout
    # below), so no loss block is added; presumably this names that output as the loss
    # to differentiate — confirm against the ORT generate_artifacts docs.
    loss_input_names=["loss"]
)

ValidationError: Data of TensorProto ( tensor name: model.embed_tokens.weight) should be stored in rank_0_TinyLlama-1.1B-Chat-v1.0_decoder_merged_model_fp32.onnx.data, but it doesn't exist or is not accessible.

In [13]:
# Map each graph-output name to its output entry, to confirm the exported graph exposes a "loss".
name_graph_output_mapping = dict((graph_out.name, graph_out) for graph_out in onnx_model.graph.output)
print(name_graph_output_mapping)

{'loss': name: "loss"
type {
  tensor_type {
    elem_type: 1
    shape {
    }
  }
}
, 'logits': name: "logits"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
      dim {
        dim_value: 32000
      }
    }
  }
}
}


# Train

In [56]:
# Load the checkpoint plus the training/eval/optimizer graphs written to
# artifacts_generated_full/ by generate_artifacts.
# (An earlier variant loaded artifacts_generated_l1/ with training_model_corrected_labels.onnx.)
state = ort_api.CheckpointState.load_checkpoint("artifacts_generated_full/checkpoint")
training_model = ort_api.Module(
    "artifacts_generated_full/training_model.onnx",
    state,
    "artifacts_generated_full/eval_model.onnx",
)
optimizer = ort_api.Optimizer(
    "artifacts_generated_full/optimizer_model.onnx",
    training_model,
)

In [71]:
# Shuffled training loader over the tokenized train split, batched with collate().
dataloader = torch.utils.data.DataLoader(
    dataset=dataset_tokenized["train"],
    batch_size=bs,
    shuffle=True,
    collate_fn=collate,
)

In [72]:
# Peek at one batch (empty dict if the loader yields nothing); each array is
# (batch_size, longest sequence in this batch) after collate's padding.
batch = next(iter(dataloader), {})
for name, array in batch.items():
    print(name, array.shape)

input_ids (2, 1029)
labels (2, 1029)
position_ids (2, 1029)
attention_mask (2, 1029)


In [58]:
# Input order of the training graph; the forward call passes tensors positionally in this order.
training_model.input_names()

['input_ids', 'attention_mask', 'position_ids', 'labels']

In [79]:
def trainEpoch(ga_steps=1, module=None, optim=None, loader=None):
    """Run one training epoch with the ONNX Runtime training module.

    Each batch is fed in the order reported by ``training_model.input_names()``
    (input_ids, attention_mask, position_ids, labels). ONNX Runtime training
    accumulates gradients across forward calls until ``lazy_reset_grad()`` is
    called (see the ORT on-device training API docs), so the optimizer steps once
    every ``ga_steps`` micro-batches, plus once at the end for any leftover
    micro-batches. Step progress and the per-step loss are printed.

    Args:
        ga_steps: micro-batches per optimizer step. The default of 1 steps after every
            batch, as before. Accumulated gradients are summed, not averaged; the loss
            is not rescaled.
        module: ORT training Module; defaults to the global ``training_model``.
        optim: ORT Optimizer; defaults to the global ``optimizer``.
        loader: sized iterable of collated batches; defaults to the global ``dataloader``.

    Returns:
        Mean of the per-step losses as a float, or NaN if the loader was empty.
        Previously the collected losses were silently discarded.

    Raises:
        ValueError: if ``ga_steps`` < 1.
    """
    module = training_model if module is None else module
    optim = optimizer if optim is None else optim
    loader = dataloader if loader is None else loader
    if ga_steps < 1:
        raise ValueError(f"ga_steps must be >= 1, got {ga_steps}")

    module.train()
    losses = []
    n_batches = len(loader)  # hoisted: constant for the whole epoch
    pending = 0              # micro-batches whose gradients are not yet applied

    for i, batch in enumerate(loader):
        print(i, 'out of', n_batches)
        forward_inputs = [batch["input_ids"], batch["attention_mask"], batch["position_ids"], batch["labels"]]

        loss, _ = module(*forward_inputs)
        pending += 1
        if pending == ga_steps:
            optim.step()
            module.lazy_reset_grad()
            pending = 0
        losses.append(loss)
        print(loss)

    # Flush gradients from a trailing partial accumulation window.
    if pending:
        optim.step()
        module.lazy_reset_grad()

    return float(sum(losses) / len(losses)) if losses else float("nan")

In [80]:
# One pass over the training dataloader; prints step progress and the per-step loss.
trainEpoch()

0 out of 2439
11.074423
1 out of 2439
10.598811
2 out of 2439
10.149166
3 out of 2439
11.018109
4 out of 2439
10.19532
5 out of 2439
10.395526
6 out of 2439
11.050265
7 out of 2439
9.975668
8 out of 2439
9.344571
9 out of 2439
9.393676
10 out of 2439
9.087124
11 out of 2439
8.929982
12 out of 2439
9.041899
13 out of 2439
8.3379345
14 out of 2439
8.945558
15 out of 2439
8.974789
16 out of 2439
8.334281
17 out of 2439
8.903706
18 out of 2439
8.9175825
19 out of 2439
8.534397
20 out of 2439
8.369703
21 out of 2439
7.768069
22 out of 2439
7.896776
23 out of 2439
8.12299
24 out of 2439
8.340614
25 out of 2439
8.178937
26 out of 2439
8.498382
27 out of 2439
7.4799457
28 out of 2439
4.9762287
29 out of 2439
7.3726296
30 out of 2439
8.418573
31 out of 2439
7.931908
32 out of 2439
7.957644
33 out of 2439
7.294777
34 out of 2439
7.556009
35 out of 2439
7.756798
36 out of 2439
7.711334
37 out of 2439
7.4318843
38 out of 2439
7.2211747
39 out of 2439
7.7063537
40 out of 2439
7.2114935
41 out of 24

In [20]:
import onnx

# Load the l1 training graph so its mis-declared "labels" input can be patched in the next cells.
model = onnx.load("artifacts_generated_l1/training_model.onnx")


In [21]:
# Patch the "labels" graph input of the l1 training model. The exported graph declared it
# with a [Castloss_dim_0, 32000] shape (see the first print), but collate() builds labels
# shaped exactly like input_ids. Clone the input_ids declaration, rename it, and force INT64.
import copy

# Look inputs up by name instead of hard-coded positions 0 and 2, so a different
# input order or naming fails loudly (KeyError) instead of patching the wrong input.
input_index = {graph_input.name: idx for idx, graph_input in enumerate(model.graph.input)}
labels_idx = input_index["labels"]
print(model.graph.input[labels_idx])

labels_input = copy.deepcopy(model.graph.input[input_index["input_ids"]])
labels_input.name = "labels"
labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64
model.graph.input[labels_idx].CopyFrom(labels_input)
print(model.graph.input[labels_idx].type.tensor_type.shape)

name: "labels"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "Castloss_dim_0"
      }
      dim {
        dim_value: 32000
      }
    }
  }
}

dim {
  dim_param: "batch_size"
}
dim {
  dim_param: "sequence_length"
}



In [22]:
# Persist the patched graph; the commented-out l1 setup in the Train section loads this file.
onnx.save(model, "artifacts_generated_l1/training_model_corrected_labels.onnx")