### Teste de distilação de modelo usando BERT e BiLSTM no dataset SST-2.

Neste notebook exploraremos a distilação de conhecimento utilizando o modelo pré-treinado BERT como professor e treinando uma BiLSTM como aluna.

### Carregando Base de Dados

In [1]:
from pathlib import Path
from os.path import join

def load_sentences(data_folder, file):
    path = join(data_folder, file)
    sentences = open(path).readlines()
    sentences = list(map(lambda x: x.strip(), sentences))
    return sentences

data_folder = Path('../../data/STT2/')

train_file = Path('train.txt')
test_file = Path('test.txt')
dev_file = Path('dev.txt')

train_sentences = load_sentences(data_folder, train_file)
test_sentences = load_sentences(data_folder, test_file)
dev_sentences = load_sentences(data_folder, dev_file)

print("Sentenças de Treino: \n", train_sentences[:10])
print("Sentenças de Teste: \n", test_sentences[:10])
print("Sentenças de Dev: \n", dev_sentences[:10])

Sentenças de Treino: 
 ["The Rock is destined to be the 21st Century 's new `` Conan '' and that he 's going to make a splash even greater than Arnold Schwarzenegger , Jean-Claud Van Damme or Steven Segal .", "The gorgeously elaborate continuation of `` The Lord of the Rings '' trilogy is so huge that a column of words can not adequately describe co-writer\\/director Peter Jackson 's expanded vision of J.R.R. Tolkien 's Middle-earth .", 'Singer\\/composer Bryan Adams contributes a slew of songs -- a few potential hits , a few more simply intrusive to the story -- but the whole package certainly captures the intended , er , spirit of the piece .', "You 'd think by now America would have had enough of plucky British eccentrics with hearts of gold .", 'Yet the act is still charming here .', "Whether or not you 're enlightened by any of Derrida 's lectures on `` the other '' and `` the self , '' Derrida is an undeniably fascinating and playful fellow .", 'Just the labour involved in creati

### Carregamento do Modelo

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")
bert_model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2")
embedding = bert_model.bert.embeddings

### Formato do Modelo

In [3]:
bert_model.bert.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

### Novo Modelo BiLSTM

In [4]:
from torch import nn
from torch.nn import functional as F

class BiLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.bilstm = nn.LSTM(
            input_size=768,
            hidden_size=150,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.dense = nn.Linear(
            in_features=300,
            out_features=200,
        )
        self.output = nn.Linear(
            in_features=200,
            out_features=2,
        )
    def forward(self, x):
        _, (last_state, _) = self.bilstm(x)
        last_state = last_state.view(x.size(0), -1)
        dense_state = nn.functional.relu(self.dense(last_state))
        logits = self.output(dense_state)
        return logits

In [5]:
from torch.utils.data import Dataset, DataLoader

class STT2Dataset(Dataset):
    def __init__(self, sentences, tokenizer, embedding, bert, max_length=50):
        self.data = sentences
        self.tokenizer = tokenizer
        self.bert = bert     
        self.max_length = max_length
        self.embedding = embedding
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        
        inputs = tokenizer(self.data[idx],
                          return_tensors='pt',
                          padding='max_length',
                          truncation=True,
                          max_length=self.max_length)
        
        self.embedding.eval()
        self.bert.eval()
        with torch.no_grad():
            bert_logits = self.bert(**inputs)[0]
            inputs = self.embedding(inputs['input_ids'])
        
        return inputs.squeeze(0), bert_logits.squeeze(0)

### Predição

In [6]:
bilstm_model = BiLSTM()
optim = torch.optim.Adam(bilstm_model.parameters())
criterion = torch.nn.MSELoss()

examples = train_sentences[:500]
train_dataset = STT2Dataset(examples, tokenizer, embedding, bert_model)
train_dataloader = DataLoader(train_dataset, batch_size=64, num_workers=8)

for e in range(5):
    e_loss = 0
    for batch in train_dataloader:
        optim.zero_grad()
        x, y = batch
        logits = bilstm_model(x)
        loss = criterion(logits, y)
        e_loss += loss.data
        loss.backward()
        optim.step()
    e_loss = e_loss / len(train_dataloader)
    print("Epoch: ", e, "Loss:", e_loss)

Epoch:  0 Loss: tensor(7.8126)
Epoch:  1 Loss: tensor(3.8722)
Epoch:  2 Loss: tensor(3.7157)
Epoch:  3 Loss: tensor(3.5315)
Epoch:  4 Loss: tensor(3.4798)
