In [4]:
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training the distilled model.
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
"""
import argparse
import json
import os
import pickle
import shutil

import numpy as np
import torch
from torch import nn

from causal_distiller import *
from lm_seqs_dataset import LmSeqsDataset
from transformers import (
    BertConfig,
    BertTokenizer,
    DistilBertConfig,
    DistilBertTokenizer,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
)
from models.modeling_distilbert import DistilBertForMaskedLM # we need to customize it a little.
from models.modeling_bert import BertForMaskedLM # we need to customize it a little.
from utils import git_log, init_gpu_params, logger, set_seed


# Model-type key -> (config class, model class, tokenizer class).
# Note: the "distilbert" and "bert" entries use the customized masked-LM classes imported from
# `models.modeling_distilbert` / `models.modeling_bert` above, not the stock transformers ones.
MODEL_CLASSES = dict(
    distilbert=(DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    roberta=(RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    bert=(BertConfig, BertForMaskedLM, BertTokenizer),
    gpt2=(GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
)


In [5]:
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil the student.
    Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import math
import os
import time

import psutil
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from transformers import get_linear_schedule_with_warmup
from utils import logger

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

import argparse
import json
import os
import pickle
import shutil
import random

import numpy as np
import torch

from distiller import Distiller
from lm_seqs_dataset import LmSeqsDataset
from transformers import (
    AutoConfig,
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    DistilBertConfig,
    DistilBertTokenizer,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
    AutoTokenizer,
    AutoModelForMaskedLM,
)
from utils import git_log, init_gpu_params, logger, set_seed
from datasets import load_dataset
from counterfactual_utils import *
import wandb
from models.modeling_distilbert import DistilBertForMaskedLM

# Examples of interchange.
# activations_counterfactual_teacher = get_activation_at(
#     teacher_bert,
#     batch["input_ids"],
#     batch["attention_mask"],
#     variable_names=["$L:1$H:1$[0:32]"]
# )
# interchange_with_activation_at(
#     teacher_bert,
#     batch["input_ids"],
#     batch["attention_mask"],
#     interchanged_variables=[torch.zeros(32, 512, 32)],
#     variable_names=["$L:1$H:1$[0:32]"]
# )

class CausalDistiller:
    def __init__(
        self, params: dict, dataset: LmSeqsDataset, 
        token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
    ):
        """
        Set up causal distillation of `student` from `teacher`.

        Args:
            params: training options, read by attribute (e.g. `params.dump_path`), so presumably an
                `argparse.Namespace` despite the `dict` annotation — TODO confirm.
            dataset: training dataset; must expose `lengths` (used when `params.group_by_size`) and
                `batch_sequences` (used as the DataLoader `collate_fn`).
            token_probs: per-token-id sampling weights used to pick MLM targets.
            student: model being trained; must expose `.config.vocab_size`.
            teacher: reference model.

        Side effects: optionally starts a wandb run, reads the neuron-mapping JSON file at
        `params.neuron_mapping`, and may wrap student/teacher for fp16 (apex) or multi-GPU training.
        """
        if params.is_wandb:
            # The run handle returned by wandb.init was never used, so it is not kept.
            wandb.init(
                project="Causal-BERT-Distillation", 
                entity="wuzhengx",
                name=params.run_name,
            )
            wandb.config.update(params)
        self.is_wandb = params.is_wandb
        
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher
        
        # Causal neuron mappings: each entry is [teacher variables, student variables], both
        # deserialized with `deserialize_variable_name`.
        self.deserialized_interchange_variable_mappings = []
        with open(params.neuron_mapping, encoding="utf-8") as json_file:
            neuron_mapping = json.load(json_file)
        logger.info(f"Neuron Mapping: {neuron_mapping}")
        for m in neuron_mapping["interchange_variable_mappings"]:
            teacher_deserialized_variables = [
                deserialize_variable_name(variable) for variable in m["teacher_variable_names"]
            ]
            student_deserialized_variables = [
                deserialize_variable_name(variable) for variable in m["student_variable_names"]
            ]
            self.deserialized_interchange_variable_mappings.append(
                [teacher_deserialized_variables, student_deserialized_variables]
            )

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        # overwrite slightly on this.
        if params.local_rank == -1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)
            
        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)

        self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.0

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos
        self.alpha_causal = params.alpha_causal

        self.mlm = params.mlm
        if self.mlm:
            logger.info("Using MLM loss for LM step.")
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            # Compare with a tolerance: exact float equality rejects valid splits,
            # e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999.
            assert math.isclose(params.word_mask + params.word_keep + params.word_rand, 1.0)
            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(torch.device("cuda")) if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(torch.device("cuda")) if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info("Using CLM loss for LM step.")

        self.interchange_mlm = params.interchange_mlm
        self.interchange_prop = params.interchange_prop
            
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0

        self.last_loss_causal_ce = 0
        self.last_teacher_interchange_efficacy = 0
        self.last_student_interchange_efficacy = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        )

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": 0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i"
            % sum([p.numel() for p in self.student.parameters() if p.requires_grad])
        )
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(
            optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
        )

        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
        )

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student, self.optimizer, opt_level=self.params.fp16_opt_level
            )
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            elif params.local_rank == -1:
                logger.info("Using nn.DataParallel for the teacher model.")
                # teacher also use multi-GPU.
                self.teacher = torch.nn.DataParallel(self.teacher)
                self.teacher.to(torch.device("cuda")) # no rank is needed!

                logger.info("Using nn.DataParallel for the student model.")
                self.student = torch.nn.DataParallel(self.student)
                self.student.to(torch.device("cuda")) # no rank is needed!
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

        self.is_master = params.is_master
        
        
    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0

        # mask a number of words == 0 [8] (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True).to(_token_ids_real.device)
        _token_ids = (
            _token_ids_mask * (probs == 0).long()
            + _token_ids_real * (probs == 1).long()
            + _token_ids_rand * (probs == 2).long()
        )
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[~attn_mask] = -100  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self, x: torch.tensor, lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # number of sentences == 0 [8]
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # sequence length == 0 [8]
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids["pad_token"]
            else:
                pad_id = self.params.special_tok_ids["unk_token"]
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def prepare_interchange_position(self, lengths, dual_lengths):
        interchange_prop = self.interchange_prop
        batch_size = lengths.shape[0]
        interchange_position = []
        for i in range(0, batch_size):
            min_len = min(lengths[i].tolist(), dual_lengths[i].tolist())
            interchange_count = int(min_len*interchange_prop)
            start_index = random.randint(0, lengths[i].tolist()-interchange_count)
            end_index = start_index + interchange_count
            dual_start_index = random.randint(0, dual_lengths[i].tolist()-interchange_count)
            dual_end_index = dual_start_index + interchange_count
            interchange_position += [[start_index, end_index, dual_start_index, dual_end_index]]
        interchange_position = torch.tensor(interchange_position, dtype=torch.long).to(lengths.device)
        return interchange_position
    
    def train(self):
        """
        The real training loop.

        Runs `self.params.n_epoch` passes over `self.dataloader`. Each batch carries a main and a dual
        sequence batch. Both are moved to the GPU when `n_gpu > 0` and prepared for MLM or CLM. They
        then get per-pair interchange windows and are passed to `self.step`. The master process saves
        `pytorch_model.bin` once training ends.
        """
        if self.is_master:
            logger.info("Starting training")
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        prepare_batch = self.prepare_batch_mlm if self.mlm else self.prepare_batch_clm
        device = torch.device("cuda")

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")

            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                token_ids, lengths, dual_token_ids, dual_lengths = batch

                if self.params.n_gpu > 0:
                    token_ids = token_ids.to(device)
                    lengths = lengths.to(device)
                    dual_token_ids = dual_token_ids.to(device)
                    dual_lengths = dual_lengths.to(device)

                token_ids, attn_mask, lm_labels = prepare_batch(batch=(token_ids, lengths))
                dual_token_ids, dual_attn_mask, dual_lm_labels = prepare_batch(
                    batch=(dual_token_ids, dual_lengths)
                )

                # With fp16, `round_batch` may drop and permute rows, so the raw lengths no longer line
                # up with the prepared ids. The attention masks encode the lengths that actually match.
                lengths = attn_mask.sum(dim=1)
                dual_lengths = dual_attn_mask.sum(dim=1)

                # from length, let us get the intervention points.
                sampled_interchange_position = self.prepare_interchange_position(lengths, dual_lengths)

                self.step(
                    input_ids=token_ids,
                    attention_mask=attn_mask,
                    lm_labels=lm_labels,
                    dual_input_ids=dual_token_ids,
                    dual_attention_mask=dual_attn_mask,
                    dual_lm_labels=dual_lm_labels,
                    sampled_interchange_position=sampled_interchange_position,
                    is_parallel=self.params.parallel_crossway,
                    is_crossway=self.params.include_crossway,
                )
                # No manual `iter_bar.update()`: iterating over the tqdm wrapper already advances it
                # once per batch, and an extra update double-counts progress.
                iter_bar.set_postfix(
                    {
                        "Last_loss": f"{self.last_loss:.2f}",
                        "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}",
                        "Last_cf_loss": f"{self.last_loss_causal_ce:.2f}",
                    }
                )
            iter_bar.close()

            if self.is_master:
                logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
            self.end_epoch()

        if self.is_master:
            logger.info("Save very last checkpoint as `pytorch_model.bin`.")
            self.save_checkpoint(checkpoint_name="pytorch_model.bin")
            logger.info("Training is finished")

    def step(
        self, input_ids: torch.tensor, 
        attention_mask: torch.tensor, 
        lm_labels: torch.tensor,
        dual_input_ids: torch.tensor, 
        dual_attention_mask: torch.tensor, 
        dual_lm_labels: torch.tensor,
        sampled_interchange_position: torch.tensor,
        is_parallel=False,
        is_crossway=False,
    ):
        if is_parallel:
            assert is_crossway
            """
            If we enable crossway and parallel, we make 
            sure we compute pair-wise losses for two 
            examples together.
            Note that this requires larger GPUs.
            """
            self._step_parallel(
                input_ids,
                attention_mask,
                lm_labels,
                dual_input_ids,
                dual_attention_mask,
                dual_lm_labels,
                sampled_interchange_position,
            )
        else:
            """
            If it is not parallel, we will have two mini-step
            within each step. The second step will only backprop
            loss without updating the iteration, so the optimization
            is not affected.
            """
            if is_crossway:
                self._step(
                    input_ids,
                    attention_mask,
                    lm_labels,
                    dual_input_ids,
                    dual_attention_mask,
                    dual_lm_labels,
                    sampled_interchange_position,
                    skip_update_iter=True,
                )
                # the second mini-step for the reversed pair.
                self._step(
                    dual_input_ids,
                    dual_attention_mask,
                    dual_lm_labels,
                    input_ids,
                    attention_mask,
                    lm_labels,
                    sampled_interchange_position,
                    skip_update_iter=False,
                )
            else:
                """
                This subroutine will be the normal distillation
                with optional causal loss.
                """
                self._step(
                    input_ids,
                    attention_mask,
                    lm_labels,
                    dual_input_ids,
                    dual_attention_mask,
                    dual_lm_labels,
                    sampled_interchange_position,
                    skip_update_iter=False,
                )

    def _step(
        self, input_ids: torch.tensor, 
        attention_mask: torch.tensor, 
        lm_labels: torch.tensor,
        dual_input_ids: torch.tensor, 
        dual_attention_mask: torch.tensor, 
        dual_lm_labels: torch.tensor,
        sampled_interchange_position: torch.tensor,
        skip_update_iter=False,
    ):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        The standard distillation losses come from a student forward on the main example. The causal
        loss comes from a second student forward on the main example. In that pass, activations recorded
        on the dual example (`get_activation_at`) are interchanged at `sampled_interchange_position`.
        The teacher is run the same way, without gradients, to provide the counterfactual targets.

        Input:
        ------
        input_ids/dual_input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask/dual_attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels/dual_lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
            `dual_lm_labels` is currently not used by this method.
        sampled_interchange_position: `torch.tensor(bs, 4)` - Rows `[start, end, dual_start, dual_end]`.
        skip_update_iter: forwarded to `self.optimize`.

        Raises:
        -------
        NotImplementedError: if `self.mlm` is False (the CLM path is not supported).
        """
        if not self.mlm:
            # Explicit raise instead of `assert False`, which is stripped under `python -O` and
            # would then fail later with a confusing NameError.
            raise NotImplementedError(
                "Causal distillation only supports the MLM objective (self.mlm must be True)."
            )

        # preparing for causal distillation.
        # we randomly select the pool of neurons to interchange.
        selector = random.randint(0, len(self.deserialized_interchange_variable_mappings)-1)
        interchange_variable_mapping = self.deserialized_interchange_variable_mappings[selector]
        # NOTE(review): the teacher and student groups are drawn independently of each other; confirm
        # this is intended if the two lists in a mapping are meant to be aligned element-wise.
        teacher_variable_names = random.choice(interchange_variable_mapping[0])
        student_variable_names = random.choice(interchange_variable_mapping[1])
        teacher_interchanged_variables_mapping = self._group_interchange_variables_by_layer(teacher_variable_names)
        student_interchanged_variables_mapping = self._group_interchange_variables_by_layer(student_variable_names)

        with torch.no_grad():
            # teacher forward pass normal.
            teacher_outputs = self.teacher(
                input_ids=input_ids, attention_mask=attention_mask
            )  # (bs, seq_length, voc_size)
            # teacher activations recorded on the dual example.
            dual_counterfactual_activations_teacher = get_activation_at(
                self.teacher,
                dual_input_ids,
                dual_attention_mask,
                variable_names=teacher_variable_names
            )
            # teacher forward pass on the main example with the dual activations interchanged.
            counterfactual_outputs_teacher = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
                interchanged_variables=dual_counterfactual_activations_teacher,
                variable_names=teacher_interchanged_variables_mapping,
                sampled_interchange_position=sampled_interchange_position,
            )
        t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"]
        student_outputs = self.student(
            input_ids=input_ids, attention_mask=attention_mask,
            t_logits=t_logits,
            t_hidden_states=t_hidden_states,
            temperature=self.temperature,
            restrict_ce_to_mask=self.params.restrict_ce_to_mask,
            lm_labels=lm_labels,
            alpha_mlm=self.alpha_mlm,
            alpha_clm=self.alpha_clm,
            alpha_mse=self.alpha_mse,
            alpha_cos=self.alpha_cos,
        )  # (bs, seq_length, voc_size)
        s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"]
        causal_t_logits, causal_t_hidden_states = \
            counterfactual_outputs_teacher["logits"], counterfactual_outputs_teacher["hidden_states"]

        # standard losses.
        loss_ce = self._reduce_across_replicas(student_outputs["loss_ce"])
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.0:
            loss_mlm = self._reduce_across_replicas(student_outputs["loss_mlm"])
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.0:
            loss_clm = self._reduce_across_replicas(student_outputs["loss_clm"])
            loss += self.alpha_clm * loss_clm
        if self.alpha_mse > 0.0:
            loss_mse = self._reduce_across_replicas(student_outputs["loss_mse"])
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.0:
            loss_cos = self._reduce_across_replicas(student_outputs["loss_cos"])
            loss += self.alpha_cos * loss_cos

        # causal distillation loss: student activations recorded on the dual example ...
        dual_counterfactual_activations_student = get_activation_at(
            self.student,
            dual_input_ids,
            dual_attention_mask,
            variable_names=student_variable_names
        )
        # ... interchanged into a student forward pass on the main example.
        counterfactual_outputs_student = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # interchange.
            interchanged_variables=dual_counterfactual_activations_student,
            variable_names=student_interchanged_variables_mapping,
            sampled_interchange_position=sampled_interchange_position,
            # loss.
            t_logits=t_logits,
            t_hidden_states=t_hidden_states,
            causal_t_logits=causal_t_logits,
            causal_t_hidden_states=causal_t_hidden_states,
            s_logits=s_logits,
            s_hidden_states=s_hidden_states,
            temperature=self.temperature,
            restrict_ce_to_mask=self.params.restrict_ce_to_mask,
        )
        # sanity check: the counterfactual pass must only produce the causal outputs.
        assert "loss_ce" not in counterfactual_outputs_student
        assert "loss_mlm" not in counterfactual_outputs_student
        assert "loss_clm" not in counterfactual_outputs_student
        assert "loss_mse" not in counterfactual_outputs_student
        assert "loss_cos" not in counterfactual_outputs_student
        causal_loss_ce = self._reduce_across_replicas(counterfactual_outputs_student["causal_loss_ce"])
        teacher_interchange_efficacy = self._reduce_across_replicas(
            counterfactual_outputs_student["teacher_interchange_efficacy"]
        )
        student_interchange_efficacy = self._reduce_across_replicas(
            counterfactual_outputs_student["student_interchange_efficacy"]
        )
        if self.alpha_causal > 0.0:
            loss += self.alpha_causal * causal_loss_ce

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.0:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.0:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()
        # recorded even when alpha_causal == 0, for monitoring.
        self.last_loss_causal_ce = causal_loss_ce.item()
        # record efficacy of the interchange.
        self.last_teacher_interchange_efficacy = teacher_interchange_efficacy.item()
        self.last_student_interchange_efficacy = student_interchange_efficacy.item()

        self.optimize(loss, skip_update_iter=skip_update_iter)

        self.n_sequences_epoch += input_ids.size(0)

    @staticmethod
    def _group_interchange_variables_by_layer(variable_names):
        """
        Group interchange variables by layer index.

        Returns `{layer_index: [(position_in_variable_names, head_index, LOC), ...]}`, built with
        `parse_variable_name` and preserving the order of `variable_names`.
        """
        grouped = {}
        for i, variable in enumerate(variable_names):
            layer_index, head_index, LOC = parse_variable_name(variable)
            grouped.setdefault(layer_index, []).append((i, head_index, LOC))
        return grouped

    def _reduce_across_replicas(self, value):
        """Average `value` when `self.multi_gpu` is set (e.g. per-device outputs from DataParallel); else return it unchanged."""
        return value.mean() if self.multi_gpu else value
        
    def _step_parallel(
        self,
        input_ids: torch.tensor,
        attention_mask: torch.tensor,
        lm_labels: torch.tensor,
        dual_input_ids: torch.tensor,
        dual_attention_mask: torch.tensor,
        dual_lm_labels: torch.tensor,
        sampled_interchange_position: torch.tensor,
    ):
        """
        One optimization step that computes both interchange directions ("dual on main" and
        "main on dual") for a (main, dual) pair of batches in a single pass.

        WARNING: Parallel requires GPUs with larger memory. It involves computations across two examples
        with two parallel iterations.

        The step runs: forward of teacher AND student on both examples (standard distillation losses),
        teacher and student interchange-intervention forwards in both directions (causal loss),
        backward on the summed loss (for gradient accumulation), and possibly a parameter update
        (depending on the gradient accumulation). Only the MLM setting (`self.mlm`) is supported.

        Input:
        ------
        input_ids/dual_input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask/dual_attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels/dual_lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        sampled_interchange_position: forwarded unchanged to every interchange forward call, in both directions.

        Raises:
        -------
        NotImplementedError: if `self.mlm` is False.
        """

        def _group_by_layer(variable_names):
            # layer_index -> [(position in `variable_names`, head_index, LOC), ...]; this is the
            # structure passed as `variable_names=` to the interchange forward calls below.
            grouped = {}
            for i, variable in enumerate(variable_names):
                layer_index, head_index, LOC = parse_variable_name(variable)
                grouped.setdefault(layer_index, []).append((i, head_index, LOC))
            return grouped

        def _reduce(outputs, key):
            # With multiple GPUs each replica contributes its own value, so average them.
            value = outputs[key]
            return value.mean() if self.multi_gpu else value

        # Preparing for causal distillation: randomly select the pool of neurons to interchange.
        selector = random.randint(0, len(self.deserialized_interchange_variable_mappings) - 1)
        interchange_variable_mapping = self.deserialized_interchange_variable_mappings[selector]
        teacher_variable_names = random.choice(interchange_variable_mapping[0])
        student_variable_names = random.choice(interchange_variable_mapping[1])
        teacher_interchanged_variables_mapping = _group_by_layer(teacher_variable_names)
        student_interchanged_variables_mapping = _group_by_layer(student_variable_names)

        if self.mlm:
            with torch.no_grad():
                # teacher forward pass normal.
                teacher_outputs = self.teacher(
                    input_ids=input_ids, attention_mask=attention_mask
                )  # (bs, seq_length, voc_size)
                # teacher forward pass normal for the dual example.
                dual_teacher_outputs = self.teacher(
                    input_ids=dual_input_ids, attention_mask=dual_attention_mask
                )  # (bs, seq_length, voc_size)

                # dual on main example: activations taken from the dual input ...
                dual_counterfactual_activations_teacher = get_activation_at(
                    self.teacher,
                    dual_input_ids,
                    dual_attention_mask,
                    variable_names=teacher_variable_names
                )
                # ... are spliced into the forward pass on the main input.
                counterfactual_outputs_teacher = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    interchanged_variables=dual_counterfactual_activations_teacher,
                    variable_names=teacher_interchanged_variables_mapping,
                    sampled_interchange_position=sampled_interchange_position,
                )

                # main on dual example: activations taken from the main input ...
                counterfactual_activations_teacher = get_activation_at(
                    self.teacher,
                    input_ids,
                    attention_mask,
                    variable_names=teacher_variable_names
                )
                # ... are spliced into the forward pass on the dual input.
                dual_counterfactual_outputs_teacher = self.teacher(
                    input_ids=dual_input_ids,
                    attention_mask=dual_attention_mask,
                    interchanged_variables=counterfactual_activations_teacher,
                    variable_names=teacher_interchanged_variables_mapping,
                    sampled_interchange_position=sampled_interchange_position,
                )
            t_logits, t_hidden_states = \
                teacher_outputs["logits"], teacher_outputs["hidden_states"]
            dual_t_logits, dual_t_hidden_states = \
                dual_teacher_outputs["logits"], dual_teacher_outputs["hidden_states"]
            student_outputs = self.student(
                input_ids=input_ids, attention_mask=attention_mask,
                t_logits=t_logits,
                t_hidden_states=t_hidden_states,
                temperature=self.temperature,
                restrict_ce_to_mask=self.params.restrict_ce_to_mask,
                lm_labels=lm_labels,
                alpha_mlm=self.alpha_mlm,
                alpha_clm=self.alpha_clm,
                alpha_mse=self.alpha_mse,
                alpha_cos=self.alpha_cos,
            )  # (bs, seq_length, voc_size)
            # The dual student pass is scored against the dual example's own labels.
            dual_student_outputs = self.student(
                input_ids=dual_input_ids, attention_mask=dual_attention_mask,
                t_logits=dual_t_logits,
                t_hidden_states=dual_t_hidden_states,
                temperature=self.temperature,
                restrict_ce_to_mask=self.params.restrict_ce_to_mask,
                lm_labels=dual_lm_labels,
                alpha_mlm=self.alpha_mlm,
                alpha_clm=self.alpha_clm,
                alpha_mse=self.alpha_mse,
                alpha_cos=self.alpha_cos,
            )  # (bs, seq_length, voc_size)
            s_logits, s_hidden_states = \
                student_outputs["logits"], student_outputs["hidden_states"]
            dual_s_logits, dual_s_hidden_states = \
                dual_student_outputs["logits"], dual_student_outputs["hidden_states"]
            causal_t_logits, causal_t_hidden_states = \
                counterfactual_outputs_teacher["logits"], counterfactual_outputs_teacher["hidden_states"]
            dual_causal_t_logits, dual_causal_t_hidden_states = \
                dual_counterfactual_outputs_teacher["logits"], dual_counterfactual_outputs_teacher["hidden_states"]
        else:
            # Explicit raise (rather than `assert False`) so the guard survives `python -O`.
            raise NotImplementedError("_step_parallel only supports MLM distillation (`mlm=True`).")

        # Standard losses, summed over the main and the dual example (out-of-place, so the
        # tensors stored in the output dicts are not mutated).
        loss_ce = _reduce(student_outputs, "loss_ce") + _reduce(dual_student_outputs, "loss_ce")
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.0:
            loss_mlm = _reduce(student_outputs, "loss_mlm") + _reduce(dual_student_outputs, "loss_mlm")
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.0:
            loss_clm = _reduce(student_outputs, "loss_clm") + _reduce(dual_student_outputs, "loss_clm")
            loss += self.alpha_clm * loss_clm
        if self.alpha_mse > 0.0:
            loss_mse = _reduce(student_outputs, "loss_mse") + _reduce(dual_student_outputs, "loss_mse")
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.0:
            loss_cos = _reduce(student_outputs, "loss_cos") + _reduce(dual_student_outputs, "loss_cos")
            loss += self.alpha_cos * loss_cos

        # Causal distillation loss: student activations at the sampled variables for each example.
        dual_counterfactual_activations_student = get_activation_at(
            self.student,
            dual_input_ids,
            dual_attention_mask,
            variable_names=student_variable_names
        )
        counterfactual_activations_student = get_activation_at(
            self.student,
            input_ids,
            attention_mask,
            variable_names=student_variable_names
        )
        # dual on main: dual activations spliced into the main forward pass.
        counterfactual_outputs_student = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # interchange.
            interchanged_variables=dual_counterfactual_activations_student,
            variable_names=student_interchanged_variables_mapping,
            sampled_interchange_position=sampled_interchange_position,
            # loss.
            t_logits=t_logits,
            t_hidden_states=t_hidden_states,
            causal_t_logits=causal_t_logits,
            causal_t_hidden_states=causal_t_hidden_states,
            s_logits=s_logits,
            s_hidden_states=s_hidden_states,
            temperature=self.temperature,
            restrict_ce_to_mask=self.params.restrict_ce_to_mask,
        )
        # main on dual: main activations spliced into the dual forward pass (mirrors the teacher).
        dual_counterfactual_outputs_student = self.student(
            input_ids=dual_input_ids,
            attention_mask=dual_attention_mask,
            # interchange.
            interchanged_variables=counterfactual_activations_student,
            variable_names=student_interchanged_variables_mapping,
            sampled_interchange_position=sampled_interchange_position,
            # loss.
            t_logits=dual_t_logits,
            t_hidden_states=dual_t_hidden_states,
            causal_t_logits=dual_causal_t_logits,
            causal_t_hidden_states=dual_causal_t_hidden_states,
            s_logits=dual_s_logits,
            s_hidden_states=dual_s_hidden_states,
            temperature=self.temperature,
            restrict_ce_to_mask=self.params.restrict_ce_to_mask,
        )

        # sanity check: interchange forwards must only produce causal outputs.
        for key in ("loss_ce", "loss_mlm", "loss_clm", "loss_mse", "loss_cos"):
            assert key not in counterfactual_outputs_student and key not in dual_counterfactual_outputs_student
        causal_loss_ce = \
            _reduce(counterfactual_outputs_student, "causal_loss_ce") + \
            _reduce(dual_counterfactual_outputs_student, "causal_loss_ce")
        # Efficacies are summed over both interchange directions (not averaged).
        teacher_interchange_efficacy = \
            _reduce(counterfactual_outputs_student, "teacher_interchange_efficacy") + \
            _reduce(dual_counterfactual_outputs_student, "teacher_interchange_efficacy")
        student_interchange_efficacy = \
            _reduce(counterfactual_outputs_student, "student_interchange_efficacy") + \
            _reduce(dual_counterfactual_outputs_student, "student_interchange_efficacy")
        if self.alpha_causal > 0.0:
            loss += self.alpha_causal * causal_loss_ce

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.0:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.0:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()
        # optional recording of the value.
        self.last_loss_causal_ce = causal_loss_ce.item()
        # record efficacy of the interchange.
        self.last_teacher_interchange_efficacy = teacher_interchange_efficacy.item()
        self.last_student_interchange_efficacy = student_interchange_efficacy.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss, skip_update_iter=False):
        """
        Normalize the loss and run the backward pass, then possibly update the parameters.

        The loss is averaged over GPUs when `self.multi_gpu` is set. It is divided by
        `gradient_accumulation_steps` when that is > 1. When an accumulation window closes,
        gradients are clipped and the optimizer and scheduler are stepped. `self.iter` is always
        called, and it updates the counters, checkpoints and logs.

        Args:
            loss: the loss tensor for this (mini-)step.
            skip_update_iter: when True, gradients are only accumulated. The iteration counters are
                not advanced and no optimizer/scheduler step is taken, so these gradients are applied
                together with the next non-skipped call (e.g. two crossway mini-steps per step).

        Raises:
            SystemExit: with status 1 if the loss contains NaN.
        """
        # Check for NaN. Exit with a non-zero status so launchers see a failure
        # (a bare `exit()` would report success with status 0).
        if torch.isnan(loss).any():
            logger.error("NaN detected")
            raise SystemExit(1)

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # For the two crossway mini-steps (dual on main, then main on dual), the caller skips the
        # iteration update on one of them, so both mini-steps' gradients land in a single step.
        self.iter(skip_update_iter=skip_update_iter)

        # A skipped mini-step did not advance `n_iter`. Without this guard, the modulo test below would
        # re-fire on the count left by the previous (already applied) step, and also on the very first
        # mini-step when `n_iter == 0`. That would step the optimizer and scheduler an extra time.
        if skip_update_iter:
            return

        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self, skip_update_iter=False):
        """
        Advance the iteration counters, save a periodic checkpoint, and log at the log interval.

        Args:
            skip_update_iter: when truthy, `n_iter`/`n_total_iter` are left unchanged and no
                checkpoint is written. The logging check still runs.
        """
        advance_counters = not skip_update_iter
        if advance_counters:
            self.n_iter += 1
            self.n_total_iter += 1
            checkpoint_due = self.n_total_iter % self.params.checkpoint_interval == 0
            if checkpoint_due:
                self.save_checkpoint()

        # Logging deliberately ignores `skip_update_iter`: crossway mini-steps are logged too,
        # and their losses are expected to be of the same magnitude.
        log_due = self.n_total_iter % self.params.log_interval == 0
        if log_due:
            self.log_tensorboard()
            self.last_log = time.time()

    def log_tensorboard(self):
        """
        Log the latest training metrics to Weights & Biases.

        Only the master process logs, and only when `self.is_wandb` is set. All metrics for the
        current step are sent in a single `wandb.log` call at `step=self.n_total_iter`. Optional
        losses are included only when their `alpha_*` weight is positive.
        """
        if not self.is_master or not self.is_wandb:
            return

        # `n_iter` can still be 0 here: a skipped mini-step logs without advancing it (e.g. right at
        # the start of training or of an epoch), so guard the running average against division by zero.
        metrics = {
            "train/cum_avg_loss_epoch": self.total_loss_epoch / max(self.n_iter, 1),
            "train/loss": self.last_loss,
            "train/loss_ce": self.last_loss_ce,
        }
        if self.alpha_mlm > 0.0:
            metrics["train/loss_mlm"] = self.last_loss_mlm
        if self.alpha_clm > 0.0:
            metrics["train/loss_clm"] = self.last_loss_clm
        if self.alpha_mse > 0.0:
            metrics["train/loss_mse"] = self.last_loss_mse
        if self.alpha_cos > 0.0:
            metrics["train/loss_cos"] = self.last_loss_cos

        metrics["train/loss_causal_ce"] = self.last_loss_causal_ce
        metrics["train/teacher_interchange_efficacy"] = self.last_teacher_interchange_efficacy
        metrics["train/student_interchange_efficacy"] = self.last_student_interchange_efficacy

        # `get_last_lr` is the supported accessor in recent PyTorch (`get_lr` is meant for internal
        # use and warns). Fall back to `get_lr` for schedulers that lack it.
        if hasattr(self.scheduler, "get_last_lr"):
            learning_rate = self.scheduler.get_last_lr()[0]
        else:
            learning_rate = self.scheduler.get_lr()[0]
        metrics["train/learning_rate"] = learning_rate
        metrics["train/memory_usage"] = psutil.virtual_memory()._asdict()["used"] / 1_000_000
        metrics["train/speed"] = time.time() - self.last_log

        wandb.log(metrics, step=self.n_total_iter)

    def end_epoch(self):
        """
        Finish the current epoch (a full pass over the dataset).

        The master process saves `model_epoch_{epoch}.pth` and, if wandb is enabled, logs the epoch's
        average loss. Afterwards the epoch counter is advanced and the per-epoch counters are reset.
        """
        logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
            if self.is_wandb:
                # Guard against an epoch with no counted iteration (n_iter == 0).
                wandb.log(
                    {
                        "epoch/loss": self.total_loss_epoch / max(self.n_iter, 1),
                        'epoch': self.epoch
                    }
                )

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Write the student's config and weights into `self.dump_path`. No-op on non-master processes.

        If the student exposes a `.module` attribute (a wrapper), the wrapped model is saved instead.
        """
        if self.is_master:
            model = getattr(self.student, "module", self.student)
            model.config.save_pretrained(self.dump_path)
            target_file = os.path.join(self.dump_path, checkpoint_name)
            torch.save(model.state_dict(), target_file)

In [6]:
def sanity_checks(args):
    """
    Validate the parsed command-line arguments before any expensive setup starts.

    Raises AssertionError on the first inconsistent or invalid combination found.
    """
    # The MLM weight must be positive exactly when the MLM objective is selected.
    assert (args.alpha_mlm > 0.0) if args.mlm else (args.alpha_mlm == 0.0)

    # Exactly one of the MLM / CLM objectives carries weight.
    mlm_only = args.alpha_mlm > 0.0 and args.alpha_clm == 0.0
    clm_only = args.alpha_mlm == 0.0 and args.alpha_clm > 0.0
    assert mlm_only or clm_only

    if not args.mlm:
        assert args.student_type == "gpt2" and args.teacher_type == "gpt2"
    else:
        assert os.path.isfile(args.token_counts)
        assert args.student_type in ("roberta", "distilbert") and args.teacher_type in ("roberta", "bert")

    # Student and teacher share a family, except for the BERT -> DistilBERT pairing.
    same_family = args.teacher_type == args.student_type
    bert_to_distilbert = args.student_type == "distilbert" and args.teacher_type == "bert"
    assert same_family or bert_to_distilbert

    assert os.path.isfile(args.student_config)
    if args.student_pretrained_weights is not None:
        assert os.path.isfile(args.student_pretrained_weights)

    if args.freeze_token_type_embds:
        assert args.student_type == "roberta"

    # Every loss weight is non-negative and at least one objective is active.
    weight_names = ("alpha_ce", "alpha_mlm", "alpha_clm", "alpha_mse", "alpha_cos", "alpha_causal")
    for weight_name in weight_names:
        assert getattr(args, weight_name) >= 0.0
    assert sum(getattr(args, weight_name) for weight_name in weight_names) > 0.0


def freeze_pos_embeddings(student, args):
    """
    Stop gradient updates to the student's learned position embeddings.

    Only RoBERTa and GPT-2 students are handled; any other `student_type` is left untouched.
    """
    student_type = args.student_type
    if student_type == "roberta":
        position_weight = student.roberta.embeddings.position_embeddings.weight
    elif student_type == "gpt2":
        position_weight = student.transformer.wpe.weight
    else:
        return
    position_weight.requires_grad = False


def freeze_token_type_embeddings(student, args):
    """Stop gradient updates to a RoBERTa student's token-type embeddings; other student types are ignored."""
    if args.student_type != "roberta":
        return
    student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False


