In [None]:
!pip install datasets transformers

In [25]:
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer
import torch
from transformers import BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import random

In [None]:
 torch.manual_seed(2023)

**Подготовка данных**

In [None]:
dataset = load_dataset("imdb")

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
def tokenize_batch(batch):
    return tokenizer(batch["text"], padding=True, truncation=True)

In [31]:
train_data, test_data = train_test_split(dataset['train'].select(range(5001)), test_size=0.2)
train_data_enc = tokenize_batch(train_data)
test_data_enc = tokenize_batch(test_data)

**обучение учителя (трансформера)**

In [36]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
model_teacher = BertForSequenceClassification.from_pretrained('bert-base-uncased').to(device)
optimizer_teacher = AdamW(model_teacher.parameters(), lr=2e-5)

train_dataset = TensorDataset(
    torch.tensor(train_data_enc["input_ids"]),
    torch.tensor(train_data_enc["attention_mask"]),
    torch.tensor(train_data["label"])
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [38]:
model_teacher.train()
for epoch in range(3):
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{3}', leave=False)
    for batch in progress_bar:
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

        optimizer_teacher.zero_grad()
        outputs = model_teacher(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer_teacher.step()



**Обучение ученика LSTM**

In [55]:
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleLSTM, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        embedded = self.embedding(x)
        output, _ = self.lstm(embedded)
        output = self.fc(output[:, -1, :])
        return output

In [56]:
train_data_enc_lstm = tokenizer(train_data["text"], return_tensors="pt", padding=True, truncation=True)
train_labels_lstm = torch.tensor(train_data["label"])

In [57]:
train_dataset_lstm = TensorDataset(
    train_data_enc_lstm["input_ids"],
    train_data_enc_lstm["attention_mask"],
    train_labels_lstm
)

train_loader_lstm = DataLoader(train_dataset_lstm, batch_size=8, shuffle=True, drop_last=True)

In [58]:
model_student = SimpleLSTM(input_size=len(tokenizer), hidden_size=128, output_size=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer_student = optim.Adam(model_student.parameters(), lr=0.001)

In [59]:
model_student.train()
for epoch in range(3):
    progress_bar = tqdm(train_loader_lstm, desc=f'Epoch {epoch + 1}/{3}', leave=False)
    for batch in progress_bar:
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

        optimizer_student.zero_grad()
        outputs = model_student(input_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_student.step()

        progress_bar.set_postfix(loss=loss.item())



**Дистилляция модели**

In [60]:
def distillation_loss(logits_teacher, logits_student, temperature=1.0):
    soft_teacher = torch.softmax(logits_teacher / temperature, dim=1)
    soft_student = torch.softmax(logits_student / temperature, dim=1)
    return nn.KLDivLoss()(torch.log(soft_student), soft_teacher)

In [69]:
model_teacher.eval()
input_ids_teacher = input_ids.to(device)
attention_mask_teacher = attention_mask.to(device)

with torch.no_grad():
    logits_teacher = model_teacher(input_ids_teacher, attention_mask=attention_mask_teacher).logits

# Дистилляция по батчам
batch_size = 8
num_batches = len(train_loader_lstm)
loss_distillation_total = 0.0

model_student.train()

for batch in tqdm(train_loader_lstm, desc='Distillation', leave=False):
    input_ids, attention_mask, labels = batch
    input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

    # Получение logits от ученика
    logits_student = model_student(input_ids)

    # Вычисление дистилляционной потери для текущего батча
    loss_distillation_batch = distillation_loss(logits_teacher, logits_student)

    # Суммирование потерь по батчам
    loss_distillation_total += loss_distillation_batch.item()

# Вычисление средней дистилляционной потери
avg_loss_distillation = loss_distillation_total / num_batches
print(avg_loss_distillation)

                                                                

1.3548936418374069e-05




**Сравнение качества моделей**

In [None]:
model_teacher.eval()
model_student.eval()

In [72]:
# Оценка качества учителя на тестовых данных
with torch.no_grad():
    logits_teacher_test = []
    for i in range(0, len(test_data_enc["input_ids"]), batch_size):
        input_ids_batch = torch.tensor(test_data_enc["input_ids"][i:i+batch_size]).to(device)
        attention_mask_batch = torch.tensor(test_data_enc["attention_mask"][i:i+batch_size]).to(device)

        logits_batch = model_teacher(input_ids_batch, attention_mask=attention_mask_batch).logits
        logits_teacher_test.append(logits_batch)

    logits_teacher_test = torch.cat(logits_teacher_test, dim=0)

    predictions_teacher = torch.argmax(logits_teacher_test, dim=1)
    accuracy_teacher = (predictions_teacher == torch.tensor(test_data["label"]).to(device)).float().mean().item()
    print(f"Accuracy Teacher: {accuracy_teacher}")

# Оценка качества ученика на тестовых данных
with torch.no_grad():
    logits_student_test = []
    for i in range(0, len(test_data_enc["input_ids"]), batch_size):
        input_ids_batch = torch.tensor(test_data_enc["input_ids"][i:i+batch_size]).to(device)

        logits_batch = model_student(input_ids_batch)
        logits_student_test.append(logits_batch)

    logits_student_test = torch.cat(logits_student_test, dim=0)

    predictions_student = torch.argmax(logits_student_test, dim=1)
    accuracy_student = (predictions_student == torch.tensor(test_data["label"]).to(device)).float().mean().item()
    print(f"Accuracy Student: {accuracy_student}")

# Оценка качества дистиллированной модели на тестовых данных
with torch.no_grad():
    logits_distilled_test = []
    for i in range(0, len(test_data_enc["input_ids"]), batch_size):
        input_ids_batch = torch.tensor(test_data_enc["input_ids"][i:i+batch_size]).to(device)

        logits_batch = model_student(input_ids_batch)
        logits_distilled_test.append(logits_batch)

    logits_distilled_test = torch.cat(logits_distilled_test, dim=0)

    predictions_distilled = torch.argmax(logits_distilled_test, dim=1)
    accuracy_distilled = (predictions_distilled == torch.tensor(test_data["label"]).to(device)).float().mean().item()
    print(f"Accuracy Distilled: {accuracy_distilled}")


Accuracy Teacher: 1.0
Accuracy Student: 1.0
Accuracy Distilled: 1.0
