In [1]:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from onnxruntime.training import artifacts
import onnxruntime.training.api as ort_api
import torch
import onnx
import transformers
import numpy as np
from datasets import load_dataset
from functools import partial
import os


  from .autonotebook import tqdm as notebook_tqdm


# Set parameters

In [None]:
# modelpath="models/TinyLlama-1.1B-intermediate-step-1431k-3T"
modelpath="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
dataset_name="g-ronimo/oasst2_top1_en"
lr=0.00002      # learning rate
bs=2            # batch size
bs_eval=16      # batch size for evals
ga_steps=16     # gradient acc. steps
epochs=4
max_length=2048      # samples max. length
output_dir="out"

# Load model and tokenizer

In [None]:
# model = AutoModelForCausalLM.from_pretrained(
#     modelpath,    
#     device_map="auto",
#     torch_dtype=torch.bfloat16,
#     # attn_implementation="flash_attention_2",
# )

tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False)    # fast tokenizer sometimes ignores added tokens

# Add ChatML tokens 

In [None]:
tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))

1

# Load and prepare OA2 dataset

In [None]:
# Load Dataset
dataset = load_dataset(dataset_name)
dataset = dataset["train"].train_test_split(test_size=0.1)

# chatML Template and tokenize dataset
templates=[
    "<|im_start|>assistant\n{msg}<|im_end|>",
    "<|im_start|>user\n{msg}<|im_end|>"
]
IGNORE_INDEX=-100

def get_position_ids(attention_mask):
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    # Shape: (batch_size, sequence_length)
    return position_ids

# tokenize dataset, set input_ids and attention_mask to train on assistant outputs only
def tokenize(input, max_length):
    input_ids, attention_mask, position_ids, labels = [], [], [], []

    for i,msg in enumerate(input["conversation"]):
        isHuman = msg["role"]=="user"
        msg_chatml=templates[isHuman].format(msg=msg["content"])
        msg_tokenized=tokenizer(msg_chatml, truncation=False, add_special_tokens=False)

        input_ids+=msg_tokenized["input_ids"]
        attention_mask+=msg_tokenized["attention_mask"]
        labels+=[IGNORE_INDEX]*len(msg_tokenized["input_ids"]) if isHuman else msg_tokenized["input_ids"]

    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "position_ids": get_position_ids(torch.tensor(attention_mask[:max_length])),
        "labels": labels[:max_length],
    }

dataset_tokenized = dataset.map(
    partial(tokenize, max_length=max_length), 
    batched=False, 
    # num_proc=os.cpu_count(),    # multithreaded
    remove_columns=dataset["train"].column_names  # don't need this anymore, we have tokens from here on
)

Map:  14%|█▍        | 671/4877 [00:01<00:11, 373.41 examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2981 > 2048). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 4877/4877 [00:13<00:00, 354.93 examples/s]
Map: 100%|██████████| 542/542 [00:01<00:00, 346.00 examples/s]


In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['conversation'],
        num_rows: 4877
    })
    test: Dataset({
        features: ['conversation'],
        num_rows: 542
    })
})

In [None]:
dataset_tokenized

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'position_ids', 'labels'],
        num_rows: 4877
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'position_ids', 'labels'],
        num_rows: 542
    })
})

In [None]:
# collate function - to transform list of dictionaries [ {input_ids: [123, ..]}, {.. ] to single batch dictionary { input_ids: [..], labels: [..], attention_mask: [..] }
def collate(elements):
    tokens=[e["input_ids"] for e in elements]
    tokens_maxlen=max([len(t) for t in tokens])

    for i,sample in enumerate(elements):
        input_ids=sample["input_ids"]
        labels=sample["labels"]
        position_ids=sample["position_ids"]
        attention_mask=sample["attention_mask"]

        pad_len=tokens_maxlen-len(input_ids)

        input_ids.extend( pad_len * [tokenizer.pad_token_id] )   
        labels.extend( pad_len * [IGNORE_INDEX] )    
        position_ids.extend( pad_len * [1] )
        attention_mask.extend( pad_len * [0] ) 

    batch={
        "input_ids": torch.tensor( [e["input_ids"] for e in elements] ).numpy(),
        "labels": torch.tensor( [e["labels"] for e in elements] ).numpy(),
        "position_ids": torch.tensor( [e["position_ids"] for e in elements] ).numpy(),
        # "position_ids": position_ids.numpy(),
        "attention_mask": torch.tensor( [e["attention_mask"] for e in elements] ).numpy(),
    }

    return batch

