In [1]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torchtext import data
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
import os
from sklearn.model_selection import train_test_split
import random
import warnings
warnings.filterwarnings("ignore")

In [2]:
os.chdir('/Users/george/Documents/Distillation/')
os.getcwd()

'/Users/george/Documents/Distillation'

## Решаю задачу классификации смс-сообщений на спам и не спам

### Обработка текстовых данных

In [3]:
def get_vocab(X):
    words = [sentence.split() for sentence in X]
    text_field = data.Field()
    text_field.build_vocab(words, max_size=10000)
    return text_field

In [4]:
def pad(seq, maxlen):
    if len(seq) < maxlen:
        seq = seq + ['<pad>'] * (maxlen - len(seq))
    return seq

In [5]:
def to_indices(vocab, words):
    return [vocab.stoi[w] for w in words]

In [6]:
def to_dataset(X, y, field):
    X_split = [t.split() for t in X]
    X_pad = [pad(s, maxlen) for s in X_split]
    X_index = [to_indices(field.vocab, s) for s in X_pad]
    torch_x = torch.tensor(X_index, dtype=torch.long)
    torch_y = torch.tensor(y, dtype=torch.float)
    return TensorDataset(torch_x, torch_y)

In [7]:
def to_dataset_distill(X, y, teacher_output, field):
    X_split = [t.split() for t in X]
    X_pad = [pad(s, maxlen) for s in X_split]
    X_index = [to_indices(field.vocab, s) for s in X_pad]
    torch_x = torch.tensor(X_index, dtype=torch.long)
    torch_y = torch.tensor(y, dtype=torch.float)
    return TensorDataset(torch_x, torch_y, teacher_output)

In [8]:
def to_dataset_for_bert(X, y, tokenizer):
    X_split = [t.split() for t in X]
    text = [pad(s, maxlen) for s in X_split]
    lines = [" ".join(s) for s in text]
    masks = [[int(word != '<pad>') for word in sentence] for sentence in text]
    inds = [tokenizer.encode(line.split(), add_special_tokens=False) for line in lines]
    inds = torch.tensor(inds)
    masks = torch.tensor(masks, dtype=torch.int8)
    torch_y = torch.tensor(y, dtype=torch.float)
    return TensorDataset(inds, torch_y, masks)

In [9]:
def read_and_preprocess_spam_data(path):
    X = []
    y = []
    maxlen = 0
    with open(os.getcwd() + path) as file:
        for line in file:
            words = line.split()
            y.append(0 if words[0] == 'ham' else 1)
            X.append(' '.join(words[1:]))
            maxlen = max(maxlen, len(words))
    return X, y, maxlen

### Модель учитель

в качестве учителя я взял предобученный Берт для классификации

In [565]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [566]:
bert = BertModel.from_pretrained('bert-base-uncased')

### Функция потерь для дистилляции

In [639]:
class DistillLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super(DistillLoss, self).__init__()
        self.alpha = alpha
    
    def forward(self, real_prediction, real_output, teacher_prediction, teacher_output):
        bce = nn.CrossEntropyLoss()
        mse = nn.MSELoss()
        prediction_loss = bce(real_prediction, torch.tensor(real_output, dtype=torch.long))
        teacher_loss = mse(teacher_prediction, teacher_output)
        return self.alpha * prediction_loss + (1 - self.alpha) * teacher_loss

### BiLSTM

In [555]:
class BiLSTM(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, \
                bidirectional, dropout, num_layers):
        super(BiLSTM, self).__init__()
        
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        
        self.rnn = nn.LSTM(
                            input_size=embedding_dim, 
                            hidden_size=hidden_dim, 
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            dropout=dropout
                        )
        
        self.label_prediction = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def init_state(self, batch_size):
        return torch.zeros(2 * self.num_layers, batch_size, self.hidden_dim), \
               torch.zeros(2 * self.num_layers, batch_size, self.hidden_dim)
    
    def forward(self, x):
        x = self.embedding(x)
        x = torch.transpose(x, dim0=1, dim1=0)
        x, hidden = self.rnn(x)
        hidden, cell = hidden
        hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        label_prediction = self.label_prediction(hidden)
        return label_prediction

### Модели

In [387]:
class ClassificationHead(nn.Module):
    def __init__(self, input_size, hidden_size, num_labels=2):
        super(ClassificationHead, self).__init__()
        
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_labels)
        
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        return self.fc3(x)        

