In [1]:
#comment this if you are not using AIT proxy...
import os
os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'

In [2]:
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import torch.nn.functional as F
import random
import os
import numpy as np
import utils
# Set the random seed for reproducible experiments
random.seed(230)
torch.manual_seed(230)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


## 1. Load Dataset

In [3]:
from datasets import load_dataset

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}
task_name = "sst2"
datasets = load_dataset("glue",task_name)
datasets

Found cached dataset glue (/root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

## 2. Preprocessing

In [4]:
# student = "distilroberta-base"
# teacher = "textattack/roberta-base-SST-2"
teacher = 'bert-base-uncased'

In [5]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(teacher)

In [6]:
# Labels
if task_name is not None:
    is_regression = task_name == "stsb"
    if not is_regression:
        label_list = datasets["train"].features["label"].names
        num_labels = len(label_list)
    else:
        num_labels = 1
else:
    # Trying to have good defaults here, don't hesitate to tweak to your needs.
    is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"]
    if is_regression:
        num_labels = 1
    else:
        # A useful fast method:
        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
        label_list = datasets["train"].unique("label")
        label_list.sort()  # Let's sort it for determinism
        num_labels = len(label_list)
        
num_labels, is_regression

(2, False)

In [7]:
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer, PretrainedConfig
model_name_or_path = teacher
config = AutoConfig.from_pretrained(
    model_name_or_path, 
    num_labels=num_labels, 
    finetuning_task=task_name)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    from_tf=bool(".ckpt" in model_name_or_path),
    config=config,
) #student

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
label_to_id = None

if (
    model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
    and task_name is not None
    and not is_regression
):
    # Some have all caps in their config, some don't.
    label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
    if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
        label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
        
elif task_name is None and not is_regression:
    label_to_id = {v: i for i, v in enumerate(label_list)}
    
