In [29]:
from transformers import AutoTokenizer
kshot = '16-100'
teacher_id = "clean-model\\checkpoint-24-epoch-6"
student_id = "distilroberta-base"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_id, use_fast=False)
student_tokenizer = AutoTokenizer.from_pretrained(student_id, use_fast=False)

# sample input
sample = "This is a basic example, with different words to test."

# assert results
assert teacher_tokenizer(sample) == student_tokenizer(sample), "Tokenizers haven't created the same output"

In [30]:
#pre-processing and tokenization
import pandas as pd
from datasets import Dataset, DatasetDict

def load_data(fileName):
    df = pd.read_csv(f'BadPrompt\\data\\k-shot\\subj\\{kshot}\\{fileName}')
    df['idx'] = range(1, len(df) + 1)
    df = df.reset_index(drop=True)
    df.columns = ['label', 'sentence', 'idx']
    df = df.reindex(columns=['sentence', 'label', 'idx'])
    ds = Dataset.from_pandas(df)
    ds = ds.class_encode_column("label")
    return ds

def process(examples):
    tokenized_inputs = teacher_tokenizer(
        examples["sentence"], truncation=True, padding=True
    )
    return tokenized_inputs

dataset = DatasetDict()

dataset['train'] = load_data('train.csv')
dataset['validation'] = load_data('dev.csv')
dataset['test'] = load_data('test.csv')

tokenized_datasets = dataset.map(process, batched=True)
tokenized_datasets = tokenized_datasets.rename_column("label","labels")

tokenized_datasets["train"].features["labels"]

                                                                        

ClassLabel(names=['0', '1'], id=None)

In [31]:
#distill class
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        
        self.alpha = alpha
        self.temperature = temperature
        
class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # place teacher on same device as student
        self._move_model_to_device(self.teacher,self.model.device)
        self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False):
        # compute student output
        outputs_student = model(**inputs)
        student_loss=outputs_student.loss
        # compute teacher output
        with torch.no_grad():
          outputs_teacher = self.teacher(**inputs)
        
        # assert size
        assert outputs_student.logits.size() == outputs_teacher.logits.size()
        
        # Soften probabilities and compute distillation loss
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = (loss_function(
            F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
            F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1)) * (self.args.temperature ** 2))
        # Return weighted student loss
        loss = self.args.alpha * student_loss + (1. - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss

In [32]:
#Hyperparameter definition, model loading
from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding

output_dir = f'distilled-model\\distilroberta-subj-{kshot}'
# create label2id, id2label dicts for nice outputs for the model
labels = tokenized_datasets["train"].features["labels"].names
num_labels = len(labels)
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

# define training args
training_args = DistillationTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=15,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    fp16=True,
    learning_rate=6e-5,
    seed=33,
    # logging & evaluation strategies
    logging_dir=f"{output_dir}\\logs",
    logging_strategy="epoch", # to get more information to TB
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    # distilation parameters
    alpha=0.5,
    temperature=4.0
    )

# define data_collator
data_collator = DataCollatorWithPadding(tokenizer=teacher_tokenizer)

# define model
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels=num_labels, 
    id2label=id2label,
    label2id=label2id,
)

# define student model
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_id,
    num_labels=num_labels, 
    id2label=id2label,
    label2id=label2id,
)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.weight'

In [33]:
import evaluate
import numpy as np

