In [1]:
#comment this if you are not using AIT proxy...
import os
os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'

In [2]:
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    get_scheduler,
    set_seed,
)
from tqdm.auto import tqdm

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import os
from io import open
import torch
import json
from glob import glob
import numpy as np
import pandas as pd
from tqdm import tqdm

In [4]:
import argparse
import logging
import math
import os
import random
from itertools import chain

## 1.Load Dataset

### Preprocessing the datasets.

In [5]:
from accelerate import Accelerator

accelerator = Accelerator()

In [6]:
model_checkpoint = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
# PAD_TOKEN = '<pad>'
# tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
# tokenizer

In [7]:
class Wikitext_Dataset:
    def __init__(self, path):
        self.train = os.path.join(path, 'train/train.txt')
        self.valid = os.path.join(path, 'valid/valid.txt')
        self.test  = os.path.join(path, 'test/test.txt')

    def build_corpus(self, path):
        files = open(path,'r')
        lines = []
        for line in files:
            line = line.strip().lower()
            if len(line) == 0:
                continue
            lines.append(line)
        return lines
path_files = './data/wikitext-2-add10b'
corpus = Wikitext_Dataset(path_files)
train_dataset = corpus.build_corpus(corpus.train)
valid_dataset = corpus.build_corpus(corpus.valid)
test_dataset  = corpus.build_corpus(corpus.test)

In [8]:
from datasets import Dataset
from datasets import DatasetDict
import pandas as pd

raw_datasets_train = Dataset.from_pandas(pd.DataFrame(data = {'text': train_dataset}))
raw_datasets_valid = Dataset.from_pandas(pd.DataFrame(data = {'text': valid_dataset}))
raw_datasets_test  = Dataset.from_pandas(pd.DataFrame(data = {'text': test_dataset}))
#remove .shuffle if you want to train the whole dataset....

raw_datasets = DatasetDict(
    {
        'train':raw_datasets_train,
        'validation':raw_datasets_valid,
        'test':raw_datasets_test
    }
)
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 23777
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 2461
    })
    test: Dataset({
        features: ['text'],
        num_rows: 2891
    })
})

## 2. Preprocessing

In [9]:
# First we tokenize all the texts.
column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]

def tokenize_function(examples):
    return tokenizer(examples[text_column_name])

# def tokenize_function(example):
#     outputs =  tokenizer(example[text_column_name], truncation=True, padding='max_length')
#     input_batch = []
#     for input_ids in outputs["input_ids"]:
#         input_batch.append(input_ids)
#     return {"input_ids": input_batch}


preprocessing_num_workers = None
with accelerator.main_process_first():
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=preprocessing_num_workers,
        remove_columns=column_names,
        desc="Running tokenizer on dataset",
    )

tokenized_datasets

                                                                                             

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 23777
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 2461
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 2891
    })
})

In [10]:
block_size = 1024
if block_size is None:
    block_size = tokenizer.model_max_length
    if block_size > 1024:
        # logger.warning(
        #     f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
        #     "Picking 1024 instead. You can change that default value by passing --block_size xxx."
        # )
        block_size = 1024
else:
    if block_size > tokenizer.model_max_length:
        # logger.warning(
        #     f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
        #     f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
        # )
        block_size = min(block_size, tokenizer.model_max_length)
    
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [11]:
# # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
# # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
# # to preprocess.
# #
# # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
preprocessing_num_workers = 1
with accelerator.main_process_first():
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=preprocessing_num_workers,
        desc=f"Grouping texts in chunks of {block_size}",
    )
lm_datasets.set_format("torch")
lm_datasets

                                                                                                

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2405
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 255
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 290
    })
})

In [12]:
small_train_dataset = lm_datasets["train"].shuffle(seed=55) #.select(range(100))
small_eval_dataset = lm_datasets["validation"].shuffle(seed=55) #.select(range(10))
small_test_dataset = lm_datasets["test"].shuffle(seed=55)

## 3. Dataloaders

In [13]:
from torch.utils.data import DataLoader
per_device_train_batch_size = 8
per_device_eval_batch_size = 8

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=per_device_train_batch_size, pin_memory=True)
val_dataloader = DataLoader(small_eval_dataset, batch_size=per_device_eval_batch_size, pin_memory=True)
test_dataloader = DataLoader(small_test_dataset, batch_size=per_device_eval_batch_size)

In [14]:
#checking chucking
for i in train_dataloader:
    print(i['input_ids'].shape, i['labels'].shape)
    break
for i in val_dataloader:
    print(i['input_ids'].shape, i['labels'].shape)
    break
for i in test_dataloader:
    print(i['input_ids'].shape, i['labels'].shape)
    break

torch.Size([8, 1024]) torch.Size([8, 1024])
torch.Size([8, 1024]) torch.Size([8, 1024])
torch.Size([8, 1024]) torch.Size([8, 1024])


## 4. Model

In [15]:
# config = AutoConfig.from_pretrained(model_checkpoint, tie_word_embeddings=False)
# model = AutoModelForCausalLM.from_config(config)
# model.resize_token_embeddings(len(tokenizer))

In [16]:
# model.config

In [17]:
# model

In [18]:
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel
# Define the configuration for the student model
teacher_config = GPT2Config(
    vocab_size=50257,    # Example vocabulary size
    n_positions=1024,    # Example sequence length
    n_embd=768,         # Adjust the embedding dimension to match teacher's
    n_layer=48,  # Number of student layers
)

# Initialize the student model architecture
teacher_model = GPT2LMHeadModel(config=teacher_config)
teacher_model.train()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-47): 48 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [19]:
# # Create a configuration for a 48-layer GPT2 model
# teacher = 'gpt2-xl'
# teacher_config = AutoConfig.from_pretrained(teacher)#, tie_word_embeddings=False)
# teacher_model = AutoModelForCausalLM.from_pretrained(
#     teacher,
#     config=teacher_config,
# )
# teacher_model.train()

In [20]:
import math
def pseudo_uniform_selection(n, k):
    # Require: n > k; n mod k = 0; n mod 2 = 0
    # assert n > k and n % k == 0 and n % 2 == 0, "Invalid input"
    
    step = math.floor(n / k)
    start = 0
    end = n - 1
    selection = []
    while start <= end:
        selection.append(start)
        selection.append(end)
        start += step
        end -= step
    selection.sort()
    return selection

# Select the layers to copy from the teacher model
teacher_layers = teacher_model.config.n_layer
student_layers = 6
teacher_layers_to_use = pseudo_uniform_selection(teacher_layers, student_layers)
print(teacher_layers_to_use)

[0, 8, 16, 31, 39, 47]


In [21]:
# Select the layers to copy from the teacher model
# teacher_layers_to_use = [num for num in range(teacher_model.config.n_layer) if num % 8 == 0]  # Indices of layers to copy
# print(len(teacher_layers_to_use))

# Define the configuration for the student model
student_config = GPT2Config(
    vocab_size=50257,    # Example vocabulary size
    n_positions=1024,    # Example sequence length
    n_embd=768,         # Adjust the embedding dimension to match teacher's
    n_layer=len(teacher_layers_to_use),  # Number of student layers
)

# Initialize the student model architecture
student_model = GPT2LMHeadModel(config=student_config)

# Copy teacher layers to student
for student_layer_idx, teacher_layer_idx in enumerate(teacher_layers_to_use):
    teacher_layer = teacher_model.transformer.h[teacher_layer_idx]
    student_layer = student_model.transformer.h[student_layer_idx]
    student_layer.load_state_dict(teacher_layer.state_dict())

# Now you can use the student model for further tasks
student_model.train()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [22]:
# # Create a configuration for a 12-layer GPT2 model
# student = "gpt2"
# student_config = AutoConfig.from_pretrained(student) #, tie_word_embeddings=False)
# student_model = AutoModelForCausalLM.from_pretrained(
#     student,
#     config=student_config,
# )
# student_model.train()

