In [None]:
!pip install torch transformers datasets sentencepiece
!pip install tensorboard scikit-learn psutil sacrebleu rouge-score tensorflow_datasets pytorch-lightning matplotlib git-python faiss-cpu streamlit elasticsearch nltk pandas datasets fire pytest conllu sentencepiece protobuf
!pip install jax jaxlib
!#pip install torch-lr-finder
!pip install wandb



In [None]:
# imports 
import logging
from transformers import PegasusTokenizerFast, PegasusForConditionalGeneration,PegasusConfig,AutoTokenizer,AutoModelForSeq2SeqLM
import datasets

import torch
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F

from typing import Callable, Dict, Iterable, List, Tuple, Union
from transformers import EvalPrediction, PreTrainedTokenizer
from transformers import TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

from transformers import AdamW
import wandb
import gc

import numpy as np

from torch.utils.tensorboard import SummaryWriter


In [None]:
# Logging: console messages prefixed with timestamp, level and logger name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)

# TensorBoard writer shared by the evaluation and training cells; with no log_dir
# argument SummaryWriter writes under ./runs/ by default.
writer = SummaryWriter()

In [None]:
# Teacher model and tokenizer, both loaded from the 'google/pegasus-gigaword' checkpoint
# (tokenizer first, then the model).
tokenizer, teacher = (
    AutoTokenizer.from_pretrained('google/pegasus-gigaword'),
    AutoModelForSeq2SeqLM.from_pretrained('google/pegasus-gigaword'),
)

In [None]:
#Student configuration

import warnings
import torch
from torch import nn
from typing import Optional, Tuple, List, Union
from transformers import PegasusModel, PegasusConfig, PegasusForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel
from transformers import SummarizationPipeline

# Candidate symmetric student configurations, keyed by the layer count as a string
# ('2', '4', ...); each has the same number of encoder and decoder layers.
# (Not used by the cells below, which derive the student config from the teacher.)
students_config_book = {
    str(n_layers): PegasusConfig(encoder_layers=n_layers, decoder_layers=n_layers)
    for n_layers in (2, 4, 6, 8, 10, 12, 16)
}


# Teacher layer indices copied into the student, keyed as
# LAYERS_TO_COPY[n_teacher_layers][n_student_layers]. Picks are spread across the
# teacher; for two or more student layers they start at layer 0 and end at the
# teacher's last layer.
LAYERS_TO_COPY = {
    4: {1: [0], 2: [0, 3], 3: [0, 1, 3], 4: [0, 1, 2, 3]},
    8: {1: [0], 2: [0, 7], 3: [0, 4, 7], 4: [0, 3, 6, 7], 6: [0, 2, 3, 5, 6, 7], 8: list(range(8))},
    12: {
        1: [0], 2: [0, 11], 3: [0, 6, 11], 4: [0, 4, 9, 11],
        6: [0, 2, 5, 8, 10, 11], 8: [0, 1, 3, 5, 7, 9, 10, 11],
        12: list(range(12)),
    },
    16: {
        1: [0], 2: [0, 15], 3: [0, 8, 15], 4: [0, 5, 10, 15],
        6: [0, 3, 6, 9, 12, 15], 8: [0, 2, 4, 6, 8, 10, 12, 15],
        9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
        12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15],
        16: list(range(16)),
    },
}

# Teacher hidden-state indices that the student's hidden states are matched against
# in the intermediate-supervision loss, keyed as
# LAYERS_TO_SUPERVISE[n_teacher_layers][n_student_layers] (read by get_layers_to_supervise).
# NOTE(review): every index in the 8-layer row is <= 5, which looks like a table for a
# 6-layer teacher rather than an 8-layer one -- confirm the key.
LAYERS_TO_SUPERVISE = {
    8: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
    12: {
        1: [11], 2: [5, 11], 3: [3, 7, 11], 4: [1, 3, 7, 11],
        6: [1, 3, 5, 8, 10, 11], 8: [1, 2, 3, 5, 7, 8, 9, 11],
    },
    16: {
        1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15],
        12: [1, 2, 3, 5, 7, 8, 9, 11, 12, 13, 14, 15],
    },
}


def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList, layers_to_copy) -> None:
    """Load the weights of ``src_layers[i]`` for every ``i`` in ``layers_to_copy`` into ``dest_layers``.

    The selected source layers are gathered into a temporary ``nn.ModuleList`` so that
    their state-dict keys are renumbered 0..k-1 and line up with ``dest_layers``.

    Raises:
        AssertionError: if ``dest_layers`` and ``layers_to_copy`` differ in length.
    """
    picked = nn.ModuleList([src_layers[idx] for idx in layers_to_copy])
    n_dest, n_picked = len(dest_layers), len(picked)
    assert n_dest == n_picked, f"{n_dest} != {n_picked}"
    dest_layers.load_state_dict(picked.state_dict())

