In [None]:
%%capture
#download the data
# NOTE(review): the clone URL below carries what looks like an access token in plain
# text, so anyone who can read this notebook can read it too. Confirm that it is a
# read-only token meant to be shared; otherwise rotate it and supply it at run time.
!git clone https://dev:dtKN5sX9We7pw1soPB19@gitlab.lrz.de/josh-o/leichte-sprache-corpus.git

In [None]:
%%capture
#install the pinned dependencies (versions as of April 14, 2023)
#torch is not installed by this cell; the versions noted for this setup are:
#  pytorch 2.0.0+cu118
#  Python 3.9.16
!pip install transformers==4.28.0 
!pip install sentencepiece==0.1.98
!pip install datasets==2.11.0

In [None]:
import torch

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda:0"
else:
  device = "cpu"

# Directories of the aligned corpora (original and augmented) for 20min and Kurier.
PREFIX = "/content/leichte-sprache-corpus/aligned/20min/"
PREFIX_KURIER = "/content/leichte-sprache-corpus/aligned/kurier/"
PREFIX_AUGMENTED = "/content/leichte-sprache-corpus/aligned/20min/augmented/"
PREFIX_AUGMENTED_KURIER = "/content/leichte-sprache-corpus/aligned/kurier/augmented/"

# Checkpoint identifier and the revision handed to from_pretrained when loading it.
model_path = "facebook/mbart-large-cc25"
revision = "57cecec5a3185d3ec7b3021a53093cf96835a634"

In [None]:
#experiments
#Each line below sets, for one run: the training csv (data), an optional csv with
#augmented inputs (augmentation_data, None to disable) and the encoder/decoder
#dropout rates. Leave exactly one line uncommented; if several are active the
#last assignment wins.

#20min dropout tests
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX + "20min_aligned_train.csv", None, 0.1, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX + "20min_aligned_train.csv", None, 0.3, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX + "20min_aligned_train.csv", None, 0.8, 0.0)

#20min data augmentation test
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX + "20min_aligned_train.csv" ,PREFIX_AUGMENTED + "simple_noise.csv", 0.0, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX + "20min_aligned_train.csv", PREFIX_AUGMENTED + "bart_noise.csv", 0.0, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX + "20min_aligned_train.csv", PREFIX_AUGMENTED + "inputs_back_google.csv", 0.0, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX + "20min_aligned_train.csv", PREFIX_AUGMENTED + "inputs_back_google_simple_noise.csv", 0.0, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX + "20min_aligned_train.csv", PREFIX_AUGMENTED + "inputs_english_deepl.csv", 0.0, 0.0)

#Kurier dropout tests
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX_KURIER + "kurier_aligned_train.csv", None, 0.1, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX_KURIER + "kurier_aligned_train.csv", None, 0.3, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX_KURIER + "kurier_aligned_train.csv", None, 0.8, 0.0)

#Kurier data augmentation test
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX_KURIER + "kurier_aligned_train.csv" ,PREFIX_AUGMENTED_KURIER + "simple_noise.csv", 0.0, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX_KURIER + "kurier_aligned_train.csv", PREFIX_AUGMENTED_KURIER + "bart_noise.csv", 0.0, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX_KURIER + "kurier_aligned_train.csv", PREFIX_AUGMENTED_KURIER + "inputs_back_google.csv", 0.0, 0.0)
#data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX_KURIER + "kurier_aligned_train.csv", PREFIX_AUGMENTED_KURIER + "inputs_back_google_simple_noise.csv", 0.0, 0.0)
data, augmentation_data, encoder_dropout, decoder_dropout = (PREFIX_KURIER + "kurier_aligned_train.csv", PREFIX_AUGMENTED_KURIER + "inputs_english_deepl.csv", 0.0, 0.0)

# Model

In [None]:
from transformers import MBartForConditionalGeneration, MBartTokenizerFast, MBartConfig

# Load config, weights and tokenizer from the same pinned revision so that none of
# them silently follows later changes on the hub's default branch.
model_config = MBartConfig.from_pretrained(model_path, revision=revision)
# zero the config-level dropout; the per-experiment rates are applied by set_dropout
model_config.dropout = 0.0

model = MBartForConditionalGeneration.from_pretrained(model_path, config=model_config, revision=revision)

# German is both source and target language
tokenizer = MBartTokenizerFast.from_pretrained(model_path, revision=revision, src_lang="de_DE", tgt_lang="de_DE")

# set decoding params
# NOTE(review): 250003 is presumably the id of the de_DE language code and 0 the id
# of <s> in this vocabulary - confirm against the tokenizer.
model.config.decoder_start_token_id=250003
model.config.forced_bos_token_id=0
# also passed to get_dataset as the maximum input length
model.config.max_length = 1024

