In [35]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torchtext import data
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertForSequenceClassification
import torch.nn.functional as F
import os
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings("ignore")

## Решаю задачу классификации смс-сообщений на спам и не спам

### Обработка текстовых данных

In [3]:
def get_vocab(X):
    words = [sentence.split() for sentence in X]
    text_field = data.Field()
    text_field.build_vocab(words, max_size=10000)
    return text_field

In [4]:
def pad(seq, maxlen):
    if len(seq) < maxlen:
        seq = seq + ['<pad>'] * (maxlen - len(seq))
    return seq

In [5]:
def to_indices(vocab, words):
    return [vocab.stoi[w] for w in words]

In [6]:
def to_dataset(x, y, teacher_output):
    torch_x = torch.tensor(x, dtype=torch.long)
    torch_y = torch.tensor(y, dtype=torch.float)
    if teacher_output is None:
        torch_teacher_output = torch.full_like(torch_y, 0)
    else:
        torch_teacher_output = teacher_output
    return TensorDataset(torch_x, torch_y, torch_teacher_output)

In [73]:
def read_and_preprocess_spam_data(path):
    X = []
    y = []
    maxlen = 0
    with open(os.getcwd() + path) as file:
        for line in file:
            words = line.split()
            y.append(0 if words[0] == 'ham' else 1)
            X.append(' '.join(words[1:]))
            maxlen = max(maxlen, len(words))
    return X, y, maxlen

### Модель учитель

в качестве учителя я взял предобученный Берт для классификации

In [7]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [8]:
bert = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

In [76]:
class TeacherModel(nn.Module):
    def __init__(self, teacher):
        super(TeacherModel, self).__init__()
        self.teacher = teacher
    
    def forward(self, x):
        return self.teacher(x)[0], 0 # здесь я возвращаю кортеж с фиктивным вторым элементом,
                                    # чтобы можно было дообучить Берт используя тот же код,
                                    # который я использую для обучения модели ученика

### Модель ученик

In [80]:
class BiLSTM(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, \
                bidirectional, dropout, num_layers):
        super(BiLSTM, self).__init__()
        
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        
        self.rnn = nn.LSTM(
                            input_size=embedding_dim, 
                            hidden_size=hidden_dim, 
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            dropout=dropout
                        )
        
        self.label_prediction = nn.Linear(hidden_dim * 2, output_dim)
        self.teacher_prediction = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def init_state(self, batch_size):
        return torch.zeros(2 * self.num_layers, batch_size, self.hidden_dim), \
               torch.zeros(2 * self.num_layers, batch_size, self.hidden_dim)
    
    def forward(self, x):
        x = self.embedding(x)
        x = torch.transpose(x, dim0=1, dim1=0)
        x, hidden = self.rnn(x)
        hidden, cell = hidden
        hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        label_prediction = self.label_prediction(hidden)
        teacher_prediction = self.teacher_prediction(hidden)
        return label_prediction, teacher_prediction

### Функция потерь для дистилляции

In [77]:
class DistilLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super(DistilLoss, self).__init__()
        self.alpha = alpha
    
    def forward(self, real_prediction, real_output, teacher_prediction=None, teacher_output=None):
        bce = nn.CrossEntropyLoss()
        mse = nn.MSELoss()
        prediction_loss = bce(real_prediction, torch.tensor(real_output, dtype=torch.long))
        if teacher_output is None:
            return prediction_loss # если учимся без учителя, то обычная кросс-энтропия
        teacher_loss = mse(teacher_prediction, teacher_output)
        return self.alpha * prediction_loss + (1 - self.alpha) * teacher_loss

### Обучение моделей

