In [1]:
import logging
import os
import random
import pickle
import time
import psutil

import numpy as np
import torch
from torch.utils.data import RandomSampler, SequentialSampler
from tqdm import tqdm, trange
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from BERT.pytorch_pretrained_bert.modeling import BertConfig
from BERT.pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from BERT.pytorch_pretrained_bert.tokenization import BertTokenizer

from distiller import TaskSpecificDistiller
from causal_distiller import TaskSpecificCausalDistiller

from src.argument_parser import default_parser, get_predefine_argv, complete_argument
from src.nli_data_processing import processors, output_modes
from src.data_processing import init_model, get_task_dataloader
from src.modeling import BertForSequenceClassificationEncoder, FCClassifierForSequenceClassification, FullFCClassifierForSequenceClassification
from src.utils import load_model, count_parameters, eval_model_dataloader_nli, eval_model_dataloader
from src.KD_loss import distillation_loss, patience_loss
from envs import HOME_DATA_FOLDER

# Root logging setup: INFO level, timestamped "<time> - <level> - <logger> - <msg>" lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
_LOG_DATE_FORMAT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
logger = logging.getLogger(__name__)

In [2]:
#########################################################################
# Prepare Parser
##########################################################################
parser = default_parser()
# Set to False to read arguments from the real command line instead of a preset.
DEBUG = True
if not DEBUG:
    logger.info("IN CMD MODE")
    args = parser.parse_args()
else:
    logger.info("IN DEBUG MODE")
    # Pick exactly one preset (uncomment it, comment out the others):
    #   fine-tune the *teacher*:  argv = get_predefine_argv('glue', 'RTE', 'finetune_teacher')
    #   fine-tune the *student*:  argv = get_predefine_argv('glue', 'RTE', 'finetune_student')
    #   vanilla KD:               argv = get_predefine_argv('glue', 'RTE', 'kd')
    #   Patient Teacher (currently active):
    argv = get_predefine_argv('glue', 'SST-2', 'kd.cls')
    try:
        args = parser.parse_args(argv)
    except NameError:
        # Raised when `argv` is unbound, i.e. every preset above is commented out.
        raise ValueError('please uncomment one of option above to start training')
    # Keep debug runs small.
    args.max_training_examples = 1000
args = complete_argument(args, is_debug=DEBUG)

02/11/2022 01:38:10 - INFO - __main__ -   IN DEBUG MODE
02/11/2022 01:38:10 - INFO - src.argument_parser -   encoder checkpoint not provided, use pre-trained at /dfs/user/wuzhengx/workspace/Causal-Distill-XXS/data/models/pretrained/bert-base-uncased/pytorch_model.bin instead
02/11/2022 01:38:10 - INFO - src.argument_parser -   encoder checkpoint not provided, use default directory for fine-tuned model at /dfs/user/wuzhengx/workspace/Causal-Distill-XXS/data/models/finetuned/bert-base-uncased/SST-2/pytorch_model.bin instead
02/11/2022 01:38:10 - INFO - src.argument_parser -   encoder checkpoint not provided, use default directory for fine-tuned model at /dfs/user/wuzhengx/workspace/Causal-Distill-XXS/data/models/finetuned/bert-base-uncased/SST-2/pytorch_model.bin instead
02/11/2022 01:38:10 - INFO - src.argument_parser -   folder exist but empty, use it as output
02/11/2022 01:38:10 - INFO - src.argument_parser -   device: cpu n_gpu: 0, 16-bits training: False
02/11/2022 01:38:10 - INFO 

In [3]:
# Raw and pre-featurized data locations for this task.
args.raw_data_dir = os.path.join(HOME_DATA_FOLDER, 'data_raw', args.task_name)
args.feat_data_dir = os.path.join(HOME_DATA_FOLDER, 'data_feat', args.task_name)

# NOTE: not idempotent -- re-running this cell divides the batch size again.
args.train_batch_size //= args.gradient_accumulation_steps
logger.info('actual batch size on all GPU = %d', args.train_batch_size)
device, n_gpu = args.device, args.n_gpu

# Seed every RNG in use (Python, NumPy, torch CPU, and all CUDA devices if present).
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
    torch.cuda.manual_seed_all(args.seed)

logger.info('Input Argument Information')
args_dict = vars(args)
for arg_name, arg_value in args_dict.items():
    logger.info('%-28s  %s', arg_name, arg_value)

#########################################################################
# Prepare  Data
##########################################################################
task_name = args.task_name.lower()
is_race = 'race' in task_name

if task_name not in processors and not is_race:
    raise ValueError("Task not found: %s" % (task_name))

# RACE tasks skip processor/label setup in this cell.
if not is_race:
    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)

