In [None]:
%%capture
#download the data
!git clone https://dev:dtKN5sX9We7pw1soPB19@gitlab.lrz.de/josh-o/leichte-sprache-corpus.git

In [None]:
%%capture
#install dependencies (April 14, 2023)
#pytorch 2.0.0+cu118
#Python 3.9.16
!pip install transformers==4.28.0 
!pip install sentencepiece==0.1.98
!pip install sacremoses==0.0.53
!pip install evaluate==0.3.0
!pip install sacrebleu==2.3.1

In [None]:
import os

# Allocator configuration has to be in place before the CUDA caching allocator is initialised,
# so it is set before torch is even imported.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"

import torch
from transformers import set_seed

device = "cuda:0" if torch.cuda.is_available() else "cpu"
seed = 42
set_seed(seed) # no direct effect on text generation

# corpus locations, relative to the notebook's working directory
# NOTE(review): the clone cell puts the corpus into the current working directory, while these
# paths point two levels up — adjust them if the notebook is not run from a nested directory.
PREFIX = "../../leichte-sprache-corpus/aligned/20min/"
PREFIX_KURIER = "../../leichte-sprache-corpus/aligned/kurier/"
PREFIX_AUGMENTED = "../../leichte-sprache-corpus/aligned/20min/augmented/"
PREFIX_AUGMENTED_KURIER = "../../leichte-sprache-corpus/aligned/kurier/augmented/"

# base model and the exact hub revision it is pinned to
mbart_path, mbart_revision = ("facebook/mbart-large-cc25", "57cecec5a3185d3ec7b3021a53093cf96835a634")

In [None]:
#experiments
# Uncomment exactly ONE configuration line below. Every setting a configuration line does not
# assign keeps its default from the two lines directly underneath.

decoder_path, custom_decoder_path, decoder_revision, augmentation, pretrained_weights, dropout_encoder, dropout_decoder, n_best = None, None, None, None, None, None, None, None
data, experiment_name = None, None

#finetuned mBART decoder experiments

#decoder_path, decoder_revision, data, experiment_name = ("josh-oo/mbart-decoder-easy", "4d0d3be96da2a7327c8e38434c20a413b464b9a2", "Kurier", "ft_no_inputs")
#decoder_path, decoder_revision, data, experiment_name = ("josh-oo/mbart-decoder-easy", "8c785939db9cf0851da9ea9003805bd78f7ec6cd", "Kurier", "ft_gaussian")
#decoder_path, decoder_revision, data, experiment_name = ("josh-oo/mbart-decoder-easy", "0afac431259d36a796e9a08f428c70d537fd579d", "Kurier", "ft_bart_encodings")

#custom decoder experiments (using the custom gpt version)
#custom_decoder_path, decoder_revision, data, experiment_name = ("josh-oo/german-gpt2-easy", "f47cb790634678a65be9df09491edf154bd51f7c", "20min", "custom_gpt")
#custom_decoder_path, decoder_revision, data, pretrained_weights, experiment_name = ("josh-oo/german-gpt2-easy", "f47cb790634678a65be9df09491edf154bd51f7c", "20min", ("josh-oo/calibrated-decoder", "28295adb39f2f8c82d981e53e519dd74959190cf"), "custom_pretrained_one2one")
#custom_decoder_path, decoder_revision, data, pretrained_weights, experiment_name = ("josh-oo/german-gpt2-easy", "f47cb790634678a65be9df09491edf154bd51f7c", "20min", ("josh-oo/calibrated-decoder", "7bf004553a1f7d02e33dfece6b5901a20b873b69"), "custom_pretrained_bart_noise")

#dataset size experiments
#n_best, data, experiment_name = (0.75, "20min", "train_on_best_data_75p")
#n_best, data, experiment_name = (0.50, "20min", "train_on_best_data_50p")
#n_best, data, experiment_name = (0.25, "20min", "train_on_best_data_25p")
#n_best, data, experiment_name = (0.05, "20min", "train_on_best_data_05p")
#n_best, data, experiment_name = (None, "20min", "train_on_best_data_100p")