# Copied from transformers.models.bart.modeling_bart.shift_tokens_right


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right, placing ``decoder_start_token_id`` at position 0.

    Any ``-100`` (ignored-label) values in the shifted result are replaced by ``pad_token_id``.
    The input tensor is not modified.

    NOTE: a two-argument ``shift_tokens_right`` is defined further down in this notebook;
    once that cell runs it shadows this definition.
    """
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = decoder_start_token_id

    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
    # replace possible -100 values in labels by `pad_token_id`
    return shifted.masked_fill(shifted == -100, pad_token_id)


def pick_layers_to_copy(n_student, n_teacher):
    """Return which teacher layer indices to copy into an ``n_student``-layer student.

    Uses the hard-coded ``LAYERS_TO_COPY[n_teacher][n_student]`` entry when there is one.
    Otherwise falls back to the first ``n_student`` layers, and warns unless the student
    and teacher have the same depth.
    """
    try:
        return LAYERS_TO_COPY[n_teacher][n_student]
    except KeyError:
        pass

    if n_student != n_teacher:
        warnings.warn(
            f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
        )
    return list(range(n_student))

def get_layers_to_supervise(n_student, n_teacher) -> List[int]:
    """Used for the --supervise_forward kwarg: teacher hidden-state indices matched to the student.

    Equal depths map one-to-one, and a 1-layer student is matched to the teacher's last
    index. Other sizes are looked up in ``LAYERS_TO_SUPERVISE[n_teacher][n_student]``.

    Raises:
        ValueError: if the student is deeper than the teacher.
        KeyError: if no table entry exists for the pair.
    """
    if n_student > n_teacher:
        raise ValueError(f"Cannot perform intermediate supervision for student {n_student} > teacher {n_teacher}")
    if n_student == n_teacher:
        return list(range(n_teacher))
    if n_student == 1:
        return [n_teacher - 1]
    return LAYERS_TO_SUPERVISE[n_teacher][n_student]

def create_student_with_configuration(teacher, e=None, d=None, copy_first_teacher_layers=False, save_path='./student'):
    """Build a shallower student from ``teacher`` with ``e`` encoder and ``d`` decoder layers.

    The student's config is the teacher's config with only the layer counts overridden.
    First, every teacher weight is loaded non-strictly, so student layer i receives teacher
    layer i. Unless ``copy_first_teacher_layers`` is set, the student's encoder/decoder
    layers are then overwritten with the teacher layers chosen by ``pick_layers_to_copy``.
    The chosen indices are also recorded on ``student.config.init_metadata``.

    Args:
        teacher: seq2seq model exposing ``.model.encoder.layers`` / ``.model.decoder.layers``.
        e: student encoder depth; ``None`` keeps the teacher's depth.
        d: student decoder depth; ``None`` keeps the teacher's depth.
        copy_first_teacher_layers: keep the first ``e``/``d`` teacher layers as loaded.
        save_path: currently unused (saving is disabled).

    Returns:
        ``(student, encoder_layer_ids, decoder_layer_ids)``: the teacher layer indices that
        the student's encoder/decoder layers were taken from.
    """
    teacher.eval()
    teacher_e = teacher.config.encoder_layers
    teacher_d = teacher.config.decoder_layers
    e = teacher_e if e is None else e
    d = teacher_d if d is None else d

    config_kwargs = teacher.config.to_diff_dict()
    config_kwargs.update({"encoder_layers": e, "decoder_layers": d})
    student = AutoModelForSeq2SeqLM.from_config(teacher.config_class(**config_kwargs))

    # Non-strict load: the student has a subset of the teacher's keys (fewer layers), so
    # every student key must be found in the teacher's state dict.
    load_info = student.load_state_dict(teacher.state_dict(), strict=False)
    assert load_info.missing_keys == [], load_info.missing_keys

    if copy_first_teacher_layers:
        # The non-strict load already placed teacher layers 0..e-1 / 0..d-1.
        return student, list(range(e)), list(range(d))

    # Not exactly alternating -- the tables try to keep the first and last teacher layer.
    e_layers_to_copy: List[int] = pick_layers_to_copy(e, teacher_e)
    d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
    for src_layers, dest_layers, layer_ids in (
        (teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy),
        (teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy),
    ):
        copy_layers(src_layers, dest_layers, layer_ids)

    # Record the copy plan on the config for easier reproducibility.
    student.config.init_metadata = dict(
        teacher_type=teacher.config.model_type,
        copied_encoder_layers=e_layers_to_copy,
        copied_decoder_layers=d_layers_to_copy,
    )
    return student, e_layers_to_copy, d_layers_to_copy
# Run a full garbage collection before moving on to data loading.
gc.collect()


4

# Data Loading

In [None]:
#student

In [None]:
class PegasusDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized source/target pairs.

    ``encodings`` maps column name -> per-example sequences (for example a tokenizer output
    with ``input_ids`` / ``attention_mask``). ``labels`` is a mapping whose ``"input_ids"``
    entry holds the target ids. Each item is a dict of tensors containing every encoder
    column followed by ``"labels"``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels["input_ids"])

    def __getitem__(self, idx):
        item = {}
        for column_name, column in self.encodings.items():
            item[column_name] = torch.tensor(column[idx])
        item['labels'] = torch.tensor(self.labels['input_ids'][idx])
        return item


def prepare_data(model_name,
                 train_texts, train_labels,
                 val_texts=None, val_labels=None,
                 test_texts=None, test_labels=None):
    """
    Prepare input data for model fine-tuning.

    Tokenizes (text, summary) pairs with the ``model_name`` tokenizer. Truncation is on,
    and each split is padded to its own longest sequence. The val/test datasets are built
    only when both their texts and labels are given; otherwise ``None`` is returned.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)`` as ``PegasusDataset`` objects (or None).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_data(texts, labels):
        source = tokenizer(texts, truncation=True, padding='longest')
        target = tokenizer(labels, truncation=True, padding='longest')
        return PegasusDataset(source, target)

    def tokenize_optional(texts, labels):
        if texts is None or labels is None:
            return None
        return tokenize_data(texts, labels)

    train_dataset = tokenize_data(train_texts, train_labels)
    val_dataset = tokenize_optional(val_texts, val_labels)
    test_dataset = tokenize_optional(test_texts, test_labels)
    return train_dataset, val_dataset, test_dataset