In [392]:
class TeacherModel(nn.Module):
    def __init__(self, teacher, head_hidden_size=128):
        super(TeacherModel, self).__init__()
        self.teacher = teacher
        hidden_size = self.teacher.config.hidden_size
        self.classification_head = ClassificationHead(hidden_size, head_hidden_size, 2)
        self.loss = nn.CrossEntropyLoss()
    
    def forward(self, inp):
        inds = inp[0]
        labels = inp[1]
        masks = inp[2]
        labels = torch.tensor(labels, dtype=torch.long)
        output = self.teacher(inds, attention_mask=masks)[0]
        output = output[:, 0, :]
        prediction = self.classification_head(output)
        loss = self.loss(prediction, labels)
        return loss
    
    def inference(self, inp):
        inds = inp[0]
        masks = inp[2]
        output = self.teacher(inds, attention_mask=masks)[0]
        output = output[:, 0, :]
        prediction = self.classification_head(output)
        return prediction
    
    def parameters(self):
        return self.classification_head.parameters()

In [556]:
class SimpleModel(nn.Module):
    def __init__(self, bilstm):
        super(SimpleModel, self).__init__()
        self.bilstm = bilstm
        self.loss = nn.CrossEntropyLoss()
    
    def forward(self, inp):
        inds = inp[0]
        labels = inp[1]
        labels = torch.tensor(labels, dtype=torch.long)
        prediction = self.bilstm(inds)
        loss = self.loss(prediction, labels)
        return loss
    
    def inference(self, inp):
        inds = inp[0]
        return self.bilstm(inds)

In [640]:
class DistillModel(nn.Module):
    def __init__(self, student, alpha=0.5):
        super(DistillModel, self).__init__()
        self.student = student
        self.loss = DistillLoss(alpha)
    
    def forward(self, inp):
        inds = inp[0]
        labels = inp[1]
        teacher_output = inp[2]
        labels = torch.tensor(labels, dtype=torch.long)
        label_prediction = self.student(inds)
        loss = self.loss(label_prediction, labels, label_prediction, teacher_output)
        return loss
    
    def inference(self, inp):
        inds = inp[0]
        return self.student(inds)

In [427]:
def get_teacher_output(teacher, dataset):
    dataloader = DataLoader(dataset, 20, shuffle=False)
    teacher_output = []
    for info in dataloader: # прогоняю батчами, потому что кернел падает, если считать все сразу
        result = teacher.inference(info).detach()
        teacher_output.append(result)
    teacher_output = torch.cat(teacher_output)
    return teacher_output

### Обучение моделей

In [558]:
def train(model, dataset, epochs=5, batch_size=64):
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    epoch_loss = []
    optimizer = optim.Adam(model.parameters())    
    print('training started...')
    model.train()
    for e in range(epochs):
        losses = 0
        count = 0
        print(f'epoch {e} loss:', end=' ')
        for info in dataloader:
            loss = model(info)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses += loss
            count += 1
        losses /= count
        print(losses.item())
        epoch_loss.append(losses)
    print('training finished...')
    return epoch_loss

### Функция для измерения точности модели

In [381]:
def measure_accuracy(model, dataset):
    dataloader = DataLoader(dataset, 1, shuffle=True)
    correct = 0
    count = 0
    positive = 0
    true_positive = 0
    for info in dataloader:
        prediction = torch.argmax(model.inference(info))
        correct += int(prediction.item() == int(info[1]))
        count += 1
        positive += int(int(prediction.item()) and int(info[1]))
        true_positive += int(info[1])
    print(positive / true_positive)
    return correct / count

#### Прочитаем данные и поделим на трейн и тест

In [563]:
X, y, maxlen = read_and_preprocess_spam_data('/data/dataset.txt')
field = get_vocab(X)
vocab_size = len(field.vocab.stoi.keys())
# ham_data = [x for x in zip(X, y) if x[1] == 0]
# spam_data = [x for x in zip(X, y) if x[1] == 1]
# spam_count = len(spam_data)
# ham_data_subset = random.sample(ham_data, spam_count)
# sms_data = ham_data_subset + spam_data
# X, y = list(zip(*sms_data))

In [570]:
X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, test_size=0.25)
print(len(X_train))
print(len(X_test))

4180
1394


Создадим отдельные датасеты для обучения модели без учителя и для обучения берта

In [571]:
simple_dataset_train = to_dataset(X_train, y_train, field)
teacher_dataset_train = to_dataset_for_bert(X_train, y_train, tokenizer)

In [572]:
simple_dataset_test = to_dataset(X_test, y_test, field)
teacher_dataset_test = to_dataset_for_bert(X_test, y_test, tokenizer)

#### Обучим Берт

In [573]:
teacher = TeacherModel(bert)

In [574]:
train(teacher, teacher_dataset_train, epochs=8, batch_size=50)

