In [None]:
!pip install sentence-transformers datasets torch transformers accelerate

In [None]:
from datasets import load_dataset

def load_nq_german(data_file="./data/ng_german.jsonl.gz", test_size=1_000, seed=12, num_proc=8):
    """Load the German question/answer pairs and split them into train and test sets.

    The file is read with the ``"json"`` dataset builder. The existing
    ``query`` and ``answer`` columns are dropped, after which ``question_de``
    and ``answer_de`` are renamed to ``query`` and ``answer`` respectively.

    Args:
        data_file: Path of the JSONL file to load.
            NOTE(review): the default file name starts with "ng_", which may
            be a typo for "nq_" -- confirm against the actual data directory.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        seed: Forwarded to ``train_test_split`` as ``seed``.
        num_proc: Forwarded to ``load_dataset`` as ``num_proc``.

    Returns:
        The result of ``train_test_split``; callers index it with ``"train"``
        and ``"test"``.
    """
    raw_dataset = load_dataset("json", data_files=data_file, split="train", num_proc=num_proc)

    # Replace the original query/answer columns by their German counterparts.
    german_dataset = (
        raw_dataset
        .remove_columns(["query", "answer"])
        .rename_column("question_de", "query")
        .rename_column("answer_de", "answer")
    )

    return german_dataset.train_test_split(test_size=test_size, seed=seed)


In [None]:
import logging
import random

import numpy
import torch
#from torch import mps  # noqa: F401
#torch.mps.device = mps
from datasets import Dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss, CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

In [None]:
# Emit INFO-level (and above) log records, each prefixed with a timestamp.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Seed the NumPy, PyTorch and Python RNGs with the same fixed value (12).
numpy.random.seed(12)
torch.manual_seed(12)
random.seed(12)

In [None]:
# 1. Load a model to finetune with 2. (Optional) model card data
# Identifier of the pretrained checkpoint that gets fine-tuned.
base_model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# Feel free to adjust these variables:
#   use_prompts                -- whether the prompt dictionary is defined and
#                                 handed to the training arguments
#   include_prompts_in_pooling -- value passed to model.set_pooling_include_prompt
# Both switches also influence the run name.
use_prompts = True
include_prompts_in_pooling = True

In [None]:
# Metadata attached to the model (language, license and a descriptive name).
card_data = SentenceTransformerModelCardData(
    model_name=f"{base_model_name} trained on german Natural Questions pairs",
    language="de",
    license="apache-2.0",
)

# Load the base checkpoint together with the model card metadata.
model = SentenceTransformer(
    base_model_name,
    model_card_data=card_data,
    # tokenizer_kwargs={"max_seq_length": 512},  # left disabled
)

# Cast the model to bfloat16.
# NOTE(review): the training arguments also set bf16=True; confirm that
# training with bfloat16 parameters (instead of float32 weights combined with
# bf16 mixed precision) is intended.
model = model.to(torch.bfloat16)

In [None]:
# Forward the `include_prompts_in_pooling` switch to the model.
# Presumably this controls whether prompt tokens take part in the pooling
# step -- confirm against the sentence-transformers documentation.
model.set_pooling_include_prompt(include_prompts_in_pooling)

In [None]:
# 2. (Optional) Define prompts
# The dictionary keys ("query", "answer") are the column names that
# load_nq_german produces; each maps to the prompt string for that column.
# Nothing is defined when use_prompts is False.
if use_prompts:
    query_prompt, corpus_prompt = "query: ", "document: "
    prompts = dict(query=query_prompt, answer=corpus_prompt)

In [None]:
# 3. Load a dataset to finetune on
dataset_dict = load_nq_german()

# Pull the two splits out of the returned mapping.
train_dataset: Dataset
eval_dataset: Dataset
train_dataset, eval_dataset = dataset_dict["train"], dataset_dict["test"]

In [None]:
# 4. Define a loss function
# The cached variant of the multiple-negatives ranking loss, using a
# mini-batch size of 16.
loss = CachedMultipleNegativesRankingLoss(model=model, mini_batch_size=16)

# Alternative (non-cached) loss:
# loss = MultipleNegativesRankingLoss(model)  # <- this does work with mps (Apple Silicon)


In [None]:
# 5. (Optional) Specify training arguments
# The run name is derived from the base model name (the part after the last
# "/") plus a suffix for each prompt-related switch.
run_name = (
    "nq-german-"
    + base_model_name.rsplit("/", 1)[-1]
    + ("-prompts" if use_prompts else "")
    + ("" if include_prompts_in_pooling else "-exclude-pooling-prompts")
)

args = SentenceTransformerTrainingArguments(
    # --- Required ---
    output_dir=f"models/{run_name}",
    # --- Optimisation ---
    num_train_epochs=3,
    learning_rate=4e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    # --- Precision ---
    # Keep fp16 off if the GPU cannot run FP16; enable bf16 only on GPUs
    # that support BF16.
    fp16=False,
    bf16=True,
    # --- Batching ---
    # MultipleNegativesRankingLoss benefits from having no duplicate samples
    # within a batch.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # --- Evaluation / checkpointing ---
    # A float below 1 for eval_steps/save_steps is read by transformers as a
    # fraction of the total number of training steps (see TrainingArguments docs).
    eval_strategy="steps",
    eval_steps=0.5,
    save_strategy="steps",
    save_steps=0.5,
    save_total_limit=2,
    # --- Logging / tracking ---
    logging_steps=5,
    logging_first_step=True,
    run_name=run_name,  # picked up by W&B when `wandb` is installed
    # --- Reproducibility and prompts ---
    seed=12,
    prompts=prompts if use_prompts else None,
)

In [None]:
# 7. Create a trainer & train
# Wire the model, training arguments, loss and both dataset splits together.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    loss=loss,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # evaluator=dev_evaluator,  # disabled; no dev_evaluator is defined in the cells above
)

# Run the fine-tuning loop.
trainer.train()

In [None]:
# 8. Save the trained model
# The final model is written to a "final" sub-directory next to the
# training output directory.
final_model_dir = f"models/{run_name}/final"
model.save_pretrained(final_model_dir)