In [None]:
%%capture
#download the data
# NOTE(review): the clone URL below embeds an access token in plain text. Unless this is a
# deliberately public, read-only token, rotate it and read it from an environment variable
# or getpass prompt instead of storing it in the notebook.
!git clone https://dev:dtKN5sX9We7pw1soPB19@gitlab.lrz.de/josh-o/leichte-sprache-corpus.git

In [None]:
%%capture
#install dependencies (April 14, 2023)
#pytorch 2.0.0+cu118
#Python 3.9.16
# Every package is pinned to an exact version so that a fresh runtime resolves the same
# dependencies. torch is not installed by this cell; it is expected to be provided by the
# runtime (see the versions noted above).
!pip install transformers==4.28.0 
!pip install sentencepiece==0.1.98
!pip install pytokenizations==0.8.4
!pip install datasets==2.11.0

In [None]:
import torch
from transformers import GPT2Tokenizer
from transformers import EncoderDecoderModel, AutoModelForCausalLM
from transformers import MBartForConditionalGeneration, MBartTokenizerFast, MBartConfig
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import set_seed

import random
import tokenizations
import copy

import numpy as np
import matplotlib.pyplot as plt

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
  device = "cuda:0"
else:
  device = "cpu"

# Fixed seed for reproducibility (no direct effect on text generation).
seed = 42
set_seed(seed)

# Locations of the aligned corpus and of its augmented variant (relative paths).
PREFIX = "../../leichte-sprache-corpus/aligned/20min/"
PREFIX_AUGMENTED = "../../leichte-sprache-corpus/aligned/20min/augmented/"

# Checkpoints used for the encoder/teacher (mBART) and for the student decoder (GPT-2).
encoder_path = "facebook/mbart-large-cc25"
decoder_path = "josh-oo/german-gpt2-easy"

# Model

Choose the model you want to train by uncommenting it

## mBART with Custom GPT-2 decoder

In [None]:
#mBART with gpt-2 decoder:

from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from typing import Optional, Tuple, Union

#modified GPT2Block to access the cross attentions keys and values
#copied from huggingface transformers modeling_gpt2.py
class CachableGPT2Block(GPT2Block):
    """GPT-2 block whose cache entry also carries the cross-attention keys and values.

    The forward pass follows the block it was copied from (huggingface transformers
    ``modeling_gpt2.py``) with two differences:

    * ``use_cache`` is forwarded to the cross-attention layer, and the returned
      ``present`` becomes ``(self_key, self_value, cross_key, cross_value)`` instead of
      ``(self_key, self_value)``.
    * Only the first two tensors of ``layer_past`` are handed to self-attention, so a
      cache entry produced by this block can be fed back in on the next step.

    When no cache is produced (``use_cache=False``) the block behaves like the plain
    GPT-2 block instead of failing while trying to extend a ``None`` cache entry.
    """

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Run self-attention, optional cross-attention and the MLP.

        Returns:
            ``(hidden_states, present, (attentions, cross_attentions))`` when a cache is
            requested, otherwise ``(hidden_states, (attentions, cross_attentions))``.
        """
        # A cache entry emitted by this block holds four tensors; self-attention only
        # understands its own (key, value) pair.
        self_attn_past = layer_past[:2] if layer_past is not None else None

        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=self_attn_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output

            self_present = outputs[0]
            cross_present = cross_attn_outputs[1]
            # Both entries are None when no cache was requested; only extend a real cache.
            if self_present is not None and cross_present is not None:
                extended_present = (self_present[0], self_present[1]) + tuple(cross_present)
                outputs = (extended_present,) + tuple(outputs[1:])
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # drop the (None) cache slot
            outputs = (hidden_states,) + outputs[1:]
        return outputs  # hidden_states, present, (attentions, cross_attentions)


class DummyEncoder(torch.nn.Module):
  """Parameter-free placeholder module that only exposes a ``config`` attribute."""

  def __init__(self, config):
    """Store ``config`` on an otherwise empty module.

    Args:
      config: object to expose as ``self.config``.
    """
    # Initialise the Module machinery first: torch rejects Module/Parameter attributes
    # that are assigned before ``Module.__init__`` has run.
    super().__init__()
    self.config = config

#prepare output_tokenizer

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
  """Return ``token_ids_0`` followed by the tokenizer's EOS id.

  ``token_ids_1`` is accepted for signature compatibility and is ignored.
  """
  eos_suffix = [self.eos_token_id]
  return token_ids_0 + eos_suffix

# Patch the tokenizer class so that every encoded sequence ends with the EOS token.
GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens

output_tokenizer = GPT2Tokenizer.from_pretrained(decoder_path)

output_tokenizer.pad_token_id = 1
output_tokenizer.bos_token_id = 0
output_tokenizer.eos_token_id = 2

encoder_config = MBartConfig.from_pretrained(encoder_path)

mbart = MBartForConditionalGeneration.from_pretrained(encoder_path, config=encoder_config)

input_tokenizer = MBartTokenizerFast.from_pretrained(encoder_path)
if hasattr(input_tokenizer, "src_lang"):
  input_tokenizer.src_lang = "de_DE"

decoder = AutoModelForCausalLM.from_pretrained(decoder_path)
#start of fix: cross attention value output
# Swap every block for a CachableGPT2Block and copy the loaded weights into it.
# (Calling `decoder.from_pretrained(...)` would build and return a *new* model and leave
# `decoder` untouched, so the freshly created blocks would keep their random weights.)
for i in range(0, len(decoder.transformer.h)):
  pretrained_block = decoder.transformer.h[i]
  cachable_block = CachableGPT2Block(config=decoder.config, layer_idx=i)
  cachable_block.load_state_dict(pretrained_block.state_dict())
  cachable_block.train(pretrained_block.training)
  decoder.transformer.h[i] = cachable_block

model = EncoderDecoderModel(encoder=mbart.model.encoder,decoder=decoder)

# The mBART encoder and decoder are kept outside of the trained model: the encoder is run
# separately and the mBART decoder acts as the teacher.
encoder_model = mbart.model.encoder
teacher = mbart.model.decoder

model.encoder = DummyEncoder(model.encoder.config)

# set decoding params
model.config.decoder_start_token_id = output_tokenizer.bos_token_id
model.config.eos_token_id = output_tokenizer.eos_token_id
model.config.pad_token_id = 1
model.config.max_length = 1024

#freeze all
for param in model.parameters():
    param.requires_grad = False

#make cross attention trainable
for module in model.decoder.transformer.h:
  for param in module.crossattention.parameters():
    param.requires_grad = True
  for param in module.ln_cross_attn.parameters():
    param.requires_grad = True

# Unfreeze the encoder-to-decoder projection, if the model has one. `requires_grad` has to
# be set on the parameters; assigning it on the module only creates a plain attribute.
if hasattr(model,'enc_to_dec_proj'):
  for param in model.enc_to_dec_proj.parameters():
    param.requires_grad = True

#unfreeze layer norms
for module in model.modules():
  if isinstance(module, torch.nn.LayerNorm):
    for param in module.parameters():
      param.requires_grad = True

## mBART

In [None]:
#pure mBART:
# Alternative model setup. It is disabled by wrapping the whole cell body in a string
# literal; remove the surrounding triple quotes to use it instead of the setup above.
"""
from transformers import MBartForConditionalGeneration, MBartTokenizer, MBartConfig
import copy

