In [None]:
%%capture
#download the data
!git clone https://dev:dtKN5sX9We7pw1soPB19@gitlab.lrz.de/josh-o/leichte-sprache-corpus.git

In [None]:
%%capture
#install dependencies (April 14, 2023)
#pytorch 2.0.0+cu118
#Python 3.9.16
!pip install transformers==4.28.0 
!pip install sentencepiece==0.1.98
!pip install datasets==2.11.0

In [None]:
import torch
from transformers import set_seed

#use the first GPU if CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = "cuda:0"
else:
  device = "cpu"

#fixed seed that is handed to set_seed (and later to the training arguments)
seed = 42
set_seed(seed)

#directory prefix of the *_noise.csv corpus files
PREFIX = "../../leichte-sprache-corpus/monolingual/bart_noise/"

# Model

In [None]:
#choose your desired input mode

#identifiers of the available input modes
NO_INPUT, GAUSSIAN_NOISE, BART_NOISE = (
  "no_input",
  "gaussian_noise_input",
  "bart_noise_input",
)

#currently active mode (alternative used before: BART_NOISE)
input_mode = GAUSSIAN_NOISE

In [None]:
from transformers import MBartForConditionalGeneration, MBartTokenizer, MBartForCausalLM

model_path = "facebook/mbart-large-cc25"

#BART_NOISE uses the conditional-generation model, every other mode MBartForCausalLM
model_type = MBartForConditionalGeneration if input_mode == BART_NOISE else MBartForCausalLM

tokenizer = MBartTokenizer.from_pretrained(model_path, tgt_lang = "de_DE", src_lang = "de_DE")

model = model_type.from_pretrained(model_path)

#freeze all
model.requires_grad_(False)

#unfreeze self attention of every decoder layer
for decoder_layer in model.model.decoder.layers:
  decoder_layer.self_attn.requires_grad_(True)

#unfreeze all LayerNorm modules (anywhere in the model)
for submodule in model.modules():
  if isinstance(submodule, torch.nn.LayerNorm):
    submodule.requires_grad_(True)

# Data

In [None]:
from datasets import load_dataset,concatenate_datasets, Features, Value
import unicodedata
import numpy as np
import random

def normalize(text):
  """Clean both phrase columns of a row.

  'normal_phrase' and 'simple_phrase' are stripped of surrounding whitespace,
  converted to Unicode NFC form and prefixed with "<s>".
  The row is modified in place and returned.
  """
  for column in ('normal_phrase', 'simple_phrase'):
    cleaned = unicodedata.normalize("NFC", text[column].strip())
    text[column] = "<s>" + cleaned
  return text

def tokenize(text, tokenizer, max_input_length):
  """Tokenize 'normal_phrase' as model input and 'simple_phrase' as labels.

  Only the label side is tokenized with truncation (to max_input_length).
  Returns the encoding of 'normal_phrase' with the label ids added under 'labels'.
  """
  encoded = tokenizer(text["normal_phrase"], return_tensors="np")
  target_encoding = tokenizer(
    text["simple_phrase"],
    return_tensors="np",
    truncation=True,
    max_length=max_input_length,
  )
  encoded['labels'] = target_encoding['input_ids']
  return encoded

def count(text):
  """Store the number of label tokens of a row under 'length'.

  The value is used to group the data samples by length
  (we want the sample with the highest memory consumption to come first
  so that Out-Of-Memory issues show up early).
  The row is modified in place and returned.
  """
  n_label_tokens = len(text['labels'])
  text['length'] = n_label_tokens
  return text

def add_gaussian_noise(row):
  """Attach a random sequence length (128..1024, both inclusive) to the row.

  Only the length is stored here (under 'random_sequence_length');
  no noise tensor is created in this function.
  The row is modified in place and returned.
  """
  row['random_sequence_length'] = random.randint(128, 1024)
  return row

def add_decoder_atention_mask(row):
  """Add an 'attention_mask' of ones with one entry per element of 'input_ids'.

  The row is modified in place and returned.
  """
  all_ones = [1 for _ in range(len(row['input_ids']))]
  row['attention_mask'] = all_ones
  return row