#data augmentation 20min
#dropout_encoder, dropout_decoder, data, experiment_name = (0.0, 0.1, "20min", "da_dropout_0_0")
#dropout_encoder, dropout_decoder, data, experiment_name = (0.1, 0.1, "20min", "da_dropout_0_1")
#dropout_encoder, dropout_decoder, data, experiment_name = (0.3, 0.1, "20min", "da_dropout_0_3")

#dropout_encoder, dropout_decoder, data, augmentation, experiment_name = (0.0, 0.1, "20min", "inputs_simple_noise.csv", "da_simple_noise")
#dropout_encoder, dropout_decoder, data, augmentation, experiment_name = (0.0, 0.1, "20min", "inputs_bart_noise.csv",  "da_bart_noise")
#dropout_encoder, dropout_decoder, data, augmentation, experiment_name = (0.0, 0.1, "20min", "inputs_back_google.csv", "da_bt")
#dropout_encoder, dropout_decoder, data, augmentation, experiment_name = (0.0, 0.1, "20min", "inputs_back_google_simple_noise.csv", "da_bt_noise")
#dropout_encoder, dropout_decoder, data, augmentation, experiment_name = (0.0, 0.1, "20min", "inputs_english_deepl.csv", "da_english")

#data augmentation Kurier
#TODO dropout_encoder, dropout_decoder, data, experiment_name = (0.0, 0.1, "Kurier", "da_dropout_0_0")
#dropout_encoder, dropout_decoder, data, experiment_name = (0.1, 0.1, "Kurier", "da_dropout_0_1")
#TODO dropout_encoder, dropout_decoder, data, experiment_name = (0.3, 0.1, "Kurier", "da_dropout_0_3")

dropout_encoder, dropout_decoder, data, augmentation, experiment_name = (0.0, 0.1, "Kurier", "inputs_simple_noise.csv", "kurier_da_simple_noise")
#...

# fail here, not with a NameError in a later cell, if no (valid) experiment was selected
assert data in ("20min", "Kurier"), f"no experiment selected (data={data!r}); uncomment one configuration line above"

# Hints

After selecting an experiment in the cell above (uncomment exactly one configuration line), you can run all cells consecutively.

# Model

## mBART with Custom GPT-2 decoder
Skip this section if you want to train a pure mBART model; its cells only run when `custom_decoder_path` is set.

In [None]:
#mBART with gpt-2 decoder:

if custom_decoder_path is not None:
    from transformers import EncoderDecoderModel, AutoModelForCausalLM, MBartModel, GPT2Tokenizer, MBartTokenizerFast

    #prepare output_tokenizer

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
      """Append only the EOS id to a single sequence (no BOS); `token_ids_1` is ignored."""
      outputs = token_ids_0 + [self.eos_token_id]
      return outputs

    # patches the GPT2Tokenizer class itself, i.e. every GPT2Tokenizer instance in this kernel
    GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens

    output_tokenizer = GPT2Tokenizer.from_pretrained(custom_decoder_path)

    # special-token ids of the custom decoder
    # NOTE(review): presumably chosen to mirror mBART's <s>=0 / <pad>=1 / </s>=2 — confirm against the decoder vocab
    output_tokenizer.pad_token_id = 1
    output_tokenizer.bos_token_id = 0
    output_tokenizer.eos_token_id = 2

    input_tokenizer = MBartTokenizerFast.from_pretrained(mbart_path)
    if hasattr(input_tokenizer, "src_lang"):
      input_tokenizer.src_lang = "de_DE"

    # only mBART's encoder is used; the custom causal LM takes the role of the decoder
    mbart = MBartModel.from_pretrained(mbart_path, revision=mbart_revision)
    decoder = AutoModelForCausalLM.from_pretrained(custom_decoder_path, revision=decoder_revision)

    model = EncoderDecoderModel(encoder=mbart.encoder,decoder=decoder)

    # set decoding params
    model.config.decoder_start_token_id = output_tokenizer.bos_token_id
    model.config.eos_token_id = output_tokenizer.eos_token_id
    model.config.pad_token_id = 1  # same id as output_tokenizer.pad_token_id above
    model.config.max_length = 1024

    #freeze all
    for param in model.parameters():
        param.requires_grad = False

    #make cross attention trainable (assumes a GPT-2 style decoder with `transformer.h` blocks)
    for module in model.decoder.transformer.h:
      for param in module.crossattention.parameters():
        param.requires_grad = True
      for param in module.ln_cross_attn.parameters():
        param.requires_grad = True

    # the encoder->decoder projection (if the model has one) must be trainable as well; its
    # parameters have to be flagged, since assigning `requires_grad` on the Module object itself
    # only creates a plain attribute and leaves every parameter frozen
    if hasattr(model,'enc_to_dec_proj'):
      model.enc_to_dec_proj.requires_grad_(True)

    #unfreeze all LayerNorm parameters
    for module in model.modules():
      if isinstance(module, torch.nn.LayerNorm):
        for param in module.parameters():
          param.requires_grad = True

### Load Pre-trained weights

In [None]:
if custom_decoder_path is not None and pretrained_weights is not None:
  # pretrained_weights is a (hub repo id, revision) pair
  pretrained = EncoderDecoderModel.from_pretrained(pretrained_weights[0], revision=pretrained_weights[1])
  checkpoint_dir = "temp"
  pretrained.save_pretrained(checkpoint_dir)
  del pretrained  # only the saved checkpoint is needed from here on

  #update the original weights
  # the checkpoint is read back from the same relative directory save_pretrained just wrote to
  all_states = model.state_dict()
  update_states = torch.load(os.path.join(checkpoint_dir, "pytorch_model.bin"), map_location=device)
  all_states.update(update_states)
  model.load_state_dict(all_states)

## mBART
Skip this section if you want to train a model with a custom decoder; its cells only run when `custom_decoder_path` is `None`.

In [None]:
if custom_decoder_path is None:
    from transformers import MBartForConditionalGeneration, MBartTokenizerFast, MBartConfig

    model = MBartForConditionalGeneration.from_pretrained(mbart_path, revision=mbart_revision)

    tokenizer = MBartTokenizerFast.from_pretrained(mbart_path, src_lang="de_DE", tgt_lang="de_DE")
    input_tokenizer, output_tokenizer = tokenizer, tokenizer

    # set decoding params
    # decoding starts with the German language-code token; its id (250003 in the mbart-large-cc25
    # vocabulary) is looked up from the tokenizer instead of being hardcoded
    model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids("de_DE")
    # force <s> (id 0) as the first generated token, matching the "<s>" prefix that the data
    # preparation adds to every target phrase
    model.config.forced_bos_token_id = 0
    model.config.max_length = 1024

    #freeze all
    for param in model.parameters():
        param.requires_grad = False

    #make cross attention trainable
    for layer in model.model.decoder.layers:
      for param in layer.encoder_attn.parameters():
        param.requires_grad = True
      for param in layer.encoder_attn_layer_norm.parameters():
        param.requires_grad = True

    #unfreeze all LayerNorm parameters
    for module in model.modules():
      if isinstance(module, torch.nn.LayerNorm):
        for param in module.parameters():
          param.requires_grad = True

### Load Custom mBART Decoder