model_config = MBartConfig.from_pretrained(encoder_path)

model = MBartForConditionalGeneration.from_pretrained(encoder_path, config=model_config)

teacher = copy.deepcopy(model.model.decoder)
encoder_model = model.model.encoder
model.model.encoder = None

input_tokenizer = MBartTokenizer.from_pretrained(encoder_path)
if hasattr(input_tokenizer, "src_lang"):
  input_tokenizer.src_lang = "de_DE"

output_tokenizer = MBartTokenizer.from_pretrained(encoder_path)
if hasattr(output_tokenizer, "src_lang"):
  output_tokenizer.src_lang = "de_DE"

# set decoding params
model.config.decoder_start_token_id=250003
model.config.eos_token_id = output_tokenizer.eos_token_id
model.config.pad_token_id = 1
model.config.max_length = 1024

#freeze all
for param in model.parameters():
    param.requires_grad = False

#make cross attention trainable
for layer in model.model.decoder.layers:
  for param in layer.encoder_attn.parameters():
    param.requires_grad = True
  for param in layer.encoder_attn_layer_norm.parameters():
    param.requires_grad = True

#unfreeze batchnorms
for module in model.modules():
  if isinstance(module, torch.nn.LayerNorm):
    for param in module.parameters():
      param.requires_grad = True

#reset cross-attention to random parameters
for layer in model.model.decoder.layers:
    layer.encoder_attn.apply(model._init_weights)