def _prepare_dump_path(args):
    """Create (or, with `args.force`, recreate) `args.dump_path` and record the run parameters there."""
    if os.path.exists(args.dump_path):
        if not args.force:
            raise ValueError(
                f"Serialization dir {args.dump_path} already exists, but you have not specified whether to overwrite it. "
                "Use `--force` if you want to overwrite it"
            )
        shutil.rmtree(args.dump_path)

    os.makedirs(args.dump_path, exist_ok=True)
    logger.info(f"Experiment will be dumped and logged in {args.dump_path}")

    # SAVE PARAMS #
    logger.info(f"Param: {args}")
    with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
        json.dump(vars(args), f, indent=4)
    git_log(args.dump_path)


def _load_token_probs(args, special_tok_ids):
    """Return per-token MLM sampling weights (rare tokens up-weighted, special tokens zeroed), or None for CLM."""
    if not args.mlm:
        return None

    logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted, locally produced files.
    with open(args.token_counts, "rb") as fp:
        counts = pickle.load(fp)

    token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
    for idx in special_tok_ids.values():
        token_probs[idx] = 0.0  # do not predict special tokens
    return torch.from_numpy(token_probs)


def prepare_distiller(args):
    """
    Build a `CausalDistiller` from parsed command-line arguments.

    Steps: GPU/seed initialization, crossway flag validation, output-directory setup (master only),
    teacher tokenizer and special-token ids, training data and (for MLM) token probabilities,
    student and teacher models, optional embedding freezing, and architecture compatibility checks.

    Args:
        args: parsed arguments; `special_tok_ids` and `max_model_input_size` are set on it here.

    Returns:
        The initialized `CausalDistiller`.

    Raises:
        ValueError: on the master process, if `args.dump_path` exists and `args.force` is not set.
    """
    # ARGS #
    init_gpu_params(args)
    set_seed(args)
    # More validations: `--parallel_crossway` only makes sense together with `--include_crossway`.
    assert args.include_crossway or not args.parallel_crossway
    if args.is_master:
        _prepare_dump_path(args)

    student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
    teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]

    # TOKENIZER #
    tokenizer = teacher_tokenizer_class.from_pretrained(
        args.teacher_name,
        cache_dir=args.cache_dir
    )
    special_tok_ids = {}
    for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
        idx = tokenizer.all_special_tokens.index(tok_symbol)
        special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
    logger.info(f"Special tokens {special_tok_ids}")
    args.special_tok_ids = special_tok_ids
    # NOTE(review): assumes `teacher_name` is a key of `max_model_input_sizes` (a known shortcut name).
    args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]

    # DATA LOADER #
    logger.info(f"Loading data from {args.data_file}")
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted, locally produced files.
    with open(args.data_file, "rb") as fp:
        data = pickle.load(fp)

    token_probs = _load_token_probs(args, special_tok_ids)

    train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
    logger.info("Data loader created.")

    # STUDENT #
    logger.info(f"Loading student config from {args.student_config}")
    stu_architecture_config = student_config_class.from_pretrained(
        args.student_config,
        cache_dir=args.cache_dir
    )
    stu_architecture_config.output_hidden_states = True

    if args.student_pretrained_weights is not None:
        logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
        student = student_model_class.from_pretrained(
            args.student_pretrained_weights, config=stu_architecture_config,
            cache_dir=args.cache_dir
        )
    else:
        student = student_model_class(stu_architecture_config)
    logger.info("Student loaded.")

    # TEACHER #
    teacher = teacher_model_class.from_pretrained(
        args.teacher_name, output_hidden_states=True,
        cache_dir=args.cache_dir
    )

    logger.info(f"Teacher loaded from {args.teacher_name}.")

    # FREEZING #
    if args.freeze_pos_embs:
        freeze_pos_embeddings(student, args)
    if args.freeze_token_type_embds:
        freeze_token_type_embeddings(student, args)

    # SANITY CHECKS #
    assert student.config.vocab_size == teacher.config.vocab_size
    assert student.config.hidden_size == teacher.config.hidden_size
    assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
    if args.mlm:
        assert token_probs.size(0) == stu_architecture_config.vocab_size

    # DISTILLER #
    torch.cuda.empty_cache()
    distiller = CausalDistiller(
        params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
    )
    logger.info("Distiller initialization done.")
    return distiller