02/11/2022 01:38:10 - INFO - __main__ -   actual batch size on all GPU = 32
02/11/2022 01:38:10 - INFO - __main__ -   Input Argument Information
02/11/2022 01:38:10 - INFO - __main__ -   task_name                     SST-2
02/11/2022 01:38:10 - INFO - __main__ -   run_name                      kd.cls.True_SST-2_nlayer.6_lr.1e-05_T.10.0_alpha.0.7_beta.500.0_bs.32
02/11/2022 01:38:10 - INFO - __main__ -   output_dir                    /dfs/user/wuzhengx/workspace/Causal-Distill-XXS/data/outputs/KD/SST-2/teacher_12layer/kd.cls.True_SST-2_nlayer.6_lr.1e-05_T.10.0_alpha.0.7_beta.500.0_bs.32-run-1
02/11/2022 01:38:10 - INFO - __main__ -   log_every_step                1
02/11/2022 01:38:10 - INFO - __main__ -   log_interval                  500
02/11/2022 01:38:10 - INFO - __main__ -   checkpoint_interval           4000
02/11/2022 01:38:10 - INFO - __main__ -   max_seq_length                128
02/11/2022 01:38:10 - INFO - __main__ -   max_training_examples         1000
02/11/2022 01:38:10 -

In [4]:
if args.do_train:
    # Keep a deterministic order in DEBUG mode; shuffle the training set otherwise.
    train_sampler = SequentialSampler if DEBUG else RandomSampler
    read_set = 'train'
    logger.info('skipping loading teacher\'s predictoin, we calculate this on-the-fly')
    # Use the sampler chosen above. Hard-coding SequentialSampler here would mean
    # the training data is never shuffled, even outside DEBUG mode.
    train_examples, train_dataloader, _ = get_task_dataloader(
        task_name, read_set, tokenizer, args, train_sampler,
        batch_size=args.train_batch_size,
    )
    # train_batch_size is already the per-step micro-batch (divided by the
    # accumulation steps in the previous cell), so this counts optimizer updates.
    num_train_optimization_steps = int(len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)
    args.num_train_optimization_steps = num_train_optimization_steps

    # Run prediction for full data (evaluation order does not need shuffling).
    eval_examples, eval_dataloader, eval_label_ids = get_task_dataloader(
        task_name, 'dev', tokenizer, args, SequentialSampler,
        batch_size=args.eval_batch_size,
    )
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)


02/11/2022 01:38:10 - INFO - __main__ -   skipping loading teacher's predictoin, we calculate this on-the-fly
02/11/2022 01:38:11 - INFO - src.nli_data_processing -   Writing example 0 of 1000
02/11/2022 01:38:11 - INFO - __main__ -   ***** Running training *****
02/11/2022 01:38:11 - INFO - __main__ -     Num examples = 1000
02/11/2022 01:38:11 - INFO - __main__ -     Batch size = 32
02/11/2022 01:38:11 - INFO - __main__ -     Num steps = 124
02/11/2022 01:38:11 - INFO - src.nli_data_processing -   Writing example 0 of 872
02/11/2022 01:38:11 - INFO - __main__ -   ***** Running evaluation *****
02/11/2022 01:38:11 - INFO - __main__ -     Num examples = 872
02/11/2022 01:38:11 - INFO - __main__ -     Batch size = 32


In [5]:
#########################################################################
# Prepare model
#########################################################################
# Student and teacher configs both come from the pre-trained BERT directory;
# the student's depth is passed separately as args.student_hidden_layers.
student_config = BertConfig(os.path.join(args.bert_model, 'bert_config.json'))
teacher_config = BertConfig(os.path.join(args.bert_model, 'bert_config.json'))
if args.kd_model.lower() in ['kd', 'kd.cls']:
    logger.info('using normal Knowledge Distillation')
    # Only 'kd.cls' asks init_model for all layer outputs (its patience loss
    # compares intermediate layers).
    output_all_layers = args.kd_model.lower() == 'kd.cls'
    logger.info('*' * 77)
    logger.info("Loading the student model...")
    logger.info('*' * 77)
    student_encoder, student_classifier = init_model(
        task_name, output_all_layers,
        args.student_hidden_layers, student_config,
    )

    n_student_layer = len(student_encoder.bert.encoder.layer)
    student_encoder = load_model(
        student_encoder, args.encoder_checkpoint_student, args, 'student',
        verbose=True, DEBUG=False,
    )
    logger.info('*' * 77)
    student_classifier = load_model(
        student_classifier, args.cls_checkpoint_student, args, 'classifier',
        verbose=True, DEBUG=False,
    )

    logger.info('*' * 77)
    logger.info("Loading the teacher model...")
    logger.info('*' * 77)
    # since we also calculate teacher's output on-fly, we need to load the teacher model as well.
    # note that, we assume teacher model is pre-trained already.
    teacher_encoder, teacher_classifier = init_model(
        task_name, output_all_layers,
        teacher_config.num_hidden_layers, teacher_config,
    )

    n_teacher_layer = len(teacher_encoder.bert.encoder.layer)
    # NOTE(review): the teacher encoder is loaded with the 'student' mode of
    # load_model -- presumably that mode copies only matching keys; verify in src.utils.
    teacher_encoder = load_model(
        teacher_encoder, args.encoder_checkpoint_teacher, args, 'student',
        verbose=True, DEBUG=False,
    )
    logger.info('*' * 77)
    teacher_classifier = load_model(
        teacher_classifier, args.cls_checkpoint_teacher, args, 'classifier',
        verbose=True, DEBUG=False,
    )