def tokenize_function(examples):
    sentence1_key, sentence2_key = task_to_keys[task_name]
    args = (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
    result = tokenizer(*args, max_length=180, padding="max_length", truncation=True)
    if "label" in examples:
        if label_to_id is not None:
            # Map labels to IDs (not necessary for GLUE tasks)
            result["label"] = [label_to_id[l] for l in examples["label"]]
        else:
            # In all cases, rename the column to labels because the model will expect that.
            result["label"] = examples["label"]
    
    return result

tokenized_datasets = datasets.map(tokenize_function, batched=True)
tokenized_datasets

Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f8e4804eae872318.arrow


Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2b3d3889e534a7ce.arrow


DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [9]:
#Use list comprehension to extract non-None elements from the tuple
elements = [element for element in task_to_keys[task_name] if element is not None]
elements

['sentence']

In [10]:
tokenized_datasets = tokenized_datasets.remove_columns(elements + ["idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [11]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=55) #.select(range(10000))
small_eval_dataset = tokenized_datasets["validation_matched" if task_name == "mnli" else "validation"].shuffle(seed=55)
small_test_dataset = tokenized_datasets["test"].shuffle(seed=55)

Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-23987b2f2f6a1fc6.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2c3fc5b8ce479e09.arrow


## 3. Dataloaders

In [12]:
from torch.utils.data import DataLoader
per_device_train_batch_size = 64
per_device_eval_batch_size = 32

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=per_device_train_batch_size)
val_dataloader = DataLoader(small_eval_dataset, batch_size=per_device_eval_batch_size)
test_dataloader = DataLoader(small_test_dataset, batch_size=per_device_eval_batch_size)

## 4. Model

In [13]:
# teacher model
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher,
    config=config,
).to(device)
teacher_model.train()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [14]:
import torch
import torch.nn as nn
from transformers import BertConfig, BertForSequenceClassification

# # Create a configuration for a 6-layer BERT model
# config = BertConfig(
#     hidden_size=768,
#     num_hidden_layers=6,
#     num_attention_heads=12,
#     intermediate_size=3072,
#     hidden_dropout_prob=0.1,
#     attention_probs_dropout_prob=0.1,
# )

# # Instantiate a 6-layer BERT model
# student_model = BertForSequenceClassification(config).to(device)

# # Print the student model architecture
# print(student_model)

In [15]:
import math
def pseudo_uniform_selection(n, k):
    # reference : A Short Study on Compressing Decoder-Based Language Models 
    # https://arxiv.org/pdf/2110.08460.pdf
    # Require: n > k; n mod k = 0; n mod 2 = 0
    # assert n > k and n % k == 0 and n % 2 == 0, "Invalid input"
    
    step = math.floor(n / k)
    start = 0
    end = n - 1
    selection = []
    while start <= end:
        selection.append(start)
        selection.append(end)
        start += step
        end -= step
    selection.sort()
    return selection

# Select the layers to copy from the teacher model
teacher_layers = teacher_model.config.num_hidden_layers
student_layers = 6
teacher_layers_to_use = pseudo_uniform_selection(teacher_layers, student_layers)
print(teacher_layers_to_use)

[0, 2, 4, 7, 9, 11]


In [16]:
import torch
import torch.nn as nn
from transformers import BertConfig, BertForSequenceClassification

# Select the layers to copy from the teacher model
# teacher_layers_to_use = [num for num in range(teacher_model.config.num_hidden_layers) if num % 2 == 0]  # Indices of layers to copy
# print(len(teacher_layers_to_use))

# Define the configuration for the student model
student_config = BertConfig(
    hidden_size=768,
    num_hidden_layers=len(teacher_layers_to_use),  # Number of student layers
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)

# Initialize the student model architecture
student_model = BertForSequenceClassification(config=student_config)

# Copy teacher layers to student
for student_layer_idx, teacher_layer_idx in enumerate(teacher_layers_to_use):
    teacher_layer = teacher_model.bert.encoder.layer[teacher_layer_idx]
    student_layer = student_model.bert.encoder.layer[student_layer_idx]
    student_layer.load_state_dict(teacher_layer.state_dict())

# Now you can use the student model for further tasks
student_model.train()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [17]:
# # student model
# # student_model = AutoModelForSequenceClassification.from_pretrained(
# #     student,
# #     config=config,
# # )

# class StudentModel(nn.Module):
#     def __init__(self, teacher_model, num_classes):
#         super(StudentModel, self).__init__()
#         self.student_layers = nn.ModuleList([nn.Identity() if i % 2 == 0 else layer for i, layer in enumerate(teacher_model.bert.encoder.layer)])
#         self.fc = nn.Linear(768, num_classes)  # Modify num_classes accordingly

#     def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
#         x = teacher_model.bert.embeddings(input_ids, token_type_ids, attention_mask)
#         for layer in self.student_layers:
#             x = layer(x)
#         x = self.fc(x[:, 0, :])  # Assuming you want to classify sentence-level tasks

#         if labels is not None:
#             loss_fct = nn.CrossEntropyLoss()
#             loss = loss_fct(x.view(-1, self.num_labels), labels.view(-1))
#             return loss
#         else:
#             return x

# num_classes = 2  # Modify according to your classification task
# student_model = StudentModel(teacher_model, num_classes)

# student_model.train()

*Note teach model and student model still have same layers*

## 5. Training

![Differentially Private Knowledge Distillation (DPKD)
](images/DPKD.jpg)

### Optimizer 

In [18]:
from torch.optim import AdamW

# Define optimizer
teacher_optimizer = AdamW(teacher_model.parameters(), lr=5e-5)
student_optimizer = AdamW(student_model.parameters(), lr=6e-5)

### Accelerator

In [19]:
from accelerate import Accelerator

accelerator = Accelerator()

# teacher_model, teacher_optimizer, train_dataloader, val_dataloader = accelerator.prepare(
#     student_model, teacher_optimizer, train_dataloader, val_dataloader
# )
teacher_model = accelerator.prepare(teacher_model)
student_model, student_optimizer, train_dataloader, val_dataloader = accelerator.prepare(
    student_model, student_optimizer, train_dataloader, val_dataloader
)

In [20]:
from transformers import get_scheduler
import math

gradient_accumulation_steps = 1
num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )

num_train_epochs = 20
max_train_steps = num_train_epochs * num_update_steps_per_epoch

teacher_lr_scheduler = get_scheduler(
    "linear",
    optimizer=teacher_optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps,
)

student_lr_scheduler = get_scheduler(
    "linear",
    optimizer=student_optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps,
)

total_batch_size = (
        per_device_train_batch_size
        * accelerator.num_processes
        * gradient_accumulation_steps
    )

### Ghost clipping: memory saving differentially private learning
Turning on ghost clipping requires changing only 1 line. You should notice a drastic reduction in peak GPU memory usage once this is turned on, at a potential cost of slower training speed. One might find this especially useful when constrained to only use older GPUs with small VRAMs or fitting super large models.

In [21]:
# !pip install ml_swissknife
# !pip install opt_einsum

In [22]:
import transformers, torch
from private_transformers import PrivacyEngine
dp = True
if dp == True:
    #Student Model
    privacy_engine = PrivacyEngine(
        student_model,
        batch_size=per_device_train_batch_size,
        sample_size=len(datasets['train']),
        epochs=per_device_train_batch_size,
        max_grad_norm=0.1,
        target_epsilon=3,
        clipping_mode="ghost",  # The only change you need to make!
    )
    privacy_engine.attach(student_optimizer)
    #Teacher Model
    privacy_engine = PrivacyEngine(
        teacher_model,
        batch_size=per_device_train_batch_size,
        sample_size=len(datasets['train']),
        epochs=per_device_train_batch_size,
        max_grad_norm=0.1,
        target_epsilon=3,
        clipping_mode="ghost",  # The only change you need to make!
    )
    privacy_engine.attach(teacher_optimizer)
else :
    privacy_engine = None

privacy_engine

PrivacyEngine(
  target_epsilon=3.000000, 
  target_delta=0.000005, 
  noise_multiplier=0.730322, 
  effective_noise_multiplier=0.011411, 
  epochs=64, 
  max_grad_norm=0.1, 
  sample_rate=0.0009502739461610417, 
  batch_size=64, 
  accounting_mode=rdp, 
  clipping_mode=ghost
)

In [23]:
# !pip install ml_swissknife

### Loss Objective 
The Kullback-Leibler divergence loss. For tensors of the same shape $y_{pred}, y_{true}$ where $y_{pred}$ is the input and $y_{true}$ â€‹ is the target, we define the pointwise KL-divergence as 

$$L(y_{pred}, y_{true}) = y_{pred}\cdot \log \frac{y_{true}}{y_{pred}}  = y_{true} \cdot (\log y_{true} -\log y_{true})$$

format : torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)
more infomation click [link](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html)

In [24]:
def loss_fn_kd(student_outputs, labels, teacher_outputs, alpha = 0.9, T = 1):
    """
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha

    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities! See Issue #2
    """
    # student_outputs.logits.shape = (batch_size, class)
    # teacher_outputs.logits.shape = (batch_size, class)
    # labels.shape = (batch_size, )
    
    loss_fn = nn.KLDivLoss(reduction="none")
    kd_loss = loss_fn(F.log_softmax(student_outputs.logits/T, dim=1),
                             F.softmax(teacher_outputs.logits/T, dim=1) * (T ** 2)) #(batch_size, labels)
    
    kd_loss = kd_loss.mean(dim=1) #(batch_size, )
    # print(student_outputs.logits.shape, labels.shape)
    CELoss = F.cross_entropy(student_outputs.logits, labels, reduction="none") #.mean(dim=1) #(batch_size, )
    total_losses = alpha * CELoss + (1. - alpha) * kd_loss

    # total_losses = CELoss + alpha * kd_loss
    return total_losses.mean(dim=-1)

### Metrics

In [25]:
def accuracy(outputs, labels):
    """
    Compute the accuracy, given the outputs and labels for all images.

    Args:
        outputs: (np.ndarray) output of the model
        labels: (np.ndarray) [0, 1, ..., num_classes-1]

    Returns: (float) accuracy in [0,1]
    """
    # outputs = np.argmax(outputs, axis=1)
    return np.sum(outputs==labels)/float(labels.size)