#freeze every parameter of the model first
model.requires_grad_(False)

#re-enable gradients for each decoder layer's cross attention and its layer norm
for decoder_layer in model.model.decoder.layers:
  decoder_layer.encoder_attn.requires_grad_(True)
  decoder_layer.encoder_attn_layer_norm.requires_grad_(True)

#re-enable gradients for every LayerNorm module anywhere in the model
for submodule in model.modules():
  if isinstance(submodule, torch.nn.LayerNorm):
    submodule.requires_grad_(True)

def set_dropout(model, p_encoder, p_decoder):
  """Overwrite the `dropout` attribute of the encoder and decoder stacks.

  The attribute is set on `model.model.encoder` / `model.model.decoder` and on
  every entry of their `layers`; `p_encoder` is used for the encoder side and
  `p_decoder` for the decoder side. Nothing is returned.
  """
  stacks = ((model.model.encoder, p_encoder), (model.model.decoder, p_decoder))
  for stack, rate in stacks:
    stack.dropout = rate
    for block in stack.layers:
      block.dropout = rate

#apply the encoder/decoder dropout rates chosen in the experiments cell
set_dropout(model, p_encoder=encoder_dropout, p_decoder=decoder_dropout)

# Data

In [None]:
from datasets import load_dataset,concatenate_datasets, Features, Value
import unicodedata

def normalize(text):
  """Clean both phrases of one row in place and return the row.

  'normal_phrase' and 'simple_phrase' are stripped of surrounding whitespace,
  NFC-normalized and prefixed with "<s>".
  """
  for column in ('normal_phrase', 'simple_phrase'):
    stripped = text[column].strip()
    text[column] = "<s>" + unicodedata.normalize("NFC", stripped)
  return text

def tokenize(text, input_tokenizer, output_tokenizer, max_input_length):
  """Encode the source phrases and attach the encoded target phrases as labels.

  The 'normal_phrase' entries are encoded with `input_tokenizer` (no truncation);
  the 'simple_phrase' entries are encoded with `output_tokenizer`, truncated to
  `max_input_length`, and their 'input_ids' are stored under 'labels' in the
  encoding of the inputs, which is returned.
  """
  encoded = input_tokenizer(text["normal_phrase"], return_tensors="np")
  target = output_tokenizer(
    text["simple_phrase"],
    return_tensors="np",
    truncation=True,
    max_length=max_input_length,
  )
  encoded['labels'] = target['input_ids']
  return encoded

def count(text):
  """Store len(input_ids) * len(labels) of a row under 'length' and return the row."""
  source_size = len(text['input_ids'])
  target_size = len(text['labels'])
  text['length'] = source_size * target_size
  return text

def get_dataset(data_files, input_tokenizer, output_tokenizer, name=None, augmentation_file=None,max_input_length=None, augmentation_tokenizer=None):
  """Load, normalize and tokenize the csv corpus, optionally with augmented inputs.

  data_files: passed to load_dataset("csv", ...); the csv files must provide the
    string columns 'normal_phrase' and 'simple_phrase'.
  input_tokenizer / output_tokenizer: encode the source phrases / the labels.
  name: forwarded to load_dataset.
  augmentation_file: optional csv whose 'normal_phrase' column is encoded and
    attached to the 'train' split as 'augmented_ids' / 'augmented_mask'.
  max_input_length: truncation length for labels and augmented inputs; when it is
    not None, rows whose 'input_ids' are not shorter than it are filtered out.
  augmentation_tokenizer: tokenizer for the augmented inputs; `input_tokenizer`
    is used when it is None.

  Returns the loaded splits reduced to the model columns.
  """
  model_columns = ['labels','input_ids','attention_mask','length']
  schema = Features({'normal_phrase': Value('string'), 'simple_phrase': Value('string')})

  def encode_rows(rows):
    return tokenize(rows, input_tokenizer, output_tokenizer, max_input_length)

  data = load_dataset("csv",name=name, data_files=data_files, features=schema)
  data = data.map(normalize, num_proc=4)
  data = data.map(encode_rows, batched=True)

  #the length column is only computed for the train split
  has_train = "train" in data
  if has_train:
    data['train'] = data['train'].map(count, num_proc=4)
  reference_split = 'train' if has_train else 'test'
  data = data.remove_columns([column for column in data.column_names[reference_split] if column not in model_columns])

  if augmentation_file is not None:
    aug_tokenizer = input_tokenizer if augmentation_tokenizer is None else augmentation_tokenizer

    def encode_augmented(rows):
      return aug_tokenizer(rows["normal_phrase"], return_tensors="np", truncation=True, max_length=max_input_length)

    #encode the augmented inputs and attach them column-wise to the train split
    data_a = load_dataset("csv", data_files=augmentation_file, features=schema)
    data_a = data_a.map(normalize, num_proc=4)
    data_a = data_a.map(encode_augmented, batched=True)
    data_a = data_a.rename_column("input_ids", "augmented_ids")
    data_a = data_a.rename_column("attention_mask", "augmented_mask")
    data['train'] = concatenate_datasets([data['train'], data_a['train']], axis=1)
    kept_columns = model_columns + ['augmented_ids','augmented_mask']
    data['train'] = data['train'].remove_columns([column for column in data.column_names['train'] if column not in kept_columns])

  if max_input_length is not None:
    data = data.filter(lambda example: len(example["input_ids"]) < max_input_length)

  return data