In [None]:
if custom_decoder_path is None and decoder_path is not None and decoder_revision is not None:
  #load the finetuned mBART decoders weights
  from transformers import MBartForCausalLM

  decoder = MBartForCausalLM.from_pretrained(decoder_path, revision=decoder_revision)
  checkpoint_dir = "temp"
  decoder.save_pretrained(checkpoint_dir)

  #update the decoder weights of the full mBART
  # the checkpoint is read back from the same relative directory save_pretrained just wrote to
  # (instead of a hardcoded Colab path); load_state_dict is strict, so any checkpoint key that
  # the full model does not have raises instead of being silently ignored
  all_states = model.state_dict()
  update_states = torch.load(os.path.join(checkpoint_dir, "pytorch_model.bin"), map_location=device)
  all_states.update(update_states)
  model.load_state_dict(all_states)

# Model Preparation

In [None]:
def set_dropout(model, p_encoder, p_decoder):
  """Overwrite the dropout probability of the encoder and decoder stacks.

  Sets the `dropout` attribute of the stack returned by `model.get_encoder()` (and of every
  entry in its `layers`) to `p_encoder`, and likewise for `model.get_decoder()` with `p_decoder`.
  """
  for stack, probability in ((model.get_encoder(), p_encoder), (model.get_decoder(), p_decoder)):
    stack.dropout = probability
    for block in stack.layers:
      block.dropout = probability

# dropout is only overridden when the experiment cell chose both probabilities
if not (dropout_encoder is None or dropout_decoder is None):
  set_dropout(model, p_encoder=dropout_encoder, p_decoder=dropout_decoder)

# Data

In [None]:
from datasets import load_dataset,concatenate_datasets, Features, Value
import unicodedata
import numpy as np
import pandas as pd

def normalize(text):
  """Strip, NFC-normalize and prefix with "<s>" both phrases of one example.

  The example dict is updated in place and returned.
  """
  for column in ('normal_phrase', 'simple_phrase'):
    cleaned = unicodedata.normalize("NFC", text[column].strip())
    text[column] = "<s>" + cleaned
  return text

def tokenize(text, input_tokenizer, output_tokenizer, max_input_length):
  """Tokenize a batch of examples for seq2seq training.

  The "normal_phrase" texts become the model inputs (not truncated here). The "simple_phrase"
  texts, truncated to `max_input_length`, become the 'labels'.
  """
  encoded_sources = input_tokenizer(text["normal_phrase"], return_tensors="np")
  encoded_targets = output_tokenizer(text["simple_phrase"], return_tensors="np", truncation=True, max_length=max_input_length)
  encoded_sources['labels'] = encoded_targets['input_ids']
  return encoded_sources

def count(text):
  """Store len(input_ids) * len(labels) as the example's 'length' and return the example.

  The product is used as the grouping key for length-grouped batching. It models the GPU memory
  a sample needs, so the most memory-hungry sample comes first and out-of-memory issues
  surface early.
  """
  n_input, n_label = len(text['input_ids']), len(text['labels'])
  text['length'] = n_input * n_label
  return text

