In [1]:
from transformers import TrainingArguments, Trainer
from peft import LoraConfig, TaskType, get_peft_model
import torch
from torch.utils.data import DataLoader, random_split
import sys, os, copy
from functools import partial

# be able to import from src
sys.path.append(os.path.abspath(".."))

from src.models.causal_lm import NMTModel
from src.data import Medline, collate_translations, get_translation_prompt_skeleton, full_lang_name

  from .autonotebook import tqdm as notebook_tqdm


### Load model

In [2]:
# NOTE: replace with accurate model directory, preferably not newer than 2023
model_dir = "../data/qwen0-5b"
out_dir = "../data/out"

# load model
model = NMTModel(model_dir, "cpu")

`torch_dtype` is deprecated! Use `dtype` instead!


In [3]:
# prompt format adapted from "How Good Are GPT Models at Machine Translation? A Comprehensive Evaluation"
example_prompts = [
    "Translate this from English to Japanese: English: Would you like something to eat? Japanese: ",
    "Translate this from English to German: English: Would you like something to eat? German: ",
]

model.prompt_batch(example_prompts)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


['あなたは何か食べたいですか？\n', 'Wollen Sie etwas essen?\n']

## Load data

In [4]:
dset = Medline("de", "en", "../data/wmt22")

train_dset, test_dset = random_split(dset, [0.8, 0.2])

In [5]:
prompt_form = get_translation_prompt_skeleton(full_lang_name(dset.lang_from), full_lang_name(dset.lang_to))

dloader = DataLoader(
    dset,
    batch_size=16,
    collate_fn=partial(collate_translations, tokenizer=model.tokenizer, prompt_form=prompt_form)
)

In [6]:
next(iter(dloader))

You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': tensor([[ 27473,    419,    504,  ..., 151643, 151643, 151643],
        [ 27473,    419,    504,  ..., 151643, 151643, 151643],
        [ 27473,    419,    504,  ..., 151643, 151643, 151643],
        ...,
        [ 27473,    419,    504,  ..., 151643, 151643, 151643],
        [ 27473,    419,    504,  ..., 151643, 151643, 151643],
        [ 27473,    419,    504,  ..., 151643, 151643, 151643]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[  -100,   -100,   -100,  ..., 151643, 151643, 151643],
        [  -100,   -100,   -100,  ..., 151643, 151643, 151643],
        [  -100,   -100,   -100,  ..., 151643, 151643, 151643],
        ...,
        [  -100,   -100,   -100,  ..., 151643, 151643, 151643],
        [  -100,   -100,   -100,  ..., 151643, 151643, 151643],
       

### PEFT

In [7]:
# load finetuneable model
# TODO: find good parameters
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
model_peft = get_peft_model(model.model, peft_config)

model_peft.print_trainable_parameters()

trainable params: 540,672 || all params: 494,573,440 || trainable%: 0.1093


In [8]:
class EncoderTrainer(Trainer):
    """
    Custom trainer for optimizing encoder-only NMT transformer models.

    During training and evaluation, it provides source sentences (possibly wrapped in a prompt to indicate the type of task to the model)
    and optimizes the model for predicting correct target sentences. In other words, loss or metrics are not calculated over the
    source prompt, only over the target sentence.
    """

    def __init__(self, prompt_skeleton: str, tokenizer, **kwargs):
        super().__init__(**kwargs)

        self.collate_fn = partial(collate_translations, tokenizer=tokenizer, prompt_form=prompt_skeleton)

    def get_train_dataloader(self):
        return DataLoader( # TODO: use accelerator.prepare?
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            collate_fn=self.collate_fn
        )
    
    def get_eval_dataloader(self, eval_dataset = None):
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        
        return DataLoader(
            self.eval_dataset[eval_dataset] if isinstance(eval_dataset, str) else eval_dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.collate_fn
        )
    
    def get_test_dataloader(self, test_dataset):
        return DataLoader(
            test_dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.collate_fn
        )

In [None]:
training_args = TrainingArguments(
    output_dir=out_dir,
    learning_rate=1e-3, # TODO: finetune
    # TODO: make as big as GPU permits
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,

    # prevent source labels to be included in loss
    label_smoothing_factor=0.0,
)

# TODO: set ignore_index in the loss function
trainer = EncoderTrainer(
    # prompt used for aiding the model to make translations
    prompt_skeleton=get_translation_prompt_skeleton(full_lang_name(dset.lang_from), full_lang_name(dset.lang_to)),
    tokenizer=model.tokenizer,

    model=model.model,
    args=training_args,
    train_dataset=train_dset,
    eval_dataset=test_dset,
    processing_class=model.tokenizer,
)