In [None]:
!pip install -U peft transformers bitsandbytes accelerate datasets wandb trl flash-attn sentencepiece

In [2]:
!pip list | egrep -w "transformers|datasets|accelerate|peft|bitsandbytes|trl|torch|flash-attn|sentencepiece"

accelerate                        0.28.0             /home/g/accelerate_fork
bitsandbytes                      0.42.0
datasets                          2.18.0
flash-attn                        2.5.6
open-clip-torch                   2.23.0
peft                              0.9.0
sentence-transformers             2.3.1
sentencepiece                     0.2.0
torch                             2.2.1
torch-grammar                     0.3.3
transformers                      4.38.2
trl                               0.7.11


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig, set_seed
from peft import LoraConfig
from trl import SFTTrainer, setup_chat_format, DataCollatorForCompletionOnlyLM
from datasets import load_dataset
import torch

# Seed the RNGs through transformers' set_seed before anything stochastic is constructed.
set_seed(42)

# Base checkpoint to fine-tune. It is loaded in bfloat16 and device placement is
# delegated to the library via device_map="auto".
modelpath = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
model = AutoModelForCausalLM.from_pretrained(
    modelpath,    
    device_map = "auto",
    torch_dtype = torch.bfloat16,
    # FlashAttention-2 is intentionally left disabled: combining it with model.generate()
    # during training fails with "RuntimeError: query and key must have the same dtype".
    # attn_implementation = "flash_attention_2",   
)
# use_fast = False requests the slow (non-"fast") tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast = False)

# TRL helper that returns the (model, tokenizer) pair prepared for a chat format.
# NOTE(review): presumably this installs the ChatML special tokens ("<|im_start|>" /
# "<|im_end|>") and resizes the embeddings; the collator templates below rely on
# "<|im_start|>" being present. Confirm against the TRL docs for the pinned version.
model, tokenizer = setup_chat_format(model, tokenizer)
# Make sure a pad token exists and differs from EOS; fall back to the unk token otherwise.
# NOTE(review): presumably so that padding is never confused with a real EOS - confirm.
if tokenizer.pad_token in [None, tokenizer.eos_token]: 
    tokenizer.pad_token = tokenizer.unk_token

# Dataset from the Hugging Face Hub; its "train" and "test" splits are used below.
dataset = load_dataset("g-ronimo/oasst2_top4k_en")

# Training hyper-parameters:
#   - evaluate and save a checkpoint every 250 steps, log every step
#   - batch size 16 per device, no gradient accumulation
#   - constant learning rate of 1e-5 for 10 epochs
#   - bf16 training with gradient checkpointing and the paged 32-bit AdamW optimizer
#   - group_by_length batches samples of similar length together
training_arguments = TrainingArguments(
    output_dir = "out_OA_TL",
    evaluation_strategy = "steps",
    label_names = ["labels"],
    per_device_train_batch_size = 16,
    gradient_accumulation_steps = 1,
    save_steps = 250,
    eval_steps = 250,
    logging_steps = 1, 
    learning_rate = 1e-5,
    num_train_epochs=10,
    lr_scheduler_type = "constant",
    optim = 'paged_adamw_32bit',
    bf16 = True,
    gradient_checkpointing = True,
    group_by_length = True,
)

# Supervised fine-tuning trainer. No peft_config is passed, so the LoraConfig imported
# above is not used here. Sequences are limited to max_seq_length = 512.
# The collator is given the user / assistant turn markers of the chat format.
# NOTE(review): DataCollatorForCompletionOnlyLM presumably restricts the loss to the
# assistant responses located via these templates - confirm against the TRL docs.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset["train"],
    eval_dataset = dataset['test'],
    data_collator = DataCollatorForCompletionOnlyLM(
        instruction_template = "<|im_start|>user", 
        response_template = "<|im_start|>assistant", 
        tokenizer = tokenizer, 
        mlm = False),
    max_seq_length = 512,
    args = training_arguments,
)

In [None]:
from transformers import TrainerCallback
from statistics import mean

from semscore import ModelPredictionGenerator, EmbeddingModelWrapper

class SemscoreEvalCallback(TrainerCallback):
    """Log the mean cosine similarity between reference and generated answers.

    After every evaluation run the current model generates answers for the first
    ``num_samples`` rows of an evaluation dataset. Reference and predicted answers
    are embedded and compared, and the mean similarity is logged as ``cosine_sim``.

    Args:
        eval_dataset: Dataset to draw samples from. When ``None`` the notebook-global
            ``dataset["test"]`` is used (original behaviour).
        num_samples: Upper bound on the number of rows to generate answers for.
            Clamped to the dataset size so small test splits do not raise.
        log_fn: Callable that receives the metrics dict. When ``None`` the
            notebook-global ``trainer.log`` is used (original behaviour).
    """

    def __init__(self, eval_dataset=None, num_samples=100, log_fn=None):
        super().__init__()
        self.eval_dataset = eval_dataset
        self.num_samples = num_samples
        self.log_fn = log_fn

    def on_evaluate(self, args, state, control, model, tokenizer, eval_dataloader, **kwargs):
        """Generate answers with the current model and log their semantic similarity."""
        # Resolve the globals lazily so a no-argument instance keeps working exactly
        # as before, while explicit arguments remove the dependency on notebook state.
        source_ds = self.eval_dataset if self.eval_dataset is not None else dataset["test"]
        log_fn = self.log_fn if self.log_fn is not None else trainer.log

        # Never request more rows than exist: select() raises on out-of-range indices.
        sample_count = min(self.num_samples, len(source_ds))
        if sample_count <= 0:
            # Nothing to score; skip instead of failing in the middle of training.
            return

        generator = ModelPredictionGenerator(model = model, tokenizer = tokenizer)
        eval_ds = source_ds.select(range(sample_count))
        # Each result is read below via its "answer_ref" and "answer_pred" entries.
        results = generator.run(dataset = eval_ds)

        # A fresh embedding wrapper is created per evaluation, as in the original code.
        em = EmbeddingModelWrapper()
        similarities = em.get_similarities(
            em.get_embeddings( [a["answer_ref"] for a in results] ),
            em.get_embeddings( [a["answer_pred"] for a in results] ),
        )
        cosine_sim = mean(similarities)
        log_fn({"cosine_sim": cosine_sim})


# Register the callback with explicit dependencies instead of relying on globals
# being looked up at evaluation time.
trainer.add_callback(
    SemscoreEvalCallback(
        eval_dataset = dataset["test"],
        num_samples = 100,
        log_fn = trainer.log,
    )
)

In [None]:
# Start fine-tuning with the trainer configured above. With evaluation_strategy = "steps"
# and eval_steps = 250, each evaluation also fires the registered SemscoreEvalCallback.
trainer.train()