else:
    # originally, the codebase supports kd.full, but that is never used.
    # Report args.kd_model -- the attribute checked above. `args.kd` is not
    # referenced anywhere else and likely does not exist, which would turn this
    # ValueError into an AttributeError.
    raise ValueError('%s KD not found, please use kd or kd.cls' % args.kd_model)

n_param_student = count_parameters(student_encoder) + count_parameters(student_classifier)
logger.info('number of layers in student model = %d' % n_student_layer)
logger.info('num parameters in student model are %d and %d' % (count_parameters(student_encoder), count_parameters(student_classifier)))

02/11/2022 01:38:11 - INFO - __main__ -   using normal Knowledge Distillation
02/11/2022 01:38:11 - INFO - __main__ -   *****************************************************************************
02/11/2022 01:38:11 - INFO - __main__ -   Loading the student model...
02/11/2022 01:38:11 - INFO - __main__ -   *****************************************************************************
02/11/2022 01:38:11 - INFO - src.nli_data_processing -   predicting for SST-2
02/11/2022 01:38:11 - INFO - src.modeling -   num hidden layer is set as 6
02/11/2022 01:38:11 - INFO - src.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

02/11/2022 01:38:13 - INFO - src.utils -   loading BertForSequenceClassificationEncode

In [20]:
import logging
import os
import random
import pickle
import time
import psutil
import wandb

import numpy as np
import torch
from torch.utils.data import RandomSampler, SequentialSampler
from tqdm import tqdm, trange
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from BERT.pytorch_pretrained_bert.modeling import BertConfig
from BERT.pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from BERT.pytorch_pretrained_bert.tokenization import BertTokenizer

from src.argument_parser import default_parser, get_predefine_argv, complete_argument
from src.nli_data_processing import processors, output_modes
from src.data_processing import init_model, get_task_dataloader
from src.modeling import BertForSequenceClassificationEncoder, FCClassifierForSequenceClassification, FullFCClassifierForSequenceClassification
from src.utils import load_model, count_parameters, eval_model_dataloader_nli, eval_model_dataloader
from src.KD_loss import distillation_loss, patience_loss
from envs import HOME_DATA_FOLDER

# Root logging setup (repeated so this cell can run standalone).
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
_LOG_DATE_FORMAT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
logger = logging.getLogger(__name__)