training started...
epoch 0 loss: 0.332807332277298
epoch 1 loss: 0.2568117678165436
epoch 2 loss: 0.2305523306131363
epoch 3 loss: 0.2213379293680191
epoch 4 loss: 0.2192654311656952
epoch 5 loss: 0.21612565219402313
epoch 6 loss: 0.21643000841140747
epoch 7 loss: 0.2127368301153183
training finished...


[tensor(0.3328, grad_fn=<DivBackward0>),
 tensor(0.2568, grad_fn=<DivBackward0>),
 tensor(0.2306, grad_fn=<DivBackward0>),
 tensor(0.2213, grad_fn=<DivBackward0>),
 tensor(0.2193, grad_fn=<DivBackward0>),
 tensor(0.2161, grad_fn=<DivBackward0>),
 tensor(0.2164, grad_fn=<DivBackward0>),
 tensor(0.2127, grad_fn=<DivBackward0>)]

In [575]:
print('teacher accuracy:', measure_accuracy(teacher, teacher_dataset_test))

0.6042780748663101
teacher accuracy: 0.926829268292683


#### Обучим обычный BiLSTM без учителя

In [None]:
model = BiLSTM(input_dim=vocab_size, 
               embedding_dim=16,
               hidden_dim=16, 
               output_dim=2,
               bidirectional=True,
               dropout=0.5,
               num_layers=1,
            )

In [646]:
simple_model = SimpleModel(model)

In [647]:
train(simple_model, simple_dataset_train, epochs=8, batch_size=50)

training started...
epoch 0 loss: 0.5213630199432373
epoch 1 loss: 0.4039366543292999
epoch 2 loss: 0.3245985209941864
epoch 3 loss: 0.28657665848731995
epoch 4 loss: 0.26641058921813965
epoch 5 loss: 0.2391241192817688
epoch 6 loss: 0.2133490890264511
epoch 7 loss: 0.19731676578521729
training finished...


[tensor(0.5214, grad_fn=<DivBackward0>),
 tensor(0.4039, grad_fn=<DivBackward0>),
 tensor(0.3246, grad_fn=<DivBackward0>),
 tensor(0.2866, grad_fn=<DivBackward0>),
 tensor(0.2664, grad_fn=<DivBackward0>),
 tensor(0.2391, grad_fn=<DivBackward0>),
 tensor(0.2133, grad_fn=<DivBackward0>),
 tensor(0.1973, grad_fn=<DivBackward0>)]

In [648]:
print('simple model accuracy:', measure_accuracy(simple_model, simple_dataset_test))

0.6844919786096256
simple model accuracy: 0.9225251076040172


#### Применим дистилляцию

посчитаем выходы берта для данного датасета и положим их в отдельный датасет

In [626]:
teacher_output_train = get_teacher_output(teacher, teacher_dataset_train)
teacher_output_test = get_teacher_output(teacher, teacher_dataset_test)

0

In [621]:
distill_dataset_train = to_dataset_distill(X_train, y_train, teacher_output_train, field)
distill_dataset_test = to_dataset_distill(X_test, y_test, teacher_output_test, field)

In [649]:
distill_model = BiLSTM(input_dim=vocab_size, 
               embedding_dim=16,
               hidden_dim=16, 
               output_dim=2,
               bidirectional=True,
               dropout=0.8,
               num_layers=1,
            )

In [650]:
student = DistillModel(distill_model, alpha=0.5)

In [651]:
train(student, distill_dataset_train, epochs=8, batch_size=50)

training started...
epoch 0 loss: 1.96905517578125
epoch 1 loss: 1.2974913120269775
epoch 2 loss: 1.088428258895874
epoch 3 loss: 1.043113112449646
epoch 4 loss: 1.0074676275253296
epoch 5 loss: 0.9557787775993347
epoch 6 loss: 0.920454204082489
epoch 7 loss: 0.8890780210494995
training finished...


[tensor(1.9691, grad_fn=<DivBackward0>),
 tensor(1.2975, grad_fn=<DivBackward0>),
 tensor(1.0884, grad_fn=<DivBackward0>),
 tensor(1.0431, grad_fn=<DivBackward0>),
 tensor(1.0075, grad_fn=<DivBackward0>),
 tensor(0.9558, grad_fn=<DivBackward0>),
 tensor(0.9205, grad_fn=<DivBackward0>),
 tensor(0.8891, grad_fn=<DivBackward0>)]

In [652]:
print('student model accuracy:', measure_accuracy(student, distill_dataset_test))

0.026737967914438502
student model accuracy: 0.8687230989956959