# Generating artifacts

In [None]:

# transformers_model = transformers.LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", ignore_mismatched_sizes=True)
# tokenizer = transformers.AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
dataloader = torch.utils.data.DataLoader(dataset_tokenized["train"], batch_size=bs, shuffle=True, collate_fn = collate)

batch = {}
for batch_from_dl in dataloader:
    batch = batch_from_dl
    break

# inputs = (torch.tensor(batch['input_ids'], dtype=torch.int64), torch.tensor(batch['attention_mask'], dtype=torch.int64))
# print(inputs[0])
# print(inputs[1].shape)

for x in batch.keys():
    print(x, batch[x].shape, batch[x].max())


input_ids (2, 530) 32002
labels (2, 530) 32002
position_ids (2, 530) 529
attention_mask (2, 530) 1


In [None]:

onnx_model_path = "rank_0_TinyLlama-1.1B-Chat-v1.0_decoder_merged_model_fp32.onnx"
onnx_model = onnx.load(onnx_model_path, load_external_data=False)
onnx_model.graph.input

[name: "input_ids"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
    }
  }
}
, name: "attention_mask"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
    }
  }
}
, name: "position_ids"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
    }
  }
}
, name: "labels"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
    }
  }
}
]

In [11]:
onnx_model_path = "rank_0_TinyLlama-1.1B-Chat-v1.0_decoder_merged_model_fp32.onnx"
onnx_model = onnx.load(onnx_model_path, load_external_data=False)
requires_grad = [param.name for param in onnx_model.graph.initializer] # if param.name not in requires_grad]
frozen_params = []
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    # loss=artifacts.LossType.CrossEntropyLoss,
    artifact_directory="artifacts_generated_full_test",
    optimizer=artifacts.OptimType.AdamW,
    ort_format=False,
    # loss_input_names=["loss"]
)

: 

In [13]:
name_graph_output_mapping = {output.name: output for output in onnx_model.graph.output}
print(name_graph_output_mapping)

{'loss': name: "loss"
type {
  tensor_type {
    elem_type: 1
    shape {
    }
  }
}
, 'logits': name: "logits"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
      dim {
        dim_value: 32000
      }
    }
  }
}
}


# Train

In [47]:
# state = ort_api.CheckpointState.load_checkpoint('artifacts_generated_l1/checkpoint')
# training_model = ort_api.Module('artifacts_generated_l1/training_model_corrected_labels.onnx', state, 'artifacts_generated_l1/eval_model.onnx')
# optimizer = ort_api.Optimizer('artifacts_generated_l1/optimizer_model.onnx', training_model)

state = ort_api.CheckpointState.load_checkpoint('artifacts_generated_full_test/checkpoint')
training_model = ort_api.Module('artifacts_generated_full_test/training_model.onnx', state, 'artifacts_generated_full_test/eval_model.onnx')
optimizer = ort_api.Optimizer('artifacts_generated_full_test/optimizer_model.onnx', training_model)

In [46]:
training_onnx = onnx.load('artifacts_generated_full_test/training_model.onnx')
training_onnx.graph.input

