In [52]:
!pip install sentence-transformers datasets tensorboardX peft



In [25]:
import logging
import random

import numpy
import torch
#from torch import mps  # noqa: F401
#torch.mps.device = mps
from datasets import Dataset, load_dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss, CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

In [26]:
# Seed every random number generator used in this notebook with the same value (12).
numpy.random.seed(seed=12)
torch.manual_seed(seed=12)
random.seed(a=12)

# Timestamped, INFO-level log output for the remaining cells.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(message)s",
)

In [27]:
# Experiment switches -- adjust as needed.
# include_prompts_in_pooling: forwarded to `model.set_pooling_include_prompt(...)`;
#   when False, "-exclude-pooling-prompts" is appended to the run name.
include_prompts_in_pooling = True
# use_prompts: when True, the "query: " / "document: " prompts are defined and
#   handed to the training arguments, and "-prompts" is appended to the run name.
use_prompts = True

In [28]:
# Identifier of the base checkpoint handed to `SentenceTransformer(...)`; the part
# after the last "/" is also reused to build the run name / output directory.
base_model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"


In [29]:
# Metadata attached to the model card of the fine-tuned model.
card_metadata = SentenceTransformerModelCardData(
    model_name=f"{base_model_name} trained on german Natural Questions pairs",
    language="de",
    license="apache-2.0",
)

# Load the base checkpoint together with the model-card metadata.
model = SentenceTransformer(
    base_model_name,
    #tokenizer_kwargs={"max_seq_length": 512},
    model_card_data=card_metadata,
)

# Cast the model to bfloat16.
# NOTE(review): the training arguments further down also set `bf16=True`; confirm that
# holding the weights themselves in bfloat16 (not only autocasting) is intended.
model = model.to(torch.bfloat16)

In [30]:
# Apply the `include_prompts_in_pooling` switch from the configuration cell to the model.
model.set_pooling_include_prompt(include_prompt=include_prompts_in_pooling)

In [55]:
from peft import LoraModel, LoraConfig, TaskType

In [56]:
# LoRA adapter settings: rank 64, alpha 128, dropout 0.1, feature-extraction task,
# configured for training (inference_mode=False).
peft_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.1,
    inference_mode=False,
    task_type=TaskType.FEATURE_EXTRACTION,
)

# Attach the adapter to the already-loaded model.
# NOTE(review): this mutates `model`; re-running the cell without reloading the model
# presumably fails or stacks adapters -- confirm before relying on re-execution.
model.add_adapter(peft_config)

# Optional prompts. The dictionary keys ("query", "answer") match the column names
# the dataset is given after renaming.
if use_prompts:
    query_prompt, corpus_prompt = "query: ", "document: "
    prompts = dict(query=query_prompt, answer=corpus_prompt)

In [57]:
# Load the German Natural Questions pairs, keep only the German text, and split off
# 1000 rows for evaluation (seeded so the split is reproducible).
natural_questions_german = (
    load_dataset("oliverguhr/natural-questions-german", split="train")
    # Drop the English-language columns ...
    .remove_columns(["answer", "query"])
    # ... and let the German columns take over their names.
    .rename_column("query_de", "query")
    .rename_column("answer_de", "answer")
    .train_test_split(test_size=1000, seed=12)
)

train_dataset: Dataset = natural_questions_german["train"]
eval_dataset: Dataset = natural_questions_german["test"]

In [58]:
# Display the training split's summary (feature names and row count).
train_dataset

Dataset({
    features: ['query', 'answer'],
    num_rows: 99231
})

In [59]:
# Loss function: cached variant of the multiple-negatives ranking loss, processing
# the batch in mini-batches of 32.
# The cached variant does not work on the "mps" backend; there, use instead:
#loss = MultipleNegativesRankingLoss(model)
# NOTE(review): mini_batch_size (32) equals per_device_train_batch_size (32) in the
# training arguments, so the caching presumably brings no memory benefit at this batch
# size -- consider a larger per-device batch or the plain loss; confirm.
loss = CachedMultipleNegativesRankingLoss(model=model, mini_batch_size=32)

In [60]:
# Run name: "german-nq-<model>" plus one suffix per active experiment switch.
name_parts = ["german-nq", base_model_name.split("/")[-1]]
if use_prompts:
    name_parts.append("prompts")
if not include_prompts_in_pooling:
    name_parts.append("exclude-pooling-prompts")
run_name = "-".join(name_parts)

# Training configuration.
args = SentenceTransformerTrainingArguments(
    # Where checkpoints are written (required).
    output_dir=f"models/{run_name}",
    # --- Optimisation ---
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=4e-5,
    warmup_ratio=0.1,
    # Mixed precision: bf16 on, fp16 off. Flip these if the GPU lacks BF16 support.
    fp16=False,
    bf16=True,
    # Avoid duplicate samples within a batch, which the in-batch-negatives loss benefits from.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # --- Evaluation / checkpointing ---
    # Per the TrainingArguments docs, a float in [0, 1) is a fraction of the total
    # training steps, so evaluation and saving happen every 50% of training.
    eval_strategy="steps",
    eval_steps=0.5,
    save_strategy="steps",
    save_steps=0.5,
    save_total_limit=2,
    # --- Logging ---
    logging_steps=5,
    logging_first_step=True,
    run_name=run_name,
    report_to="tensorboard",
    # --- Misc ---
    seed=12,
    # `prompts` only exists when `use_prompts` is True, hence the guard.
    prompts=prompts if use_prompts else None,
)

In [None]:
# Assemble the trainer from the model, training arguments, data splits and loss.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    loss=loss,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    #evaluator=dev_evaluator,  # disabled: no `dev_evaluator` is defined in the cells above
)

# Run the fine-tuning; the returned training summary is shown as the cell output.
trainer.train()

Step,Training Loss,Validation Loss
