### Teste de distilação de modelo usando BERT e BiLSTM no dataset SST-2.

Neste notebook exploraremos a distilação de conhecimento utilizando o modelo pré-treinado BERT como professor e treinando uma BiLSTM como aluna.

### Carregando Base de Dados

In [1]:
from pathlib import Path
from os.path import join

data_folder = Path('../../data/STT2/')
sentences_file = Path('sentences.txt')
sentences_path = join(data_folder, sentences_file)

sentences = open(sentences_path).readlines()[1:]
sentences = list(map(lambda x: x.strip(), sentences))
sentences = list(map(lambda x: x.split('\t')[1], sentences))

print(sentences[:10])

[&quot;The Rock is destined to be the 21st Century &#39;s new `` Conan &#39;&#39; and that he &#39;s going to make a splash even greater than Arnold Schwarzenegger , Jean-Claud Van Damme or Steven Segal .&quot;, &quot;The gorgeously elaborate continuation of `` The Lord of the Rings &#39;&#39; trilogy is so huge that a column of words can not adequately describe co-writer\\/director Peter Jackson &#39;s expanded vision of J.R.R. Tolkien &#39;s Middle-earth .&quot;, &#39;Effective but too-tepid biopic&#39;, &#39;If you sometimes like to go to the movies to have fun , Wasabi is a good place to start .&#39;, &quot;Emerges as something rare , an issue movie that &#39;s so honest and keenly observed that it does n&#39;t feel like one .&quot;, &#39;The film provides some great insight into the neurotic mindset of all comics -- even those who have reached the absolute top of the game .&#39;, &#39;Offers that rare combination of entertainment and education .&#39;, &#39;Perhaps no picture ever 

### Carregamento do Modelo

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")
bert_model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2")
embedding = bert_model.bert.embeddings

### Formato do Modelo

In [3]:
bert_model.bert.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

### Predição

In [4]:
examples = sentences[:20]
inputs = tokenizer(examples, return_tensors='pt', padding=True)

bert_model.eval()

with torch.no_grad():
    bert_logits = bert_model(**inputs)[0]

embedding.eval()

with torch.no_grad():
    input_representations = embedding(inputs['input_ids'])

### Novo Modelo BiLSTM

In [5]:
from torch import nn
from torch.nn import functional as F

class BiLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.bilstm = nn.LSTM(
            input_size=768,
            hidden_size=150,
            num_layers=1,
            batch_first=True,
            dropout=0.1,
            bidirectional=True,
        )
        self.dense = nn.Linear(
            in_features=300,
            out_features=200,
        )
        self.output = nn.Linear(
            in_features=200,
            out_features=2,
        )
    def forward(self, x):
        _, (last_state, _) = self.bilstm(x)
        last_state = last_state.view(x.size(0), -1)
        dense_state = nn.functional.relu(self.dense(last_state))
        logits = self.output(dense_state)
        return logits

In [6]:
from torch.utils.data import Dataset, DataLoader

class STT2Dataset(Dataset):
    def __init__(self, sentences, bert_logits):
        self.data = sentences
        self.labels = bert_logits
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return (self.data[idx], self.labels[idx])


In [7]:
bilstm_model = BiLSTM()
optim = torch.optim.Adam(bilstm_model.parameters())
criterion = torch.nn.MSELoss()

train_dataset = STT2Dataset(input_representations, bert_logits)
train_dataloader = DataLoader(train_dataset, batch_size=5)

for e in range(100):
    for batch in train_dataloader:
        optim.zero_grad()
        x, y = batch
        logits = bilstm_model(x)
        loss = criterion(logits, y)
        print(loss)
        loss.backward()
        optim.step()

tensor(10.8305, grad_fn=&lt;MseLossBackward&gt;)
tensor(8.9409, grad_fn=&lt;MseLossBackward&gt;)
tensor(10.4380, grad_fn=&lt;MseLossBackward&gt;)
tensor(5.4852, grad_fn=&lt;MseLossBackward&gt;)
tensor(6.4239, grad_fn=&lt;MseLossBackward&gt;)
tensor(2.9726, grad_fn=&lt;MseLossBackward&gt;)
tensor(3.2037, grad_fn=&lt;MseLossBackward&gt;)
tensor(2.8697, grad_fn=&lt;MseLossBackward&gt;)
tensor(3.5355, grad_fn=&lt;MseLossBackward&gt;)
tensor(0.3920, grad_fn=&lt;MseLossBackward&gt;)
tensor(0.2983, grad_fn=&lt;MseLossBackward&gt;)
tensor(4.3300, grad_fn=&lt;MseLossBackward&gt;)
tensor(3.9817, grad_fn=&lt;MseLossBackward&gt;)
tensor(0.7778, grad_fn=&lt;MseLossBackward&gt;)
tensor(0.2138, grad_fn=&lt;MseLossBackward&gt;)
tensor(2.2633, grad_fn=&lt;MseLossBackward&gt;)
tensor(2.3566, grad_fn=&lt;MseLossBackward&gt;)
tensor(0.3163, grad_fn=&lt;MseLossBackward&gt;)
tensor(0.9506, grad_fn=&lt;MseLossBackward&gt;)
tensor(1.4314, grad_fn=&lt;MseLossBackward&gt;)
tensor(1.9735, grad_fn=&lt;MseLossBack