[name: "input_ids"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
    }
  }
}
, name: "attention_mask"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
    }
  }
}
, name: "position_ids"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
    }
  }
}
, name: "labels"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "sequence_length"
      }
    }
  }
}
, name: "model.embed_tokens.weight"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 32000
      }
      dim {
        dim_value: 2048
      }
    }
  }
}
, name: "model.layers.0.self_attn.q

In [22]:
for node in training_onnx.graph.node:
    if node.name == "/model/embed_tokens/Gather":
        print(node) # or save it to another variable to access its elements
 

input: "model.embed_tokens.weight"
input: "input_ids"
output: "/model/embed_tokens/Gather_output_0"
name: "/model/embed_tokens/Gather"
op_type: "Gather"
attribute {
  name: "axis"
  type: INT
  i: 0
}



In [12]:
dataloader = torch.utils.data.DataLoader(dataset_tokenized["train"], batch_size=bs, shuffle=True, collate_fn = collate)

In [23]:
batch = {}
for batch_from_dl in dataloader:
    batch = batch_from_dl
    break

# inputs = (torch.tensor(batch['input_ids'], dtype=torch.int64), torch.tensor(batch['attention_mask'], dtype=torch.int64))
# print(inputs[0])
# print(inputs[1].shape)

for x in batch.keys():
    print(x, batch[x].shape, batch[x].max())

input_ids (2, 518) 32002
labels (2, 518) 32002
position_ids (2, 518) 517
attention_mask (2, 518) 1


In [15]:
training_model.input_names()

['input_ids', 'attention_mask', 'position_ids', 'labels']

In [14]:
def trainEpoch():
    training_model.train()
    losses = []
    i = 0
    for batch in dataloader:
        print(i, 'out of', len(dataloader))
        forward_inputs = [batch["input_ids"], batch["attention_mask"], batch["position_ids"], batch["labels"]]
        # print(batch.keys())
        # print("input ids shape", batch["input_ids"].shape)
        # print("attention mask shape", batch["attention_mask"].shape)
        # print("position_ids shape", batch["position_ids"].shape)
        # print("labels shape", batch["labels"].shape)

        loss, _ = training_model(*forward_inputs)
        # print('after training acll')
        optimizer.step()
        training_model.lazy_reset_grad()
        losses.append(loss)
        print(loss)
        i += 1

In [16]:
trainEpoch()

0 out of 2439


