In [None]:
import logging
from dataclasses import dataclass, field
import os
import sys
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger

import transformers
from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
    HfArgumentParser,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    set_seed,
)
from transformers.trainer_utils import seed_worker

from peft import LoraConfig, get_peft_model

from llm2vec import LLM2Vec
from llm2vec.dataset.utils import load_dataset
from llm2vec.loss.utils import load_loss

from tqdm import tqdm

# Config classes registered in the masked-LM mapping, and the model_type of each one.
MODEL_CONFIG_CLASSES = [*MODEL_FOR_MASKED_LM_MAPPING.keys()]
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)

In [None]:
from bg2vec.arguments import simcse_parser
    
# Read the four argument groups from the JSON run configuration.
(
    model_args,
    data_args,
    training_args,
    custom_args,
) = simcse_parser.parse_json_file("model_configurations/bggpt-7b-simcse.json")

# Handler list for the Accelerator: empty unless DDP is asked to look for unused parameters.
kwargs = []
if training_args.ddp_find_unused_parameters:
    kwargs.append(
        DistributedDataParallelKwargs(
            dim=0,
            broadcast_buffers=True,
            bucket_cap_mb=25,
            find_unused_parameters=True,
            check_reduction=False,
            gradient_as_bucket_view=False,
        )
    )

In [None]:
# Create the Accelerator with the handlers collected in `kwargs`.
accelerator = Accelerator(kwargs_handlers=kwargs)

# Seed using the value carried by the training arguments.
set_seed(training_args.seed)

# With gradient checkpointing enabled, request the non-reentrant variant.
if training_args.gradient_checkpointing:
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

In [None]:
import datasets

In [None]:
from bg2vec.data_util import PairedDataset, load_raw_datasets

In [None]:
# (Re)load the autoreload extension so edits to imported project modules are picked up live.
%reload_ext autoreload
%autoreload 2

In [None]:
# Bind the loaded splits to their own name: assigning them to `datasets` would
# shadow the `datasets` module imported earlier in this notebook, making the
# meaning of that name depend on which cell ran last.
raw_datasets = load_raw_datasets(data_args, model_args)

# Wrap the 'train' and 'validation' splits in PairedDataset.
train_dataset = PairedDataset(raw_datasets['train'])
valid_dataset = PairedDataset(raw_datasets['validation'])

In [None]:
# Materialise every item of both datasets into plain lists by indexing them one by one.
# The progress bars are only shown on the main process.
train_examples = [
    train_dataset[i]
    for i in tqdm(
        range(len(train_dataset)),
        desc="Loading train examples...",
        disable=not accelerator.is_main_process,
    )
]
validation_examples = [
    valid_dataset[i]
    for i in tqdm(
        range(len(valid_dataset)),
        # Label the bar for the split actually being loaded (it previously said "train").
        desc="Loading validation examples...",
        disable=not accelerator.is_main_process,
    )
]

# "auto" and None are passed through unchanged; any other value is looked up
# as an attribute of the torch module.
torch_dtype = (
    model_args.torch_dtype
    if model_args.torch_dtype in ["auto", None]
    else getattr(torch, model_args.torch_dtype)
)
    

In [None]:
# Show the parsed model arguments as this cell's output.
model_args

In [None]:
# Load the LLM2Vec model; every setting except the cache directory comes from the parsed arguments.
model = LLM2Vec.from_pretrained(
    # Base checkpoint and the PEFT checkpoint loaded with it (merge_peft is switched on).
    base_model_name_or_path=model_args.model_name_or_path,
    peft_model_name_or_path=model_args.peft_model_name_or_path,
    merge_peft=True,
    # Encoder behaviour.
    enable_bidirectional=model_args.bidirectional,
    pooling_mode=model_args.pooling_mode,
    max_length=model_args.max_seq_length,
    # The SimCSE dropout value is handed over as the attention dropout.
    attention_dropout=custom_args.simcse_dropout,
    attn_implementation=model_args.attn_implementation,
    # Loading options.
    torch_dtype=torch_dtype,
    cache_dir="/data/bggpt/"
)

In [None]:
# Override the tokenizer's maximum length with a fixed value.
# NOTE(review): 512 is hard-coded and independent of model_args.max_seq_length,
# which is what the model itself receives as max_length — confirm the two are meant to differ.
model.tokenizer.model_max_length = 512

In [None]:
from bg2vec.model import initialize_peft

In [None]:
# LLM2Vec keeps the underlying HF model in `model.model`; PEFT has to be applied
# to that inner model, so it is replaced with the result of initialize_peft.
# lora_alpha is set to twice the LoRA rank.
model.model = initialize_peft(
    model.model,
    lora_r=custom_args.lora_r,
    lora_dropout=custom_args.lora_dropout,
    lora_alpha=custom_args.lora_r * 2,
)

In [None]:
from bg2vec.training import SimCSEDefaultCollator, SimCSETrainer

In [None]:
# Collator constructed from the LLM2Vec model.
data_collator = SimCSEDefaultCollator(model)

# Loss selected by class name from the custom arguments, with the configured scale.
train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale)

In [None]:
class StopTrainingCallback(TrainerCallback):
    """Trainer callback that asks the training loop to stop after a fixed number of global steps.

    Defined here because the notebook uses it below without defining or importing it,
    which raised a NameError whenever ``stop_after_n_steps`` was configured.

    Args:
        stop_after_n_steps: Global step count at which training should stop.
    """

    def __init__(self, stop_after_n_steps: int):
        self.stop_after_n_steps = stop_after_n_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Set ``control.should_training_stop`` once ``state.global_step`` reaches the limit."""
        if state.global_step >= self.stop_after_n_steps:
            control.should_training_stop = True


# Assemble the trainer from the model, the materialised examples, the collator and the loss.
trainer = SimCSETrainer(
    model=model,
    args=training_args,
    train_dataset=train_examples,
    eval_dataset=validation_examples,
    data_collator=data_collator,
    tokenizer=model.tokenizer,
    loss_function=train_loss,
)

# Optional early stop, only when a step limit is configured.
if custom_args.stop_after_n_steps is not None:
    trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))

# Drop the Weights & Biases callback via the Trainer's public API.
trainer.remove_callback(transformers.integrations.integration_utils.WandbCallback)

In [None]:
# Start the training run with the trainer configured above.
trainer.train()