### Teste de destilação de modelo usando BERT e BiLSTM no dataset SST-2.

Neste notebook exploraremos a destilação de conhecimento utilizando o modelo pré-treinado BERT como professor e treinando uma BiLSTM como aluna.

### Carregando Base de Dados

In [1]:
from pathlib import Path
from os.path import join

def load_sentences(data_folder, file):
    """Read a text file and return its lines with surrounding whitespace removed.

    Args:
        data_folder: Directory that contains the file (str or path-like).
        file: File name relative to ``data_folder`` (str or path-like).

    Returns:
        list[str]: One entry per line of the file, in file order, each passed
        through ``str.strip()``. An empty file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
    """
    path = join(data_folder, file)
    # The context manager closes the handle as soon as reading finishes instead
    # of leaving it open until the garbage collector gets to it.
    with open(path) as handle:
        return [line.strip() for line in handle]

# Folder holding the three split files, relative to the notebook's working directory.
data_folder = Path('../../data/STT2/')

train_file, test_file, dev_file = Path('train.txt'), Path('test.txt'), Path('dev.txt')

# Read the splits in the order train -> test -> dev.
train_sentences, test_sentences, dev_sentences = [
    load_sentences(data_folder, split_file)
    for split_file in (train_file, test_file, dev_file)
]

# Show the first ten sentences of every split.
for _header, _split in (
    ("Sentenças de Treino: \n", train_sentences),
    ("Sentenças de Teste: \n", test_sentences),
    ("Sentenças de Dev: \n", dev_sentences),
):
    print(_header, _split[:10])

# Show how many sentences each split contains.
for _header, _split in (
    ("Tamanho de Treino: \n", train_sentences),
    ("Tamanho de Teste: \n", test_sentences),
    ("Tamanho de Dev: \n", dev_sentences),
):
    print(_header, len(_split))


