In [8]:
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
import transformers
from tqdm import tqdm
import os

In [9]:
from huggingface_hub import notebook_login

# Handle Lightning AI Studio: when the working directory is under /teamspace,
# switch into the project checkout and echo the resulting working directory.
running_in_studio = '/teamspace' in os.getcwd()
if running_in_studio:
    os.chdir('/teamspace/studios/this_studio/llm2vec-da')
    print(os.getcwd())

/teamspace/studios/this_studio/llm2vec-da


In [7]:
# Render the Hugging Face Hub login widget (huggingface_hub.notebook_login).
# NOTE(review): presumably needed to access gated/private checkpoints — confirm.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [10]:
from transformers import HfArgumentParser
from llm2vec_da.arguments import SimCSEModelArguments, SimCSEDataTrainingArguments, TrainingArguments, SimCSECustomArguments

# Parse the four argument dataclasses (model, data, training, custom) from a
# single JSON experiment config.
simcse_config_path = "configs/simcse/MetaLlama3-swe-dk-wiki-scandi.json"
simcse_parser = HfArgumentParser(
    (
        SimCSEModelArguments,
        SimCSEDataTrainingArguments,
        TrainingArguments,
        SimCSECustomArguments,
    )
)
model_args, data_args, training_args, custom_args = simcse_parser.parse_json_file(
    simcse_config_path
)

# The accelerator is created before seeding, then the seed from the training
# arguments is applied.
accelerator = Accelerator()
transformers.set_seed(training_args.seed)

# When gradient checkpointing is enabled, request the non-reentrant variant.
if training_args.gradient_checkpointing:
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

### Load the data

In [None]:
from llm2vec_da.data_utils import PairedDataset, load_raw_datasets, _load_from_files

In [None]:
def _materialize_examples(dataset, desc, show_progress=True):
    """Index every item of ``dataset`` eagerly and return them as a list.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        desc: Label shown on the tqdm progress bar.
        show_progress: When False, the progress bar is disabled.

    Returns:
        A list holding ``dataset[i]`` for every ``i`` in ``range(len(dataset))``.
    """
    return [
        dataset[i]
        for i in tqdm(range(len(dataset)), desc=desc, disable=not show_progress)
    ]


# Load the raw splits and wrap each one in a PairedDataset.
datasets = _load_from_files(data_args, model_args)
train_dataset = PairedDataset(datasets['train'])
valid_dataset = PairedDataset(datasets['validation'])

# Only the main process draws progress bars.
train_examples = _materialize_examples(
    train_dataset,
    desc="Loading train examples...",
    show_progress=accelerator.is_main_process,
)
# The validation bar previously reused the "train" label (copy-paste slip).
validation_examples = _materialize_examples(
    valid_dataset,
    desc="Loading validation examples...",
    show_progress=accelerator.is_main_process,
)

### Loading the model

In [None]:
from llm2vec_da import LLM2Vec

# Resolve the configured dtype. The config stores it as a string naming a torch
# dtype (e.g. "bfloat16"), but "auto" and None are also valid `torch_dtype`
# values for Hugging Face loaders and are not attributes of `torch`, so they
# are passed through unchanged instead of going through getattr (which would
# raise AttributeError / TypeError respectively).
# NOTE(review): assumes LLM2Vec.from_pretrained forwards torch_dtype to the
# underlying Hugging Face from_pretrained call — confirm.
dtype_name = model_args.torch_dtype
torch_dtype = dtype_name if dtype_name in ("auto", None) else getattr(torch, dtype_name)

# Build the LLM2Vec model from the base checkpoint plus the configured PEFT
# weights (merged on load); attention dropout comes from the SimCSE custom args.
model = LLM2Vec.from_pretrained(
    base_model_name_or_path=model_args.model_name_or_path,
    enable_bidirectional=model_args.bidirectional,
    peft_model_name_or_path=model_args.peft_model_name_or_path,
    merge_peft=True,
    pooling_mode=model_args.pooling_mode,
    max_length=model_args.max_seq_length,
    torch_dtype=torch_dtype,
    attn_implementation=model_args.attn_implementation,
    attention_dropout=custom_args.simcse_dropout,
    load_in_8bit=True
)
# Bare expression so the notebook displays the loaded model.
model

## Training

In [None]:
from llm2vec.loss.utils import load_loss
# Instantiate the training loss via load_loss, passing the configured loss
# class and forwarding the configured loss scale as `scale`.
train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale)
# Bare expression so the notebook displays whatever load_loss returned.
train_loss

In [None]:
from bg2vec.training import SimCSEDefaultCollator
# Collator passed to the trainer; it is handed the model's `tokenize` method
# (presumably used to tokenize each batch — confirm in bg2vec.training).
data_collator = SimCSEDefaultCollator(model.tokenize)

In [None]:
from bg2vec.training import SimCSETrainer, StopTrainingCallback
# Wire the model, its tokenizer, the pre-loaded example lists, the collator and
# the loss into the SimCSE trainer.
trainer = SimCSETrainer(
    model=model,
    tokenizer=model.tokenizer,
    args=training_args,
    data_collator=data_collator,
    loss_function=train_loss,
    train_dataset=train_examples,
    eval_dataset=validation_examples,
)

# Optional early stop: only registered when the config sets a step limit.
stop_after_n_steps = custom_args.stop_after_n_steps
if stop_after_n_steps is not None:
    trainer.add_callback(StopTrainingCallback(stop_after_n_steps))

# Remove the Weights & Biases callback from the trainer's callback handler.
wandb_callback_cls = transformers.integrations.integration_utils.WandbCallback
trainer.callback_handler.remove_callback(wandb_callback_cls)

In [None]:
# Launch the training run using the trainer configured in the previous cell.
trainer.train()