In [23]:
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
weight_decay = 0
teacher_optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in teacher_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p
            for n, p in teacher_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]

student_optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in student_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p
            for n, p in student_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
# params=model.parameters()
teacher_optimizer = torch.optim.Adam(teacher_optimizer_grouped_parameters, lr=1e-4)
student_optimizer = torch.optim.Adam(student_optimizer_grouped_parameters, lr=1e-4)

## Accelator

In [24]:
# Prepare everything with our `accelerator`.
# model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
#     model, optimizer, train_dataloader, eval_dataloader
# )
teacher_model = accelerator.prepare(teacher_model)
student_model, student_optimizer, train_dataloader, val_dataloader = accelerator.prepare(
    student_model, student_optimizer, train_dataloader, val_dataloader
)

In [25]:
from transformers import get_scheduler
import math

gradient_accumulation_steps = 1
num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
num_train_epochs = 10
max_train_steps = num_train_epochs * num_update_steps_per_epoch

teacher_lr_scheduler = get_scheduler(
    "linear",
    optimizer=teacher_optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps,
)

student_lr_scheduler = get_scheduler(
    "linear",
    optimizer=student_optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps,
)

total_batch_size = (
        per_device_train_batch_size
        * accelerator.num_processes
        * gradient_accumulation_steps
    )

## Ghost clipping: memory saving differentially private learning
Turning on ghost clipping requires changing only 1 line. You should notice a drastic reduction in peak GPU memory usage once this is turned on, at a potential cost of slower training speed. One might find this especially useful when constrained to only use older GPUs with small VRAMs or fitting super large models.

In [26]:
# !pip install ml_swissknife
# !pip install opt_einsum

In [27]:
import transformers, torch
from private_transformers import PrivacyEngine
dp = False
if dp == True:
    #student_model
    privacy_engine = PrivacyEngine(
        student_model,
        batch_size=per_device_train_batch_size,
        sample_size=len(lm_datasets['train']),
        epochs=1,
        max_grad_norm=0.1,
        target_epsilon=3,
        clipping_mode="ghost",  # The only change you need to make!
    )
    privacy_engine.attach(student_optimizer)
    #Teacher Model
    privacy_engine = PrivacyEngine(
        teacher_model,
        batch_size=per_device_train_batch_size,
        sample_size=len(lm_datasets['train']),
        epochs=per_device_train_batch_size,
        max_grad_norm=0.1,
        target_epsilon=3,
        clipping_mode="ghost",  # The only change you need to make!
    )
    privacy_engine.attach(teacher_optimizer)

else :
    privacy_engine = None

In [28]:
privacy_engine

In [29]:
delta = 1.0/42061 # We instead use the accountant from Gopi et al. (2021) as described in the paper.

### Loss Objective 
The Kullback-Leibler divergence loss. For tensors of the same shape $y_{pred}, y_{true}$ where $y_{pred}$ is the input and $y_{true}$ ​ is the target, we define the pointwise KL-divergence as 

$$L(y_{pred}, y_{true}) = y_{pred}\cdot \log \frac{y_{true}}{y_{pred}}  = y_{true} \cdot (\log y_{true} -\log y_{true})$$

format : torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)
more infomation click [link](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html)

In [30]:
def loss_fn_kd(student_outputs, labels, teacher_outputs, alpha = 0.1, T = 1, batch_size = per_device_train_batch_size):
    """
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha

    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities! See Issue #2
    """
    # student_outputs.logits.shape = (batch_size, seq_len, vocab_size)
    # teacher_outputs.logits.shape = (batch_size, seq_len, vocab_size)
    # labels.shape = (batch_size, seq_len)

    # Define loss functions (Cross-Entropy for student, KL Divergence for distillation)
    student_loss_fn = nn.CrossEntropyLoss()
    distillation_loss_fn = nn.KLDivLoss()

    # Flatten the logits and targets
    student_logits_flat = student_outputs.logits.view(-1, student_outputs.logits.size(-1))
    labels_flat = labels.view(-1)
    # student_logits_flat shape = (batch_size * seq_len, vocab_size)
    # labels_flat shape = (batch_size * seq_len, )

    # Calculate Cross-Entropy loss for student
    student_loss = student_loss_fn(student_logits_flat, labels_flat)

    # Calculate distillation loss (using KL Divergence)
    distillation_loss = distillation_loss_fn(
        torch.log_softmax(student_logits_flat/T, dim=-1),
        torch.softmax(teacher_outputs.logits.view(-1, teacher_outputs.logits.size(-1))/T, dim=-1)* (T ** 2)
    )
    
    # Combine the two losses (you can adjust the weighting factor)
    total_loss = (1-alpha)* student_loss + alpha * distillation_loss #label smoothing
    print('student_loss', student_loss) 
    print('distillation_loss', distillation_loss)
    # print('total_loss',total_loss)
    return total_loss


In [31]:
class RunningAverage():
    """A simple class that maintains the running average of a quantity
    
    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg() = 3
    ```
    """
    def __init__(self):
        self.steps = 0
        self.total = 0
    
    def update(self, val):
        self.total += val
        self.steps += 1
    
    def __call__(self):
        return self.total/float(self.steps)

In [32]:
# Defining train_kd & train_and_evaluate_kd functions
def train_kd(student_model, teacher_model, optimizer, train_dataloader):
    student_model.train()
    teacher_model.eval()
    
    # summary for current training loop and a running average object for loss
    summ = []
    loss_avg = RunningAverage()
    
    for step, batch in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()
        outputs_student_batch = student_model(**batch)
        labels_batch = batch['labels']
        
        # get one batch output from teacher_outputs list
        with torch.no_grad():
            output_teacher_batch = teacher_model(**batch)
            
        loss = loss_fn_kd(outputs_student_batch, labels_batch, output_teacher_batch)
        # loss = outputs_student_batch.loss
        loss = loss / gradient_accumulation_steps
        # loss = loss.reshape(-1)
        # accelerator.backward(loss)
        if (
            step % gradient_accumulation_steps == 0
            or step == len(train_dataloader) - 1
        ):
            # Perform one optimization step with the PrivacyEngine
            if dp:
                optimizer.step(loss=loss)
            else:
                accelerator.backward(loss)
                optimizer.step()
            student_lr_scheduler.step()
            # optimizer.zero_grad()
            # progress_bar.update(1)
            # completed_steps += 1

        # if completed_steps >= max_train_steps:
        #     break

In [33]:
def evaluate_kd(model, eval_dataloader):
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        loss = outputs.loss
        losses.append(
            accelerator.gather(loss.repeat(per_device_eval_batch_size))
        )

    losses = torch.cat(losses)
    losses = losses[: len(small_eval_dataset)]
    try:
        perplexity = math.exp(torch.mean(losses))
    except OverflowError:
        perplexity = float("inf")
    return perplexity, torch.mean(losses)

In [34]:
criterion = nn.KLDivLoss(reduction="none") 
# distillation_loss = distillation_loss_fn(
#             torch.log_softmax(student_outputs, dim=-1),
#             torch.softmax(teacher_outputs, dim=-1)
#         )

In [35]:
def train_and_evaluate_kd(student_model, teacher_model, train_dataloader, val_dataloader, optimizer, save_path, restore_file=None):
    # Only show the progress bar once on each machine.
    # progress_bar = tqdm(
    #     range(max_train_steps), disable=not accelerator.is_local_main_process
    # )
    # completed_steps = 0
    best_val_perplexity = float("inf")
    
    for epoch in range(num_train_epochs):
        # compute number of batches in one epoch (one full pass over the training set)
        train_kd(student_model, teacher_model, optimizer, train_dataloader)
        
        # Evaluate for one epoch on validation set
        perplexity, loss = evaluate_kd(student_model, val_dataloader)

        # logger.info(f"epoch {epoch}: perplexity: {perplexity}")
        print(f"epoch {epoch}: perplexity: {perplexity} : loss {loss}")
    
        if dp:
            # Printing epsilon from opacus privacy engine at the end of each epoch
            eps, alpha = optimizer.privacy_engine.get_privacy_spent(delta)
            print("End of epoch {}, we have epsilon {} for alpha {}".format(epoch, eps, alpha))
    
        if perplexity < best_val_perplexity and save_path is not None:
            best_val_perplexity = perplexity
            
            print(f"saved model! epoch {epoch}: perplexity: {best_val_perplexity}")
            # torch.save(student_model.state_dict(), save_path)