In [None]:
dataset = datasets.load_dataset('gigaword')

# Source data: the first N rows of each split. Slicing the split itself (`split[:n]`)
# returns a dict of column lists for just those rows. Writing `split['document'][:n]`
# instead first materialises the entire column as a Python list, and only then slices it.
N_TRAIN, N_VALID, N_TEST = 100000, 10000, 1000
train_rows = dataset['train'][:N_TRAIN]
valid_rows = dataset['validation'][:N_VALID]
test_rows = dataset['test'][:N_TEST]
train_texts, train_labels = train_rows['document'], train_rows['summary']
valid_texts, valid_labels = valid_rows['document'], valid_rows['summary']
test_texts, test_labels = test_rows['document'], test_rows['summary']
train_dataset, valid_dataset, test_dataset = prepare_data(
    'google/pegasus-gigaword',
    train_texts, train_labels,
    valid_texts, valid_labels,
    test_texts, test_labels,
)

del dataset, train_rows, valid_rows, test_rows
gc.collect()



0

In [None]:
# Sanity check: number of training examples kept after slicing.
len(train_texts)

100000

In [None]:
#train_dataloader = DataLoader(train_dataset,batch_size=128)
#test_dataloader = DataLoader(test_dataset,batch_size=128)
#validation_dataloader = DataLoader(valid_dataset,batch_size=128)

In [None]:
# All loaders share one batch size and use 2 worker processes.
# Only the training set is shuffled, with a seeded generator so the order is reproducible;
# without shuffling every epoch would see identical batches in identical order.
# Validation/test keep a fixed order so their per-batch losses line up across epochs.
BATCH_SIZE = 24
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              generator=torch.Generator().manual_seed(0), num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=2)
validation_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, num_workers=2)

In [None]:
# Number of training batches per epoch (ceil(len(train_dataset) / batch_size), since
# drop_last is off). len() on the DataLoader gives the same value as len(iter(loader))
# without creating an iterator and starting its worker processes.
len(train_dataloader)


4167

# Utils

In [None]:
!pip install nltk
import re

from filelock import FileLock


# nltk is optional. It is only needed to split summaries into sentences for rougeLsum.
try:
    import nltk
except ImportError:  # also covers ModuleNotFoundError, which subclasses ImportError
    NLTK_AVAILABLE = False
else:
    NLTK_AVAILABLE = True

if NLTK_AVAILABLE:
    # Serialise the punkt download between processes that share this directory.
    with FileLock(".lock"):
        nltk.download("punkt", quiet=True)


def add_newline_to_end_of_each_sentence(x: str) -> str:
    """Put each sentence of ``x`` on its own line, since rougeLsum expects newline-separated sentences.

    This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS.
    Pegasus' ``<n>`` newline marker is stripped before sentence splitting.

    Raises:
        AssertionError: if nltk is not installed.
    """
    # remove pegasus newline char (the result of re.sub must be assigned; it was discarded)
    x = re.sub("<n>", "", x)
    assert NLTK_AVAILABLE, "nltk must be installed to separate newlines between sentences. (pip install nltk)"
    return "\n".join(nltk.sent_tokenize(x))



05/16/2021 10:46:15 - INFO - filelock -   Lock 140389981684176 acquired on .lock
05/16/2021 10:46:16 - INFO - filelock -   Lock 140389981684176 released on .lock


In [None]:
## ROUGE Utils
from rouge_score import rouge_scorer, scoring

