In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig, set_seed
from peft import LoraConfig
from trl import SFTTrainer, setup_chat_format, DataCollatorForCompletionOnlyLM
from datasets import load_dataset
import torch

# Seed every RNG that transformers controls so that runs are repeatable.
set_seed(42)

# Location of the base checkpoint to fine-tune -- change me!
modelpath = "models/TinyLlama-1.1B-intermediate-step-1431k-3T"

# Base model: bf16 weights, FlashAttention-2 kernels, and automatic
# placement of the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# The slow (non-"fast") tokenizer implementation is requested explicitly.
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False)

# Let TRL prepare model and tokenizer for chat-style training; the
# "<|im_start|>" markers used later by the collator come from this format.
model, tokenizer = setup_chat_format(model, tokenizer)

# A pad token that is missing or identical to EOS is replaced by the UNK
# token.
# NOTE(review): presumably this keeps padding distinguishable from the
# end-of-turn token so the latter is not masked out of the loss -- confirm.
if tokenizer.pad_token in (None, tokenizer.eos_token):
    tokenizer.pad_token = tokenizer.unk_token

# Conversation dataset; its "train" and "test" splits are used for training.
dataset = load_dataset("g-ronimo/oasst2_top4k_en")

# Hyper-parameters for the fine-tuning run, grouped by concern.
training_arguments = TrainingArguments(
    # --- output, evaluation and logging ---------------------------------
    output_dir="out_OA_TL",
    evaluation_strategy="steps",
    eval_steps=250,
    save_steps=250,
    logging_steps=1,
    label_names=["labels"],
    # --- batch layout ---------------------------------------------------
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    group_by_length=True,
    # --- optimisation ---------------------------------------------------
    num_train_epochs=3,
    learning_rate=1e-5,
    lr_scheduler_type="constant",
    optim="paged_adamw_32bit",
    # --- precision / memory ---------------------------------------------
    bf16=True,
    gradient_checkpointing=True,
)

# Collator built from the chat-format turn markers: the response template
# marks where an assistant turn begins, the instruction template where a
# user turn begins.
# NOTE(review): per the TRL docs this restricts the loss to assistant
# completions -- confirm against the installed TRL version.
completion_collator = DataCollatorForCompletionOnlyLM(
    response_template="<|im_start|>assistant",
    instruction_template="<|im_start|>user",
    tokenizer=tokenizer,
    mlm=False,
)

# Supervised fine-tuning trainer wired to the model, tokenizer, dataset
# splits and training arguments prepared earlier in the notebook.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_arguments,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=completion_collator,
    max_seq_length=512,
    # Forwarded to dataset preparation: do not add special tokens on top of
    # the text that is tokenized.
    dataset_kwargs={"add_special_tokens": False},
)

# Start the training run.
trainer.train()