In [7]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")

    parser.add_argument(
        "--dump_path", type=str, help="The output directory (log, checkpoints, parameters, etc.)"
    )
    parser.add_argument(
        "--data_file",
        type=str,
        help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="You need to give some cache dir.",
        default="./distill_cache/"
    )
    parser.add_argument(
        "--student_type",
        type=str,
        choices=["distilbert", "roberta", "gpt2"],
        help="The student type (DistilBERT, RoBERTa).",
    )
    parser.add_argument("--student_config", type=str, help="Path to the student configuration.")
    parser.add_argument(
        "--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
    )

    parser.add_argument(
        "--teacher_type", choices=["bert", "roberta", "gpt2"], help="Teacher type (BERT, RoBERTa)."
    )
    parser.add_argument("--teacher_name", type=str, help="The teacher model.")

    parser.add_argument(
        "--neuron_mapping",
        type=str,
        help="Predefined neuron mapping for the interchange experiment.",
    )
    
    parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax temperature.")
    parser.add_argument(
        "--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
    )
    parser.add_argument(
        "--alpha_mlm",
        default=0.0,
        type=float,
        help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.",
    )
    parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
    parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
    parser.add_argument(
        "--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
    )
    parser.add_argument(
        "--alpha_causal", default=0.0, type=float, help="Linear weight of the causal distillation loss. Must be >=0."
    )

    parser.add_argument(
        "--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM."
    )
    parser.add_argument(
        "--mlm_mask_prop",
        default=0.15,
        type=float,
        help="Proportion of tokens for which we need to make a prediction.",
    )
    
    parser.add_argument(
        "--interchange_mlm", action="store_true", help="Whehter to follow mlm to select token positions to do interchange."
    )
    parser.add_argument(
        "--interchange_prop",
        default=0.3,
        type=float,
        help="Ratio of tokens to mask for interchange interventions. 1.0 means interchange all.",
    )
    parser.add_argument(
        "--include_crossway", default=False, action="store_true", help="Whether to include crossway losses."
    )
    parser.add_argument(
        "--parallel_crossway", default=False, action="store_true", help="Whether to calculate cross losses in a single step."
    )
    
    parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.")
    parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.")
    parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.")
    parser.add_argument(
        "--mlm_smoothing",
        default=0.7,
        type=float,
        help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
    )
    parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")

    parser.add_argument(
        "--restrict_ce_to_mask",
        action="store_true",
        help="If true, compute the distilation loss only the [MLM] prediction distribution.",
    )
    parser.add_argument(
        "--freeze_pos_embs",
        action="store_true",
        help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
    )
    parser.add_argument(
        "--freeze_token_type_embds",
        action="store_true",
        help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
    )

    parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.")
    parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
    parser.add_argument(
        "--group_by_size",
        action="store_false",
        help="If true, group sequences that have similar length into the same batch. Default is true.",
    )

    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=50,
        help="Gradient accumulation for larger training batches.",
    )
    parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
    parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
    parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
    parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
    parser.add_argument("--seed", type=int, default=56, help="Random seed")

    parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
    parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
    
    parser.add_argument(
        "--is_wandb",
        action="store_true",
        default=False,
        help="If true, we will log everything to wandb that is logged in currently.",
    )
    parser.add_argument("--run_name", type=str, help="Name of this run.")
    
    try:
        get_ipython().run_line_magic('matplotlib', 'inline')
        parser.set_defaults(
            # Exp management:
            student_type="distilbert",
            teacher_type="bert",
            mlm=True,
            alpha_ce=0.25,
            alpha_mlm=0.25,
            alpha_cos=0.25,
            alpha_clm=0.0,
            alpha_causal=0.25,
            token_counts="./demo_data/binarized_text.train.token_counts.bert-base-uncased.pickle",
            student_config="./training_configs/distilbert-base-uncased.json",
            dump_path="./arxiv_results/",
            teacher_name="bert-base-uncased",
            force=True,
            data_file="./demo_data/binarized_text.train.bert-base-uncased.pickle",
            n_gpu=0,
            is_wandb=False,
            log_interval=10,
            neuron_mapping="./training_configs/single_multilayer.nm",
            local_rank=-1,
            interchange_prop=0.3,
            batch_size=5,
            gradient_accumulation_steps=50,
            include_crossway=False,
            parallel_crossway=False,
        )
        print("Prelude: running in notebook for testing only.")
        args = parser.parse_args([])
    except:
        print("Prelude: running with command line.")
        args = parser.parse_args()
    
    # config the runname here and overwrite.
    data_name = args.data_file.split("/")[-2]
    neuron_mapping = args.neuron_mapping.split("/")[-1].split(".")[0]
    run_name = f"s_{args.student_type}_t_{args.teacher_type}_data_{data_name}_seed_{args.seed}_mlm_{args.mlm}_ce_{args.alpha_ce}_mlm_{args.alpha_mlm}_cos_{args.alpha_cos}_causal_{args.alpha_causal}_nm_{neuron_mapping}_crossway_{args.include_crossway}"
    args.run_name = run_name
    args.dump_path = os.path.join(args.dump_path, args.run_name)
    sanity_checks(args)
    # for arXiv, we enforce the following settings.
    assert not args.include_crossway
    assert not args.parallel_crossway
    
    distiller = prepare_distiller(args)
    
    # distiller.train()
    # logger.info("Hey Zen: Let's go get some drinks.")