# ROUGE variants computed by calculate_rouge by default.
ROUGE_KEYS = "rouge1 rouge2 rougeL rougeLsum".split()


def extract_rouge_mid_statistics(dct):
    """Reduce aggregated ROUGE results to their ``mid`` precision/recall/fmeasure, rounded to 4 places.

    ``dct`` maps a rouge key to an object with a ``mid`` attribute (as returned by
    ``BootstrapAggregator.aggregate()``). The result is
    ``{key: {"precision": p, "recall": r, "fmeasure": f}}``.
    """
    stat_names = ("precision", "recall", "fmeasure")
    return {
        rouge_key: {name: round(getattr(aggregate.mid, name), 4) for name in stat_names}
        for rouge_key, aggregate in dct.items()
    }


def calculate_rouge(
    pred_lns: List[str],
    tgt_lns: List[str],
    use_stemmer=True,
    rouge_keys=ROUGE_KEYS,
    return_precision_and_recall=False,
    bootstrap_aggregation=True,
    newline_sep=True,
) -> Dict:
    """Calculate rouge using rouge_scorer package.

    Args:
        pred_lns: list of summaries generated by model
        tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
        use_stemmer: whether the Porter stemmer strips word suffixes to improve matching.
        rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
        return_precision_and_recall: (False) whether to also return precision and recall.
        bootstrap_aggregation: whether to do the typical bootstrap resampling of scores.
            Defaults to True. If False, this function returns a
            collections.defaultdict[metric: list of values for each observation for each subscore].
        newline_sep: (default=True) whether to add a newline between sentences. This is
            essential for calculating rougeLsum on multi-sentence summaries (CNN/DM dataset).

    Returns:
         Dict[score: value] (F-measure x 100, rounded to 4 places) if aggregated, the
         per-metric mid statistics if return_precision_and_recall, else a defaultdict(list)
         keyed by rouge_keys.
    """
    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    for prediction, target in zip(pred_lns, tgt_lns):
        if newline_sep:
            # rougeLsum expects "\n" separated sentences within a summary
            prediction = add_newline_to_end_of_each_sentence(prediction)
            target = add_newline_to_end_of_each_sentence(target)
        # RougeScorer.score takes (target, prediction) in that order.
        aggregator.add_scores(scorer.score(target, prediction))

    if not bootstrap_aggregation:
        return aggregator._scores  # here we return defaultdict(list)

    result = aggregator.aggregate()
    if return_precision_and_recall:
        return extract_rouge_mid_statistics(result)  # here we return dict
    return {rouge_key: round(aggregate.mid.fmeasure * 100, 4) for rouge_key, aggregate in result.items()}

In [None]:
def freeze(model):
    """Turn off gradients for every parameter of ``model`` (modifies it in place)."""
    for parameter in model.parameters():
        parameter.requires_grad_(False)


# Loss

In [None]:
def cross_entropy_loss(logits, labels,label_smoothing,pad_token_id):
    """Label-smoothed NLL of ``labels`` under ``logits``, ignoring targets equal to ``pad_token_id``.

    Returns the summed smoothed loss (the first output of ``label_smoothed_nll_loss``).
    """
    log_probs = F.log_softmax(logits, dim=-1)
    smoothed_loss, _nll = label_smoothed_nll_loss(
        log_probs, labels, label_smoothing, ignore_index=pad_token_id
    )
    return smoothed_loss

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq: label-smoothed negative log-likelihood, summed over all positions.

    Args:
        lprobs: log-probabilities whose last dim is the vocabulary.
        target: class indices, either with one fewer dim than ``lprobs`` or with a trailing dim of 1.
        epsilon: smoothing mass spread uniformly over the vocabulary.
        ignore_index: target value whose positions contribute nothing (``None`` disables masking).

    Returns:
        ``(loss, nll_loss)`` as summed scalar tensors.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)

    if ignore_index is None:
        nll_loss, smooth_loss = nll_loss.squeeze(-1), smooth_loss.squeeze(-1)
    else:
        ignored = target.eq(ignore_index)
        nll_loss.masked_fill_(ignored, 0.0)
        smooth_loss.masked_fill_(ignored, 0.0)

    # Summed (not averaged), as in fairseq.
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_per_class = epsilon / lprobs.size(-1)
    return (1.0 - epsilon) * nll_loss + eps_per_class * smooth_loss, nll_loss