class TaskSpecificCausalDistiller:
    def __init__(
        self, params, 
        train_dataset, eval_dataset, 
        eval_label_ids, num_labels, output_mode,
        student_encoder: nn.Module, student_classifier: nn.Module,
        teacher_encoder: nn.Module, teacher_classifier: nn.Module,
    ):
        if params.is_wandb:
            run = wandb.init(
                project=params.wandb_metadata.split(":")[-1], 
                entity=params.wandb_metadata.split(":")[0],
                name=params.run_name,
            )
            wandb.config.update(params)
        self.is_wandb = params.is_wandb
        logger.info("Initializing Normal Distiller (Task Specific)")
        
        self.params = params
        
        self.output_model_file = '{}_nlayer.{}_lr.{}_T.{}.alpha.{}_beta.{}_bs.{}'.format(
            self.params.task_name, 
            self.params.student_hidden_layers,
            self.params.learning_rate,
            self.params.T, 
            self.params.alpha, 
            self.params.beta,
            self.params.train_batch_size * self.params.gradient_accumulation_steps
        )
        
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.eval_label_ids = eval_label_ids
        self.num_labels = num_labels
        self.output_mode = output_mode
        
        self.student_encoder = student_encoder
        self.student_classifier = student_classifier
        self.teacher_encoder = teacher_encoder
        self.teacher_classifier = teacher_classifier
        
        # common used vars
        self.fp16 = params.fp16
        self.T = params.T
        self.alpha = params.alpha
        self.beta = params.beta
        self.normalize_patience = params.normalize_patience
        self.learning_rate = params.learning_rate
        self.train_batch_size = params.train_batch_size
        self.output_dir = params.output_dir
        self.warmup_proportion = params.warmup_proportion
        self.num_train_optimization_steps = params.num_train_optimization_steps
        self.task_name = params.task_name
        self.kd_model = params.kd_model 
        self.weights = params.weights
        self.fc_layer_idx = params.fc_layer_idx
        self.n_gpu = params.n_gpu
        self.device = params.device
        self.num_train_epochs = params.num_train_epochs
        self.gradient_accumulation_steps = params.gradient_accumulation_steps
        self.loss_scale = params.loss_scale
        
        # DIITO params
        self.is_diito = params.is_diito
        self.diito_type = params.diito_type
        self.interchange_prop = params.interchange_prop
        self.interchange_max_token = params.interchange_max_token
        self.interchange_masked_token_only = params.interchange_masked_token_only
        self.interchange_consecutive_only = params.interchange_consecutive_only
        self.data_augment = params.data_augment
        
        # log to a local file
        log_train = open(os.path.join(self.output_dir, 'train_log.txt'), 'w', buffering=1)
        log_eval = open(os.path.join(self.output_dir, 'eval_log.txt'), 'w', buffering=1)
        print('epoch,global_steps,step,acc,loss,kd_loss,ce_loss,AT_loss', file=log_train)
        print('epoch,acc,loss', file=log_eval)
        log_train.close()
        log_eval.close()
    
        param_optimizer = list(
            self.student_encoder.named_parameters()
        ) + list(
            self.student_classifier.named_parameters()
        )
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        if self.fp16:
            logger.info('FP16 activate, use apex FusedAdam')
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

            self.optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=self.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if self.loss_scale == 0:
                self.optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                self.optimizer = FP16_Optimizer(optimizer, static_loss_scale=self.loss_scale)
        else:
            logger.info('FP16 is not activated, use BertAdam')
            self.optimizer = BertAdam(
                optimizer_grouped_parameters,
                lr=self.learning_rate,
                warmup=self.warmup_proportion,
                t_total=self.num_train_optimization_steps
            )
        
        # other params that report to tensorboard
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_dl = 0
        self.last_kd_loss = 0
        self.last_ce_loss = 0 
        self.last_pt_loss = 0
        self.lr_this_step = 0
        self.last_log = 0
        
        self.acc_tr_loss = 0
        self.acc_tr_kd_loss = 0
        self.acc_tr_ce_loss = 0
        self.acc_tr_pt_loss = 0
        self.acc_tr_acc = 0
        
        self.tr_loss = 0
        self.tr_kd_loss = 0
        self.tr_ce_loss = 0
        self.tr_pt_loss = 0
        self.tr_acc = 0
        
        # DIITO related params that report to tensorboard
    
    def prepare_batch(self, input_ids, input_mask, segment_ids, label_ids, ):
        if self.is_diito:
            dual_input_ids = input_ids.clone()
            dual_input_mask = input_mask.clone()
            dual_segment_ids = segment_ids.clone()
            dual_label_ids = label_ids.clone()
            causal_sort_index = [i for i in range(dual_input_ids.shape[0])]
            random.shuffle(causal_sort_index)
            dual_input_ids = dual_input_ids[causal_sort_index]
            dual_input_mask = dual_input_mask[causal_sort_index]
            dual_segment_ids = dual_segment_ids[causal_sort_index]
            dual_label_ids = dual_label_ids[causal_sort_index]
            return input_ids, input_mask, segment_ids, label_ids, \
                dual_input_ids, dual_input_mask, dual_segment_ids, dual_label_ids
        else:
            return input_ids, input_mask, segment_ids, label_ids
    
    def train(self):
        global_step = 0
        nb_tr_steps = 0
        tr_loss = 0
        self.student_encoder.train()
        self.student_classifier.train()
        self.teacher_encoder.eval()
        self.teacher_classifier.eval()
        self.last_log = time.time()
        
        for epoch in trange(int(self.num_train_epochs), desc="Epoch"):
            tr_loss, tr_ce_loss, tr_kd_loss, tr_acc = 0, 0, 0, 0
            nb_tr_examples, nb_tr_steps = 0, 0
            
            iter_bar = tqdm(self.train_dataset, desc="-Iter", disable=False)
            for batch in iter_bar:
                batch = tuple(t.to(self.device) for t in batch)
                # teascher patient is on-the-fly, we can skip the logic for different batch format.
                prepared_batch = self.prepare_batch(
                    *batch,
                )
                if self.is_diito:
                    self.step_diito(
                        *prepared_batch
                    )
                else:
                    self.step(
                        *prepared_batch
                    )
                iter_bar.update()
                iter_bar.set_postfix(
                    {
                        "Last_loss": f"{self.last_loss:.2f}", 
                        "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}", 
                    }
                )
            iter_bar.close()

            logger.info(f"--- Ending epoch {self.epoch}/{self.num_train_epochs-1}")
            self.end_epoch()

        logger.info("Save very last checkpoint as `pytorch_model.bin`.")
        self.save_checkpoint(checkpoint_name="pytorch_model.bin")
        logger.info("Training is finished")
    
    def prepare_interchange_mask(
        self,
        lengths, dual_lengths,
        pred_mask, dual_pred_mask,
    ):        
        # params
        interchange_prop = self.interchange_prop
        interchange_max_token = self.interchange_max_token # if -1 then we don't restrict on this.
        interchange_masked_token_only = self.interchange_masked_token_only
        interchange_consecutive_only = self.interchange_consecutive_only
        
        interchange_mask = torch.zeros_like(pred_mask, dtype=torch.bool)
        dual_interchange_mask = torch.zeros_like(dual_pred_mask, dtype=torch.bool)

        batch_size, max_seq_len = pred_mask.shape[0], pred_mask.shape[1]
        _, dual_max_seq_len = dual_pred_mask.shape[0], dual_pred_mask.shape[1]
        interchange_position = []
        for i in range(0, batch_size):
            min_len = min(lengths[i].tolist(), dual_lengths[i].tolist())
            if interchange_consecutive_only:
                if interchange_max_token != -1:
                    interchange_count = min(interchange_max_token, int(min_len*interchange_prop))
                else:
                    interchange_count = int(min_len*interchange_prop)
                start_index = random.randint(0, lengths[i].tolist()-interchange_count)
                end_index = start_index + interchange_count
                dual_start_index = random.randint(0, dual_lengths[i].tolist()-interchange_count)
                dual_end_index = dual_start_index + interchange_count
                interchange_mask[i][start_index:end_index] = 1
                dual_interchange_mask[i][dual_start_index:dual_end_index] = 1
            else:
                # we follow these steps to sample the position:
                # 1. sample positions in the main example
                # 2. get the actual sampled positions
                # 3. sample accordingly from the dual example
                if interchange_masked_token_only:
                    # a corner case we need to consider is that the masked token
                    # numbers may differ across two examples.
                    interchange_count = pred_mask[i].sum()
                    if interchange_count > dual_lengths[i]:
                        # not likely, but we need to handle this.
                        interchange_count = dual_lengths[i]
                    interchange_position = pred_mask[i].nonzero().view(-1).tolist()
                    interchange_position = random.sample(interchange_position, interchange_count)
                    interchange_mask[i][interchange_position] = 1
                    dual_interchange_position = random.sample(range(dual_max_seq_len), interchange_count)
                    dual_interchange_mask[i][dual_interchange_position] = 1
                else:
                    if interchange_max_token != -1:
                        interchange_count = min(interchange_max_token, int(min_len*interchange_prop))
                    else:
                        interchange_count = int(min_len*interchange_prop)
                    interchange_position = random.sample(range(max_seq_len), interchange_count)
                    interchange_mask[i][interchange_position] = 1
                    dual_interchange_position = random.sample(range(dual_max_seq_len), interchange_count)
                    dual_interchange_mask[i][dual_interchange_position] = 1

        # sanity checks
        assert interchange_mask.long().sum(dim=-1).tolist() == \
                dual_interchange_mask.long().sum(dim=-1).tolist()

        return interchange_mask, dual_interchange_mask
    
    def step_diito(
        self,
        input_ids,
        input_mask,
        segment_ids,
        label_ids,
        dual_input_ids,
        dual_input_mask,
        dual_segment_ids,
        dual_label_ids,
    ):
        interchange_mask, dual_interchange_mask = self.prepare_interchange_mask(
            input_mask.sum(dim=-1), dual_input_mask.sum(dim=-1),
            input_mask, dual_input_mask,
        )
        print(interchange_mask.sum(dim=-1))
        print(dual_interchange_mask.sum(dim=-1))
        # first, we simply prepare interchange positions.
        pass
    
    def step(
        self,
        input_ids,
        input_mask,
        segment_ids,
        label_ids,
    ):
        # teacher no_grad() forward pass.
        with torch.no_grad():
            if self.alpha == 0:
                teacher_pred, teacher_patience = None, None
            else:
                # define a new function to compute loss values for both output_modes
                full_output_teacher, pooled_output_teacher = self.teacher_encoder(
                    input_ids, segment_ids, input_mask
                )
                if self.kd_model.lower() in['kd', 'kd.cls']:
                    teacher_pred = self.teacher_classifier(pooled_output_teacher)
                    if self.kd_model.lower() == 'kd.cls':
                        teacher_patience = torch.stack(full_output_teacher[:-1]).transpose(0, 1)
                        if self.fp16:
                            teacher_patience = teacher_patience.half()
                        layer_index = [int(i) for i in self.fc_layer_idx.split(',')]
                        teacher_patience = torch.stack(
                            [torch.FloatTensor(teacher_patience[:,int(i)]) for i in layer_index]
                        ).transpose(0, 1)
                    else:
                        teacher_patience = None
                else:
                    raise ValueError(f'{self.kd_model} not implemented yet')
                if self.fp16:
                    teacher_pred = teacher_pred.half()
            
        # student with_grad() forward pass.
        full_output_student, pooled_output_student = self.student_encoder(
            input_ids, segment_ids, input_mask
        )
        if self.kd_model.lower() in['kd', 'kd.cls']:
            logits_pred_student = self.student_classifier(
                pooled_output_student
            )
            if self.kd_model.lower() == 'kd.cls':
                student_patience = torch.stack(full_output_student[:-1]).transpose(0, 1)
            else:
                student_patience = None
        else:
            raise ValueError(f'{self.kd_model} not implemented yet')

        # calculate loss
        loss_dl, kd_loss, ce_loss = distillation_loss(
            logits_pred_student, label_ids, teacher_pred, T=self.T, alpha=self.alpha
        )
        if self.beta > 0:
            if student_patience.shape[0] != input_ids.shape[0]:
                # For RACE
                n_layer = student_patience.shape[1]
                student_patience = student_patience.transpose(0, 1).contiguous().view(
                    n_layer, input_ids.shape[0], -1
                ).transpose(0,1)
            pt_loss = self.beta * patience_loss(
                teacher_patience, student_patience, 
                self.normalize_patience
            )
            loss = loss_dl + pt_loss
        else:
            pt_loss = torch.tensor(0.0)
            loss = loss_dl
        if self.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        
        # bookkeeping?
        self.last_loss_dl = 0
        self.last_kd_loss = 0
        self.last_ce_loss = 0 
        self.last_pt_loss = 0
        
        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_dl = loss_dl.mean().item() if self.n_gpu > 0 else loss_dl.item()
        self.last_kd_loss = kd_loss.mean().item() if self.n_gpu > 0 else kd_loss.item()
        self.last_ce_loss = ce_loss.mean().item() if self.n_gpu > 0 else ce_loss.item()
        self.last_pt_loss = pt_loss.mean().item() if self.n_gpu > 0 else pt_loss.item()
        
        n_sample = input_ids.shape[0]
        self.acc_tr_loss += self.last_loss * n_sample
        self.acc_tr_kd_loss += self.last_kd_loss * n_sample
        self.acc_tr_ce_loss += self.last_ce_loss * n_sample
        self.acc_tr_pt_loss = self.last_pt_loss * n_sample
        pred_cls = logits_pred_student.data.max(1)[1]
        self.acc_tr_acc += pred_cls.eq(label_ids).sum().cpu().item()
        self.n_sequences_epoch += n_sample
        
        self.tr_loss = self.acc_tr_loss / self.n_sequences_epoch
        self.tr_kd_loss = self.acc_tr_kd_loss / self.n_sequences_epoch
        self.tr_ce_loss = self.acc_tr_ce_loss / self.n_sequences_epoch
        self.tr_pt_loss = self.acc_tr_pt_loss / self.n_sequences_epoch
        self.tr_acc = self.acc_tr_acc / self.n_sequences_epoch
              
        self.optimize(loss)
            
    def optimize(self, loss):
        if self.gradient_accumulation_steps > 1:
            loss = loss / self.gradient_accumulation_steps
        
        # backward()
        if self.fp16:
            self.optimizer.backward(loss)
        else:
            loss.backward()

        self.iter()

        if self.n_iter % self.gradient_accumulation_steps == 0:
            if self.fp16:
                self.lr_this_step = self.learning_rate * warmup_linear(
                    self.n_total_iter / self.num_train_optimization_steps,
                    self.warmup_proportion
                )
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.lr_this_step
            self.optimizer.step()
            self.optimizer.zero_grad()
    
    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        
        self.n_iter += 1
        self.n_total_iter += 1
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            pass
            # you can uncomment this line, if you really have checkpoints.
            # self.save_checkpoint()
        
        """
        Logging is not affected by the flag skip_update_iter.
        We want to log crossway effects, and losses should be
        in the same magnitude.
        """
        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()

    def log_tensorboard(self):
        """
        Record the current training statistics.

        Always appends one CSV row to ``<output_dir>/train_log.txt`` with the
        columns: epoch (1-based), n_total_iter, n_iter, tr_acc, tr_loss,
        tr_kd_loss, tr_ce_loss, tr_pt_loss.

        When ``self.is_wandb`` is set, it also sends the following to wandb at
        step ``n_total_iter``: the running statistics, the most recent
        per-step losses (``last_*``), the current learning rate, and the time
        elapsed since ``last_log``.

        Despite the name, nothing is written to TensorBoard here.
        """
        row = '{},{},{},{},{},{},{},{}'.format(
            self.epoch + 1, self.n_total_iter, self.n_iter,
            self.tr_acc,
            self.tr_loss,
            self.tr_kd_loss,
            self.tr_ce_loss,
            self.tr_pt_loss,
        )
        # The context manager closes the handle even if the write fails
        # (the previous open()/close() pair leaked it on error).
        with open(os.path.join(self.output_dir, 'train_log.txt'), 'a', buffering=1) as log_train:
            print(row, file=log_train)

        if not self.is_wandb:
            return

        # NOTE(review): `wandb` does not appear in the notebook's import cell;
        # confirm it is imported before running with is_wandb enabled.
        wandb.log(
            {
                "train/loss": self.last_loss,
                "train/loss_dl": self.last_loss_dl,
                "train/kd_loss": self.last_kd_loss,
                "train/ce_loss": self.last_ce_loss,
                "train/pt_loss": self.last_pt_loss,

                "train/epoch_loss": self.tr_loss,
                "train/epoch_kd_loss": self.tr_kd_loss,
                "train/epoch_ce_loss": self.tr_ce_loss,
                "train/epoch_pt_loss": self.tr_pt_loss,
                "train/epoch_tr_acc": self.tr_acc,
            },
            step=self.n_total_iter,
        )

        wandb.log(
            {
                # NOTE(review): lr_this_step is only refreshed in the fp16
                # branch of the optimizer step; with BertAdam it may be stale.
                "train/learning_rate": self.lr_this_step,
                # Seconds since last_log was stamped — a duration, not a rate.
                "train/speed": time.time() - self.last_log,
            },
            step=self.n_total_iter,
        )

    def end_epoch(self):
        """
        Close out one full pass over the training data.

        Steps, in order:
          1. Evaluate the student on the eval data (bookkeeping only).
          2. Append ``epoch,metric,eval_loss`` to ``<output_dir>/eval_log.txt``.
             The metric is MCC for CoLA and accuracy otherwise.
          3. Save a per-epoch checkpoint.
          4. Optionally log epoch metrics to wandb.
          5. Reset the per-epoch counters and accumulators.

        The evaluation must never feed back into training; otherwise eval data
        would leak into the model.
        """
        logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")

        is_race = 'race' in self.task_name
        if is_race:
            result = eval_model_dataloader(
                self.student_encoder, self.student_classifier,
                self.eval_dataset, self.device, False
            )
        else:
            result = eval_model_dataloader_nli(
                self.task_name.lower(), self.eval_label_ids,
                self.student_encoder, self.student_classifier, self.eval_dataset,
                self.kd_model, self.num_labels, self.device,
                self.weights, self.fc_layer_idx, self.output_mode
            )

        # CoLA reports Matthews correlation, every other task reports accuracy.
        # The RACE evaluator names its loss 'loss' instead of 'eval_loss'.
        if self.task_name in ['CoLA']:
            metric_name, metric_value = 'mcc', result['mcc']
        else:
            metric_name, metric_value = 'acc', result['acc']
        eval_loss = result['loss'] if is_race else result['eval_loss']

        # The context manager closes the file even if the write fails.
        with open(os.path.join(self.output_dir, 'eval_log.txt'), 'a', buffering=1) as log_eval:
            print('{},{},{}'.format(self.epoch + 1, metric_value, eval_loss), file=log_eval)

        self.save_checkpoint()

        if self.is_wandb:
            # An epoch that ran zero iterations (e.g. an empty dataloader)
            # used to raise ZeroDivisionError here; log NaN instead.
            if self.n_iter > 0:
                epoch_loss = self.total_loss_epoch / self.n_iter
            else:
                epoch_loss = float('nan')
            wandb.log(
                {
                    "epoch/loss": epoch_loss,
                    'epoch': self.epoch
                }
            )
            # Key is "epoch/eval_mcc" for CoLA, "epoch/eval_acc" otherwise.
            wandb.log(
                {
                    f"epoch/eval_{metric_name}": metric_value,
                    "epoch/eval_loss": eval_loss,
                    'epoch': self.epoch
                }
            )

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

        self.acc_tr_loss = 0
        self.acc_tr_kd_loss = 0
        self.acc_tr_ce_loss = 0
        self.acc_tr_acc = 0
        # NOTE(review): a pt-loss statistic (tr_pt_loss) is logged elsewhere,
        # but no acc_tr_pt_loss is reset here. Confirm whether one is needed.
    
    def save_checkpoint(self, checkpoint_name=None):
        """
        Save the student encoder and classifier ``state_dict``s into ``self.output_dir``.

        Args:
            checkpoint_name: optional file-name suffix.
                - When ``None`` (default), the files are
                  ``<output_model_file>_e.<epoch>.encoder.pkl`` and
                  ``<output_model_file>_e.<epoch>.cls.pkl``.
                - Otherwise they are ``encoder.<checkpoint_name>`` and
                  ``cls.<checkpoint_name>``.

        When ``self.n_gpu > 1`` the models are expected to be wrapped
        (presumably by ``nn.DataParallel``), and the inner ``.module`` weights
        are saved so the files load into an unwrapped model.
        """
        if checkpoint_name is None:
            prefix = self.output_model_file + f'_e.{self.epoch}'
            encoder_file = prefix + '.encoder.pkl'
            classifier_file = prefix + '.cls.pkl'
        else:
            encoder_file = "encoder." + checkpoint_name
            classifier_file = "cls." + checkpoint_name

        targets = (
            (self.student_encoder, encoder_file),
            (self.student_classifier, classifier_file),
        )
        for model, file_name in targets:
            # Unwrap the multi-GPU wrapper so the saved keys carry no "module." prefix.
            to_save = model.module if self.n_gpu > 1 else model
            torch.save(to_save.state_dict(), os.path.join(self.output_dir, file_name))