In [None]:
#build the dataset for the experiment selected above (training csv plus optional augmentation csv)
data_train = get_dataset({'train': data}, tokenizer, tokenizer, augmentation_file=augmentation_data, max_input_length=model.config.max_length)

# Evaluation

In [None]:
from transformers import DataCollatorForSeq2Seq

class CustomCollator(DataCollatorForSeq2Seq):
    """Seq2seq collator that also batches optional augmented inputs.

    When the features carry 'augmented_ids' and 'augmented_mask', those fields are
    first collated on their own (as 'input_ids' / 'attention_mask') and written
    back as lists, after which the regular collation of the parent class runs on
    the complete features.
    """

    def __call__(self, features, return_tensors=None):
        first_keys = features[0].keys()
        has_augmentation = "augmented_ids" in first_keys and "augmented_mask" in first_keys

        if has_augmentation:
            # take the augmented fields out of every feature and collate them
            # as if they were ordinary model inputs
            augmented_features = [
                {
                    'input_ids': feature.pop("augmented_ids"),
                    'attention_mask': feature.pop("augmented_mask"),
                }
                for feature in features
            ]
            collated = self.__call__(augmented_features)

            # write the collated rows back as plain lists
            rows = zip(features, collated["input_ids"], collated["attention_mask"])
            for feature, ids_row, mask_row in rows:
                feature["augmented_ids"] = ids_row.tolist()
                feature["augmented_mask"] = mask_row.tolist()

        return super().__call__(features, return_tensors)

data_collator = CustomCollator(tokenizer=tokenizer, model=model, pad_to_multiple_of=8)

In [None]:
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader

#per-example similarity scores, filled by the loop below
all_loss = []

train_dataloader = DataLoader(data_train['train'], batch_size=2, collate_fn=data_collator)
model.to(device)

#cosine similarity along the last (feature) dimension
loss_fn = torch.nn.CosineSimilarity(dim=-1, eps=1e-08)

with torch.no_grad():

  for batch in tqdm(train_dataloader):

    #original inputs
    test_item = {
      'input_ids': batch['input_ids'].to(device),
      'attention_mask': batch['attention_mask'].to(device),
      'labels': batch['labels'].to(device),
    }

    #second view: the augmented inputs if there are any, else the same inputs again
    if augmentation_data is None:
      augmented_item = test_item
    else:
      augmented_item = {
        'input_ids': batch['augmented_ids'].to(device),
        'attention_mask': batch['augmented_mask'].to(device),
        'labels': batch['labels'].to(device),
      }

    #the first pass always runs in eval mode
    model.eval()
    outputs_original = model(output_hidden_states=True, **test_item)
    #without augmentation data the second pass runs in train mode
    #(presumably so that dropout is what perturbs it)
    if augmentation_data is None:
      model.train()
    outputs_augmented = model(output_hidden_states=True, **augmented_item)

    #positions whose label is -100 are left out of the average
    valid = test_item['labels'] != -100

    similarity = loss_fn(outputs_original.decoder_hidden_states[-1], outputs_augmented.decoder_hidden_states[-1])

    #mean similarity over the valid positions of each example
    per_example = (similarity * valid).sum(dim=-1) / valid.sum(dim=-1)

    all_loss.extend(per_example.detach().tolist())

#Summary over the per-example scores. Every score is counted exactly once:
#duplicating entries would leave the mean unchanged but bias the sample
#standard deviation that Tensor.std() computes.
losses = torch.tensor(all_loss)
print("Mean: ", losses.mean())
print("STD: ", losses.std())

## Auto Disconnect from Colab to Save Credits

In [None]:
from google.colab import runtime
#give up the Colab runtime once everything above has finished; only works inside Colab
runtime.unassign()