In [70]:
def train(model, field, X, y, maxlen, epochs=5, batch_size=64, teacher=None, alpha=0.5):
    X_split = [t.split() for t in X]
    X_pad = [pad(s, maxlen) for s in X_split]
    X_index = [to_indices(field.vocab, s) for s in X_pad]
    teacher_output = None
    if teacher: # делаю предподсчет выходов модели учителя, чтобы использовать их во время обучения
        print('calculating teacher output...')
        lines = [" ".join(s) for s in X_pad]
        inds = [tokenizer.encode(line.split(), add_special_tokens=False) for line in lines]
        inds = torch.tensor(inds)
        teacher_output = []
        for i in range(len(inds) // 20): # прогоняю батчами, потому что кернел падает, если считать все сразу
            result = teacher(inds[i * 20: (i + 1) * 20])[0].detach()
            teacher_output.append(result)
        teacher_output = torch.cat(teacher_output)
        print('finished calculating teacher output...')
        print()

    dataset = to_dataset(X_index, y, teacher_output)
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    
    epoch_loss = []
    optimizer = optim.Adam(model.parameters())
    loss_function = DistilLoss(alpha=alpha)
    
    print('training started...')
    for e in range(epochs):
        losses = 0
        count = 0
        print(f'epoch {e}', end=' ')
        for X_batch, y_batch, y_teacher in dataloader:
            label_prediction, teacher_prediction = model(X_batch)
            teacher_output = None
            if teacher:
                teacher_output = y_teacher
            labels = torch.tensor(y_batch, dtype=torch.long)
            loss = loss_function(label_prediction, labels, teacher_prediction, teacher_output)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses += loss
            count += 1
        losses /= count
        print(losses)
        epoch_loss.append(losses)

### Функция для измерения точности модели

In [62]:
def measure_accuracy(model, field, X, y, maxlen):
    X_split = [t.split() for t in X]
    X_pad = [pad(s, maxlen) for s in X_split]
    X_index = [to_indices(field.vocab, s) for s in X_pad]
    dataset = to_dataset(X_index, y, None)
    dataloader = DataLoader(dataset, 1, shuffle=True)
    correct = 0
    count = 0
    for sample, label, _ in dataloader:
        prediction = torch.argmax(model(sample)[0])
        correct += int(prediction.item() == int(label))
        count += 1
    return correct / count

In [37]:
X, y, maxlen = read_and_preprocess_spam_data('dataset.txt')
field = get_vocab(X)
vocab_size = len(field.vocab.stoi.keys())

  
Обычный BiLSTM обучался очень хорошо на данном датасете, поэтом я решил выбрать всего 1000 примеров  

In [57]:
X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, test_size=0.8205)
print('train lenghts:', len(X_train), len(y_train))
print('test lengths:', len(X_test), len(y_test))

train lenghts: 1000 1000
test lengths: 4574 4574


In [58]:
model = BiLSTM(input_dim=vocab_size, 
               embedding_dim=64,
               hidden_dim=32, 
               output_dim=2,
               bidirectional=True,
               dropout=0.5,
               num_layers=1,
            )

Здесь я решил немного дообучить Берт на данном датасете

In [59]:
teacher = TeacherModel(bert)

In [60]:
train(teacher, field, X_train, y_train, maxlen, teacher=None, epochs=1, batch_size=50)

training started...
epoch 0 tensor(0.4848, grad_fn=<DivBackward0>)


#### Обучение обычной модели

In [61]:
train(model, field, X_train, y_train, maxlen, epochs=3, batch_size=50, teacher=None)

training started...
epoch 0 tensor(0.5167, grad_fn=<DivBackward0>)
epoch 1 tensor(0.4005, grad_fn=<DivBackward0>)
epoch 2 tensor(0.3491, grad_fn=<DivBackward0>)


In [63]:
print('simple model accuracy:', measure_accuracy(model, field, X_test, y_test, maxlen))

simple model accuracy: 0.8692610406646262


#### Теперь обучим модель с дистилляцией

In [64]:
distill_model = BiLSTM(input_dim=vocab_size, 
               embedding_dim=64,
               hidden_dim=32, 
               output_dim=2,
               bidirectional=True,
               dropout=0.5,
               num_layers=1,
            )

Коэффициент альфу я выбрал равным 0.1, то есть модель обучалась в большей степени на лоссе "подражания", чем на кросс-энтропии

In [78]:
train(distill_model, field, X_train, y_train, maxlen, epochs=3, batch_size=50, teacher=teacher, alpha=0.1)

calculating teacher output...
finished calculating teacher output...

training started...
epoch 0 tensor(0.0799, grad_fn=<DivBackward0>)
epoch 1 tensor(0.0597, grad_fn=<DivBackward0>)
epoch 2 tensor(0.0540, grad_fn=<DivBackward0>)


Как видно, несмотря на то, что модель в меньшей степени обучалась на кросс-энтропии, получилось добиться даже более высокой точности

In [79]:
print('distilled model accuracy:', measure_accuracy(distill_model, field, X_test, y_test, maxlen))

distilled model accuracy: 0.9007433318758199