In [21]:
# Build the task-specific causal distiller from the parsed args, the data
# loaders / label info, and the student and teacher models prepared above.
distiller = TaskSpecificCausalDistiller(
    args,
    train_dataloader,
    eval_dataloader,
    eval_label_ids,
    num_labels,
    output_mode,
    student_encoder,
    student_classifier,
    teacher_encoder,
    teacher_classifier,
)

02/11/2022 01:43:22 - INFO - __main__ -   Initializing Normal Distiller (Task Specific)
02/11/2022 01:43:22 - INFO - __main__ -   FP16 is not activated, use BertAdam


In [22]:
# Turn on the distiller's is_diito flag before training.
# NOTE(review): presumably this switches train() to the causal
# interchange-intervention (DIITO) objective; confirm in causal_distiller.
distiller.is_diito = True

In [23]:
# Launch distillation training of the student with the distiller configured above.
logger.info("Hey Zen: Let's go get some drinks.")
distiller.train()

02/11/2022 01:43:23 - INFO - __main__ -   Hey Zen: Let's go get some drinks.
Epoch:   0%|          | 0/4 [00:00<?, ?it/s]
-Iter:   0%|          | 0/32 [00:00<?, ?it/s][A
Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([3, 1, 1, 3, 2, 3, 1, 1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 2, 1, 3, 1, 5, 2, 3,
        2, 3, 4, 1, 1, 3, 1, 1])
tensor([3, 1, 1, 3, 2, 3, 1, 1, 1, 1, 1, 3, 3, 2, 2, 1, 1, 2, 1, 3, 1, 5, 2, 3,
        2, 3, 4, 1, 1, 3, 1, 1])





ZeroDivisionError: division by zero

In [None]:
if args.do_eval:
    # Final evaluation of the student. Each metric is logged and written to
    # <output_dir>/eval_results_<output_model_file>.txt; no model is saved here.
    # NOTE(review): the race check reads the global `task_name`, while the NLI
    # branch uses `args.task_name`. Confirm the two always agree.
    if 'race' not in task_name:
        result = eval_model_dataloader_nli(
            args.task_name.lower(), eval_label_ids,
            student_encoder, student_classifier, eval_dataloader,
            args.kd_model, num_labels, device,
            args.weights, args.fc_layer_idx, output_mode,
        )
    else:
        result = eval_model_dataloader(
            student_encoder, student_classifier, eval_dataloader, device, False
        )

    output_test_file = os.path.join(
        args.output_dir, "eval_results_" + output_model_file + '.txt'
    )
    with open(output_test_file, "w") as writer:
        logger.info("***** Eval results *****")
        for key in sorted(result):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
            