In [None]:
save_path = f'models/{student_model.__class__.__name__}_distill_nodp.pt'
train_and_evaluate_kd(student_model, teacher_model, train_dataloader, val_dataloader, student_optimizer, save_path, restore_file=None)



student_loss tensor(8.0095, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.1345e-06, device='cuda:0', grad_fn=<MeanBackward0>)


  0%|          | 1/301 [00:01<07:34,  1.51s/it]

student_loss tensor(7.3910, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.4652e-06, device='cuda:0', grad_fn=<MeanBackward0>)


  1%|          | 2/301 [00:02<06:46,  1.36s/it]

student_loss tensor(7.0843, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.7436e-06, device='cuda:0', grad_fn=<MeanBackward0>)


  1%|          | 3/301 [00:04<06:30,  1.31s/it]

student_loss tensor(7.5291, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(1.7392e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  1%|▏         | 4/301 [00:05<06:22,  1.29s/it]

student_loss tensor(6.4207, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7906e-06, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|▏         | 5/301 [00:06<06:18,  1.28s/it]

student_loss tensor(5.7539, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(1.3059e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|▏         | 6/301 [00:07<06:14,  1.27s/it]

student_loss tensor(5.6416, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(1.3344e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|▏         | 7/301 [00:09<06:13,  1.27s/it]

student_loss tensor(4.8405, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(1.7833e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  3%|▎         | 8/301 [00:10<06:11,  1.27s/it]

student_loss tensor(4.6089, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(1.9325e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  3%|▎         | 9/301 [00:11<06:09,  1.27s/it]

student_loss tensor(4.4022, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.0683e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  3%|▎         | 10/301 [00:12<06:08,  1.27s/it]

student_loss tensor(4.4880, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.0418e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  4%|▎         | 11/301 [00:14<06:07,  1.27s/it]

student_loss tensor(4.0189, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.2938e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  4%|▍         | 12/301 [00:15<06:06,  1.27s/it]

student_loss tensor(3.9951, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.2856e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  4%|▍         | 13/301 [00:16<06:04,  1.27s/it]

student_loss tensor(3.8194, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.4705e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  5%|▍         | 14/301 [00:17<06:03,  1.27s/it]

student_loss tensor(3.5399, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.6384e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  5%|▍         | 15/301 [00:19<06:03,  1.27s/it]

student_loss tensor(3.4722, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.6464e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  5%|▌         | 16/301 [00:20<06:01,  1.27s/it]

student_loss tensor(3.3308, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.7087e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  6%|▌         | 17/301 [00:21<06:00,  1.27s/it]

student_loss tensor(3.1312, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.8432e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  6%|▌         | 18/301 [00:23<05:59,  1.27s/it]

student_loss tensor(3.0781, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.8237e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  6%|▋         | 19/301 [00:24<05:59,  1.27s/it]

student_loss tensor(2.9647, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.8635e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  7%|▋         | 20/301 [00:25<05:57,  1.27s/it]

student_loss tensor(2.9070, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.8837e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  7%|▋         | 21/301 [00:26<05:57,  1.28s/it]

student_loss tensor(2.7380, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.9036e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  7%|▋         | 22/301 [00:28<05:55,  1.27s/it]

student_loss tensor(2.5957, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.0156e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  8%|▊         | 23/301 [00:29<05:55,  1.28s/it]

student_loss tensor(2.5049, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.0505e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  8%|▊         | 24/301 [00:30<05:53,  1.28s/it]

student_loss tensor(2.4033, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.1493e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  8%|▊         | 25/301 [00:31<05:53,  1.28s/it]

student_loss tensor(2.3724, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.0560e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  9%|▊         | 26/301 [00:33<05:51,  1.28s/it]

student_loss tensor(2.4015, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.9595e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  9%|▉         | 27/301 [00:34<05:50,  1.28s/it]

student_loss tensor(2.3158, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(2.9585e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  9%|▉         | 28/301 [00:35<05:49,  1.28s/it]

student_loss tensor(2.1394, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.1533e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 10%|▉         | 29/301 [00:37<05:48,  1.28s/it]

student_loss tensor(2.1214, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.0923e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 10%|▉         | 30/301 [00:38<05:47,  1.28s/it]

student_loss tensor(2.0317, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.1276e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 10%|█         | 31/301 [00:39<05:45,  1.28s/it]

student_loss tensor(1.9546, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.2829e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 11%|█         | 32/301 [00:40<05:44,  1.28s/it]

student_loss tensor(1.9588, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.1648e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 11%|█         | 33/301 [00:42<05:43,  1.28s/it]

student_loss tensor(1.8671, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.2914e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 11%|█▏        | 34/301 [00:43<05:42,  1.28s/it]

student_loss tensor(1.7186, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.4513e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 12%|█▏        | 35/301 [00:44<05:41,  1.28s/it]

student_loss tensor(1.6743, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.5078e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 12%|█▏        | 36/301 [00:46<05:40,  1.29s/it]

student_loss tensor(1.6348, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.5667e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 12%|█▏        | 37/301 [00:47<05:39,  1.29s/it]

student_loss tensor(1.6564, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.5279e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 13%|█▎        | 38/301 [00:48<05:38,  1.29s/it]

student_loss tensor(1.4967, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.6168e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 13%|█▎        | 39/301 [00:49<05:36,  1.29s/it]

student_loss tensor(1.5161, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.5447e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 13%|█▎        | 40/301 [00:51<05:36,  1.29s/it]

student_loss tensor(1.4503, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.6028e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 14%|█▎        | 41/301 [00:52<05:34,  1.29s/it]

student_loss tensor(1.3807, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.7697e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 14%|█▍        | 42/301 [00:53<05:34,  1.29s/it]

student_loss tensor(1.3936, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.6849e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 14%|█▍        | 43/301 [00:55<05:32,  1.29s/it]

student_loss tensor(1.2832, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.8915e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 15%|█▍        | 44/301 [00:56<05:31,  1.29s/it]

student_loss tensor(1.2678, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.8823e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 15%|█▍        | 45/301 [00:57<05:30,  1.29s/it]

student_loss tensor(1.2098, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.9482e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 15%|█▌        | 46/301 [00:58<05:29,  1.29s/it]

student_loss tensor(1.2618, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(3.8763e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 16%|█▌        | 47/301 [01:00<05:28,  1.29s/it]

student_loss tensor(1.1056, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.0919e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 16%|█▌        | 48/301 [01:01<05:27,  1.29s/it]

student_loss tensor(1.1107, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.0901e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 16%|█▋        | 49/301 [01:02<05:25,  1.29s/it]

student_loss tensor(1.1113, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.0484e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 17%|█▋        | 50/301 [01:04<05:24,  1.29s/it]

student_loss tensor(1.0115, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.3017e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 17%|█▋        | 51/301 [01:05<05:23,  1.30s/it]

student_loss tensor(1.1086, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.0208e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 17%|█▋        | 52/301 [01:06<05:22,  1.29s/it]

student_loss tensor(0.9704, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.2636e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 18%|█▊        | 53/301 [01:08<05:21,  1.30s/it]

student_loss tensor(0.9601, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.2060e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 18%|█▊        | 54/301 [01:09<05:19,  1.29s/it]

student_loss tensor(0.9784, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.2319e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 18%|█▊        | 55/301 [01:10<05:18,  1.29s/it]

student_loss tensor(0.9580, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.2517e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 19%|█▊        | 56/301 [01:11<05:16,  1.29s/it]

student_loss tensor(0.9284, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.3285e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 19%|█▉        | 57/301 [01:13<05:15,  1.29s/it]

student_loss tensor(0.8436, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.4893e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 19%|█▉        | 58/301 [01:14<05:14,  1.29s/it]

student_loss tensor(0.8786, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.4665e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 20%|█▉        | 59/301 [01:15<05:13,  1.29s/it]

student_loss tensor(0.8139, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.5345e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 20%|█▉        | 60/301 [01:17<05:11,  1.29s/it]

student_loss tensor(0.8027, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.5755e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 20%|██        | 61/301 [01:18<05:10,  1.29s/it]

student_loss tensor(0.7404, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.6832e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 21%|██        | 62/301 [01:19<05:09,  1.30s/it]

student_loss tensor(0.7701, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.6828e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 21%|██        | 63/301 [01:21<05:08,  1.30s/it]

student_loss tensor(0.7501, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.6630e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 21%|██▏       | 64/301 [01:22<05:07,  1.30s/it]

student_loss tensor(0.7040, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.7416e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 22%|██▏       | 65/301 [01:23<05:06,  1.30s/it]

student_loss tensor(0.6982, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.7494e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 22%|██▏       | 66/301 [01:24<05:05,  1.30s/it]

student_loss tensor(0.7084, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.9225e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 22%|██▏       | 67/301 [01:26<05:03,  1.30s/it]

student_loss tensor(0.6720, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.8182e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 23%|██▎       | 68/301 [01:27<05:03,  1.30s/it]

student_loss tensor(0.6590, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.7967e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 23%|██▎       | 69/301 [01:28<05:01,  1.30s/it]

student_loss tensor(0.6610, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.8879e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 23%|██▎       | 70/301 [01:30<05:01,  1.30s/it]

student_loss tensor(0.6312, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.8330e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 24%|██▎       | 71/301 [01:31<04:59,  1.30s/it]

student_loss tensor(0.5944, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.0079e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 24%|██▍       | 72/301 [01:32<04:59,  1.31s/it]

student_loss tensor(0.6299, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(4.9295e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 24%|██▍       | 73/301 [01:34<04:57,  1.30s/it]

student_loss tensor(0.6113, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.0086e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 25%|██▍       | 74/301 [01:35<04:56,  1.30s/it]

student_loss tensor(0.5437, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.0681e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 25%|██▍       | 75/301 [01:36<04:54,  1.30s/it]

student_loss tensor(0.6135, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.0332e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 25%|██▌       | 76/301 [01:37<04:53,  1.30s/it]

student_loss tensor(0.5417, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.0341e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 26%|██▌       | 77/301 [01:39<04:51,  1.30s/it]

student_loss tensor(0.5373, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.0506e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 26%|██▌       | 78/301 [01:40<04:50,  1.30s/it]

student_loss tensor(0.5180, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.2107e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 26%|██▌       | 79/301 [01:41<04:48,  1.30s/it]

student_loss tensor(0.5081, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.2883e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 27%|██▋       | 80/301 [01:43<04:47,  1.30s/it]

student_loss tensor(0.5074, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.2859e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 27%|██▋       | 81/301 [01:44<04:45,  1.30s/it]

student_loss tensor(0.4885, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.2916e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 27%|██▋       | 82/301 [01:45<04:45,  1.30s/it]

student_loss tensor(0.4760, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.3219e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 28%|██▊       | 83/301 [01:47<04:44,  1.30s/it]

student_loss tensor(0.4632, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.5456e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 28%|██▊       | 84/301 [01:48<04:42,  1.30s/it]

student_loss tensor(0.5180, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.2880e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 28%|██▊       | 85/301 [01:49<04:41,  1.30s/it]

student_loss tensor(0.4889, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.3313e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 29%|██▊       | 86/301 [01:50<04:40,  1.30s/it]

student_loss tensor(0.4739, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.3864e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 29%|██▉       | 87/301 [01:52<04:39,  1.31s/it]

student_loss tensor(0.4458, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.3954e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 29%|██▉       | 88/301 [01:53<04:37,  1.30s/it]

student_loss tensor(0.3981, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.5838e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 30%|██▉       | 89/301 [01:54<04:36,  1.30s/it]

student_loss tensor(0.4104, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.7045e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 30%|██▉       | 90/301 [01:56<04:34,  1.30s/it]

student_loss tensor(0.4218, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.4964e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 30%|███       | 91/301 [01:57<04:33,  1.30s/it]

student_loss tensor(0.3943, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.6755e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 31%|███       | 92/301 [01:58<04:31,  1.30s/it]

student_loss tensor(0.4038, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.5893e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 31%|███       | 93/301 [02:00<04:31,  1.30s/it]

student_loss tensor(0.4044, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.7176e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 31%|███       | 94/301 [02:01<04:29,  1.30s/it]

student_loss tensor(0.4042, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.7227e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 32%|███▏      | 95/301 [02:02<04:28,  1.30s/it]

student_loss tensor(0.3646, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.7799e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 32%|███▏      | 96/301 [02:03<04:27,  1.30s/it]

student_loss tensor(0.3780, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.7628e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 32%|███▏      | 97/301 [02:05<04:25,  1.30s/it]

student_loss tensor(0.3639, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.6957e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 33%|███▎      | 98/301 [02:06<04:24,  1.30s/it]

student_loss tensor(0.3901, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.7973e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 33%|███▎      | 99/301 [02:07<04:23,  1.30s/it]

student_loss tensor(0.3475, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.8114e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 33%|███▎      | 100/301 [02:09<04:22,  1.30s/it]

student_loss tensor(0.3548, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.9052e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 34%|███▎      | 101/301 [02:10<04:20,  1.30s/it]

student_loss tensor(0.3167, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.8990e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 34%|███▍      | 102/301 [02:11<04:19,  1.30s/it]

student_loss tensor(0.3327, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.8165e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 34%|███▍      | 103/301 [02:13<04:18,  1.30s/it]

student_loss tensor(0.3357, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.9164e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 35%|███▍      | 104/301 [02:14<04:16,  1.30s/it]

student_loss tensor(0.3186, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.9816e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 35%|███▍      | 105/301 [02:15<04:15,  1.30s/it]

student_loss tensor(0.3446, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(5.8255e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 35%|███▌      | 106/301 [02:17<04:14,  1.30s/it]

student_loss tensor(0.3014, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.0380e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 36%|███▌      | 107/301 [02:18<04:12,  1.30s/it]

student_loss tensor(0.3023, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.0124e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 36%|███▌      | 108/301 [02:19<04:11,  1.31s/it]

student_loss tensor(0.2928, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.0505e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 36%|███▌      | 109/301 [02:20<04:10,  1.30s/it]

student_loss tensor(0.2853, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.0907e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 37%|███▋      | 110/301 [02:22<04:09,  1.30s/it]

student_loss tensor(0.2710, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.2265e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 37%|███▋      | 111/301 [02:23<04:07,  1.30s/it]

student_loss tensor(0.2776, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.1486e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 37%|███▋      | 112/301 [02:24<04:06,  1.31s/it]

student_loss tensor(0.2733, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.1963e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 38%|███▊      | 113/301 [02:26<04:05,  1.30s/it]

student_loss tensor(0.2783, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.1679e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 38%|███▊      | 114/301 [02:27<04:04,  1.31s/it]

student_loss tensor(0.2817, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.1357e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 38%|███▊      | 115/301 [02:28<04:02,  1.30s/it]

student_loss tensor(0.3077, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.1456e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 39%|███▊      | 116/301 [02:30<04:01,  1.31s/it]

student_loss tensor(0.2716, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.1847e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 39%|███▉      | 117/301 [02:31<04:00,  1.31s/it]

student_loss tensor(0.2628, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.2101e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 39%|███▉      | 118/301 [02:32<03:59,  1.31s/it]

student_loss tensor(0.2482, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.3067e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 40%|███▉      | 119/301 [02:33<03:57,  1.31s/it]

student_loss tensor(0.2432, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.3680e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 40%|███▉      | 120/301 [02:35<03:56,  1.31s/it]

student_loss tensor(0.2524, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.2758e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 40%|████      | 121/301 [02:36<03:55,  1.31s/it]

student_loss tensor(0.2380, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.4395e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 41%|████      | 122/301 [02:37<03:54,  1.31s/it]

student_loss tensor(0.2392, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.3882e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 41%|████      | 123/301 [02:39<03:52,  1.31s/it]

student_loss tensor(0.2355, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.3522e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 41%|████      | 124/301 [02:40<03:51,  1.31s/it]

student_loss tensor(0.2264, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.4386e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 42%|████▏     | 125/301 [02:41<03:49,  1.31s/it]

student_loss tensor(0.3048, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.2555e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 42%|████▏     | 126/301 [02:43<03:48,  1.31s/it]

student_loss tensor(0.2464, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.3255e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 42%|████▏     | 127/301 [02:44<03:47,  1.31s/it]

student_loss tensor(0.2436, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.4622e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 43%|████▎     | 128/301 [02:45<03:45,  1.31s/it]

student_loss tensor(0.2201, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.5341e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 43%|████▎     | 129/301 [02:47<03:44,  1.31s/it]

student_loss tensor(0.2290, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.5324e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 43%|████▎     | 130/301 [02:48<03:43,  1.31s/it]

student_loss tensor(0.2253, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.4361e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 44%|████▎     | 131/301 [02:49<03:42,  1.31s/it]

student_loss tensor(0.2234, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.5793e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 44%|████▍     | 132/301 [02:50<03:40,  1.31s/it]

student_loss tensor(0.2128, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.6194e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 44%|████▍     | 133/301 [02:52<03:39,  1.31s/it]

student_loss tensor(0.2215, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.5862e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 45%|████▍     | 134/301 [02:53<03:37,  1.30s/it]

student_loss tensor(0.1998, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.6880e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 45%|████▍     | 135/301 [02:54<03:36,  1.31s/it]

student_loss tensor(0.2085, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.6869e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 45%|████▌     | 136/301 [02:56<03:35,  1.31s/it]

student_loss tensor(0.2024, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.6662e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 46%|████▌     | 137/301 [02:57<03:34,  1.31s/it]

student_loss tensor(0.1994, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.6695e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 46%|████▌     | 138/301 [02:58<03:33,  1.31s/it]

student_loss tensor(0.2054, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.7151e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 46%|████▌     | 139/301 [03:00<03:31,  1.31s/it]

student_loss tensor(0.1995, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.7143e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 47%|████▋     | 140/301 [03:01<03:30,  1.31s/it]

student_loss tensor(0.1892, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.7795e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 47%|████▋     | 141/301 [03:02<03:28,  1.30s/it]

student_loss tensor(0.1737, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.8544e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 47%|████▋     | 142/301 [03:04<03:27,  1.31s/it]

student_loss tensor(0.1832, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.8093e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 48%|████▊     | 143/301 [03:05<03:25,  1.30s/it]

student_loss tensor(0.1889, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.7973e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 48%|████▊     | 144/301 [03:06<03:25,  1.31s/it]

student_loss tensor(0.1753, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.8371e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 48%|████▊     | 145/301 [03:07<03:23,  1.31s/it]

student_loss tensor(0.1551, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.9465e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 49%|████▊     | 146/301 [03:09<03:22,  1.31s/it]

student_loss tensor(0.1916, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.7754e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 49%|████▉     | 147/301 [03:10<03:21,  1.31s/it]

student_loss tensor(0.1763, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.8648e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 49%|████▉     | 148/301 [03:11<03:20,  1.31s/it]

student_loss tensor(0.1693, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.9412e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 50%|████▉     | 149/301 [03:13<03:18,  1.31s/it]

student_loss tensor(0.1771, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.9449e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 50%|████▉     | 150/301 [03:14<03:17,  1.31s/it]

student_loss tensor(0.1867, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.8751e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 50%|█████     | 151/301 [03:15<03:16,  1.31s/it]

student_loss tensor(0.1740, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.0007e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 50%|█████     | 152/301 [03:17<03:14,  1.31s/it]

student_loss tensor(0.1751, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.9945e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 51%|█████     | 153/301 [03:18<03:13,  1.31s/it]

student_loss tensor(0.1833, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(6.8476e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 51%|█████     | 154/301 [03:19<03:11,  1.31s/it]

student_loss tensor(0.1651, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.0663e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 51%|█████▏    | 155/301 [03:21<03:10,  1.31s/it]

student_loss tensor(0.1682, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.0881e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 52%|█████▏    | 156/301 [03:22<03:09,  1.30s/it]

student_loss tensor(0.1568, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.0803e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 52%|█████▏    | 157/301 [03:23<03:08,  1.31s/it]

student_loss tensor(0.1755, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.0372e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 52%|█████▏    | 158/301 [03:24<03:06,  1.31s/it]

student_loss tensor(0.1586, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.0326e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 53%|█████▎    | 159/301 [03:26<03:05,  1.31s/it]

student_loss tensor(0.1511, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.1871e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 53%|█████▎    | 160/301 [03:27<03:04,  1.31s/it]

student_loss tensor(0.1461, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.2041e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 53%|█████▎    | 161/301 [03:28<03:03,  1.31s/it]

student_loss tensor(0.1571, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.0730e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 54%|█████▍    | 162/301 [03:30<03:01,  1.31s/it]

student_loss tensor(0.1379, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.2837e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 54%|█████▍    | 163/301 [03:31<03:00,  1.31s/it]

student_loss tensor(0.1505, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.1435e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 54%|█████▍    | 164/301 [03:32<02:58,  1.31s/it]

student_loss tensor(0.1514, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.1953e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 55%|█████▍    | 165/301 [03:34<02:57,  1.31s/it]

student_loss tensor(0.1625, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.1172e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 55%|█████▌    | 166/301 [03:35<02:56,  1.31s/it]

student_loss tensor(0.1478, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.2058e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 55%|█████▌    | 167/301 [03:36<02:55,  1.31s/it]

student_loss tensor(0.1644, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.1913e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 56%|█████▌    | 168/301 [03:38<02:53,  1.31s/it]

student_loss tensor(0.1559, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.2863e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 56%|█████▌    | 169/301 [03:39<02:52,  1.31s/it]

student_loss tensor(0.1339, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.2847e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 56%|█████▋    | 170/301 [03:40<02:51,  1.31s/it]

student_loss tensor(0.1238, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.3021e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 57%|█████▋    | 171/301 [03:41<02:50,  1.31s/it]

student_loss tensor(0.1371, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.3605e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 57%|█████▋    | 172/301 [03:43<02:48,  1.31s/it]

student_loss tensor(0.1324, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.3307e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 57%|█████▋    | 173/301 [03:44<02:47,  1.31s/it]

student_loss tensor(0.1268, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.4002e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 58%|█████▊    | 174/301 [03:45<02:46,  1.31s/it]

student_loss tensor(0.1279, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.3120e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 58%|█████▊    | 175/301 [03:47<02:44,  1.30s/it]

student_loss tensor(0.1419, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.2805e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 58%|█████▊    | 176/301 [03:48<02:43,  1.31s/it]

student_loss tensor(0.1173, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.4096e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 59%|█████▉    | 177/301 [03:49<02:41,  1.30s/it]

student_loss tensor(0.1159, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.4532e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 59%|█████▉    | 178/301 [03:51<02:40,  1.31s/it]

student_loss tensor(0.1277, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.4754e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 59%|█████▉    | 179/301 [03:52<02:38,  1.30s/it]

student_loss tensor(0.1192, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.4373e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 60%|█████▉    | 180/301 [03:53<02:37,  1.30s/it]

student_loss tensor(0.1316, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.4841e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 60%|██████    | 181/301 [03:55<02:36,  1.30s/it]

student_loss tensor(0.1262, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.5689e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 60%|██████    | 182/301 [03:56<02:34,  1.30s/it]

student_loss tensor(0.1196, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.5230e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 61%|██████    | 183/301 [03:57<02:33,  1.30s/it]

student_loss tensor(0.1242, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.5229e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 61%|██████    | 184/301 [03:58<02:32,  1.30s/it]

student_loss tensor(0.1181, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.5116e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 61%|██████▏   | 185/301 [04:00<02:30,  1.30s/it]

student_loss tensor(0.1255, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.5277e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 62%|██████▏   | 186/301 [04:01<02:29,  1.30s/it]

student_loss tensor(0.1269, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.5436e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 62%|██████▏   | 187/301 [04:02<02:28,  1.30s/it]

student_loss tensor(0.1065, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.6297e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 62%|██████▏   | 188/301 [04:04<02:26,  1.30s/it]

student_loss tensor(0.1196, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.5301e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 63%|██████▎   | 189/301 [04:05<02:25,  1.30s/it]

student_loss tensor(0.1201, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.4926e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 63%|██████▎   | 190/301 [04:06<02:24,  1.31s/it]

student_loss tensor(0.1138, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7073e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 63%|██████▎   | 191/301 [04:08<02:23,  1.30s/it]

student_loss tensor(0.1078, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.6980e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 64%|██████▍   | 192/301 [04:09<02:21,  1.30s/it]

student_loss tensor(0.1084, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.6077e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 64%|██████▍   | 193/301 [04:10<02:20,  1.30s/it]

student_loss tensor(0.0997, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7246e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 64%|██████▍   | 194/301 [04:11<02:19,  1.30s/it]

student_loss tensor(0.1141, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.5669e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 65%|██████▍   | 195/301 [04:13<02:18,  1.30s/it]

student_loss tensor(0.1134, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.6593e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 65%|██████▌   | 196/301 [04:14<02:17,  1.30s/it]

student_loss tensor(0.1143, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.6691e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 65%|██████▌   | 197/301 [04:15<02:15,  1.31s/it]

student_loss tensor(0.1017, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7202e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 66%|██████▌   | 198/301 [04:17<02:14,  1.31s/it]

student_loss tensor(0.1142, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.6958e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 66%|██████▌   | 199/301 [04:18<02:13,  1.31s/it]

student_loss tensor(0.0989, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7734e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 66%|██████▋   | 200/301 [04:19<02:11,  1.31s/it]

student_loss tensor(0.1021, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7690e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 67%|██████▋   | 201/301 [04:21<02:10,  1.30s/it]

student_loss tensor(0.0998, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7530e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 67%|██████▋   | 202/301 [04:22<02:09,  1.31s/it]

student_loss tensor(0.1060, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7599e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 67%|██████▋   | 203/301 [04:23<02:07,  1.30s/it]

student_loss tensor(0.1038, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7817e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 68%|██████▊   | 204/301 [04:24<02:06,  1.31s/it]

student_loss tensor(0.0954, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.8298e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 68%|██████▊   | 205/301 [04:26<02:05,  1.30s/it]

student_loss tensor(0.0975, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7851e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 68%|██████▊   | 206/301 [04:27<02:04,  1.31s/it]

student_loss tensor(0.0892, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.9053e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 69%|██████▉   | 207/301 [04:28<02:02,  1.31s/it]

student_loss tensor(0.1070, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.7691e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 69%|██████▉   | 208/301 [04:30<02:01,  1.30s/it]

student_loss tensor(0.0929, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.8618e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 69%|██████▉   | 209/301 [04:31<02:00,  1.31s/it]

student_loss tensor(0.0950, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.8679e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 70%|██████▉   | 210/301 [04:32<01:58,  1.31s/it]

student_loss tensor(0.0937, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.8926e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 70%|███████   | 211/301 [04:34<01:57,  1.31s/it]

student_loss tensor(0.0923, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.8680e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 70%|███████   | 212/301 [04:35<01:56,  1.31s/it]

student_loss tensor(0.0905, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.9160e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 71%|███████   | 213/301 [04:36<01:55,  1.31s/it]

student_loss tensor(0.0953, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.8961e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 71%|███████   | 214/301 [04:38<01:53,  1.31s/it]

student_loss tensor(0.0991, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.8420e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 71%|███████▏  | 215/301 [04:39<01:52,  1.31s/it]

student_loss tensor(0.0843, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.9916e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 72%|███████▏  | 216/301 [04:40<01:50,  1.30s/it]

student_loss tensor(0.0882, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.9502e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 72%|███████▏  | 217/301 [04:41<01:49,  1.31s/it]

student_loss tensor(0.0886, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.0091e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 72%|███████▏  | 218/301 [04:43<01:48,  1.31s/it]

student_loss tensor(0.0907, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.9418e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 73%|███████▎  | 219/301 [04:44<01:47,  1.31s/it]

student_loss tensor(0.0811, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.0446e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 73%|███████▎  | 220/301 [04:45<01:46,  1.31s/it]

student_loss tensor(0.0858, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.9937e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 73%|███████▎  | 221/301 [04:47<01:44,  1.31s/it]

student_loss tensor(0.0825, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.0864e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 74%|███████▍  | 222/301 [04:48<01:43,  1.31s/it]

student_loss tensor(0.0879, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(7.9955e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 74%|███████▍  | 223/301 [04:49<01:42,  1.31s/it]

student_loss tensor(0.0901, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.0171e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 74%|███████▍  | 224/301 [04:51<01:40,  1.31s/it]

student_loss tensor(0.0873, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.0514e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 75%|███████▍  | 225/301 [04:52<01:39,  1.31s/it]

student_loss tensor(0.0783, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.1732e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 75%|███████▌  | 226/301 [04:53<01:38,  1.31s/it]

student_loss tensor(0.0806, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.1111e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 75%|███████▌  | 227/301 [04:55<01:36,  1.30s/it]

student_loss tensor(0.0865, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.0586e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 76%|███████▌  | 228/301 [04:56<01:35,  1.30s/it]

student_loss tensor(0.0806, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.2333e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 76%|███████▌  | 229/301 [04:57<01:33,  1.30s/it]

student_loss tensor(0.0909, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.0914e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 76%|███████▋  | 230/301 [04:58<01:32,  1.30s/it]

student_loss tensor(0.0840, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.0716e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 77%|███████▋  | 231/301 [05:00<01:31,  1.30s/it]

student_loss tensor(0.0847, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.1772e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 77%|███████▋  | 232/301 [05:01<01:29,  1.30s/it]

student_loss tensor(0.0818, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.1756e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 77%|███████▋  | 233/301 [05:02<01:28,  1.30s/it]

student_loss tensor(0.0736, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.2520e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 78%|███████▊  | 234/301 [05:04<01:27,  1.31s/it]

student_loss tensor(0.0790, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.1765e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 78%|███████▊  | 235/301 [05:05<01:26,  1.31s/it]

student_loss tensor(0.0752, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.2300e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 78%|███████▊  | 236/301 [05:06<01:24,  1.31s/it]

student_loss tensor(0.0773, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.2233e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 79%|███████▊  | 237/301 [05:08<01:23,  1.31s/it]

student_loss tensor(0.0844, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.2252e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 79%|███████▉  | 238/301 [05:09<01:22,  1.31s/it]

student_loss tensor(0.0695, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.2593e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 79%|███████▉  | 239/301 [05:10<01:20,  1.31s/it]

student_loss tensor(0.0767, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.1992e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 80%|███████▉  | 240/301 [05:12<01:19,  1.30s/it]

student_loss tensor(0.0746, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.2374e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 80%|████████  | 241/301 [05:13<01:18,  1.31s/it]

student_loss tensor(0.0754, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.2826e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 80%|████████  | 242/301 [05:14<01:16,  1.30s/it]

student_loss tensor(0.0741, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.2523e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 81%|████████  | 243/301 [05:15<01:15,  1.31s/it]

student_loss tensor(0.0797, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.3482e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 81%|████████  | 244/301 [05:17<01:14,  1.30s/it]

student_loss tensor(0.0672, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.3295e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 81%|████████▏ | 245/301 [05:18<01:13,  1.31s/it]

student_loss tensor(0.0694, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4138e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 82%|████████▏ | 246/301 [05:19<01:11,  1.30s/it]

student_loss tensor(0.0653, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4114e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 82%|████████▏ | 247/301 [05:21<01:10,  1.31s/it]

student_loss tensor(0.0689, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.3588e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 82%|████████▏ | 248/301 [05:22<01:09,  1.31s/it]

student_loss tensor(0.0644, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.3795e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 83%|████████▎ | 249/301 [05:23<01:07,  1.31s/it]

student_loss tensor(0.0637, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4861e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 83%|████████▎ | 250/301 [05:25<01:06,  1.31s/it]

student_loss tensor(0.0678, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4531e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 83%|████████▎ | 251/301 [05:26<01:05,  1.30s/it]

student_loss tensor(0.0683, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.3914e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 84%|████████▎ | 252/301 [05:27<01:04,  1.31s/it]

student_loss tensor(0.0745, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.3886e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 84%|████████▍ | 253/301 [05:28<01:02,  1.31s/it]

student_loss tensor(0.0663, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4041e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 84%|████████▍ | 254/301 [05:30<01:01,  1.31s/it]

student_loss tensor(0.0638, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4782e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 85%|████████▍ | 255/301 [05:31<01:00,  1.31s/it]

student_loss tensor(0.0645, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4889e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 85%|████████▌ | 256/301 [05:32<00:58,  1.31s/it]

student_loss tensor(0.0635, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5019e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 85%|████████▌ | 257/301 [05:34<00:57,  1.31s/it]

student_loss tensor(0.0686, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4452e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 86%|████████▌ | 258/301 [05:35<00:56,  1.31s/it]

student_loss tensor(0.0668, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4969e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 86%|████████▌ | 259/301 [05:36<00:54,  1.31s/it]

student_loss tensor(0.0687, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4098e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 86%|████████▋ | 260/301 [05:38<00:53,  1.31s/it]

student_loss tensor(0.0571, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5283e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 87%|████████▋ | 261/301 [05:39<00:52,  1.31s/it]

student_loss tensor(0.0617, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4735e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 87%|████████▋ | 262/301 [05:40<00:50,  1.31s/it]

student_loss tensor(0.0620, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5425e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 87%|████████▋ | 263/301 [05:42<00:49,  1.30s/it]

student_loss tensor(0.0702, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4284e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 88%|████████▊ | 264/301 [05:43<00:48,  1.30s/it]

student_loss tensor(0.0655, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.4811e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 88%|████████▊ | 265/301 [05:44<00:47,  1.31s/it]

student_loss tensor(0.0607, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5290e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 88%|████████▊ | 266/301 [05:45<00:45,  1.31s/it]

student_loss tensor(0.0663, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5047e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 89%|████████▊ | 267/301 [05:47<00:44,  1.32s/it]

student_loss tensor(0.0648, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5705e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 89%|████████▉ | 268/301 [05:48<00:44,  1.35s/it]

student_loss tensor(0.0593, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5671e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 89%|████████▉ | 269/301 [05:50<00:42,  1.34s/it]

student_loss tensor(0.0680, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5994e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 90%|████████▉ | 270/301 [05:51<00:41,  1.32s/it]

student_loss tensor(0.0630, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5739e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 90%|█████████ | 271/301 [05:52<00:39,  1.32s/it]

student_loss tensor(0.0643, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.6105e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 90%|█████████ | 272/301 [05:53<00:37,  1.31s/it]

student_loss tensor(0.0655, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5347e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 91%|█████████ | 273/301 [05:55<00:36,  1.31s/it]

student_loss tensor(0.0689, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5059e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 91%|█████████ | 274/301 [05:56<00:35,  1.31s/it]

student_loss tensor(0.0605, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.5994e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 91%|█████████▏| 275/301 [05:57<00:34,  1.31s/it]

student_loss tensor(0.0662, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.6108e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 92%|█████████▏| 276/301 [05:59<00:32,  1.31s/it]

student_loss tensor(0.0564, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7201e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 92%|█████████▏| 277/301 [06:00<00:31,  1.31s/it]

student_loss tensor(0.0569, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.6924e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 92%|█████████▏| 278/301 [06:01<00:30,  1.31s/it]

student_loss tensor(0.0597, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.6216e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 93%|█████████▎| 279/301 [06:03<00:28,  1.31s/it]

student_loss tensor(0.0532, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7488e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 93%|█████████▎| 280/301 [06:04<00:27,  1.30s/it]

student_loss tensor(0.0542, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7003e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 93%|█████████▎| 281/301 [06:05<00:26,  1.31s/it]

student_loss tensor(0.0592, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7012e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 94%|█████████▎| 282/301 [06:06<00:24,  1.30s/it]

student_loss tensor(0.0606, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7038e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 94%|█████████▍| 283/301 [06:08<00:23,  1.30s/it]

student_loss tensor(0.0528, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8137e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 94%|█████████▍| 284/301 [06:09<00:22,  1.30s/it]

student_loss tensor(0.0603, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.6872e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 95%|█████████▍| 285/301 [06:10<00:20,  1.30s/it]

student_loss tensor(0.0610, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.6629e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 95%|█████████▌| 286/301 [06:12<00:19,  1.30s/it]

student_loss tensor(0.0529, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8157e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 95%|█████████▌| 287/301 [06:13<00:18,  1.30s/it]

student_loss tensor(0.0555, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8071e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 96%|█████████▌| 288/301 [06:14<00:16,  1.30s/it]

student_loss tensor(0.0623, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.6823e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 96%|█████████▌| 289/301 [06:16<00:15,  1.30s/it]

student_loss tensor(0.0592, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7557e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 96%|█████████▋| 290/301 [06:17<00:14,  1.30s/it]

student_loss tensor(0.0577, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7747e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 97%|█████████▋| 291/301 [06:18<00:12,  1.30s/it]

student_loss tensor(0.0547, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7210e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 97%|█████████▋| 292/301 [06:19<00:11,  1.30s/it]

student_loss tensor(0.0676, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.6271e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 97%|█████████▋| 293/301 [06:21<00:10,  1.30s/it]

student_loss tensor(0.0550, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7342e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 98%|█████████▊| 294/301 [06:22<00:09,  1.30s/it]

student_loss tensor(0.0524, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8667e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 98%|█████████▊| 295/301 [06:23<00:07,  1.31s/it]

student_loss tensor(0.0566, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8186e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 98%|█████████▊| 296/301 [06:25<00:06,  1.30s/it]

student_loss tensor(0.0518, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8599e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 99%|█████████▊| 297/301 [06:26<00:05,  1.30s/it]

student_loss tensor(0.0507, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8115e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 99%|█████████▉| 298/301 [06:27<00:03,  1.30s/it]

student_loss tensor(0.0557, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.7566e-05, device='cuda:0', grad_fn=<MeanBackward0>)


 99%|█████████▉| 299/301 [06:29<00:02,  1.30s/it]

student_loss tensor(0.0525, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8656e-05, device='cuda:0', grad_fn=<MeanBackward0>)


100%|█████████▉| 300/301 [06:30<00:01,  1.30s/it]

student_loss tensor(0.0523, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8240e-05, device='cuda:0', grad_fn=<MeanBackward0>)


100%|██████████| 301/301 [06:31<00:00,  1.30s/it]


epoch 0: perplexity: 9916841.442980865 : loss 16.109745025634766
saved model! epoch 0: perplexity: 9916841.442980865


  0%|          | 0/301 [00:00<?, ?it/s]

student_loss tensor(0.0491, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.9527e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  0%|          | 1/301 [00:01<06:30,  1.30s/it]

student_loss tensor(0.0450, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.9187e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  1%|          | 2/301 [00:02<06:30,  1.31s/it]

student_loss tensor(0.0510, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8710e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  1%|          | 3/301 [00:03<06:29,  1.31s/it]

student_loss tensor(0.0511, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8470e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  1%|▏         | 4/301 [00:05<06:27,  1.31s/it]

student_loss tensor(0.0462, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.9240e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|▏         | 5/301 [00:06<06:25,  1.30s/it]

student_loss tensor(0.0444, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.9391e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|▏         | 6/301 [00:07<06:24,  1.30s/it]

student_loss tensor(0.0432, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.9984e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|▏         | 7/301 [00:09<06:23,  1.30s/it]

student_loss tensor(0.0466, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.8872e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  3%|▎         | 8/301 [00:10<06:22,  1.31s/it]

student_loss tensor(0.0433, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(8.9973e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  3%|▎         | 9/301 [00:11<06:20,  1.30s/it]

student_loss tensor(0.0457, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(9.0170e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  3%|▎         | 10/301 [00:13<06:20,  1.31s/it]

student_loss tensor(0.0457, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(9.0128e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  4%|▎         | 11/301 [00:14<06:18,  1.30s/it]

student_loss tensor(0.0426, device='cuda:0', grad_fn=<NllLossBackward0>)
distillation_loss tensor(9.0662e-05, device='cuda:0', grad_fn=<MeanBackward0>)


  4%|▍         | 12/301 [00:16<06:32,  1.36s/it]


In [None]:
# output_dir = "./savemodel/"
# save_path = f'models/{model.__class__.__name__}_add10b.pt'

# # Only show the progress bar once on each machine.
# progress_bar = tqdm(
#     range(max_train_steps), disable=not accelerator.is_local_main_process
# )
# completed_steps = 0
# best_val_perplexity = float("inf")

# for epoch in range(num_train_epochs):
#     model.train()
#     for step, batch in enumerate(train_dataloader):
#         optimizer.zero_grad()
#         outputs = model(**batch)
#         loss = outputs.loss
#         loss = loss / gradient_accumulation_steps
#         loss = loss.reshape(-1)
#         # accelerator.backward(loss)
#         if (
#             step % gradient_accumulation_steps == 0
#             or step == len(train_dataloader) - 1
#         ):
#             # Perform one optimization step with the PrivacyEngine
#             optimizer.step(loss=loss)
#             lr_scheduler.step()
#             # optimizer.zero_grad()
#             progress_bar.update(1)
#             completed_steps += 1

#         if completed_steps >= max_train_steps:
#             break

#     model.eval()
#     losses = []
#     for step, batch in enumerate(eval_dataloader):
#         with torch.no_grad():
#             outputs = model(**batch)

#         loss = outputs.loss
#         losses.append(
#             accelerator.gather(loss.repeat(per_device_eval_batch_size))
#         )

#     losses = torch.cat(losses)
#     losses = losses[: len(small_eval_dataset)]
#     try:
#         perplexity = math.exp(torch.mean(losses))
#     except OverflowError:
#         perplexity = float("inf")

#     # logger.info(f"epoch {epoch}: perplexity: {perplexity}")
#     print(f"epoch {epoch}: perplexity: {perplexity}")

#     # Printing epsilon from opacus privacy engine at the end of each epoch
#     eps, alpha = optimizer.privacy_engine.get_privacy_spent(delta)
#     print("End of epoch {}, we have epsilon {} for alpha {}".format(epoch, eps, alpha))

#     if perplexity < best_val_perplexity and output_dir is not None:
#         best_val_perplexity = perplexity
#     #     accelerator.wait_for_everyone()
#     #     unwrapped_model = accelerator.unwrap_model(model)
#     #     unwrapped_model.save_pretrained(
#     #         output_dir, save_function=accelerator.save
#     #     )
#         # logger.info(
#         #     f"saved model! epoch {epoch}: perplexity: {best_val_perplexity}"
#         # )
#         print(f"saved model! epoch {epoch}: perplexity: {best_val_perplexity}")
#         torch.save(model.state_dict(), save_path)
#         # tokenizer.save_pretrained(output_dir)
#         # if accelerator.is_main_process:
#         #     # tokenizer.save_pretrained(output_dir)
#         #     if push_to_hub:
#         #         repo.push_to_hub(
#         #             commit_message="Best val perplexity", auto_lfs_prune=True
#         #         )

#     # if push_to_hub and epoch < num_train_epochs - 1:
#     #     accelerator.wait_for_everyone()
#     #     unwrapped_model = accelerator.unwrap_model(model)
#     #     unwrapped_model.save_pretrained(
#     #         output_dir, save_function=accelerator.save
#     #     )
#     #     if accelerator.is_main_process:
#     #         tokenizer.save_pretrained(output_dir)
#     #         repo.push_to_hub(
#     #             commit_message=f"Training in progress epoch {epoch}",
#     #             blocking=False,
#     #             auto_lfs_prune=True,
#     #         )

#     # if epoch == (num_train_epochs - 1):
#     #     save_fir = output_dir + f"_epoch_{num_train_epochs - 1}"
#     #     accelerator.wait_for_everyone()
#     #     unwrapped_model = accelerator.unwrap_model(model)
#     #     unwrapped_model.save_pretrained(save_fir, save_function=accelerator.save)
#     #     tokenizer.save_pretrained(save_fir)

## Test

In [None]:
# save_path = f'models/{model.__class__.__name__}_add10b.pt'
# model.load_state_dict(torch.load(save_path,  map_location=device))
# perplexity = evaluate(model, test_dataloader)
# print(f'Test Perplexity: {perplexity}')

## Inference

In [None]:
# import torch
# from transformers import (
#     CONFIG_MAPPING,
#     MODEL_MAPPING,
#     AdamW,
#     AutoConfig,
#     AutoModelForCausalLM,
#     AutoTokenizer,
#     default_data_collator
# )
# from itertools import chain

# # Load the trained model
# # model_path = 'dp-gpt2-clm-model.pth'
# model_checkpoint = "gpt2"
# config = AutoConfig.from_pretrained(model_checkpoint)
# model = AutoModelForCausalLM.from_config(config)

# save_path = f'models/{student_model.__class__.__name__}_distill_nodp.pt'
# model.load_state_dict(torch.load(save_path))
# model = model.eval()

In [None]:
# Set the device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Set up the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [None]:
# input_ids = tokenizer.encode('My ID is ', return_tensors='pt').to(device)
# input_ids[0]

In [None]:
# log_interval = 10
# max_seq_len = 200
# temperature = 1

def generate(prompt, max_seq_len, temperature, model, tokenizer, device, seed=None):
    tokens = ""
    if seed is not None:
        torch.manual_seed(seed)
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    with open('nodp-distill-gpt2-generated.txt', 'w') as output_files:
        model.eval()
        with torch.no_grad():  # no tracking history
            for i in range(max_seq_len):
                
                output = model(input_ids)
                word_weights = output[0].squeeze().div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                word_tensor = torch.Tensor([[word_idx]]).long().to(device)
                input = torch.cat([input_ids, word_tensor], 1)
    
                word = tokenizer.decode(word_idx)
                tokens = tokens + word + ('\n' if i % 20 == 19 else '')
                output_files.write(word + ('\n' if i % 20 == 19 else ''))
    
                # if i % log_interval == 0:
                #     print('| Generated {}/{} words'.format(i, max_seq_len))
            # print(tokens)
    return tokens

In [None]:
prompt = 'my number is'
max_seq_len = 10
seed = 0
temperatures = [0.5, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, device, seed)
    print(f'{str(temperature)}\n{generation}\n')