In [None]:
!git clone https://dev:dtKN5sX9We7pw1soPB19@gitlab.lrz.de/josh-o/leichte-sprache-corpus.git

Cloning into 'leichte-sprache-corpus'...
remote: Enumerating objects: 236, done.[K
remote: Counting objects: 100% (26/26), done.[K
remote: Compressing objects: 100% (26/26), done.[K


In [None]:
#%%capture
!pip install transformers==4.25.1
!pip install sentencepiece

!pip install textstat

In [None]:
# Runtime setup: global seed, compute device, and German rules for the textstat readability metrics.
import torch
import textstat
from transformers import set_seed

seed = 42
set_seed(seed)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Readability scores (e.g. Flesch Reading Ease) below are computed with German language settings.
textstat.set_lang("de")

In [None]:
from transformers import MBartForConditionalGeneration, MBartTokenizer, MBartConfig, MBartForCausalLM

model_path = "facebook/mbart-large-cc25"

# Load the mBART tokenizer and the mBART decoder as a stand-alone causal LM.
tokenizer = MBartTokenizer.from_pretrained(model_path)
model = MBartForCausalLM.from_pretrained(model_path)

# German is used as both the source and the target language code.
tokenizer.tgt_lang = "de_DE"
tokenizer.src_lang = "de_DE"

# Freeze every parameter, then selectively re-enable training for
#   * the self-attention sub-module of each decoder layer, and
#   * every LayerNorm in the model (mBART uses LayerNorm; there are no BatchNorm layers).
for param in model.parameters():
  param.requires_grad = False

for layer in model.model.decoder.layers:
  for param in layer.self_attn.parameters():
    param.requires_grad = True

for module in model.modules():
  if isinstance(module, torch.nn.LayerNorm):
    for param in module.parameters():
      param.requires_grad = True

In [None]:
#TODO replace this code cell by python import

from abc import ABC, abstractmethod
from typing import Iterable
import torch.utils.data
import unicodedata
import random


class AbstractDataset(torch.utils.data.Dataset, ABC):
    """Tokenized text corpus for causal-LM training, paired with random "encoder" noise.

    Subclasses build a dataframe with exactly the columns ``topic`` and ``phrase``
    (one row per document) and implement :meth:`get_name` and :meth:`get_columns`.
    Every item additionally carries an ``encoder_hidden_states`` tensor of Gaussian
    noise, so the decoder's cross-attention receives input during training.
    """

    # Per-item noise has a random length in [noise_min_length, noise_max_length]
    # (inclusive) and noise_dim features.
    # NOTE(review): noise_dim must match the model's hidden size (presumably 1024 for mbart-large).
    noise_min_length = 128
    noise_max_length = 1024
    noise_dim = 1024
    # Maximum number of tokens per tokenized chunk (passed as the tokenizer's max_length).
    max_token_length = 1024

    def __init__(self, text_dataframe, stride_length, tokenizer):
        """
        text_dataframe: pandas dataframe with exactly the columns ['topic', 'phrase'], in that order
        stride_length: stored as ``self.stride_length``; currently not passed to the tokenizer
        tokenizer: Hugging Face tokenizer; its ``eos_token`` string is appended to every text
        """
        columns = list(text_dataframe.columns)
        # Compare plain lists: an element-wise NumPy comparison breaks or warns when the
        # number of columns is not exactly two.
        assert columns == ['topic', 'phrase'], f"expected columns ['topic', 'phrase'], got {columns}"
        self.texts = text_dataframe
        self.stride_length = stride_length

        # NOTE(review): the tokenizer presumably appends its own special tokens as well, so
        # texts may end with two EOS tokens — verify this is intended.
        text_list = [unicodedata.normalize("NFC", s) + tokenizer.eos_token
                     for s in self.texts['phrase'].values]

        self.encodings = tokenizer(
            text_list,
            truncation=True,
            max_length=self.max_token_length,
            #stride = stride_length,
            return_special_tokens_mask=True,
            return_overflowing_tokens=True,
        )

    def get_source(self, idx) -> str:
        """
        Returns the source/topic of the requested item
        idx: index of a dataset item

        :return: str - "<dataset name> -> <topic>" of the row the item was tokenized from
        """
        # Fast tokenizers return `overflow_to_sample_mapping` (chunk index -> source row).
        # Slow tokenizers such as MBartTokenizer do not; there, items map 1:1 to rows.
        mapping = self.encodings.get('overflow_to_sample_mapping')
        row = mapping[idx] if mapping is not None else idx
        return self.get_name() + " -> " + self.texts.iloc[row]['topic']

    def evaluate(self):
        """
        Adds a Flesch-Reading-Ease column 'fre' to ``self.texts`` and summarizes it.

        Note: this mutates ``self.texts``; ``get_columns`` implementations that return
        ``self.texts.columns`` will include 'fre' afterwards.

        :return: pandas dataframe - ``describe()`` summary of the numeric columns
        """
        #TODO replace by our metrics
        self.texts['fre'] = self.texts['phrase'].map(textstat.flesch_reading_ease)
        return self.texts.describe()

    def __getitem__(self, idx):
        """Returns the tokenized chunk ``idx`` as tensors plus fresh Gaussian encoder noise."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # New noise on every access (drawn from the global `random`/torch RNGs), so the same
        # text is paired with different pseudo encoder states across epochs.
        noise_length = random.randint(self.noise_min_length, self.noise_max_length)
        item['encoder_hidden_states'] = torch.randn(noise_length, self.noise_dim)
        return item

    def __len__(self) -> int:
        """
        Returns number of samples (tokenized chunks) in the data set

        :return: int - number of samples in data set
        """
        return len(self.encodings['input_ids'])

    @abstractmethod
    def get_name(self) -> str:
        """
        Returns the name of the data set

        :return: str - name of the data set
        """
        pass

    @abstractmethod
    def get_columns(self) -> Iterable[str]:
        """
        Returns the names of all columns that the data set contains

        :return: list - names of the columns that are available
        """
        pass


class CombinedDataset(torch.utils.data.ConcatDataset):
    """ConcatDataset of named datasets that can report per-source names and shares."""

    def __init__(self, datasets: "Iterable[AbstractDataset]"):
        # String annotation: AbstractDataset does not need to be resolvable when this class is defined.
        super(CombinedDataset, self).__init__(datasets)

    def get_names(self) -> Iterable[str]:
        """
        Returns a list with the names of all data sets that are contained in this combined data set

        :return: list - names of data sets in the data set collection
        """
        return [ds.get_name() for ds in self.datasets]

    def get_summary(self) -> str:
        """
        Returns a one-line summary: the total item count and each source's share in percent.

        Sources that share a name are reported under one key (the last one wins), while the
        total counts every source. If the collection has no items at all, every share is
        reported as 0.00% instead of raising ZeroDivisionError.

        :return: str - e.g. "Dataset contains 4 items {'a': '25.00%', 'b': '75.00%'}"
        """
        sizes = {ds.get_name(): len(ds) for ds in self.datasets}
        total_items = sum(len(ds) for ds in self.datasets)
        shares = {
            name: "{:.2f}%".format((size / total_items) * 100 if total_items else 0.0)
            for name, size in sizes.items()
        }
        return f"Dataset contains {total_items} items {shares}"

In [None]:
import pandas as pd

class NewsData(AbstractDataset):
    """Monolingual corpus read from a CSV with (at least) topic, phrase_number and phrase columns.

    Phrases are ordered by ``phrase_number`` and joined with single spaces into one
    document per topic.
    """

    def __init__(self, name, csv_file, stride_length, tokenizer):
        phrases = pd.read_csv(csv_file)
        # Drop rows without a phrase instead of filling them with the literal string 'text',
        # which would inject a spurious word into the training documents. Remaining gaps in
        # other columns (e.g. a missing topic) are still filled with 'text' as before.
        phrases = phrases.dropna(subset=['phrase']).fillna('text')
        texts = (
            phrases.sort_values(['phrase_number'])
            .groupby(['topic'])['phrase']
            .apply(' '.join)
            .reset_index()
        )
        self.name = name
        super().__init__(texts, stride_length, tokenizer)

    def get_name(self) -> str:
      return self.name

    def get_columns(self) -> Iterable[str]:
      return self.texts.columns

class HurrakiData(AbstractDataset):
    """Hurraki corpus: one document per topic, built from phrases ordered by phrase_number.

    With ``remove_useless`` (default) only rows whose 'useful' column equals "YES" are kept.
    The 'useful' column must exist either way, because it is dropped afterwards.
    """

    def __init__(self, csv_file, stride_length,tokenizer, remove_useless = True):
        raw = pd.read_csv(csv_file)
        kept = raw[raw['useful'] == "YES"] if remove_useless else raw
        kept = kept.drop(columns=['useful'])

        per_topic = kept.sort_values(['phrase_number']).groupby(['topic'])['phrase']
        documents = per_topic.apply(' '.join).reset_index()
        super().__init__(documents, stride_length, tokenizer)

    def get_name(self) -> str:
        return "Hurraki"

    def get_columns(self) -> Iterable[str]:
        return self.texts.columns

class SyntheticData(AbstractDataset):
    """Synthetic corpus whose CSV already has exactly the columns topic, phrase."""

    def __init__(self, csv_file, stride_length,tokenizer, remove_useless = True):
        # `remove_useless` is accepted for signature parity with HurrakiData but is unused.
        super().__init__(pd.read_csv(csv_file), stride_length, tokenizer)

    def get_name(self) -> str:
        return "Synthetic"

    def get_columns(self) -> Iterable[str]:
        return self.texts.columns


In [None]:
# Load every monolingual Easy-German source, concatenate them and make a seeded 80/10/10 split.
stride_length = 64
PREFIX = "/content/leichte-sprache-corpus/monolingual/"


def _load_news(name, filename):
    """Read one CSV below PREFIX as a NewsData corpus."""
    return NewsData(name, PREFIX + filename, stride_length, tokenizer)


dataset_nachrichtenleicht = _load_news("NachrichtenLeicht", "nachrichtenleicht.csv")
dataset_ndr = _load_news("NDR", "ndr.csv")
dataset_einfachstars = _load_news("einfachstars", "einfachstars.csv")
dataset_hda = _load_news("hda", "hda_sprachtechnologie.csv")
dataset_lebenshilfe = _load_news("lebenshilfe", "lebenshilfe.csv")
# NOTE(review): hurraki is read with NewsData, so HurrakiData's `useful == "YES"` filter is
# not applied here — confirm that this is intended.
dataset_hurraki = _load_news("hurraki", "hurraki.csv")
dataset_kurier = _load_news("kurier", "kurier.csv")

dataset = CombinedDataset([
    dataset_nachrichtenleicht,
    dataset_hurraki,
    dataset_ndr,
    dataset_einfachstars,
    dataset_hda,
    dataset_lebenshilfe,
    dataset_kurier,
])

generator = torch.Generator().manual_seed(42)

# 10% validation, 10% test (each rounded down), the rest for training.
n_items = len(dataset)
test_val_length = int(.1 * n_items)
train_length = n_items - 2 * test_val_length
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset,
    [train_length, test_val_length, test_val_length],
    generator=generator,
)

dataset.get_summary()

In [None]:
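## Experiment log after 1 epoch
# Columns: training loss, validation loss (presumably from the Trainer log), then the setup
# and, where measured, the perplexity (ppl). Losses with label smoothing include the smoothing
# term and are not directly comparable to losses without it; compare ppl instead.

#5.027800 	4.941916 # first 6 layers frozen + label smoothing 0.2
#2.517400 	2.386357 # nothing frozen + no label smoothing, ppl 17.5
#2.589700 	2.418603 # attentions frozen + no label smoothing, ppl 18.3
#3.932400 	3.754726 # attentions frozen + label smoothing 0.1, ppl 18.7
#3.890600 	3.746634 # attentions frozen + label smoothing 0.1 + contrastive, ppl 18.03

## Experiment log after 2 epochs

#3.738600 	3.648037 # attentions frozen + label smoothing 0.1, ppl 19.19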

In [None]:
from transformers import DataCollatorForLanguageModeling
import numpy as np

class CustomCollator(DataCollatorForLanguageModeling):
  """Causal-LM collator that also batches variable-length ``encoder_hidden_states``.

  If the features carry an ``encoder_hidden_states`` entry (a 2-D ``(source_len, hidden)``
  array per example, as produced by ``AbstractDataset.__getitem__``), each one is
  zero-padded along ``source_len`` to the longest example in the batch. The length is
  rounded up to ``pad_to_multiple_of`` when set, and padding goes on the tokenizer's
  padding side. A matching ``encoder_attention_mask`` (1 = real row, 0 = padding) is
  added unless one is already present, so cross-attention ignores the padded rows.
  Token fields are then handled by ``DataCollatorForLanguageModeling``.
  """

  def __call__(self, features, return_tensors=None):
    if "encoder_hidden_states" in features[0].keys():
      self._pad_encoder_states(features)
    return super().__call__(features, return_tensors)

  def _pad_encoder_states(self, features):
    """Pad each feature's encoder_hidden_states in place and attach an encoder_attention_mask."""
    states = [torch.as_tensor(feature["encoder_hidden_states"]) for feature in features]

    max_width = max(state.shape[0] for state in states)
    if self.pad_to_multiple_of is not None:
      multiple = self.pad_to_multiple_of
      max_width = (max_width + multiple - 1) // multiple * multiple

    pad_right = self.tokenizer.padding_side == "right"
    for feature, state in zip(features, states):
      width = state.shape[0]
      missing = max_width - width

      if "encoder_attention_mask" not in feature:
        real, empty = [1] * width, [0] * missing
        feature["encoder_attention_mask"] = real + empty if pad_right else empty + real

      if missing == 0:
        continue
      # Zero rows with the same trailing (hidden) shape as the state; no hidden size is hard-coded.
      state = state.to(torch.float32)
      padding = torch.zeros((missing,) + tuple(state.shape[1:]), dtype=torch.float32)
      pieces = [state, padding] if pad_right else [padding, state]
      feature["encoder_hidden_states"] = torch.cat(pieces, dim=0)


data_collator = CustomCollator(tokenizer=tokenizer, mlm=False, pad_to_multiple_of=8)

In [None]:
from transformers import TrainingArguments, Trainer

# Fine-tune for one epoch on the training split (`train_set`), evaluating on `val_set`
# every `logging_steps` optimizer steps (eval_steps is not set explicitly).
training_args = TrainingArguments(
    num_train_epochs=1,
    output_dir="./results",
    evaluation_strategy="steps",
    save_strategy='epoch',
    learning_rate=1e-4,   # hyperparameter
    weight_decay=0.01,    # hyperparameter
    per_device_train_batch_size=1,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,  # effective train batch of 16 examples per device
    warmup_steps=200,
    logging_steps=200,
    # fp16 mixed precision requires CUDA; fall back to fp32 so the notebook also runs on CPU.
    fp16=torch.cuda.is_available(),
    label_smoothing_factor=0.1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    data_collator=data_collator,
)

trainer.train()

In [None]:
import textwrap

# Generate a continuation for a short German prompt and report its readability.
prompts = ["In der Türkei "]
# Tokenize as *target* text so the target-language special tokens (tgt_lang = de_DE) are used.
encoding = tokenizer(text_target=prompts, return_tensors="pt")

# Rotate the trailing special token (the language code) to the front along the sequence
# axis and drop the last position, so the prompt reads "de_DE <text>" (see the print below).
# NOTE(review): assumes a single, unpadded prompt — with padding, pad tokens would be rotated too.
encoding['input_ids'] = torch.roll(encoding['input_ids'], shifts=1, dims=1)[:, :-1]
encoding['attention_mask'] = encoding['attention_mask'][:, :-1]

print(tokenizer.batch_decode(encoding['input_ids']))

# Gaussian noise as stand-in encoder states, shaped (batch, source_len, hidden) like training batches.
# NOTE(review): MBartForCausalLM.prepare_inputs_for_generation in transformers 4.25 appears not
# to forward `encoder_hidden_states`, so generate() may silently ignore this noise — verify.
encoding['encoder_hidden_states'] = torch.randn(encoding['input_ids'].size(0), 100, model.config.d_model)

encoding = {k: v.to(device) for k, v in encoding.items()}

simple_texts = model.generate(**encoding,
                              max_length=20,
                              num_beams=3,
                              )

print("\nReading Ease: higher = better\n")

# Score readability on text without special tokens (de_DE, </s>), which would otherwise be
# counted as words; still print the raw decoding so the special tokens remain visible.
raw_texts = tokenizer.batch_decode(simple_texts, skip_special_tokens=False)
clean_texts = tokenizer.batch_decode(simple_texts, skip_special_tokens=True)
for raw, clean in zip(raw_texts, clean_texts):
  print(f"Flesch Reading Ease: {textstat.flesch_reading_ease(clean)}\n")
  print(textwrap.fill(raw, 130), '[...]')

['de_DE In der Türkei']

Reading Ease: higher = better

Flesch Reading Ease: 89.6

de_DE In der Türkei gibt es immer mehr Konflikte. In der Türkei gibt es immer mehr Krieg</s> [...]


In [None]:
import matplotlib.pyplot as plt

# Cross-attention map of one decoder layer for the prompt/noise `encoding` built in the
# generation cell above. `out` is otherwise undefined in this notebook, so compute it here
# to make the cell work on a fresh kernel.
with torch.no_grad():
  out = model(**encoding, output_attentions=True)

layer_idx = -1
# cross_attentions[layer] is (batch, heads, target_len, source_len); show batch 0, head 0.
# .cpu() is required because the model and inputs may live on the GPU.
attention_map = out.cross_attentions[layer_idx][0][0].detach().cpu().numpy()

fig, ax = plt.subplots()
image = ax.imshow(attention_map, aspect="auto", cmap="viridis")
fig.colorbar(image, ax=ax)
ax.set(title=f"Cross-attention (decoder layer {layer_idx}, head 0)",
       xlabel="encoder position (noise)",
       ylabel="decoder position (prompt token)")
plt.show()

# Calculate Perplexity

In [None]:
from tqdm import tqdm
import pandas as pd

def calculate_perplexity(model, encodings, max_length=512, stride=256):
  """Sliding-window perplexity of `model` on one long tokenized text.

  Windows of at most `max_length` tokens advance by `stride`. Only the last (up to
  `stride`) tokens of each window are scored, so every token is scored once with up
  to `max_length - stride` tokens of preceding context (strided evaluation).

  Args:
    model: causal LM whose forward(input_ids, labels=...) returns the mean token loss
      as its first output; inputs are moved to the device of its parameters.
    encodings: tokenizer output with `input_ids` of shape (batch, seq_len) — presumably
      batch size 1, as produced by the cells below.
    max_length: maximum window length fed to the model.
    stride: number of new tokens scored per window; must be in 1..max_length.

  Returns:
    0-dim tensor: exp(total negative log-likelihood / seq_len).

  Raises:
    ValueError: if `stride` is not in 1..max_length or the sequence is empty.
  """
  if not 0 < stride <= max_length:
    raise ValueError(f"stride must be in 1..max_length, got stride={stride}, max_length={max_length}")

  seq_len = encodings.input_ids.size(1)
  if seq_len == 0:
    raise ValueError("cannot compute perplexity of an empty sequence")

  model_device = next(model.parameters()).device

  nlls = []
  for i in tqdm(range(0, seq_len, stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, seq_len)
    trg_len = end_loc - i  # may be shorter than stride on the last window
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model_device)
    target_ids = input_ids.clone()
    # Mask the context prefix: only the last trg_len tokens contribute to the loss.
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
      outputs = model(input_ids, labels=target_ids)
    # outputs[0] is the mean loss over the scored tokens; scale it back to (approximately) a sum.
    nlls.append(outputs[0] * trg_len)

  return torch.exp(torch.stack(nlls).sum() / seq_len)

# Perplexity of the fine-tuned decoder on the MDR aligned news (simple and normal side)
# and on the Klexikon corpus.
PREFIX = "/content/leichte-sprache-corpus/aligned/mdr/"


def _join_nfc(phrases):
  """NFC-normalize each phrase and join them with blank lines into one document."""
  return "\n\n".join(unicodedata.normalize("NFC", s) for s in phrases)


df = pd.read_csv(PREFIX + "mdr_aligned_news.csv")

simple_text = _join_nfc(df.dropna(subset=['simple_phrase'])['simple_phrase'].values.tolist())
simple_encodings = tokenizer(simple_text, return_tensors="pt")

normal_text = _join_nfc(df.dropna(subset=['normal_phrase'])['normal_phrase'].values.tolist())
normal_encodings = tokenizer(normal_text, return_tensors="pt")

klexikon = pd.read_csv("/content/leichte-sprache-corpus/monolingual/klexikon.csv")
klexikon_docs = klexikon.sort_values(['phrase_number']).groupby(['topic'])['phrase'].apply(' '.join).reset_index()
klexikon_text = _join_nfc(klexikon_docs['phrase'].values)
# Tokenize the Klexikon text itself (not `normal_text`), otherwise "klexikon" reports MDR perplexity.
klexikon_encodings = tokenizer(klexikon_text, return_tensors="pt")

eval_model = model.to(device).eval()
normal_ppl = calculate_perplexity(eval_model, normal_encodings).item()
simple_ppl = calculate_perplexity(eval_model, simple_encodings).item()
klexikon_ppl = calculate_perplexity(eval_model, klexikon_encodings).item()

print()
print("normal:", normal_ppl)
print("simple:", simple_ppl)
print("klexikon:", klexikon_ppl)

In [None]:
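# Perplexity on "klexikon": simple model 24.5984, normal model 19.0644
# NOTE(review): if these were measured while `klexikon_encodings` was built from `normal_text`,
# they reflect the MDR normal texts rather than Klexikon — re-run to confirm.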

# Save Model

In [None]:
# Log in to the Hugging Face Hub (interactive token prompt) so that the
# push_to_hub calls below are authenticated.
import huggingface_hub
from huggingface_hub import notebook_login

notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
# Upload the fine-tuned decoder (weights + config) to the Hub repo of the logged-in account.
hub_repo = "mbart-decoder-easy"
model_commit_message = "Trained with cross-attention (gaussian noise)"
model.push_to_hub(hub_repo, commit_message=model_commit_message)

Configuration saved in /tmp/tmp3vcwdijp/config.json
Model weights saved in /tmp/tmp3vcwdijp/pytorch_model.bin
Uploading the following files to josh-oo/mbart-decoder-easy: pytorch_model.bin,config.json


pytorch_model.bin:   0%|          | 0.00/1.83G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/josh-oo/mbart-decoder-easy/commit/8c785939db9cf0851da9ea9003805bd78f7ec6cd', commit_message='Trained with cross-attention (gaussian noise)', commit_description='', oid='8c785939db9cf0851da9ea9003805bd78f7ec6cd', pr_url=None, pr_revision=None, pr_num=None)

In [None]:
# Upload the tokenizer files to the same Hub repo as the model.
hub_repo = "mbart-decoder-easy"
tokenizer_commit_message = "Initial upload"
tokenizer.push_to_hub(hub_repo, commit_message=tokenizer_commit_message)