"""

# Data

In [None]:
from datasets import load_dataset,concatenate_datasets, Features, Value
import unicodedata
import numpy as np

def normalize(text):
  """Strip, NFC-normalize and prefix both phrase columns with ``<s>``.

  The row is updated in place and returned.
  """
  for column in ('normal_phrase', 'simple_phrase'):
    cleaned = unicodedata.normalize("NFC", text[column].strip())
    text[column] = "<s>" + cleaned
  return text

def tokenize(text, input_tokenizer, output_tokenizer, max_input_length):
  """Encode the source phrases and attach the encoded target phrases as ``labels``.

  Only the target side is truncated (to ``max_input_length``); the source side is encoded
  without truncation.
  """
  encoded = input_tokenizer(text["normal_phrase"], return_tensors="np")
  target_encoding = output_tokenizer(
      text["simple_phrase"],
      return_tensors="np",
      truncation=True,
      max_length=max_input_length,
  )
  encoded['labels'] = target_encoding['input_ids']
  return encoded

def count(text):
  """Store the grouping key ``length`` = source length * target length on the row.

  The product is used as a proxy for peak GPU memory, so that the most expensive sample
  comes first and an out-of-memory condition shows up early.
  """
  n_source = len(text['input_ids'])
  n_target = len(text['labels'])
  text['length'] = n_source * n_target
  return text

## knowledge distillation stuff:

def teacher_decoder_inputs(text, teacher_tokenizer, max_input_length):
  """Add ``teacher_decoder_input_ids``: the simple phrases encoded for the teacher.

  NOTE(review): ``teacher_tokenizer`` is currently not used; the phrases are encoded with
  the module-level ``input_tokenizer``. Confirm which tokenizer is intended before
  switching to the parameter.
  """
  encoded = input_tokenizer(
      text["simple_phrase"],
      return_tensors="np",
      truncation=True,
      max_length=max_input_length,
  )
  text['teacher_decoder_input_ids'] = encoded['input_ids']
  return text

def compute_token_transform(input_tokens, output_tokens, interpolate=False):
  """Build a row-normalized alignment matrix of shape (len(output_tokens), len(input_tokens)).

  Entry (o, i) is non-zero when ``tokenizations.get_alignments`` maps input token ``i`` to
  output token ``o``. Every row is divided by its sum; rows without any aligned input
  token stay all-zero.
  """
  input_to_output, _ = tokenizations.get_alignments(input_tokens, output_tokens)

  weights = torch.zeros(len(output_tokens), len(input_tokens))
  for column, aligned_rows in enumerate(input_to_output):
      weights[aligned_rows, column] = 1

  if interpolate == False:
    # keep only direct token connections (entries of 1), everything else becomes 0
    weights = torch.where(weights > 0.99, weights, 0.0)

  row_sums = weights.sum(dim=1)
  normalized_t = weights.T / row_sums
  # rows without a connection were divided by zero and are nan now
  if torch.isnan(normalized_t).any():
    normalized_t = torch.nan_to_num(normalized_t, nan=0.0)
  return normalized_t.T

def compute_all_token_transforms(row, teacher_tokenizer, student_tokenizer, max_input_length=None):
  """Attach the teacher-to-student token alignment of ``row["simple_phrase"]`` to the row.

  The alignment is stored as the three components of a sparse tensor
  (``token_transform_size``, ``token_transform_values``, ``token_transform_indices``)
  so that the datasets library can serialize it.
  """
  phrase = row["simple_phrase"].strip()

  def surface_forms(tokenizer):
    # one "#"-prefixed surface string per token, with a leading "<s>" token
    pieces = ["<s>"] + tokenizer.tokenize(phrase)
    return ["#" + tokenizer.convert_tokens_to_string([piece]).strip() for piece in pieces]

  teacher_tokens = surface_forms(teacher_tokenizer)
  student_tokens = surface_forms(student_tokenizer)

  transform = compute_token_transform(teacher_tokens, student_tokens)
  if max_input_length is not None:
    transform = transform[:max_input_length, :max_input_length]

  sparse_transform = transform.to_sparse()
  row['token_transform_size'] = sparse_transform.size()
  row['token_transform_values'] = sparse_transform.values()
  row['token_transform_indices'] = sparse_transform.indices()
  return row

def get_dataset(data_files, input_tokenizer, output_tokenizer, name=None, max_input_length=None):
  """Load the aligned CSV files and turn them into model-ready features.

  Args:
    data_files: mapping of split name to CSV path, passed to ``load_dataset``.
    input_tokenizer: tokenizer of the encoder; also used for the teacher-side features.
    output_tokenizer: tokenizer of the student decoder (produces ``labels``).
    name: optional dataset configuration name passed to ``load_dataset``.
    max_input_length: if given, labels and teacher inputs are truncated to it and samples
      whose ``input_ids`` are not shorter than it are dropped.

  Returns:
    The processed dataset. When both tokenizers are of different types, the train split
    additionally carries ``teacher_decoder_input_ids`` and the sparse token-transform
    columns.
  """
  features = Features({'normal_phrase': Value('string'), 'simple_phrase': Value('string')})

  data = load_dataset("csv",name=name, data_files=data_files, features=features)
  data = data.map(normalize, num_proc=4)
  data = data.map(lambda rows: tokenize(rows, input_tokenizer, output_tokenizer, max_input_length), batched=True)

  if type(input_tokenizer) is not type(output_tokenizer):
    # The teacher shares the encoder's tokenizer (the same one is passed as teacher
    # tokenizer to compute_all_token_transforms below), so hand over `input_tokenizer`.
    data['train'] = data['train'].map(lambda rows: teacher_decoder_inputs(rows, input_tokenizer, max_input_length), batched=True)
    # the transforms are stored as sparse components (size/values/indices) because dense
    # matrices would be too expensive to save
    data['train'] = data['train'].map(lambda rows: compute_all_token_transforms(rows, input_tokenizer, output_tokenizer,max_input_length), num_proc=4)


  if "train" in data:
    data['train'] = data['train'].map(count, num_proc=4)
    data = data.remove_columns([column for column in data.column_names['train'] if column not in ['labels','input_ids','attention_mask','length', 'teacher_decoder_input_ids', 'token_transform_size', 'token_transform_values', 'token_transform_indices']])
  else:
    data = data.remove_columns([column for column in data.column_names['test'] if column not in ['labels','input_ids','attention_mask','length']])

  if max_input_length is not None:
    data = data.filter(lambda example: len(example["input_ids"]) < max_input_length)

  return data

In [None]:
#choose the right dataset to pre-train:

# Active choice: augmented pre-training corpus, validated on the 20min dev split.
# To train on the original data instead, use
#   'train': PREFIX + "20min_aligned_train.csv"
data_files_20_min = {
    'train': PREFIX_AUGMENTED + "pretraining.csv",
    'val': PREFIX + "20min_aligned_dev.csv",
}

data_train = get_dataset(
    data_files_20_min,
    input_tokenizer,
    output_tokenizer,
    name="20min",
    max_input_length=model.config.max_length,
)

# Training

In [None]:
class CustomCollator(DataCollatorForSeq2Seq):
  """Seq2seq collator that also batches the knowledge-distillation features.

  On top of what ``DataCollatorForSeq2Seq`` does, it
  * rebuilds a dense ``token_transform`` matrix per sample from the stored sparse
    components (``token_transform_indices`` / ``_values`` / ``_size``),
  * pads ``teacher_decoder_input_ids`` to a common length with the tokenizer's pad id,
  * zero-pads every ``token_transform`` to a common width and height.
  Samples without these columns are passed through to the parent collator unchanged.
  """

  def __call__(self, features, return_tensors=None):

    # Remove the sparse components from the features (if present) ...
    transform_i = [feature.pop("token_transform_indices") for feature in features] if "token_transform_indices" in features[0].keys() else None
    transform_v = [feature.pop("token_transform_values") for feature in features] if "token_transform_values" in features[0].keys() else None
    transform_s = [feature.pop("token_transform_size") for feature in features] if "token_transform_size" in features[0].keys() else None
    
    # ... and turn them back into one dense matrix per sample.
    transforms = [torch.sparse_coo_tensor(i, v, s).to_dense() for i,v,s in zip(transform_i, transform_v, transform_s)] if transform_i else None
    if transforms:
      for i, feature in enumerate(features):
        feature['token_transform'] = transforms[i]
    teacher_decoder_input_ids = [feature["teacher_decoder_input_ids"] for feature in features] if "teacher_decoder_input_ids" in features[0].keys() else None

    # Pad the teacher inputs to the longest sequence of the batch (rounded up to a
    # multiple of `pad_to_multiple_of` when that is set).
    max_label_length = 0
    if teacher_decoder_input_ids is not None:
      max_label_length = max(len(l) for l in teacher_decoder_input_ids)
      if self.pad_to_multiple_of is not None:
          max_label_length = (
              (max_label_length + self.pad_to_multiple_of - 1)
              // self.pad_to_multiple_of
              * self.pad_to_multiple_of
          )

      padding_side = self.tokenizer.padding_side
      for feature in features:
          remainder = [self.tokenizer.pad_token_id] * (max_label_length - len(feature["teacher_decoder_input_ids"]))
          if isinstance(feature["teacher_decoder_input_ids"], list):
              feature["teacher_decoder_input_ids"] = (
                  feature["teacher_decoder_input_ids"] + remainder if padding_side == "right" else remainder + feature["teacher_decoder_input_ids"]
              )
          elif padding_side == "right":
              feature["teacher_decoder_input_ids"] = np.concatenate([feature["teacher_decoder_input_ids"], remainder]).astype(np.int64)
          else:
              feature["teacher_decoder_input_ids"] = np.concatenate([remainder, feature["teacher_decoder_input_ids"]]).astype(np.int64)

    if transforms is not None:
      #align width
      # The width is padded with zero columns up to the padded teacher input length
      # computed above (max_label_length).

      padding_side = self.tokenizer.padding_side
      for feature in features:
          remainder = [0] *  (max_label_length  - len(feature["token_transform"][0,:]))
          remainder = [remainder] * len(feature["token_transform"][:,0])
          if padding_side == "right":
              feature["token_transform"] = np.concatenate([feature["token_transform"], remainder], axis=1).astype(np.float32)
          else:
              feature["token_transform"] = np.concatenate([remainder, feature["token_transform"]], axis=1).astype(np.float32)

      #align height
      # The height is padded with zero rows up to the tallest transform of the batch
      # (rounded up to a multiple of `pad_to_multiple_of` when that is set).
      max_height = max(len(a[:,0]) for a in transforms)
      if self.pad_to_multiple_of is not None:
          max_height = (
              (max_height + self.pad_to_multiple_of - 1)
              // self.pad_to_multiple_of
              * self.pad_to_multiple_of
          )

      padding_side = self.tokenizer.padding_side
      for feature in features:
          remainder = [0] * len(feature["token_transform"][0,:])
          remainder = [remainder] * (max_height - len(feature["token_transform"][:,0]))
          # nothing is concatenated when the matrix already has the target height
          if padding_side == "right" and len(remainder) > 0:
              feature["token_transform"] = np.concatenate([feature["token_transform"], remainder], axis=0).astype(np.float32)
          elif len(remainder) > 0:
              feature["token_transform"] = np.concatenate([remainder, feature["token_transform"]], axis=0).astype(np.float32)

    # The parent collator handles the remaining features.
    return super().__call__(features, return_tensors)

# The collator pads with the input (encoder) tokenizer and rounds lengths up to multiples of 8.
data_collator = CustomCollator(tokenizer=input_tokenizer, model=model, pad_to_multiple_of=8)

In [None]:
class ReduceAttnMap(torch.nn.Module):
    """Collapse per-head attention maps into one (log-)distribution per query position.

    adapted from Attention Distillation: self-supervised vision transformer students need more guidance
    """

    def __init__(self, temperature, log_space=False):
        """
        Args:
            temperature: factor applied to the summed log-attention before the softmax.
            log_space: return log-probabilities (``log_softmax``) instead of probabilities.
        """
        super().__init__()
        self.temperature = temperature
        functional = torch.nn.functional
        self.softmax = functional.log_softmax if log_space else functional.softmax

    def forward(self, attn_map, attn_mask):
        """Sum the log-attention over dim 1 (heads), mask, and normalize over the last dim.

        Positions where ``attn_mask`` is False are filled with ``-inf`` before the softmax;
        nan/inf values in the result are replaced by ``torch.nan_to_num``.
        """
        epsilon = 1e-6
        log_scores = torch.log(attn_map + epsilon).sum(dim=1)
        log_scores = log_scores.masked_fill(~attn_mask, float('-inf'))
        reduced = self.softmax(self.temperature * log_scores, dim=-1)
        return torch.nan_to_num(reduced)

class ReduceValueRelations(torch.nn.Module):
    """Turn cached key/value tensors into token-to-token relation (log-)distributions.

    adapted from MiniLMv2: Multi-Head Self-Attention Relation Distillation for Compressing Pretrained Transformers
    """

    def __init__(self, log_space=False, num_relation_heads=8):
        """
        Args:
            log_space: return log-probabilities (``log_softmax``) instead of probabilities.
            num_relation_heads: number of heads the merged hidden dimension is re-split into.
        """
        super().__init__()
        self.num_relation_heads = num_relation_heads
        functional = torch.nn.functional
        self.softmax = functional.log_softmax if log_space else functional.softmax

    def forward(self, past_inputs, attn_mask):
        """Compute scaled dot-product relations between all positions, per relation head.

        ``past_inputs`` is indexed as (batch, heads, positions, head_size). The heads are
        merged and re-split into ``num_relation_heads`` heads; the result has one
        positions x positions map per relation head. Entries where ``attn_mask`` is False
        are filled with ``-inf`` before the softmax and nan/inf values are replaced by
        ``torch.nan_to_num`` afterwards.
        """
        batch_size, num_heads, seq_len, attn_head_size = past_inputs.shape
        rel_head_size = num_heads * attn_head_size // self.num_relation_heads

        # (batch, heads, pos, size) -> (batch, pos, heads * size) -> (batch, rel_heads, pos, rel_size)
        merged = past_inputs.transpose(1, 2).reshape(batch_size, seq_len, num_heads * attn_head_size)
        relation_heads = merged.reshape(batch_size, seq_len, self.num_relation_heads, rel_head_size).transpose(1, 2)

        relations = torch.matmul(relation_heads, relation_heads.transpose(-1, -2))
        relations = relations.masked_fill(~attn_mask, float('-inf'))
        relations = self.softmax((rel_head_size ** -0.5) * relations, dim=-1)
        return torch.nan_to_num(relations)

class DistillationTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose loss mixes the token loss with cross-attention distillation.

    The encoder is kept outside of the trained model: ``encoder_model`` is run without
    gradients and its output is passed to the student as ``encoder_outputs``. When a
    ``teacher_model`` is given, the student's cross-attention maps and the key/value
    relations of its cached cross-attention states are pulled towards the teacher's with a
    KL divergence, for the last layer plus two randomly drawn layers per step.

    Args:
        attention_alpha: weight of the distillation loss; the token loss is weighted with
            ``1 - attention_alpha``.
        attention_temperature: temperature of the student's reduced attention map.
        teacher_model: decoder providing the guide attentions, or ``None`` to train with
            the token loss only.
        encoder_model: encoder producing the hidden states for student and teacher.
        **kwargs: forwarded to ``Seq2SeqTrainer``.
    """

    def __init__(self, attention_alpha=0.5, attention_temperature=1/16, teacher_model=None, encoder_model=None, **kwargs):
        super().__init__(**kwargs)

        self.MAX_LAYERS_TO_DISTILL = 2
        self.current_stages = []

        self.kl_loss = torch.nn.KLDivLoss(reduction="none")
        # guides are reduced to probabilities, student outputs to log-probabilities
        # (the form KLDivLoss expects for input and target)
        self.reduce_guide_attn_map = ReduceAttnMap(temperature=1/16)
        self.reduce_actual_attn_map = ReduceAttnMap(temperature=attention_temperature, log_space=True)
        self.reduce_guide_value_relations = ReduceValueRelations()
        self.reduce_actual_value_relations = ReduceValueRelations(log_space=True)

        self.attention_alpha = attention_alpha
        self.attention_temperature = attention_temperature
        self.teacher_model = teacher_model
        self.encoder_model = encoder_model

        # The teacher is optional (compute_loss skips distillation without it).
        if self.teacher_model is not None:
            self._move_model_to_device(self.teacher_model, self.model.device)
            self.teacher_model.eval()
        self._move_model_to_device(self.encoder_model, self.model.device)
        # the encoder is only ever run under no_grad
        self.encoder_model.eval()

        self.decoder_head_mask = torch.ones(12,16, device=self.model.device)
        self.decoder_head_mask = self.decoder_head_mask.bool()

        if hasattr(self.model, "decoder") and hasattr(self.model.decoder.config, "vocab_size"):
            self.decoder_vocab_size = self.model.decoder.config.vocab_size
        else:
            self.decoder_vocab_size = self.model.config.vocab_size

    def evaluate(self,eval_dataset = None,ignore_keys = None,metric_key_prefix = "eval",**gen_kwargs):
        """Evaluate with fixed generation settings (greedy); ``gen_kwargs`` are ignored."""
        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, num_beams=1, do_sample=False)#, top_k=3, penalty_alpha=0.6)

    def predict(self,test_dataset,ignore_keys = None,metric_key_prefix = "test",**gen_kwargs):
        """Predict with fixed generation settings (3 beams); ``gen_kwargs`` are ignored."""
        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, num_beams=3, do_sample=False)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return the weighted sum of token loss and distillation loss.

        ``inputs`` is consumed: ``labels``, ``length``, ``token_transform``,
        ``decoder_input_ids`` and ``teacher_decoder_input_ids`` are popped before the
        remaining entries are passed to the encoder and the student.

        Returns:
            ``loss``, or ``(loss, logits)`` when ``return_outputs`` is true.
        """
        labels = inputs.pop('labels')
        attention_mask = inputs.get('attention_mask')

        inputs.pop('length', None)

        token_transform = inputs.pop('token_transform', None)

        decoder_input_ids = inputs.pop('decoder_input_ids')
        decoder_attention_mask = labels != -100

        teacher_decoder_input_ids = inputs.pop('teacher_decoder_input_ids', None)
        if teacher_decoder_input_ids is None:
            # no separate teacher tokenization available: reuse the student's inputs
            teacher_decoder_input_ids = decoder_input_ids
            teacher_decoder_attention_mask = decoder_attention_mask
        else:
            # 1 is the pad id used when the teacher inputs are collated
            teacher_decoder_attention_mask = teacher_decoder_input_ids != 1

        # forward computation
        with torch.no_grad():
            encoder_hidden_states = self.encoder_model(**inputs, output_hidden_states=False, output_attentions=False)
        outputs = model(**inputs,decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, encoder_outputs=encoder_hidden_states, output_attentions=True, use_cache=True)#, cross_attn_head_mask=self.decoder_head_mask)
        logits = outputs.logits

        if self.label_smoother is not None:
            regular_loss = self.label_smoother(outputs, labels, shift_labels=False)
        else:
            loss_fct = torch.nn.CrossEntropyLoss()
            regular_loss = loss_fct(logits.view(-1, self.decoder_vocab_size), labels.view(-1))

        total_attn_loss = torch.tensor(0.0, device=self.model.device)
        if self.teacher_model is not None:

            with torch.no_grad():
                teacher_outputs = self.teacher_model(input_ids=teacher_decoder_input_ids, attention_mask=teacher_decoder_attention_mask, encoder_hidden_states=encoder_hidden_states[0], encoder_attention_mask=attention_mask, output_attentions=True, use_cache=True)

                del encoder_hidden_states

                # outer products of the 1-d masks give the valid (target, source) pairs
                attn_mask = torch.bmm(decoder_attention_mask.float().unsqueeze(2), attention_mask.float().unsqueeze(1)).bool()
                vr_attn_mask = torch.bmm(attention_mask.float().unsqueeze(2), attention_mask.float().unsqueeze(1)).bool()
                teacher_attn_mask = torch.bmm(teacher_decoder_attention_mask.float().unsqueeze(2), attention_mask.float().unsqueeze(1)).bool()

                # one copy of the source/source mask per relation head
                num_relation_heads = self.reduce_guide_value_relations.num_relation_heads
                vr_attn_mask = vr_attn_mask.unsqueeze(1).repeat(1, num_relation_heads, 1, 1)

                if token_transform is not None:
                    # student positions without any aligned teacher token (all-zero rows
                    # of the transform) are excluded from the attention loss
                    no_match_mask = torch.unsqueeze(token_transform.sum(dim=-1), dim=-1)
                    no_match_mask = torch.repeat_interleave(no_match_mask, attn_mask.shape[-1], dim=-1)
                    attn_mask = (attn_mask * no_match_mask).bool()

            #attention loss

            #distill last layer and two additional layers randomly selected at each training step
            current_stages = random.sample(range(0, len(teacher_outputs['cross_attentions']) - 1), 2) + [-1]
            self.current_stages = current_stages

            for layer_index in current_stages:
                with torch.no_grad():

                    #Guide attention distribution
                    guide_attn_map = teacher_outputs['cross_attentions'][layer_index]
                    guide_attention = self.reduce_guide_attn_map(guide_attn_map, teacher_attn_mask)

                    #Guide value relations
                    # the last two cache tensors of a layer are the cross-attention key and value
                    guide_key_values = teacher_outputs['past_key_values'][layer_index]

                    guide_cross_attn_past_values = guide_key_values[-2:][1]
                    guide_vr = self.reduce_guide_value_relations(guide_cross_attn_past_values, vr_attn_mask)

                    guide_cross_attn_past_keys = guide_key_values[-2:][0]
                    guide_kr = self.reduce_guide_value_relations(guide_cross_attn_past_keys, vr_attn_mask)

                    if token_transform is not None:
                        # map the guide from teacher token positions to student token positions
                        guide_attention = torch.matmul(token_transform, guide_attention)

                #Actual attention distribution
                actual_attn_map = outputs['cross_attentions'][layer_index]
                actual_attention = self.reduce_actual_attn_map(actual_attn_map, attn_mask)

                attn_loss = self.kl_loss(actual_attention, guide_attention)
                del actual_attention
                attn_loss = (attn_loss * (attn_mask)).sum(dim=[1,2])
                attn_loss = attn_loss / attn_mask.sum(dim=1)[:,0] #Divide loss by number of tokens
                attn_loss = attn_loss.sum() / guide_attention.size(0) # to have batch_mean

                actual_key_values = outputs['past_key_values'][layer_index]
                #Actual value relation

                actual_cross_attn_past_values = actual_key_values[-2:][1]
                actual_vr = self.reduce_actual_value_relations(actual_cross_attn_past_values, vr_attn_mask)

                vr_loss = self.kl_loss(actual_vr, guide_vr)
                del actual_vr
                vr_loss = (vr_loss * vr_attn_mask).sum(dim=[1,2,3])
                vr_loss = vr_loss / vr_attn_mask.sum(dim=2)[:,0,0] #Divide loss by number of tokens
                vr_loss = vr_loss.sum() / (guide_vr.size(0) * guide_vr.size(1))

                #Actual key relation

                actual_cross_attn_past_keys = actual_key_values[-2:][0]
                actual_kr = self.reduce_actual_value_relations(actual_cross_attn_past_keys, vr_attn_mask)

                kr_loss = self.kl_loss(actual_kr, guide_kr)
                del actual_kr
                kr_loss = (kr_loss * vr_attn_mask).sum(dim=[1,2,3])
                kr_loss = kr_loss / vr_attn_mask.sum(dim=2)[:,0,0] #Divide loss by number of tokens
                kr_loss = kr_loss.sum() / (guide_kr.size(0) * guide_kr.size(1))

                del guide_vr
                del guide_kr
                total_attn_loss = total_attn_loss + kr_loss + vr_loss + attn_loss
                del kr_loss
                del vr_loss

            # average over the distilled layers (only defined when a teacher is present)
            total_attn_loss = total_attn_loss / len(current_stages)

        loss = regular_loss * (1 - self.attention_alpha) + total_attn_loss * self.attention_alpha
        return (loss, logits) if return_outputs else loss

In [None]:
# Hyper-parameters of the distillation run, grouped by purpose.
training_args = Seq2SeqTrainingArguments(
    # output / bookkeeping
    output_dir="/results",
    save_strategy='no',
    evaluation_strategy="steps",
    logging_steps=500,
    # optimisation
    optim='adamw_torch',
    learning_rate=1e-3,
    weight_decay=0.01,
    warmup_steps=100,
    num_train_epochs=3,
    fp16=True,
    # batching: one sample per step, accumulated over 8 steps
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    per_device_eval_batch_size=4,
    group_by_length=True,
    dataloader_num_workers=2,
    # the collator needs the distillation columns, so they must not be dropped
    remove_unused_columns=False,
    # reproducibility
    seed=seed,
    data_seed=seed,
)

trainer = DistillationTrainer(
    model=model,
    teacher_model=teacher,
    encoder_model=encoder_model,
    attention_alpha=1.0,
    attention_temperature=1/16,
    args=training_args,
    data_collator=data_collator,
    train_dataset=data_train['train'],
    eval_dataset=data_train['val'],
)
trainer.train()

# Visualize Results

In [None]:
kl_loss = torch.nn.KLDivLoss(reduction="none")

# Collate a single training sample exactly like a training batch.
test_item = data_collator([data_train['train'][1]])
test_item = { k: v.to(device) for k, v in test_item.items()}

encoder_model.to(device)
model.to(device)
teacher.to(device)

# Training leaves the student in train mode; switch everything to eval mode so that
# dropout does not distort the visualized attention maps.
encoder_model.eval()
model.eval()
teacher.eval()

test_item.pop('length', None)
token_transform = test_item.pop('token_transform', None)

decoder_input_ids = test_item.pop('decoder_input_ids')
teacher_decoder_input_ids = test_item.pop('teacher_decoder_input_ids', None)

if teacher_decoder_input_ids is None:
  teacher_decoder_input_ids = decoder_input_ids

attention_mask = test_item.get('attention_mask')
labels = test_item.pop('labels')
decoder_attention_mask = labels != -100
teacher_decoder_attention_mask = teacher_decoder_input_ids != 1

# layer whose attention maps and value relations are shown
layer_idx = -2

with torch.no_grad():
  encoder_output = encoder_model(**test_item, output_hidden_states=False, output_attentions=False)
  out = model(**test_item, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, encoder_outputs=encoder_output, output_attentions=True, use_cache=True)
  teacher_out = teacher(input_ids=teacher_decoder_input_ids, attention_mask=teacher_decoder_attention_mask, encoder_hidden_states=encoder_output[0], encoder_attention_mask=attention_mask, output_attentions=True, use_cache=True)

# outer products of the 1-d masks give the valid (target, source) / (source, source) pairs
attn_mask = torch.bmm(decoder_attention_mask.float().unsqueeze(2), attention_mask.float().unsqueeze(1)).cpu().detach()
teacher_attn_mask = torch.bmm(teacher_decoder_attention_mask.float().unsqueeze(2), attention_mask.float().unsqueeze(1)).cpu().detach()
vr_attn_mask = torch.bmm(attention_mask.float().unsqueeze(2), attention_mask.float().unsqueeze(1)).cpu().detach()
# explicit relation-head axis so the mask broadcasts for any batch size
vr_attn_mask_heads = vr_attn_mask.unsqueeze(1)

actual_attentions = trainer.reduce_actual_attn_map(out.cross_attentions[layer_idx].cpu().detach(), attn_mask.bool())
guide_attentions = trainer.reduce_guide_attn_map(teacher_out.cross_attentions[layer_idx].cpu().detach(), teacher_attn_mask.bool())

actual_vr = trainer.reduce_actual_value_relations(out.past_key_values[layer_idx][-2:][1].cpu().detach(), vr_attn_mask_heads.bool())
guide_vr = trainer.reduce_guide_value_relations(teacher_out.past_key_values[layer_idx][-2:][1].cpu().detach(), vr_attn_mask_heads.bool())

if token_transform is not None:
  # map the guide from teacher token positions to student token positions
  guide_attentions = torch.matmul(token_transform.cpu().detach(), guide_attentions)

  # exclude student positions without any aligned teacher token from the loss
  no_match_mask = torch.unsqueeze(token_transform.sum(dim=-1), dim=-1)
  no_match_mask = torch.repeat_interleave(no_match_mask, attn_mask.shape[-1], dim=-1)
  attn_mask = attn_mask * no_match_mask.cpu().detach()

plt.imshow(guide_attentions[0], aspect="auto", cmap="viridis", vmax=0.01)
plt.title("Guide (teacher) cross-attention")
plt.colorbar()
plt.show()

# the student maps are log-probabilities, hence the exp()
plt.imshow(actual_attentions[0].exp(), aspect="auto", cmap="viridis", vmax=0.01)
plt.title("Actual (student) cross-attention")
plt.colorbar()
plt.show()

plt.imshow(guide_vr[0][0], aspect="auto", vmax=0.3)
plt.title("Guide (teacher) value relations, relation head 0")
plt.colorbar()
plt.show()

plt.imshow(actual_vr[0][0].exp(), aspect="auto", vmax=0.3)
plt.title("Actual (student) value relations, relation head 0")
plt.colorbar()
plt.show()

plt.imshow(attn_mask[0], aspect="auto")
plt.title("Attention loss mask")
plt.colorbar()
plt.show()


# Same normalization as DistillationTrainer.compute_loss: per sample by the number of
# source tokens, then averaged over batch and relation heads.
vr_loss = kl_loss(actual_vr, guide_vr)
vr_loss = (vr_loss * vr_attn_mask_heads).sum(dim=[1,2,3])
vr_loss = vr_loss / vr_attn_mask.sum(dim=1)[:,0] #Divide loss by number of tokens
vr_loss = vr_loss.sum() / (guide_vr.size(0) * guide_vr.size(1))

print("VR loss: ", vr_loss)

attn_loss = kl_loss(actual_attentions, guide_attentions)
attn_loss = (attn_loss * attn_mask).sum(dim=[1,2])
attn_loss = attn_loss / attn_mask.sum(dim=1)[:,0] #Divide loss by number of tokens
attn_loss = attn_loss.sum() / guide_attentions.size(0) # to have batch_mean

print("Attn loss: ", attn_loss)

# Upload the Model

In [None]:
from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub interactively; needed for the push_to_hub calls in the
# upload cells below.
notebook_login()

In [None]:
#for mBART:
# Only applies when the pure mBART variant was trained. The guard keeps "Run All" from
# failing here when `model` is the mBART + GPT-2 model, which has no `model.model`.
if isinstance(model, MBartForConditionalGeneration):
  #reassign encoder
  model.model.encoder = encoder_model
  model.push_to_hub("josh-oo/mbart-ts-distil", commit_message="Distilled layers: [r, r, -1] (0.75 vr/kr)")
else:
  print("Skipping mBART upload: `model` is not an MBartForConditionalGeneration.")

In [None]:
#for mBART + GPT-2
# Only applies when the mBART + GPT-2 variant was trained. The guard prevents pushing a
# different model type to the decoder repository.
if isinstance(model, EncoderDecoderModel):
  # drop the placeholder encoder before uploading
  model.encoder = None
  model.push_to_hub("josh-oo/calibrated-decoder", commit_message="Distilled layers: [r, r, -1] (0.75 vr/kr)")
else:
  print("Skipping decoder upload: `model` is not an EncoderDecoderModel.")

## Auto Disconnect from Colab to Save Credits

In [None]:
# Only works inside Google Colab (the import fails elsewhere): releases the runtime once
# all cells above have finished.
from google.colab import runtime
runtime.unassign()