# define metrics and metrics function
accuracy_metric = evaluate.load( "accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    acc = accuracy_metric.compute(predictions=predictions, references=labels)
    return {
        "accuracy": acc["accuracy"],
    }

In [34]:
#training
trainer = DistillationTrainer(
    student_model,
    training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=teacher_tokenizer,
    compute_metrics=compute_metrics,
)
trainer.use_cuda_amp = True

In [35]:
trainer.train()

  7%|▋         | 1/15 [00:00<00:04,  3.19it/s]

{'loss': 2.1465, 'learning_rate': 5.6e-05, 'epoch': 1.0}



  7%|▋         | 1/15 [00:00<00:04,  3.19it/s]

{'eval_loss': 2.2808144092559814, 'eval_accuracy': 0.5161290322580645, 'eval_runtime': 0.0835, 'eval_samples_per_second': 371.044, 'eval_steps_per_second': 11.969, 'epoch': 1.0}


 13%|█▎        | 2/15 [00:01<00:14,  1.11s/it]

{'loss': 2.1048, 'learning_rate': 5.2000000000000004e-05, 'epoch': 2.0}



 13%|█▎        | 2/15 [00:02<00:14,  1.11s/it]

{'eval_loss': 2.2565524578094482, 'eval_accuracy': 0.5161290322580645, 'eval_runtime': 0.0846, 'eval_samples_per_second': 366.452, 'eval_steps_per_second': 11.821, 'epoch': 2.0}


 20%|██        | 3/15 [00:03<00:16,  1.38s/it]

{'loss': 2.0442, 'learning_rate': 4.8e-05, 'epoch': 3.0}



 20%|██        | 3/15 [00:03<00:16,  1.38s/it]

{'eval_loss': 2.217442274093628, 'eval_accuracy': 0.5161290322580645, 'eval_runtime': 0.0855, 'eval_samples_per_second': 362.417, 'eval_steps_per_second': 11.691, 'epoch': 3.0}


 27%|██▋       | 4/15 [00:05<00:17,  1.57s/it]

{'loss': 2.0203, 'learning_rate': 4.4e-05, 'epoch': 4.0}



 27%|██▋       | 4/15 [00:05<00:17,  1.57s/it]

{'eval_loss': 2.1455793380737305, 'eval_accuracy': 0.5161290322580645, 'eval_runtime': 0.086, 'eval_samples_per_second': 360.455, 'eval_steps_per_second': 11.628, 'epoch': 4.0}


 33%|███▎      | 5/15 [00:07<00:16,  1.64s/it]

{'loss': 1.9184, 'learning_rate': 3.9999999999999996e-05, 'epoch': 5.0}



 33%|███▎      | 5/15 [00:07<00:16,  1.64s/it]

{'eval_loss': 2.010833740234375, 'eval_accuracy': 0.6451612903225806, 'eval_runtime': 0.0825, 'eval_samples_per_second': 375.593, 'eval_steps_per_second': 12.116, 'epoch': 5.0}


 40%|████      | 6/15 [00:09<00:15,  1.68s/it]

{'loss': 1.7491, 'learning_rate': 3.6e-05, 'epoch': 6.0}



 40%|████      | 6/15 [00:09<00:15,  1.68s/it]

{'eval_loss': 1.7874598503112793, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.0836, 'eval_samples_per_second': 371.003, 'eval_steps_per_second': 11.968, 'epoch': 6.0}


 47%|████▋     | 7/15 [00:10<00:13,  1.71s/it]

{'loss': 1.4591, 'learning_rate': 3.2e-05, 'epoch': 7.0}



 47%|████▋     | 7/15 [00:10<00:13,  1.71s/it]

{'eval_loss': 1.5193419456481934, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.082, 'eval_samples_per_second': 378.036, 'eval_steps_per_second': 12.195, 'epoch': 7.0}


 53%|█████▎    | 8/15 [00:12<00:12,  1.73s/it]

{'loss': 1.2157, 'learning_rate': 2.8e-05, 'epoch': 8.0}



 53%|█████▎    | 8/15 [00:12<00:12,  1.73s/it]

{'eval_loss': 1.2998932600021362, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.0825, 'eval_samples_per_second': 375.711, 'eval_steps_per_second': 12.12, 'epoch': 8.0}


 60%|██████    | 9/15 [00:14<00:10,  1.73s/it]

{'loss': 1.0469, 'learning_rate': 2.4e-05, 'epoch': 9.0}



 60%|██████    | 9/15 [00:14<00:10,  1.73s/it]

{'eval_loss': 1.1088396310806274, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.0815, 'eval_samples_per_second': 380.29, 'eval_steps_per_second': 12.267, 'epoch': 9.0}


 67%|██████▋   | 10/15 [00:16<00:08,  1.75s/it]

{'loss': 0.882, 'learning_rate': 1.9999999999999998e-05, 'epoch': 10.0}



 67%|██████▋   | 10/15 [00:16<00:08,  1.75s/it]

{'eval_loss': 0.9221992492675781, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.083, 'eval_samples_per_second': 373.384, 'eval_steps_per_second': 12.045, 'epoch': 10.0}


 73%|███████▎  | 11/15 [00:17<00:07,  1.75s/it]

{'loss': 0.7255, 'learning_rate': 1.6e-05, 'epoch': 11.0}



 73%|███████▎  | 11/15 [00:18<00:07,  1.75s/it]

{'eval_loss': 0.7745137810707092, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.0841, 'eval_samples_per_second': 368.736, 'eval_steps_per_second': 11.895, 'epoch': 11.0}


 80%|████████  | 12/15 [00:19<00:05,  1.76s/it]

{'loss': 0.6339, 'learning_rate': 1.2e-05, 'epoch': 12.0}



 80%|████████  | 12/15 [00:19<00:05,  1.76s/it]

{'eval_loss': 0.6806623935699463, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.0837, 'eval_samples_per_second': 370.358, 'eval_steps_per_second': 11.947, 'epoch': 12.0}


 87%|████████▋ | 13/15 [00:21<00:03,  1.76s/it]

{'loss': 0.578, 'learning_rate': 8e-06, 'epoch': 13.0}



 87%|████████▋ | 13/15 [00:21<00:03,  1.76s/it]

{'eval_loss': 0.6051871180534363, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.082, 'eval_samples_per_second': 378.017, 'eval_steps_per_second': 12.194, 'epoch': 13.0}


 93%|█████████▎| 14/15 [00:23<00:01,  1.78s/it]

{'loss': 0.5176, 'learning_rate': 4e-06, 'epoch': 14.0}



 93%|█████████▎| 14/15 [00:23<00:01,  1.78s/it]

{'eval_loss': 0.5533555746078491, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.0875, 'eval_samples_per_second': 354.193, 'eval_steps_per_second': 11.426, 'epoch': 14.0}


100%|██████████| 15/15 [00:25<00:00,  1.78s/it]

{'loss': 0.4671, 'learning_rate': 0.0, 'epoch': 15.0}



100%|██████████| 15/15 [00:25<00:00,  1.78s/it]

{'eval_loss': 0.5277326703071594, 'eval_accuracy': 0.9354838709677419, 'eval_runtime': 0.083, 'eval_samples_per_second': 373.483, 'eval_steps_per_second': 12.048, 'epoch': 15.0}


100%|██████████| 15/15 [00:26<00:00,  1.79s/it]

{'train_runtime': 26.7841, 'train_samples_per_second': 17.361, 'train_steps_per_second': 0.56, 'train_loss': 1.3006048560142518, 'epoch': 15.0}





TrainOutput(global_step=15, training_loss=1.3006048560142518, metrics={'train_runtime': 26.7841, 'train_samples_per_second': 17.361, 'train_steps_per_second': 0.56, 'train_loss': 1.3006048560142518, 'epoch': 15.0})