In [1]:
from torch.utils.data import DataLoader
import torch
from torch.nn import KLDivLoss
from torch.optim import Adam
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_from_disk

# Check if a GPU is available and if not, use a CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load pre-trained BERT model and tokenizer
teacher_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load TinyBERT model
student_model = BertForSequenceClassification.from_pretrained("huawei-noah/TinyBERT_General_4L_312D", num_labels=2).to(device)

# Define the loss function and optimizer
loss_function = KLDivLoss(reduction='batchmean').to(device)
optimizer = Adam(student_model.parameters(), lr=1e-3)

dataset = load_from_disk('yelp_dataset')

# Define the training arguments
training_args = TrainingArguments(
    output_dir='./general/results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=128,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./general/logs',            # directory for storing logs
)

# Define the temperature
temperature = 2.0

# Define the training function
def train_step(model, inputs):
    model.train()
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    labels = inputs['label'].to(device)

    # Forward pass through the teacher model
    with torch.no_grad():
        teacher_output = teacher_model(input_ids, attention_mask=attention_mask)

    # Forward pass through the student model
    student_output = model(input_ids, attention_mask=attention_mask)

    # Calculate the loss
    loss = loss_function(
        torch.log_softmax(student_output.logits / temperature, dim=-1),
        torch.softmax(teacher_output.logits / temperature, dim=-1),
    )

    return loss

# Define the trainer
trainer = Trainer(
    model=student_model,
    args=training_args,
    train_dataset=dataset,
    compute_metrics=train_step,
)

# Start the training
trainer.train()

  from .autonotebook import tqdm as notebook_tqdm
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss
500,0.3608
1000,0.182
1500,0.158
2000,0.1419
2500,0.1394
3000,0.1334
3500,0.1302
4000,0.1253
4500,0.1173
5000,0.1061


TrainOutput(global_step=13125, training_loss=0.11493837273007347, metrics={'train_runtime': 14333.2324, 'train_samples_per_second': 117.21, 'train_steps_per_second': 0.916, 'total_flos': 2.408951365632e+16, 'train_loss': 0.11493837273007347, 'epoch': 3.0})

In [2]:
trainer.save_model('tinybert-gkd')