In [39]:
from transformers import TrainingArguments, Trainer
from peft import LoraConfig, TaskType, get_peft_model
import sys, os, copy

# be able to import from src
sys.path.append(os.path.abspath(".."))

from src.models.causal_lm import NMTModel

### Load model

In [2]:
# NOTE: replace with accurate model directory, preferably not newer than 2023
model_dir = "../data/qwen0-5b"
out_dir = "../data/out"

# Load model
model = NMTModel(model_dir, "cpu")

`torch_dtype` is deprecated! Use `dtype` instead!


In [3]:
# prompt format adapted from "How Good Are GPT Models at Machine Translation? A Comprehensive Evaluation"
example_prompts = [
    "Translate this from English to Japanese: English: Would you like something to eat? Japanese: ",
    "Translate this from English to German: English: Would you like something to eat? German: ",
]

model.prompt_batch(example_prompts)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


['あなたは何か食べたいですか？\n', 'Wollen Sie etwas essen?\n']

## Load data

In [4]:
from torch.utils.data import Dataset, DataLoader
import os
from typing import Tuple, Optional
from functools import partial

In [5]:
IGNORE_TOKEN = -100 # TODO: set in loss function

def full_lang_name(abbr: str):
    """ 
    Returns the full name of the language from its abbrieviation.
    """

    NAMES = {
        "es": "Spanish",
        "fr": "French",
        "pt": "Portuguese",
        "de": "German",
        "it": "Italian",
        "ru": "Russian",
        "en": "English",
    }

    return NAMES[abbr]

class Medline(Dataset):
    def __init__(self, lang_from: str, lang_to: str, folder: str):
        """
        Loads biomedical dataset from the medline corpus 2022.
        Samples will be in the language specified by `lang_from`
        and labels in the language `lang_to`. One language must be 'en' (English),
        the other one of {'es' (Spanish), 'fr' (French), 'pt' (Portuguese), 'de' (German), 'it' (Italian), 'ru' (Russian)}.

        Reads from the directory {folder}/en_{other language} (e.g., wmt22/en_pt). Assumes file names within that directory 
        follow the following convention: {file id}_{language}.txt (e.g., for file ID 120 we need files 120_en.txt and 120_pt.txt for english to/from portuguese).
        """

        VALID_LANGS = { "es", "fr", "pt", "de", "it", "ru", "en" }
        
        assert lang_from in VALID_LANGS, f"Specified language '{lang_from}' is not valid! (must be one of {VALID_LANGS})"
        assert lang_to in VALID_LANGS, f"Specified language '{lang_to}' is not valid! (must be one of {VALID_LANGS})"
        assert lang_from == "en" or lang_to == "en", "One of the languages must be english!"
        assert lang_from != lang_to, "The from and to language may not be the same!"

        # the language that is not english
        other_lang = lang_from if lang_from != "en" else lang_to

        self.data_dir = os.path.join(folder, f"en_{other_lang}")

        # load file IDs
        self.ids = []
        for file in os.listdir(self.data_dir):
            self.ids.append(file.split("_")[0])

        self.lang_from = lang_from
        self.lang_to = lang_to

    def __len__(self):
        return len(self.ids)

    def read_file(self, idx, lang) -> str:
        """
        Tries to read the contents of a language file. 
        # Errors
        * If the file corresponding to `idx` and `lang` does not exist.
        """

        path = os.path.join(self.data_dir, f"{idx}_{lang}.txt")

        with open(path, "r") as file:
            return file.read().rstrip()
    
    def __getitem__(self, index) -> Tuple[str, str]:
        """
        Fetches a sample and target.
        """

        idx = self.ids[index]

        source = self.read_file(idx, self.lang_from)
        target = self.read_file(idx, self.lang_to)

        return source, target
    
dset = Medline("de", "en", "../data/wmt22")

In [None]:
def get_translation_prompt_skeleton(lang_from: str, lang_to: str) -> str:
    """
    Returns a translation prompt skeleton with the source yet to be plugged in.
    The format is adapted from "How Good Are GPT Models at Machine Translation? A Comprehensive Evaluation".
    """
    return f"Translate this from {lang_from} to {lang_to}: {lang_from}: {{}} {lang_to}: "

def collate_fn(batch, tokenizer, prompt_form: str):
    """
    Returns a dictionary containing input_ids, labels, and attention_mask entries 
    by taking a batch of string source target pairs, formatting them into a translation prompt,
    and tokenizing.
    """  

    source = [prompt_form.format(s[0]) for s in batch]
    target = [s[1] for s in batch]

    source_toks_raw = tokenizer(source, add_special_tokens=False)["input_ids"]
    target_toks_raw = tokenizer(target, add_special_tokens=False)["input_ids"]

    # build tokens
    bos = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
    eos = [tokenizer.eos_token_id]

    source_toks = [bos + s + t + eos for s, t in zip(source_toks_raw, target_toks_raw)]
    toks = tokenizer.pad({ "input_ids": source_toks }, padding=True, return_tensors="pt")

    # for labels, set prompt tokens to IGNORE_TOKEN as we don't want to compute loss over those
    toks["labels"] = copy.deepcopy(toks["input_ids"])
    
    for i in range(len(toks["labels"])):
        toks["labels"][i][:len(source_toks_raw[i])] = IGNORE_TOKEN
    return toks
    

prompt_form = get_translation_prompt_skeleton(full_lang_name(dset.lang_from), full_lang_name(dset.lang_to))

dloader = DataLoader(
    dset,
    batch_size=16,
    collate_fn=partial(collate_fn, tokenizer=model.tokenizer, prompt_form=prompt_form)
)

In [66]:
next(iter(dloader))

{'input_ids': tensor([[ 27473,    419,    504,  ..., 151643, 151643, 151643],
        [ 27473,    419,    504,  ..., 151643, 151643, 151643],
        [ 27473,    419,    504,  ..., 151643, 151643, 151643],
        ...,
        [ 27473,    419,    504,  ..., 151643, 151643, 151643],
        [ 27473,    419,    504,  ..., 151643, 151643, 151643],
        [ 27473,    419,    504,  ..., 151643, 151643, 151643]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[  -100,   -100,   -100,  ..., 151643, 151643, 151643],
        [  -100,   -100,   -100,  ..., 151643, 151643, 151643],
        [  -100,   -100,   -100,  ..., 151643, 151643, 151643],
        ...,
        [  -100,   -100,   -100,  ..., 151643, 151643, 151643],
        [  -100,   -100,   -100,  ..., 151643, 151643, 151643],
       

### PEFT

In [6]:
# load finetuneable model
# TODO: find good parameters
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
model_peft = get_peft_model(model.model, peft_config)

model_peft.print_trainable_parameters()

trainable params: 540,672 || all params: 494,573,440 || trainable%: 0.1093


In [None]:
# TODO: find good parameters
training_args = TrainingArguments(
    output_dir=out_dir,
    learning_rate=1e-3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# TODO: find good parameters
# TODO: override get_train_dataloader, get_eval_dataloader, and get_test_dataloader
# TODO: set ignore_index in the loss function
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=...,
    eval_dataset=...,
    processing_class=model.tokenizer,
    data_collator=...,
    compute_metrics=...,
)

trainer.train()