# maintain all metrics required in this dictionary- these are used in the training and evaluation loops
metrics = {
    'accuracy': accuracy,
    # could add more metrics such as accuracy for each token type
}

class RunningAverage():
    """A simple class that maintains the running average of a quantity
    
    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg() = 3
    ```
    """
    def __init__(self):
        self.steps = 0
        self.total = 0
    
    def update(self, val):
        self.total += val
        self.steps += 1
    
    def __call__(self):
        return self.total/float(self.steps)

In [26]:
#checking chucking
for i in train_dataloader:
    print(i['input_ids'].shape, i['labels'].shape)
    break
for i in val_dataloader:
    print(i['input_ids'].shape, i['labels'].shape)
    break

torch.Size([64, 180]) torch.Size([64])
torch.Size([32, 180]) torch.Size([32])


In [27]:
# Defining train_kd & train_and_evaluate_kd functions
def train_kd(student_model, teacher_model, optimizer, loss_fn_kd, train_dataloader, metrics):
    # set model to training mode
    student_model.train()
    teacher_model.eval()
    
    # summary for current training loop and a running average object for loss
    summ = []
    loss_avg = RunningAverage()
    
    for step, batch in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()
        # batch = tuple(t.to(device) for t in batch)

        # compute model output, fetch teacher output, and compute KD loss
        output_batch = student_model(**batch) # output = loss, logits, hidden_states, attentions
        labels_batch = batch['labels']
        
        # get one batch output from teacher_outputs list
        with torch.no_grad():
            output_teacher_batch = teacher_model(**batch)

        loss = loss_fn_kd(output_batch, labels_batch, output_teacher_batch)
        loss = loss.reshape(-1)
        # clear previous gradients, compute gradients of all variables wrt loss
        optimizer.zero_grad()
        # loss.backward()
        # accelerator.backward(loss)
        
        # This step is different from existing workflows: 
        # Don't call `loss.backward`; leave it to `optimizer.step` to handle backward.
        # performs updates using calculated gradients
        # `loss` is a 1-D tensor of shape (batch_size,).
        optimizer.step(loss=loss)

        # Evaluate summaries only once in a while
        if step % gradient_accumulation_steps == 0:
            # extract data from torch, move to cpu, convert to numpy arrays            
            output_batch = output_batch.logits.argmax(dim=-1).cpu().numpy()
            labels_batch = labels_batch.cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {metric:metrics[metric](output_batch, labels_batch)
                             for metric in metrics}
            summary_batch['loss'] = loss.item()
            summ.append(summary_batch)

        # update the average loss
        loss_avg.update(loss.item())
        
    # compute mean of all metrics in summary
    metrics_mean = {metric:np.mean([x[metric] for x in summ]) for metric in summ[0]}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())

In [28]:
def evaluate_kd(model, eval_dataloader, metrics):
    # set model to evaluation mode
    model.eval()
    # summary for current eval loop
    summ = []

    # compute metrics over the dataset
    for step, batch in enumerate(eval_dataloader):
        
        # compute model output
        output_batch = model(**batch)
        labels_batch = batch['labels']
        
        loss = 0.0  #force validation loss to zero to reduce computation time

        # extract data from torch, move to cpu, convert to numpy arrays
        output_batch = output_batch.logits.argmax(dim=-1).cpu().numpy()
        labels_batch = labels_batch.cpu().numpy()

        # compute all metrics on this batch
        summary_batch = {metric: metrics[metric](output_batch, labels_batch)
                         for metric in metrics}
        # summary_batch['loss'] = loss.item()
        summary_batch['loss'] = loss
        summ.append(summary_batch)

    # compute mean of all metrics in summary
    metrics_mean = {metric:np.mean([x[metric] for x in summ]) for metric in summ[0]} 
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    print("- Eval metrics : " + metrics_string)
    return metrics_mean

In [29]:
model_dir = './experiments/distill'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