def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
    """MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT.

    Student hidden states ``0..len(matches)-1`` are compared with the teacher hidden states
    at indices ``matches``. Squared errors at masked-out positions are zeroed, and the sum is
    divided by (number of unmasked positions * hidden size).

    Args:
        attention_mask: 1/True for positions to include; cast to the hidden states' dtype/device.
            Presumably (batch, seq) with hidden states (batch, seq, hidden) -- TODO confirm.
        hidden_states: indexable sequence of student hidden states.
        hidden_states_T: indexable sequence of teacher hidden states.
        matches: teacher indices paired with student indices 0, 1, ...
        normalize_hidden: apply ``F.layer_norm`` over all but the leading (stacked) dim first.
    """
    mask = attention_mask.to(hidden_states[0])
    student_stack = torch.stack([hidden_states[k] for k in range(len(matches))])
    teacher_stack = torch.stack([hidden_states_T[j] for j in matches])
    if normalize_hidden:
        student_stack = F.layer_norm(student_stack, student_stack.shape[1:])
        teacher_stack = F.layer_norm(teacher_stack, teacher_stack.shape[1:])

    squared_error = F.mse_loss(student_stack, teacher_stack, reduction="none")
    # Broadcast the mask over the stacked-layer dim (front) and the hidden dim (back).
    masked_total = (squared_error * mask.unsqueeze(0).unsqueeze(-1)).sum()
    n_valid_elements = mask.sum() * hidden_states[0].size(-1)
    return masked_total / n_valid_elements


def calc_ce_loss(mask, s_logits, t_logits,temperature=2):
    """Copy pasted from distillbert (transformers/examples/distillation/).

    Temperature-scaled KL(teacher || student) over the positions where ``mask`` is True
    (mask has False at padding positions). The result is averaged over the selected
    positions and multiplied by ``temperature ** 2``.
    """
    vocab_size = s_logits.size(-1)
    keep = mask[:, :, None].expand_as(s_logits)
    # (bs * seq_length, voc_size), keeping only rows where mask is True
    student_rows = torch.masked_select(s_logits, keep).view(-1, vocab_size)
    teacher_rows = torch.masked_select(t_logits, keep).view(-1, vocab_size)
    kl_div = nn.KLDivLoss(reduction="batchmean")
    divergence = kl_div(
        F.log_softmax(student_rows / temperature, dim=-1),
        F.softmax(teacher_rows / temperature, dim=-1),
    )
    return divergence * temperature ** 2

def shift_tokens_right(input_ids, pad_token_id):
    """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).

    Position 0 of each row receives that row's last non-pad token, and positions 1.. receive
    ``input_ids[:, :-1]``. The input is not modified. The last non-pad token is located as
    (non-pad count - 1), which assumes 2-D ``(batch, seq)`` ids with right padding. A row
    that is entirely padding wraps its own position-0 token instead of failing with an
    out-of-range (-1) gather index.

    NOTE: this redefines, and therefore shadows, the three-argument ``shift_tokens_right``
    defined earlier in the notebook.
    """
    prev_output_tokens = input_ids.clone()
    # Index of the last non-pad token in each row; clamp so an all-pad row yields 0, not -1.
    last_non_pad = (input_ids.ne(pad_token_id).sum(dim=1) - 1).clamp(min=0)
    index_of_eos = last_non_pad.unsqueeze(-1)
    # squeeze(-1) (not squeeze()) keeps a (batch,) shape even when batch == 1.
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze(-1)
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens

def blended_loss(teacher, student, batch, labels, e_layer_ids, d_layer_ids,
                 mean_ce, mean_logits, mean_hidden, pad_token_id, label_smoothing=0.2):
    """Compute the weighted distillation loss for one batch.

    total = ce_KD / mean_ce + student_lm / mean_logits + (hid_enc + hid_dec) / mean_hidden

    Each term is scaled by the inverse of a reference magnitude. The notebook passes 1 for
    the initial measurement, and afterwards the test-set means measured by that pass.

    Args:
        teacher: the teacher seq2seq model.
        student: the student seq2seq model.
        batch: dict with ``input_ids`` and ``attention_mask`` tensors.
        labels: target token ids, padded with ``pad_token_id``.
        e_layer_ids, d_layer_ids: teacher layers the student's encoder/decoder layers came
            from; only their lengths are used, to choose which teacher hidden states to match.
        mean_ce, mean_logits, mean_hidden: reference magnitudes used as inverse weights.
        pad_token_id: padding id, ignored by every loss term.
        label_smoothing: smoothing for the student's own LM loss (previously hard-coded 0.2).

    Returns:
        ``(total_loss, losses)``, where ``losses`` holds the KD cross-entropy
        (``'ce_KD_loss'``), the student LM loss (``'student loss Logits'``) and the decoder
        hidden-state loss (``'Hidden'``).
        NOTE(review): 'Hidden' excludes the encoder hidden loss, yet its mean is used to
        weight enc + dec -- confirm that is intended.
    """
    alpha_hid = 1 / mean_hidden
    alpha_ce = 1 / mean_ce
    alpha_mlm = 1 / mean_logits

    e_matches = get_layers_to_supervise(n_student=len(e_layer_ids), n_teacher=teacher.config.encoder_layers)
    d_matches = get_layers_to_supervise(n_student=len(d_layer_ids), n_teacher=teacher.config.decoder_layers)

    different_base_models = False
    do_calc_hidden_loss = (not different_base_models) and bool(alpha_hid > 0)
    different_encoder = different_base_models or (student.config.encoder_layers != teacher.config.encoder_layers)

    decoder_input_ids = shift_tokens_right(labels, pad_token_id)

    student_outputs = student(
        batch['input_ids'],
        attention_mask=batch['attention_mask'],
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=True,
        output_attentions=False,
        use_cache=False,
    )
    student_loss = cross_entropy_loss(student_outputs["logits"], labels, label_smoothing, pad_token_id)

    def zero_tensor():
        return torch.tensor(0.0).type_as(student_loss)

    # The teacher decoder is fed the student's encoder output unless the base models differ.
    teacher_enc_outputs = student_outputs["encoder_last_hidden_state"]
    hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()

    # Bound up front so the cleanup below also works when the teacher-encoder pass is
    # skipped (the unconditional `del` used to raise UnboundLocalError whenever the
    # student had as many encoder layers as the teacher).
    all_teacher_encoder_outputs = None
    if different_encoder:  # compute encoder hidden state loss
        all_teacher_encoder_outputs = teacher.get_encoder()(
            batch['input_ids'],
            attention_mask=batch['attention_mask'],
            output_hidden_states=True,
        )
        if different_base_models:
            teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
        elif do_calc_hidden_loss:
            hid_loss_enc = calc_hidden_loss(
                batch['attention_mask'],
                student_outputs["encoder_hidden_states"],
                all_teacher_encoder_outputs["hidden_states"],
                e_matches,
                normalize_hidden=True,
            )

    teacher_outputs = teacher(
        batch['input_ids'],
        attention_mask=batch['attention_mask'],
        encoder_outputs=(teacher_enc_outputs,),
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=do_calc_hidden_loss,
        use_cache=False,  # since we are not passing labels, never let this default to True
    )

    dec_mask = decoder_input_ids.ne(pad_token_id)
    loss_ce = calc_ce_loss(dec_mask, student_outputs["logits"], teacher_outputs["logits"])
    if do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
        hid_loss_dec = calc_hidden_loss(
            dec_mask,
            student_outputs["decoder_hidden_states"],
            teacher_outputs["decoder_hidden_states"],
            d_matches,
            normalize_hidden=True,
        )

    losses = {'ce_KD_loss': loss_ce, 'student loss Logits': student_loss, 'Hidden': hid_loss_dec}
    total_loss = (
        alpha_ce * loss_ce
        + alpha_mlm * student_loss
        + alpha_hid * (hid_loss_enc + hid_loss_dec)
    )

    # Drop references to the large model outputs before releasing cached GPU memory.
    del student_outputs, teacher_outputs, all_teacher_encoder_outputs
    gc.collect()
    torch.cuda.empty_cache()

    return total_loss, losses

        

# Training Loop 

In [None]:
!nvidia-smi

Sun May 16 10:46:16 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
student, e_layers_list, d_layers_list = create_student_with_configuration(
    teacher,
    e=4,
    d=4,
    copy_first_teacher_layers=False,
    save_path='./student',
)

# torch's AdamW configured with the defaults of the deprecated `transformers.AdamW`
# (eps=1e-6, no weight decay).
optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# Multiply the LR by 0.95 on every scheduler.step(); StepLR expects an integer step_size.
# NOTE: step this once per epoch. Stepping it after every batch multiplies the LR by
# 0.95 ** len(train_dataloader) per epoch, which drives it to ~0 almost immediately.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

pad_token_id = tokenizer.pad_token_id

student.to('cuda')
teacher.to('cuda')
!nvidia-smi


Sun May 16 10:46:43 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    26W /  70W |   4050MiB / 15109MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Freeze the teacher model: freeze() sets requires_grad=False on all of its parameters,
# so backward() computes no gradients for them.
freeze(teacher)
!nvidia-smi
# Freeze the student's token and positional embeddings (shared, encoder and decoder),
# so they keep the values loaded from the teacher.
for embedding_module in (
    student.model.shared,
    student.model.encoder.embed_tokens,
    student.model.encoder.embed_positions,
    student.model.decoder.embed_tokens,
    student.model.decoder.embed_positions,
):
    freeze(embedding_module)

Sun May 16 10:46:43 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    26W /  70W |   4050MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
 # Evaluation LOOP