def get_dataset(data_files, tokenizer, input_mode, name=None, max_input_length=None):
  """Load the CSV corpus and prepare tokenized train/test splits.

  Args:
    data_files: CSV file(s) with the string columns 'normal_phrase' and 'simple_phrase'.
    tokenizer: tokenizer applied to both columns (see tokenize).
    input_mode: one of NO_INPUT, GAUSSIAN_NOISE or BART_NOISE.
    name: forwarded to load_dataset as `name`.
    max_input_length: truncation length of the labels; if not None, rows whose
      tokenized 'normal_phrase' has max_input_length or more tokens are dropped.

  Returns:
    The train/test split (10% test) after preprocessing.
    For BART_NOISE the rows keep 'input_ids', 'attention_mask' and 'labels'.
    For every other mode the labels become 'input_ids' and get an all-ones
    'attention_mask'; GAUSSIAN_NOISE rows additionally carry 'random_sequence_length'.
    'length' is only computed for the train split.
  """
  schema = Features({'normal_phrase': Value('string'), 'simple_phrase': Value('string')})
  loaded = load_dataset("csv",name=name, data_files=data_files, features=schema)
  data = loaded['train'].train_test_split(test_size=0.1)

  data = data.map(normalize, num_proc=4)
  data = data.map(lambda rows: tokenize(rows, tokenizer, max_input_length), batched=True)

  #the length column is only needed for the train split
  data['train'] = data['train'].map(count, num_proc=4)

  #drop everything except the model inputs and the helper columns
  model_columns = ['labels','input_ids','attention_mask','length', 'random_sequence_length']
  obsolete = [column for column in data.column_names['train'] if column not in model_columns]
  data = data.remove_columns(obsolete)

  if max_input_length is not None:
    data = data.filter(lambda example: len(example["input_ids"]) < max_input_length)

  if input_mode == GAUSSIAN_NOISE:
    data = data.map(add_gaussian_noise)

  if input_mode == BART_NOISE:
    return data

  #decoder only: the labels are the only token sequence that is kept
  decoder_columns = ['labels','length', 'random_sequence_length']
  obsolete = [column for column in data.column_names['train'] if column not in decoder_columns]
  data = data.remove_columns(obsolete)
  data = data.rename_column("labels", "input_ids")
  data = data.map(add_decoder_atention_mask)

  return data

In [None]:
#all monolingual CSV files, each located below PREFIX
data_files_monolingual = [
    PREFIX + file_name
    for file_name in (
        "nachrichtenleicht_noise.csv",
        "ndr_noise.csv",
        "einfachstars_noise.csv",
        "hda_sprachtechnologie_noise.csv",
        "lebenshilfe_noise.csv",
        "hurraki_noise.csv",
        "kurier_noise.csv",
    )
]

#build the train/test splits; the model's configured max_length is used as length limit
data_train = get_dataset(
    data_files_monolingual,
    tokenizer,
    input_mode,
    "20min",
    max_input_length=model.config.max_length,
)

# Training

In [None]:
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
import numpy as np

#BART_NOISE uses the seq2seq collator, every other mode the (non-masked) language modeling collator
if input_mode == BART_NOISE:
  data_collator_type = DataCollatorForSeq2Seq
  data_collator_args = {'model':model, 'pad_to_multiple_of':8}
else:
  data_collator_type = DataCollatorForLanguageModeling
  data_collator_args = {'mlm':False, 'pad_to_multiple_of':8}