In [30]:
def train_and_evaluate_kd(student_model, teacher_model, train_dataloader, val_dataloader, 
                          optimizer, loss_fn_kd, metrics, model_dir, save_path, restore_file=None):
    best_val_acc = 0.0
    for epoch in range(num_train_epochs):
        student_lr_scheduler.step()
        
         # compute number of batches in one epoch (one full pass over the training set)
        train_kd(student_model, teacher_model, optimizer, loss_fn_kd, train_dataloader,
                 metrics)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate_kd(student_model, val_dataloader, metrics)

        val_acc = val_metrics['accuracy']
        is_best = val_acc>=best_val_acc
        
        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict' : optimizer.state_dict()},
                               is_best=is_best,
                               checkpoint=model_dir)
        print(f'epoch: {epoch + 1}')
        # If best_eval, best_save_path
        if is_best:
            print("- Found new best accuracy")
            best_val_acc = val_acc
            
            print(f"saved model! epoch {epoch}: best accuracy: {best_val_acc}")
            torch.save(model.state_dict(), save_path)
            
        #     # Save best val metrics in a json file in the model directory
        #     best_json_path = os.path.join(model_dir, "metrics_val_best_weights.json")
        #     utils.save_dict_to_json(val_metrics, best_json_path)

        # # Save latest val metrics in a json file in the model directory
        # last_json_path = os.path.join(model_dir, "metrics_val_last_weights.json")
        # utils.save_dict_to_json(val_metrics, last_json_path)

In [31]:
save_path = f'models/{student_model.__class__.__name__}distillbert_pseudo.pt'
train_and_evaluate_kd(student_model, teacher_model, train_dataloader, val_dataloader, student_optimizer,
                       loss_fn_kd, metrics, model_dir, save_path, restore_file=None)

  0%|          | 0/1053 [00:00<?, ?it/s]



- Eval metrics : accuracy: 0.770 ; loss: 0.000
Checkpoint Directory exists! 
epoch: 1
- Found new best accuracy
saved model! epoch 0: best accuracy: 0.7700892857142857


  0%|          | 0/1053 [00:00<?, ?it/s]

- Eval metrics : accuracy: 0.812 ; loss: 0.000
Checkpoint Directory exists! 
epoch: 2
- Found new best accuracy
saved model! epoch 1: best accuracy: 0.8125


  0%|          | 0/1053 [00:00<?, ?it/s]

- Eval metrics : accuracy: 0.814 ; loss: 0.000
Checkpoint Directory exists! 
epoch: 3
- Found new best accuracy
saved model! epoch 2: best accuracy: 0.8136160714285714


  0%|          | 0/1053 [00:00<?, ?it/s]

- Eval metrics : accuracy: 0.809 ; loss: 0.000
Checkpoint Directory exists! 
epoch: 4


  0%|          | 0/1053 [00:00<?, ?it/s]

- Eval metrics : accuracy: 0.791 ; loss: 0.000
Checkpoint Directory exists! 
epoch: 5


  0%|          | 0/1053 [00:00<?, ?it/s]

## 6. Evaluation

In [None]:
# import numpy as np
# import evaluate
# import torch

# def testing(model, dataloader):
#     metric = evaluate.load("glue", "sst2")
#     model.eval()
#     for batch in dataloader:
#         # batch = {k: v.to(device) for k, v in batch.items()}
#         with torch.no_grad():
#             outputs = model(**batch)
        
#         loss        = outputs.loss
#         logits      = outputs.logits
#         predictions = torch.argmax(logits, dim=-1)
#         metric.add_batch(predictions=predictions, references=batch["labels"])
        
#     return metric.compute(), float(loss)

### Teacher Model Accuracy

In [None]:
# teacher_checkpoint = 'experiments/best.pth.tar'
# utils.load_checkpoint(teacher_checkpoint, teacher_model)
# metric, loss = testing(teacher_model, test_dataloader)
# metric, loss

### Student Model Accuracy

In [None]:
# metric, loss = testing(student_model, test_dataloader)
# metric, loss

## reference : 
1. https://github.com/haitongli/knowledge-distillation-pytorch/blob/master/train.py
2. https://github.com/lxuechen/private-transformers