Sentenças de Treino: 
 [&quot;The Rock is destined to be the 21st Century &#39;s new `` Conan &#39;&#39; and that he &#39;s going to make a splash even greater than Arnold Schwarzenegger , Jean-Claud Van Damme or Steven Segal .&quot;, &quot;The gorgeously elaborate continuation of `` The Lord of the Rings &#39;&#39; trilogy is so huge that a column of words can not adequately describe co-writer\\/director Peter Jackson &#39;s expanded vision of J.R.R. Tolkien &#39;s Middle-earth .&quot;, &#39;Singer\\/composer Bryan Adams contributes a slew of songs -- a few potential hits , a few more simply intrusive to the story -- but the whole package certainly captures the intended , er , spirit of the piece .&#39;, &quot;You &#39;d think by now America would have had enough of plucky British eccentrics with hearts of gold .&quot;, &#39;Yet the act is still charming here .&#39;, &quot;Whether or not you &#39;re enlightened by any of Derrida &#39;s lectures on `` the other &#39;&#39; and `` the se

### Carregamento do Modelo

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# The tokenizer and the classifier are loaded from the same checkpoint identifier.
checkpoint = "textattack/bert-base-uncased-SST-2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
bert_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
# Handle on the classifier's embedding sub-module, used on its own further down.
embedding = bert_model.bert.embeddings

In [None]:
max_length = 25

examples = train_sentences
# Tokenize every example to exactly `max_length` ids: shorter ones are padded,
# longer ones truncated, and the result is returned as PyTorch tensors.
encoded = tokenizer(
    examples,
    padding='max_length',
    truncation=True,
    max_length=max_length,
    return_tensors='pt',
)

# Inference only: put both modules in eval mode and disable gradient tracking.
bert_model.eval()
embedding.eval()
with torch.no_grad():
    # First element of the classifier output; used later as the regression target.
    bert_logits = bert_model(**encoded)[0]
    # Output of the embedding sub-module for the same token ids; used later as
    # the input features.
    inputs = embedding(encoded['input_ids'])

### Formato do Modelo

In [3]:
# Bare expression: the notebook renders the embedding sub-module's repr as the cell output.
bert_model.bert.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

### Novo Modelo BiLSTM

In [13]:
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl

class BiLSTM(pl.LightningModule):
    """Student network: one bidirectional LSTM layer followed by a two-layer head.

    The model consumes batch-first sequences with 768 features per step and
    produces 2 logits per sequence. Training regresses those logits onto the
    target logits supplied in each batch using mean squared error.

    Args:
        lr: Learning rate passed to Adam. It is stored as ``self.lr`` so it can
            be reassigned after construction (by hand or by a learning-rate
            finder) before the optimizer is built. The default matches Adam's
            own default, so ``BiLSTM()`` trains as before when ``lr`` is left
            untouched.
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.bilstm = nn.LSTM(
            input_size=768,
            hidden_size=150,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        # 300 = 2 directions x 150 hidden units (final forward and backward states).
        self.dense = nn.Linear(
            in_features=300,
            out_features=200,
        )
        self.output = nn.Linear(
            in_features=200,
            out_features=2,
        )

    def configure_optimizers(self):
        """Build the Adam optimizer using the current value of ``self.lr``."""
        # Reading self.lr here (instead of relying on Adam's default) makes a
        # learning rate assigned to the model actually take effect.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def forward(self, x):
        """Compute logits for a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, 768)``.

        Returns:
            Tensor of shape ``(batch, 2)``.
        """
        _, (last_state, _) = self.bilstm(x)
        # The final hidden state is laid out as (num_directions, batch, hidden)
        # even with batch_first=True. Move the batch axis to the front before
        # flattening so that each row holds the forward and backward states of
        # ONE sample; flattening without the transpose mixes states that
        # belong to different samples.
        last_state = last_state.transpose(0, 1).reshape(x.size(0), -1)
        hidden = F.relu(self.dense(last_state))
        # F.dropout defaults to training=True; tie it to the module's mode so
        # that dropout is switched off by .eval().
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        logits = self.output(hidden)
        return logits

    def training_step(self, batch, batch_idx):
        """Run one training step: MSE between predicted and target logits.

        Args:
            batch: Pair ``(x, y)`` of input sequences and target logits.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            A ``pl.TrainResult`` that minimizes the loss and logs it as
            ``train_loss`` on the progress bar.
        """
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        result = pl.TrainResult(loss)
        result.log('train_loss', loss, prog_bar=True)
        return result

In [14]:
from torch.utils.data import Dataset, DataLoader

class STT2Dataset(Dataset):
    """Map-style dataset that pairs each input with its target logits by position.

    Args:
        inputs: Sized, indexable collection of model inputs.
        bert_logits: Sized, indexable collection of targets; item ``i`` is the
            target for ``inputs[i]``.

    Raises:
        ValueError: If the two collections do not have the same length.
    """

    def __init__(self, inputs, bert_logits):
        # Fail at construction time: a mismatch would otherwise only surface as
        # an IndexError in the middle of an epoch, or silently ignore the
        # trailing targets, because __len__ looks at `inputs` alone.
        if len(inputs) != len(bert_logits):
            raise ValueError(
                "inputs and bert_logits must have the same length, got "
                f"{len(inputs)} and {len(bert_logits)}"
            )
        self.inputs = inputs
        self.bert_logits = bert_logits

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the tuple ``(inputs[idx], bert_logits[idx])``."""
        return self.inputs[idx], self.bert_logits[idx]

### Treinamento

In [19]:
# Number of examples per DataLoader batch.
batch_size = 512

bilstm_model = BiLSTM()
# Placeholder `lr` attribute on the model. The recorded output of this cell
# shows the LR finder reporting "Learning rate set to ...", presumably by
# overwriting this attribute — NOTE(review): confirm against the installed
# pytorch_lightning version.
bilstm_model.lr = 0

# Pair every entry of `inputs` with the entry of `bert_logits` at the same index.
train_dataset = STT2Dataset(inputs, bert_logits)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=12)

# NOTE(review): accumulate_grad_batches equals the number of batches in the
# loader, which presumably results in a single optimizer step per epoch —
# confirm this is intended.
trainer = pl.Trainer(max_epochs=20, accumulate_grad_batches=len(train_dataloader), auto_lr_find=True)
trainer.fit(bilstm_model, train_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name   | Type   | Params
----------------------------------
0 | bilstm | LSTM   | 1 M   
1 | dense  | Linear | 60 K  
2 | output | Linear | 402   


HBox(children=(HTML(value=&#39;Finding best initial lr&#39;), FloatProgress(value=0.0), HTML(value=&#39;&#39;)))

Saving latest checkpoint..
LR finder stopped early due to diverging loss.
Learning rate set to 2.7542287033381663e-07

  | Name   | Type   | Params
----------------------------------
0 | bilstm | LSTM   | 1 M   
1 | dense  | Linear | 60 K  
2 | output | Linear | 402   


HBox(children=(HTML(value=&#39;Training&#39;), FloatProgress(value=1.0, bar_style=&#39;info&#39;, layout=Layout(flex=&#39;2&#39;), max…















































































1