class CustomCollator(data_collator_type):
  """Collator that replaces 'random_sequence_length' by random encoder states.

  If the first feature of a batch carries 'random_sequence_length', a standard-normal
  matrix of shape (random_sequence_length, hidden_size) is drawn for every feature
  and stored under 'encoder_hidden_states'. All matrices of the batch are zero-padded
  along the first axis to a common number of rows (rounded up to pad_to_multiple_of
  if that is set), on the side given by the tokenizer's padding_side.
  The 'length' entry is removed from every feature before the parent collator runs.

  NOTE(review): the padded rows are not masked here (no encoder attention mask is
  created) - confirm whether the model should receive one.
  """

  #number of columns of the random encoder states
  #(presumably the hidden size of the mbart-large model - TODO confirm)
  hidden_size = 1024

  def __call__(self, features, return_tensors=None):
    """Prepare the random encoder states, then delegate to the parent collator."""
    has_noise = "random_sequence_length" in features[0].keys()
    if has_noise:
      for feature in features:
        noise = torch.randn(feature.pop("random_sequence_length"), self.hidden_size)
        feature['encoder_hidden_states'] = noise.numpy()

    #'length' is only a helper column and must not reach the model
    for feature in features:
      feature.pop('length',None)

    if has_noise:
      #align width: longest sequence of the batch, rounded up to pad_to_multiple_of
      max_width = max(feature["encoder_hidden_states"].shape[0] for feature in features)
      if self.pad_to_multiple_of is not None:
        max_width = (
            (max_width + self.pad_to_multiple_of - 1)
            // self.pad_to_multiple_of
            * self.pad_to_multiple_of
        )

      pad_right = self.tokenizer.padding_side == "right"
      for feature in features:
        states = feature["encoder_hidden_states"]
        missing_rows = max_width - states.shape[0]
        if missing_rows == 0:
          continue
        #zero block allocated directly in numpy instead of nested Python lists
        padding = np.zeros((missing_rows, states.shape[1]), dtype=np.float32)
        parts = [states, padding] if pad_right else [padding, states]
        feature["encoder_hidden_states"] = np.concatenate(parts, axis=0).astype(np.float32)

    return super().__call__(features, return_tensors)

#instantiate the collator; the extra keyword arguments come from data_collator_args and depend on the input mode
data_collator = CustomCollator(tokenizer=tokenizer, **data_collator_args)

In [None]:
from transformers import TrainingArguments, Trainer, Seq2SeqTrainingArguments, Seq2SeqTrainer

#pick the arguments/trainer classes: seq2seq variants for BART_NOISE, the plain ones otherwise
training_arguments, trainer = TrainingArguments, Trainer
if input_mode == BART_NOISE:
  training_arguments, trainer = Seq2SeqTrainingArguments, Seq2SeqTrainer

#finetune for one epoch on all data
training_args = training_arguments(
    num_train_epochs=1,
    output_dir="./results",
    #no eval_steps is given; per the TrainingArguments docs it then follows logging_steps
    evaluation_strategy="steps",
    save_strategy='epoch',
    learning_rate=1e-4,
    weight_decay=0.01,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,
    warmup_steps=200,
    logging_steps=20,
    #mixed precision is only enabled when a CUDA device is available,
    #so the CPU fallback does not fail on fp16
    fp16=torch.cuda.is_available(),
    label_smoothing_factor=0.1,
    group_by_length=True,
    seed=seed,
    data_seed=seed,
    #keep 'length' / 'random_sequence_length' so the custom collator can consume them
    remove_unused_columns=False,
    dataloader_num_workers=2,
    optim='adamw_torch',
)

#note: from here on `trainer` refers to the trainer instance, not the class
trainer = trainer(
    model=model,
    args=training_args,
    train_dataset=data_train['train'],
    eval_dataset=data_train['test'],
    data_collator=data_collator,
)

trainer.train()

# Save Model

In [None]:
from huggingface_hub import notebook_login

#log in to the Hugging Face Hub (presumably required for the later push_to_hub call)
notebook_login()

In [None]:
#convert to causal lm
#the current model is written to the local directory "temp" and reloaded as MBartForCausalLM;
#`model` is rebound to the reloaded instance
#(in BART_NOISE mode the saved model is an MBartForConditionalGeneration - presumably only
#the weights that MBartForCausalLM knows are picked up on reload; TODO confirm)
model.save_pretrained("temp")
model = MBartForCausalLM.from_pretrained("temp")

In [None]:
#upload the model to the Hub repository "mbart-decoder-easy"
#NOTE(review): the commit message is hard-coded to "gaussian noise" and does not follow input_mode -
#adjust it by hand when training with NO_INPUT or BART_NOISE
model.push_to_hub("mbart-decoder-easy", commit_message="Trained with cross-attention (gaussian noise)")

## Auto Disconnect from Colab to Save Credits

In [None]:
#Colab only: release the runtime once everything above has finished
#(presumably this ends the session and discards all local files - make sure the model was pushed before)
from google.colab import runtime
runtime.unassign()