with torch.no_grad():
    # Initial evaluation: measure the mean of each loss term on the test set with unit
    # weights. The means are then used as inverse weights of the blended loss in training.
    student.eval()

    test_loss_logits = []
    test_loss_hidden = []
    test_loss_ce = []
    for i, test_batch in enumerate(test_dataloader):
        y = test_batch['labels'].to('cuda')
        x = {
            'input_ids': test_batch['input_ids'].to('cuda'),
            'attention_mask': test_batch['attention_mask'].to('cuda'),
        }
        # Only the losses are needed here. The student.generate(**x) call that used to run
        # first was expensive and its output was never read, so it was removed.
        _, all_losses = blended_loss(teacher, student, x, y, e_layers_list, d_layers_list, 1, 1, 1, pad_token_id)
        # Keep Python floats rather than CUDA tensors so the lists hold no GPU memory.
        test_loss_logits.append(all_losses['student loss Logits'].item())
        test_loss_hidden.append(all_losses['Hidden'].item())
        test_loss_ce.append(all_losses['ce_KD_loss'].item())
        writer.add_scalars("Losses/initial Evaluation", all_losses, i)
        del x, y
        gc.collect()
        torch.cuda.empty_cache()

    mean_ce = torch.mean(torch.tensor(test_loss_ce))
    mean_logits = torch.mean(torch.tensor(test_loss_logits))
    mean_hidden = torch.mean(torch.tensor(test_loss_hidden))
    writer.flush()
    del test_loss_ce, test_loss_logits, test_loss_hidden
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Mount Google Drive (Colab only); the training loop saves checkpoints under
# /content/drive/MyDrive/GP/.
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
  # ----------------
  # TRAINING LOOP
  # ----------------
  # Per epoch:
  #   1. train on every batch;
  #   2. report the mean validation loss;
  #   3. evaluate on the test set (blended losses + ROUGE of generated summaries);
  #   4. save a checkpoint.
  # mean_ce / mean_logits / mean_hidden come from the initial evaluation cell and act as
  # inverse weights inside blended_loss.
  num_epochs = 5
  steps_per_epoch = len(train_dataloader)
  teacher.eval()
  for epoch in range(num_epochs):
    # TRAINING LOOP
    student.train(True)
    for i, train_batch in enumerate(train_dataloader, start=1):
      y = train_batch['labels'].to('cuda')
      x = {
          'input_ids': train_batch['input_ids'].to('cuda'),
          'attention_mask': train_batch['attention_mask'].to('cuda')
      }
      # blended_loss runs the student forward pass itself. A separate student(...) call
      # whose output was discarded only doubled compute and built an unused autograd graph.
      loss, all_losses = blended_loss(teacher, student, x, y, e_layers_list, d_layers_list, mean_ce, mean_logits, mean_hidden, pad_token_id)

      print(f'epoch|{epoch} iteration {i} train loss: {loss.item()}')

      loss.backward()
      optimizer.step()
      optimizer.zero_grad()
      # Global step, so that later epochs do not overwrite earlier points in TensorBoard.
      writer.add_scalar("Loss/train", loss.item(), epoch * steps_per_epoch + i)
      del x, y, train_batch
      gc.collect()
      torch.cuda.empty_cache()
    # Decay the learning rate once per epoch. With gamma=0.95, stepping after every batch
    # would multiply the LR by 0.95 ** steps_per_epoch each epoch (effectively to zero).
    scheduler.step()
    writer.flush()

    # VALIDATION LOOP
    with torch.no_grad():
      val_loss = []
      student.eval()
      for val_batch in validation_dataloader:
        y = val_batch['labels'].to('cuda')
        x = {
            'input_ids': val_batch['input_ids'].to('cuda'),
            'attention_mask': val_batch['attention_mask'].to('cuda')
        }
        loss, all_losses = blended_loss(teacher, student, x, y, e_layers_list, d_layers_list, mean_ce, mean_logits, mean_hidden, pad_token_id)
        val_loss.append(loss.item())
        del x, y, val_batch
        gc.collect()
        torch.cuda.empty_cache()

      val_losses = torch.mean(torch.tensor(val_loss))
      print('val_loss: ', val_losses.item())
      writer.add_scalar("Loss/validation", val_losses.item(), epoch)

    # Evaluation LOOP
    with torch.no_grad():
      student.eval()
      preds = []
      lbls = []
      test_loss = []
      for i, test_batch in enumerate(test_dataloader):
        y = test_batch['labels'].to('cuda')
        x = {
            'input_ids': test_batch['input_ids'].to('cuda'),
            'attention_mask': test_batch['attention_mask'].to('cuda')
        }
        prediction = student.generate(**x)
        # Decode every example of the batch (not only the first), and drop <pad>/</s> so
        # special tokens are not scored by ROUGE. Decoding per batch also avoids keeping
        # every CUDA label/prediction tensor alive until the end of the loop.
        preds.extend(tokenizer.batch_decode(prediction, skip_special_tokens=True))
        lbls.extend(tokenizer.batch_decode(y, skip_special_tokens=True))
        loss, all_losses = blended_loss(teacher, student, x, y, e_layers_list, d_layers_list, mean_ce, mean_logits, mean_hidden, pad_token_id)
        test_loss.append(loss.item())
        writer.add_scalars("Losses/Evaluation", all_losses, epoch * len(test_dataloader) + i)

      test_losses = torch.mean(torch.tensor(test_loss))
      print('test_loss: ', test_losses.item())
      rouge_score = calculate_rouge(pred_lns=preds, tgt_lns=lbls)
      writer.add_scalars("RougeScores/Evaluation", rouge_score, epoch)
      writer.flush()
      print(rouge_score)

    PATH = f'/content/drive/MyDrive/GP/student_checkpoint_epoch_{epoch}.pt'
    torch.save({
        'epoch': epoch,
        'model_state_dict': student.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': all_losses,
    }, PATH)
  # Close the writer once, after all epochs.
  writer.close()