RuntimeError: C:\Users\carolinezhu\Documents\onnxruntime\orttraining\orttraining\training_api\module.cc:632 onnxruntime::training::api::Module::TrainStep [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Gather node. Name:'/model/embed_tokens/Gather' Status Message: indices element out of data bounds, idx=32000 must be within the inclusive range [-32000,31999]
Stacktrace:
C:\Users\carolinezhu\Documents\onnxruntime\orttraining\orttraining\python\orttraining_pybind_state.cc(691): onnxruntime_pybind11_state!<lambda_6c99460bc5fb44db08ea39edf4dac239>::operator()+0x4F4
C:\Users\carolinezhu\Documents\onnxruntime\build\Windows\Debug\_deps\pybind11_project-src\include\pybind11\cast.h(1440): onnxruntime_pybind11_state!pybind11::detail::argument_loader<onnxruntime::training::api::Module *,std::vector<pybind11::object,std::allocator<pybind11::object> > const &,std::vector<OrtValue,std::allocator<OrtValue> > &>::call_impl<void,<lambda_6c99460bc5fb44db08ea39edf4dac239> &,0,1,2,pybind11::detail::void_type>+0xA7
C:\Users\carolinezhu\Documents\onnxruntime\build\Windows\Debug\_deps\pybind11_project-src\include\pybind11\cast.h(1415): onnxruntime_pybind11_state!pybind11::detail::argument_loader<onnxruntime::training::api::Module *,std::vector<pybind11::object,std::allocator<pybind11::object> > const &,std::vector<OrtValue,std::allocator<OrtValue> > &>::call<void,pybind11::detail::void_type,<lambda_6c99460bc5fb44db08ea39edf4dac239> &>+0x80
C:\Users\carolinezhu\Documents\onnxruntime\build\Windows\Debug\_deps\pybind11_project-src\include\pybind11\pybind11.h(249): onnxruntime_pybind11_state!<lambda_ef5e7e3c8b97e912a8fb54ee33c997d1>::operator()+0x171
C:\Users\carolinezhu\Documents\onnxruntime\build\Windows\Debug\_deps\pybind11_project-src\include\pybind11\pybind11.h(167): onnxruntime_pybind11_state!<lambda_ef5e7e3c8b97e912a8fb54ee33c997d1>::<lambda_invoker_cdecl>+0x20
C:\Users\carolinezhu\Documents\onnxruntime\build\Windows\Debug\_deps\pybind11_project-src\include\pybind11\pybind11.h(929): onnxruntime_pybind11_state!pybind11::cpp_function::dispatcher+0x13AD
(0): python39!PyArg_ParseTuple_SizeT+0x1D6A
(0): python39!PyObject_MakeTpCall+0xE9
(0): python39!PyErr_FormatFromCauseTstate+0xDD4F
(0): python39!Py_NewReference+0x2C4
(0): python39!PyEval_EvalFrameDefault+0x8BB
(0): python39!PyFunction_Vectorcall+0x946
(0): python39!Py_NewReference+0x92C
(0): python39!PyEval_EvalFrameDefault+0x492
(0): python39!PyFunction_Vectorcall+0x946
(0): python39!PyFunction_Vectorcall+0x23D
(0): python39!PyObject_FastCallDictTstate+0x63
(0): python39!PyObject_Call_Prepend+0x7B
(0): python39!PyWeakref_NewProxy+0x22A8
(0): python39!PyObject_Call+0x1A0
(0): python39!PyEval_EvalFrameDefault+0x167E
(0): python39!Py_NewReference+0x546
(0): python39!PyEval_EvalFrameDefault+0x492
(0): python39!PyFunction_Vectorcall+0x946
(0): python39!PyEval_EvalCodeWithName+0xA9
(0): python39!PyEval_EvalCodeEx+0x9B
(0): python39!PyEval_EvalCode+0x2D
(0): python39!PyFuture_FromASTObject+0x46A
(0): python39!PyFuture_FromASTObject+0x373
(0): python39!Py_NewReference+0x850
(0): python39!PyEval_EvalFrameDefault+0x492
(0): python39!PyArg_CheckPositional+0x2116
(0): python39!PyEval_EvalFrameDefault+0x1E9A
(0): python39!PyArg_CheckPositional+0x2116
(0): python39!PyEval_EvalFrameDefault+0x1E9A
(0): python39!PyArg_CheckPositional+0x2116
(0): python39!PyBytesWriter_Finish+0x251
(0): python39!Py_NewReference+0x2C4
(0): python39!PyEval_EvalFrameDefault+0x693
(0): python39!Py_NewReference+0x546
(0): python39!PyEval_EvalFrameDefault+0x492
(0): python39!Py_NewReference+0x546
(0): python39!PyEval_EvalFrameDefault+0x693
(0): python39!PyFunction_Vectorcall+0x946
(0): python39!PyType_GenericAlloc+0xB4A
(0): python39!PyVectorcall_Call+0xB8
(0): python39!PyObject_Call+0x13E
(0): python39!PyEval_EvalFrameDefault+0x167E
(0): python39!PyFunction_Vectorcall+0x946
(0): python39!PyType_GenericAlloc+0xB4A
(0): python39!Py_NewReference+0x2C4
(0): python39!PyEval_EvalFrameDefault+0xE28
(0): python39!PyArg_CheckPositional+0x2116
(0): python39!PyEval_EvalFrameDefault+0x1E9A
(0): python39!PyArg_CheckPositional+0x2116
(0): python39!PyEval_EvalFrameDefault+0x1E9A
(0): python39!PyArg_CheckPositional+0x2116
(0): python39!PyEval_EvalFrameDefault+0x1E9A
(0): python39!PyArg_CheckPositional+0x2116
(0): python39!PyEval_EvalFrameDefault+0x1E9A
(0): python39!PyArg_CheckPositional+0x2116
(0): python39!PyEval_EvalFrameDefault+0x1E9A
(0): python39!PyArg_CheckPositional+0x2116
(0): _asyncio!PyInit__asyncio+0x55C5
(0): _asyncio!PyInit__asyncio+0x53FB
(0): _asyncio!PyInit__asyncio+0x5CA7
(0): _asyncio!PyInit__asyncio+0x306A
(0): python39!PyObject_MakeTpCall+0xE9
(0): python39!PyContext_NewHamtForTests+0x4A
(0): python39!PyContext_NewHamtForTests+0x3C9
(0): python39!PyArg_CheckPositional+0x12E
(0): python39!PyVectorcall_Call+0x5C
(0): python39!PyObject_Call+0x4F
(0): python39!PyObject_Call+0x174
(0): python39!PyEval_EvalFrameDefault+0x167E
(0): python39!Py_NewReference+0x546
(0): python39!PyEval_EvalFrameDefault+0x693
(0): python39!Py_NewReference+0x546
(0): python39!PyEval_EvalFrameDefault+0x693
(0): python39!Py_NewReference+0x546
(0): python39!PyEval_EvalFrameDefault+0x693
(0): python39!Py_NewReference+0x546
(0): python39!PyEval_EvalFrameDefault+0x693
(0): python39!Py_NewReference+0x546
(0): python39!PyEval_EvalFrameDefault+0x693
(0): python39!PyFunction_Vectorcall+0x946
(0): python39!PyType_GenericAlloc+0xB4A
(0): python39!Py_NewReference+0x2C4
(0): python39!PyEval_EvalFrameDefault+0x8BB
(0): python39!PyFunction_Vectorcall+0x946
(0): python39!PyEval_EvalCodeWithName+0xA9
(0): python39!PyEval_EvalCodeEx+0x9B
(0): python39!PyEval_EvalCode+0x2D
(0): python39!PyFuture_FromASTObject+0x46A
(0): python39!PyFuture_FromASTObject+0x373
(0): python39!Py_NewReference+0x850
(0): python39!PyEval_EvalFrameDefault+0x492
(0): python39!PyFunction_Vectorcall+0x946
(0): python39!Py_NewReference+0x92C
(0): python39!PyEval_EvalFrameDefault+0x492
(0): python39!PyFunction_Vectorcall+0x946
(0): python39!PyFunction_Vectorcall+0x23D
(0): python39!PyVectorcall_Call+0x5C
(0): python39!PyObject_Call+0x4F
(0): python39!Py_MakePendingCalls+0x4FE
(0): python39!Py_RunMain+0x143
(0): python39!Py_RunMain+0x15
(0): python39!Py_Main+0x6F
(0): python39!Py_Main+0x25
(0): python+0x1268
(0): KERNEL32!BaseThreadInitThunk+0x1D
(0): ntdll!RtlUserThreadStart+0x28


In [20]:
import onnx

model = onnx.load("artifacts_generated_l1/training_model.onnx")


In [21]:
print(model.graph.input[2])
import copy
labels_input = copy.deepcopy(model.graph.input[0])
labels_input.name = "labels"
labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64
model.graph.input[2].CopyFrom(labels_input)
print(model.graph.input[2].type.tensor_type.shape)

name: "labels"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "Castloss_dim_0"
      }
      dim {
        dim_value: 32000
      }
    }
  }
}

dim {
  dim_param: "batch_size"
}
dim {
  dim_param: "sequence_length"
}



In [22]:
onnx.save(model, "artifacts_generated_l1/training_model_corrected_labels.onnx")