def get_dataset(data_files, input_tokenizer, output_tokenizer, name=None, augmentation_file=None, max_input_length=None, augmentation_tokenizer=None, seed=None, n_best=None):
  """Load, normalize and tokenize an aligned (normal -> simple) CSV corpus.

  Args:
    data_files: mapping split name -> CSV path; must contain 'train' when `n_best` or
      `augmentation_file` is used.
    input_tokenizer: tokenizer for the 'normal_phrase' (source) column.
    output_tokenizer: tokenizer for the 'simple_phrase' (target) column.
    name: config name passed on to `load_dataset`.
    augmentation_file: optional CSV whose 'normal_phrase' column holds an augmented version of
      every train input; it must be row-aligned with the train split. Stored as
      'augmented_ids'/'augmented_mask' plus a per-sample boolean 'indicator'.
    max_input_length: labels and augmented inputs are truncated to it, and examples whose
      'input_ids' are not shorter than it are dropped from every split.
    augmentation_tokenizer: tokenizer for the augmented inputs (defaults to `input_tokenizer`).
    seed: seed for the permutation of the augmentation indicators.
    n_best: optional fraction of train samples to keep; the samples whose precomputed 'loss'
      (a column of the train CSV) deviates least from the mean loss are kept.

  Returns:
    DatasetDict with 'input_ids', 'attention_mask', 'labels' and, on the train split, 'length'
    and the augmentation columns.

  Raises:
    ValueError: if the augmentation file and the train split differ in row count.
  """
  features = Features({'normal_phrase': Value('string'), 'simple_phrase': Value('string')})
  data = load_dataset("csv", name=name, data_files=data_files, features=features)

  if n_best is not None:
    #select only the n best train samples according to a precomputed loss
    df = pd.read_csv(data_files['train'])
    data['train'] = data['train'].add_column("loss", df['loss'])

    #calculate each sample's deviation from the mean loss to exclude outliers that are either too hard or too easy according to their loss
    df['deviation'] = (df['loss'] - df['loss'].mean()).abs()
    n = int(len(df.index) * n_best)
    df = df.nsmallest(n, 'deviation', keep='first')
    min_pre_loss = df['loss'].min()
    max_pre_loss = df['loss'].max()

    # inclusive bounds: the selected samples with the smallest and largest loss are part of the
    # selection (strict bounds silently dropped them)
    data['train'] = data['train'].filter(lambda example: min_pre_loss <= example["loss"] <= max_pre_loss)
    data['train'] = data['train'].remove_columns(['loss'])

  data = data.map(normalize, num_proc=4)
  data = data.map(lambda rows: tokenize(rows, input_tokenizer, output_tokenizer, max_input_length), batched=True)

  model_columns = ['labels', 'input_ids', 'attention_mask', 'length']
  if "train" in data:
    data['train'] = data['train'].map(count, num_proc=4)
    reference_split = 'train'
  else:
    reference_split = 'test'
  data = data.remove_columns([column for column in data.column_names[reference_split] if column not in model_columns])

  if augmentation_file is not None:
    if augmentation_tokenizer is None:
      augmentation_tokenizer = input_tokenizer
    #add augmented input to the train dataset if given
    data_a = load_dataset("csv", data_files=augmentation_file, features=features)
    data_a = data_a.map(normalize, num_proc=4)
    data_a = data_a.map(lambda row: augmentation_tokenizer(row["normal_phrase"], return_tensors="np", truncation=True, max_length=max_input_length), batched=True)
    data_a = data_a.rename_column("input_ids", "augmented_ids")
    data_a = data_a.rename_column("attention_mask", "augmented_mask")

    # the datasets are joined column-wise, so row i of the augmentation file must belong to row i
    # of the train split
    if len(data_a['train']) != len(data['train']):
      raise ValueError(
        f"augmentation file has {len(data_a['train'])} rows but the train split has {len(data['train'])}; "
        "they must be row-aligned (note that n_best filtering removes train rows)")
    data['train'] = concatenate_datasets([data['train'], data_a['train']], axis=1)
    data['train'] = data['train'].remove_columns([column for column in data.column_names['train'] if column not in model_columns + ['augmented_ids', 'augmented_mask']])

    #we want to randomize the order of original and augmented inputs to avoid the case of first train all original and then all augmented version
    #add a random permutation of indicators, indicating whether augmented inputs should be used in even or odd epochs
    #a local RandomState keeps it reproducible and yields the same permutation as seeding numpy's
    #global RNG would, without resetting that global state
    data_length = len(data['train'])
    indicators = np.ones(data_length, dtype='bool')
    indicators[:data_length//2] = False
    indicators = np.random.RandomState(seed).permutation(indicators).tolist()

    data['train'] = data['train'].add_column("indicator", indicators)

  if max_input_length is not None:
    data = data.filter(lambda example: len(example["input_ids"]) < max_input_length)

  return data

In [None]:
if data == "20min":
  #train on 20min
  # the full augmentation path goes into its own variable, so re-running this cell does not
  # prepend the directory prefix to `augmentation` a second time
  augmentation_file = PREFIX_AUGMENTED + augmentation if augmentation else None
  augmentation_tokenizer = None
  if augmentation_file is not None and "english" in augmentation_file:
    augmentation_tokenizer = None #TODO load the corresponding tokenizer

  data_files_20_min = {'train': PREFIX + "20min_aligned_train.csv", 'val': PREFIX + "20min_aligned_dev.csv", 'test': PREFIX + "20min_aligned_test.csv"}
  data_files_kurier = {'val': PREFIX_KURIER + "kurier_aligned_dev.csv", 'test': PREFIX_KURIER + "kurier_aligned_test.csv"}

  train_name = "20min"
  test_name = "KURIER"
  steps_to_train = 2000

  data_train = get_dataset(data_files_20_min, input_tokenizer, output_tokenizer, train_name, augmentation_file, augmentation_tokenizer=augmentation_tokenizer, max_input_length=model.config.max_length, seed=seed, n_best=n_best)
  data_test  = get_dataset(data_files_kurier, input_tokenizer, output_tokenizer, test_name, max_input_length=model.config.max_length, seed=seed)

In [None]:
if data == "Kurier":
  #train on kurier

  # the full augmentation path goes into its own variable, so re-running this cell does not
  # prepend the directory prefix to `augmentation` a second time
  augmentation_file = PREFIX_AUGMENTED_KURIER + augmentation if augmentation else None
  augmentation_tokenizer = None
  if augmentation_file is not None and "english" in augmentation_file:
    augmentation_tokenizer = None #TODO load the corresponding tokenizer

  data_files_kurier = {'train': PREFIX_KURIER + "kurier_aligned_train.csv",'val': PREFIX_KURIER + "kurier_aligned_dev.csv", 'test': PREFIX_KURIER + "kurier_aligned_test.csv"}
  data_files_20_min = {'val': PREFIX + "20min_aligned_dev.csv", 'test': PREFIX + "20min_aligned_test.csv"}

  train_name = "KURIER"
  test_name = "20min"
  steps_to_train = 1000

  # labels are tokenized with output_tokenizer (as in the 20min cell): both names refer to the same
  # tokenizer for pure mBART, but they differ in the custom-decoder setup
  data_train = get_dataset(data_files_kurier, input_tokenizer, output_tokenizer, train_name, augmentation_file, augmentation_tokenizer=augmentation_tokenizer, max_input_length=model.config.max_length, seed=seed, n_best=n_best)
  data_test  = get_dataset(data_files_20_min, input_tokenizer, output_tokenizer, test_name, max_input_length=model.config.max_length, seed=seed)

# Training Setup

In [None]:
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

In [None]:
class CustomCollator(DataCollatorForSeq2Seq):
    """Seq2seq collator that also pads the optional augmented inputs of a batch.

    If the features carry "augmented_ids"/"augmented_mask", those are first collated on their own
    (as if they were ordinary "input_ids"/"attention_mask") to pad them. The padded rows are
    written back as lists, and then the regular seq2seq collation runs on the whole batch.
    """
    def __call__(self, features, return_tensors=None):
        first_keys = features[0].keys()
        has_augmentation = "augmented_ids" in first_keys and "augmented_mask" in first_keys

        if has_augmentation:
            #pad the augmented inputs by collating them as regular model inputs
            padded = self.__call__([
                {'input_ids': feature.pop("augmented_ids"), 'attention_mask': feature.pop("augmented_mask")}
                for feature in features
            ])
            for row, feature in enumerate(features):
                feature["augmented_ids"] = padded["input_ids"][row].tolist()
                feature["augmented_mask"] = padded["attention_mask"][row].tolist()

        return super().__call__(features, return_tensors)

In [None]:
from transformers import Seq2SeqTrainer

class AugmentationTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that switches between original and augmented inputs once per epoch.

    Every train sample carries a boolean "indicator". The sample is fed with its augmented input
    ("augmented_ids"/"augmented_mask") in epochs whose parity matches the indicator (True ->
    even epochs, False -> odd epochs), and with its original input otherwise. Evaluation
    decodes greedily; prediction uses beam search with 3 beams.
    """
    def __init__(self, **kwargs):
        # whole-epoch index and its parity, refreshed from self.state.epoch in compute_loss
        # (a plain bool avoids depending on the global `model` for a tensor device)
        self.current_epoch = 0
        self.even_epoch = True

        super().__init__(**kwargs)

    def evaluate(self,eval_dataset = None,ignore_keys = None,metric_key_prefix = "eval",**gen_kwargs):
        # caller-supplied gen_kwargs are deliberately replaced by greedy decoding
        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, num_beams=1, do_sample=False)#, top_k=3, penalty_alpha=0.6)

    def predict(self,test_dataset,ignore_keys = None,metric_key_prefix = "test",**gen_kwargs):
        # caller-supplied gen_kwargs are deliberately replaced by 3-beam search
        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, num_beams=3, do_sample=False)

    def compute_loss(self, model, inputs, return_outputs=False):
        # state.epoch is fractional (its decimal part is the progress within the epoch), so only
        # the whole-epoch part may decide the parity; comparing the raw value flipped it after
        # every optimizer step. With gradient accumulation the switch may lag a few micro-batches
        # behind the real epoch boundary, because state.epoch only advances on optimizer steps.
        epoch_index = int(self.state.epoch or 0)
        if epoch_index != self.current_epoch:
          self.current_epoch = epoch_index
          self.even_epoch = epoch_index % 2 == 0

        inputs.pop('length', None)  # only needed for length-grouped sampling, not a model input
        indicator = inputs.pop('indicator', None)
        augmented_ids = inputs.pop("augmented_ids", None)
        augmented_mask = inputs.pop("augmented_mask", None)

        original_ids = inputs.pop("input_ids")
        original_mask = inputs.pop("attention_mask")

        #TODO handle different paddings for batchsizes > 1
        #inputs['input_ids'] = torch.where(indicator == self.even_epoch, original_ids, augmented_ids)
        #inputs['attention_mask'] = torch.where(indicator == self.even_epoch, original_mask, augmented_mask)

        # the whole batch switches only if every sample's indicator matches the current parity
        if augmented_ids is not None and indicator is not None and (indicator == self.even_epoch).all():
          inputs['input_ids'] = augmented_ids
          inputs['attention_mask'] = augmented_mask
        else:
          inputs['input_ids'] = original_ids
          inputs['attention_mask'] = original_mask

        return super().compute_loss(model, inputs, return_outputs=return_outputs)

In [None]:
import evaluate
from collections.abc import Iterable

def compute_translation_metrics(input_tokenizer, output_tokenizer, pred, control_tokens):
    """Compute SARI and BLEU (scaled to 0-100) for a generation-based evaluation.

    Args:
        input_tokenizer: tokenizer used to decode `pred.inputs`.
        output_tokenizer: tokenizer used to decode `pred.predictions` and `pred.label_ids`.
        pred: evaluation output with `inputs`, `label_ids` and `predictions` arrays in which -100
            marks ignored positions. The arrays are not modified.
        control_tokens: if truthy, the first space-separated token of every decoded input is
            dropped before computing SARI. An input consisting of that token alone becomes "".

    Returns:
        dict with the keys 'sari' and 'bleu'.
    """
    # replace -100 by the pad id in copies, so the caller's arrays stay untouched
    input_ids = np.where(pred.inputs == -100, input_tokenizer.pad_token_id, pred.inputs)
    label_ids = np.where(pred.label_ids == -100, output_tokenizer.pad_token_id, pred.label_ids)
    pred_ids = np.where(pred.predictions == -100, output_tokenizer.pad_token_id, pred.predictions)

    decode_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)
    input_str_list = input_tokenizer.batch_decode(input_ids, **decode_kwargs)
    pred_str_list = output_tokenizer.batch_decode(pred_ids, **decode_kwargs)
    label_str_list = output_tokenizer.batch_decode(label_ids, **decode_kwargs)

    if control_tokens:
      split_inputs = (text.split(' ', 1) for text in input_str_list)
      input_str_list = [parts[1] if len(parts) > 1 else "" for parts in split_inputs]

    # SARI and BLEU take a list of references per prediction
    label_str_list = [[label] for label in label_str_list]

    sari = evaluate.load("sari")
    bleu = evaluate.load("bleu")

    sari_score = sari.compute(sources=input_str_list, predictions=pred_str_list, references=label_str_list)
    bleu_score = bleu.compute(predictions=pred_str_list, references=label_str_list)

    translation_result = {
        'sari':sari_score['sari'],
        'bleu':bleu_score['bleu']*100
    }

    return {key: sum(value) / len(value) if isinstance(value, Iterable) else value for (key, value) in
            translation_result.items()}


def compute_metrics(pred):
    """Trainer hook: SARI/BLEU with this notebook's tokenizers and no control tokens."""
    return compute_translation_metrics(input_tokenizer, output_tokenizer, pred, control_tokens=False)

# Train Model

In [None]:
# inputs, labels and (if present) augmented inputs are padded to a multiple of 8
data_collator = CustomCollator(tokenizer=input_tokenizer, model=model, pad_to_multiple_of=8)

training_args = Seq2SeqTrainingArguments(
    output_dir="../results",
    seed=seed,
    data_seed=seed,

    # schedule: 2 epochs with 5 % warm-up, evaluation after every epoch, no checkpoints
    # (step-based alternative: max_steps=steps_to_train, evaluation_strategy="steps",
    #  logging_steps=steps_to_train // 4)
    num_train_epochs=2,
    warmup_ratio=0.05,
    evaluation_strategy="epoch",
    save_strategy='no',

    # optimisation
    optim='adamw_torch',
    learning_rate=3e-5,
    weight_decay=0.01,
    fp16=True,

    # batching: AugmentationTrainer.compute_loss can only switch a whole batch between original
    # and augmented inputs, so training uses single-sample batches with 16 accumulation steps
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    per_device_eval_batch_size=4,
    group_by_length=True,
    # keep the extra columns (length, indicator, augmented_*) that compute_loss pops itself
    remove_unused_columns=False,
    dataloader_num_workers=2,

    # generation-based evaluation; the inputs are needed as SARI sources
    predict_with_generate=True,
    include_inputs_for_metrics=True,
    generation_max_length=1024,
)

trainer = AugmentationTrainer(
    model=model,
    args=training_args,
    train_dataset=data_train['train'],
    eval_dataset=data_train['val'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()

# Upload the Model

In [None]:
# log in to the Hugging Face Hub, which the push_to_hub call in the next cell requires
import huggingface_hub

huggingface_hub.notebook_login()

In [None]:
# derive the commit message from the selected experiment instead of a hand-written text that
# only fits one particular configuration
commit_message = f"Trained on {data} [{experiment_name}] ({training_args.num_train_epochs:g} epochs)"
model.push_to_hub("mbart-ts", commit_message=commit_message)

In [None]:
# in-domain test set of the training corpus
preds = trainer.predict(test_dataset=data_train['test'])
print(preds.metrics)

# cross-domain test set of the other corpus (data_test, built in the data cells)
preds_cross = trainer.predict(test_dataset=data_test['test'], metric_key_prefix=f"test_{test_name}")
print(preds_cross.metrics)

## Auto Disconnect from Colab to Save Credits

In [None]:
# release the Colab runtime once everything above has finished, so no further credits are used
from google.colab import runtime as colab_runtime
colab_runtime.unassign()