epoch|0 iteration 1 train loss: 4.289752960205078
epoch|0 iteration 2 train loss: 4.206548690795898
epoch|0 iteration 3 train loss: 4.083798885345459
epoch|0 iteration 4 train loss: 4.094593048095703
epoch|0 iteration 5 train loss: 3.8483574390411377
epoch|0 iteration 6 train loss: 3.877967357635498
epoch|0 iteration 7 train loss: 3.757674217224121
epoch|0 iteration 8 train loss: 3.7347192764282227
epoch|0 iteration 9 train loss: 3.702779531478882
epoch|0 iteration 10 train loss: 3.6680104732513428
epoch|0 iteration 11 train loss: 3.6243414878845215
epoch|0 iteration 12 train loss: 3.5978105068206787
epoch|0 iteration 13 train loss: 3.585702419281006
epoch|0 iteration 14 train loss: 3.5396499633789062
epoch|0 iteration 15 train loss: 3.599029064178467
epoch|0 iteration 16 train loss: 3.6442253589630127
epoch|0 iteration 17 train loss: 3.5897066593170166
epoch|0 iteration 18 train loss: 3.4608511924743652
epoch|0 iteration 19 train loss: 3.538313865661621
epoch|0 iteration 20 train loss

In [None]:
!pip install tensorboard --upgrade
!tensorboard dev upload --logdir runs --name "pegasus distillation" --description "Simple pipeline for distilling pegasus"

Requirement already up-to-date: tensorboard in /usr/local/lib/python3.7/dist-packages (2.5.0)
2021-05-16 10:24:01.793778: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Data for the "text" plugin is now uploaded to TensorBoard.dev! Note that uploaded data is public. If you do not want to upload data for this plugin, use the "--plugins" command line argument.
Upload started and will continue reading any new data as it's added to the logdir.

To stop uploading, press Ctrl-C.

New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/fdL5EJYRRKqHaOwdocZUGA/

[1m[2021-05-16T10:24:03][0m Started scanning logdir.
[1m[2021-05-16T10:24:03][0m Total uploaded: 142 scalars, 0 tensors, 0 binary objects
Ctrl-C
Traceback (most recent call last):
  File "/usr/local/bin/tensorboard", line 8, in <module>
    sys.exit(run_main())
  File "/usr/local/lib/python3.7/dist-packages/tensorboard/main.py", line 

In [None]:
def pretty_size(size):
    """Pretty prints a torch.Size object, e.g. ``torch.Size([2, 3])`` -> ``"2 × 3"``."""
    assert isinstance(size, torch.Size)
    return " × ".join(map(str, size))


def dump_tensors(gpu_only=True):
    """Prints a list of the Tensors being tracked by the garbage collector.

    Each tensor is printed as ``Type:[ GPU][ pinned] d0 × d1 …``. A non-tensor object that
    wraps a tensor in ``.data`` is printed as ``Type → DataType:[ GPU][ pinned][ grad] …``.
    A final line gives the total number of elements. With ``gpu_only`` (the default), only
    CUDA tensors are listed. Objects that raise while being inspected are skipped; this is
    a best-effort debugging aid.
    """
    import gc
    total_size = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                if not gpu_only or obj.is_cuda:
                    # is_pinned is a method; the bare attribute is always truthy.
                    print("%s:%s%s %s" % (type(obj).__name__,
                                          " GPU" if obj.is_cuda else "",
                                          " pinned" if obj.is_pinned() else "",
                                          pretty_size(obj.size())))
                    total_size += obj.numel()
            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                data = obj.data
                if not gpu_only or data.is_cuda:
                    # `volatile` no longer exists on tensors. Reading it raised AttributeError,
                    # which the except below swallowed, so these lines were never printed.
                    print("%s → %s:%s%s%s %s" % (type(obj).__name__,
                                                 type(data).__name__,
                                                 " GPU" if data.is_cuda else "",
                                                 " pinned" if data.is_pinned() else "",
                                                 " grad" if getattr(obj, "requires_grad", False) else "",
                                                 pretty_size(data.size())))
                    total_size += data.numel()
        except Exception:
            # Some gc-tracked objects raise on attribute access; skip them.
            continue
    print("Total size:", total_size)

In [None]:
# List the CUDA tensors still referenced (useful when hunting GPU-memory leaks).
dump_tensors(gpu_only=True)

In [None]:
!nvidia-smi

In [None]:
# ROUGE scores from the last evaluated epoch (calculate_rouge's default output: F-measure x 100 per key).
rouge_score