11/15/2021 11:57:57 - INFO - utils - PID: 51231 -  Experiment will be dumped and logged in ./arxiv_results/s_distilbert_t_bert_data_demo_data_seed_56_mlm_True_ce_0.25_mlm_0.25_cos_0.25_causal_0.25_nm_single_multilayer_crossway_False
11/15/2021 11:57:57 - INFO - utils - PID: 51231 -  Param: Namespace(adam_epsilon=1e-06, alpha_causal=0.25, alpha_ce=0.25, alpha_clm=0.0, alpha_cos=0.25, alpha_mlm=0.25, alpha_mse=0.0, batch_size=5, cache_dir='./distill_cache/', checkpoint_interval=4000, data_file='./demo_data/binarized_text.train.bert-base-uncased.pickle', dump_path='./arxiv_results/s_distilbert_t_bert_data_demo_data_seed_56_mlm_True_ce_0.25_mlm_0.25_cos_0.25_causal_0.25_nm_single_multilayer_crossway_False', force=True, fp16=False, fp16_opt_level='O1', freeze_pos_embs=False, freeze_token_type_embds=False, gradient_accumulation_steps=50, group_by_size=True, include_crossway=False, initializer_range=0.02, interchange_mlm=False, interchange_prop=0.3, is_master=True, is_wandb=False, learning_ra

Prelude: running in notebook for testing only.


11/15/2021 11:58:00 - INFO - utils - PID: 51231 -  Special tokens {'unk_token': 100, 'sep_token': 102, 'pad_token': 0, 'cls_token': 101, 'mask_token': 103}
11/15/2021 11:58:00 - INFO - utils - PID: 51231 -  Loading data from ./demo_data/binarized_text.train.bert-base-uncased.pickle
11/15/2021 11:58:00 - INFO - utils - PID: 51231 -  Loading token counts from ./demo_data/binarized_text.train.token_counts.bert-base-uncased.pickle (already pre-computed)
11/15/2021 11:58:00 - INFO - utils - PID: 51231 -  Splitting 0 too long sequences.
11/15/2021 11:58:00 - INFO - utils - PID: 51231 -  Remove 542 too short (<=11 tokens) sequences.
11/15/2021 11:58:00 - INFO - utils - PID: 51231 -  Remove 0 sequences with a high level of unknown tokens (50%).
11/15/2021 11:58:00 - INFO - utils - PID: 51231 -  Preparing causal batch.
11/15/2021 11:58:00 - INFO - utils - PID: 51231 -  458 sequences
11/15/2021 11:58:00 - INFO - utils - PID: 51231 -  Data loader created.
11/15/2021 11:58:00 - INFO - utils - PID:

In [8]:
# Inspect the distiller constructed earlier by `prepare_distiller(args)`.
# As the last expression of the cell, Jupyter renders its repr
# (the saved output shows a `CausalDistiller` instance).
# NOTE(review): this cell relies on `distiller` already existing in the kernel
# namespace from the setup cell above — run that cell first on a fresh kernel.
distiller

<__main__.CausalDistiller at 0x7fd4e1614990>