In [1]:
# Pretty-print everything shown with print(...) in this notebook, so long lists of
# sentences and dicts of tensors wrap readably.
# NOTE(review): this shadows the builtin `print` for the whole notebook.
from pprint import pprint, pprint as print

### Teste de destilação de modelo usando BERT e BiLSTM no dataset SST-2

Neste notebook exploraremos a destilação de conhecimento (knowledge distillation), utilizando como professor o modelo BERT pré-treinado e ajustado para o SST-2, e treinando uma BiLSTM como aluna.

### Carregando Base de Dados

In [2]:
from pathlib import Path
from os.path import join

# NOTE(review): the folder is named 'STT2' while the dataset is SST-2 — confirm the directory name.
data_folder = Path('../../data/STT2/')
sentences_file = Path('sentences.txt')
sentences_path = str(data_folder / sentences_file)

# The first line is a header. Each remaining line is assumed to be "<id>\t<sentence>"
# (inferred from the tab split below — confirm against the data file).
# The context manager guarantees the file handle is closed.
with open(sentences_path) as sentences_fh:
    raw_lines = sentences_fh.readlines()[1:]

# Skip blank lines (e.g. a trailing newline) instead of raising IndexError on them, and
# split only on the first tab so a sentence that itself contains a tab is kept whole.
sentences = [line.strip().split('\t', 1)[1] for line in raw_lines if line.strip()]

print(sentences[:10])

[&quot;The Rock is destined to be the 21st Century &#39;s new `` Conan &#39;&#39; and that he &quot;
 &quot;&#39;s going to make a splash even greater than Arnold Schwarzenegger , &quot;
 &#39;Jean-Claud Van Damme or Steven Segal .&#39;,
 &quot;The gorgeously elaborate continuation of `` The Lord of the Rings &#39;&#39; trilogy &quot;
 &#39;is so huge that a column of words can not adequately describe &#39;
 &quot;co-writer\\/director Peter Jackson &#39;s expanded vision of J.R.R. Tolkien &#39;s &quot;
 &#39;Middle-earth .&#39;,
 &#39;Effective but too-tepid biopic&#39;,
 &#39;If you sometimes like to go to the movies to have fun , Wasabi is a good &#39;
 &#39;place to start .&#39;,
 &quot;Emerges as something rare , an issue movie that &#39;s so honest and keenly &quot;
 &quot;observed that it does n&#39;t feel like one .&quot;,
 &#39;The film provides some great insight into the neurotic mindset of all comics &#39;
 &#39;-- even those who have reached the absolute top of the game .&#

### Carregamento do Modelo

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Teacher: BERT-base (uncased) fine-tuned on SST-2, published by TextAttack.
# The tokenizer and the model must come from the same checkpoint, so name it once.
teacher_checkpoint = "textattack/bert-base-uncased-SST-2"
tokenizer = AutoTokenizer.from_pretrained(teacher_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(teacher_checkpoint)

### Formato do Modelo

In [4]:
# Display the teacher's module tree (embeddings, encoder layers, ...) via its repr.
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

### Predição

In [5]:
examples = sentences[:4]

# Pad to the longest example so the batch is one rectangular tensor, and truncate
# anything longer than the model accepts instead of failing on it.
inputs = tokenizer(examples, return_tensors='pt', padding=True, truncation=True)

# Inference only: skip autograd bookkeeping. Call the module itself (not .forward)
# so nn.Module hooks run. The first element of the model output holds the logits.
with torch.no_grad():
    logits = model(**inputs)[0]
proba = torch.softmax(logits, dim=-1)
labels = torch.argmax(proba, dim=-1)

In [6]:
# Show, in order: tokenizer inputs, raw logits, class probabilities, predicted labels.
for shown in (inputs, logits, proba, labels):
    print(shown)

{&#39;attention_mask&#39;: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]),
 &#39;input_ids&#39;: tensor([[  101,  1996,  2600,  2003, 16036,  2000,  2022,  1996,  7398,  2301,
          1005,  1055,  2047,  1036,  1036, 16608,  1005,  1005,  1998,  2008,
       

### Novo Modelo BiLSTM

In [31]:
import torch
from torch import nn
# torch.nn.functional provides mse_loss; torch.functional (the previous import) does not.
from torch.nn import functional as F
import pytorch_lightning as pl

class BiLSTM(pl.LightningModule):
    """Bidirectional LSTM student for distilling the BERT SST-2 teacher.

    ``forward`` returns the raw ``nn.LSTM`` result ``(output, (h_n, c_n))``.
    ``predict_logits`` maps the final hidden states to one score per label.

    Args:
        input_size: Feature size of each time step of the input sequence.
        hidden_size: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers (unused when num_layers == 1).
        num_labels: Length of the logit vector produced by ``predict_logits``.
    """

    def __init__(self, input_size=10, hidden_size=5, num_layers=1, dropout=0.1, num_labels=2):
        super().__init__()
        self.bilstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers. With a single layer
            # a non-zero value has no effect and only triggers a UserWarning.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True,
        )
        # The final forward and backward hidden states are concatenated, hence 2 * hidden_size.
        self.classifier = nn.Linear(2 * hidden_size, num_labels)

    def forward(self, x):
        """Run the BiLSTM.

        Args:
            x: Input of shape (batch, seq_len, input_size), since batch_first=True.

        Returns:
            The ``nn.LSTM`` tuple ``(output, (h_n, c_n))``.
        """
        return self.bilstm(x)

    def predict_logits(self, x):
        """Map a batch of sequences to logits of shape (batch, num_labels)."""
        _, (h_n, _) = self(x)
        # h_n is (num_layers * 2, batch, hidden_size). Its last two rows are the
        # forward and backward final states of the top layer.
        final_state = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        return self.classifier(final_state)

    def training_step(self, batch, batch_idx):
        """Regress the student's logits onto the targets with mean-squared error.

        Assumes ``batch`` is ``(x, y)`` with ``x`` shaped (batch, seq_len, input_size)
        and ``y`` holding teacher logits shaped (batch, num_labels). TODO: confirm this
        once the distillation dataloader exists.
        """
        x, y = batch
        logits = self.predict_logits(x)
        return F.mse_loss(logits, y)

    def configure_optimizers(self):
        """Adam with default hyperparameters over all parameters (LSTM and classifier)."""
        return torch.optim.Adam(self.parameters())

In [32]:
# Instantiate the BiLSTM student with its default configuration.
# NOTE(review): this rebinds `model`, which previously held the BERT teacher loaded from
# the textattack checkpoint. After this cell the teacher is no longer reachable by name;
# consider `student = BiLSTM()` and updating the cells that use it.
model = BiLSTM()

In [33]:
# A dummy single-step sequence of ones shaped (batch=1, seq_len=1, features=10),
# matching the student's batch_first layout and input_size of 10.
t = torch.ones(1, 1, 10)
for shown in (t, t.shape):
    print(shown)

tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])
torch.Size([1, 1, 10])


In [34]:
# Smoke-test the untrained student on the dummy input. BiLSTM.forward returns the
# nn.LSTM tuple (output, (h_n, c_n)), not logits.
output = model(t)

In [37]:
_, lstm_state = output  # second element of the nn.LSTM result: the final (h_n, c_n) pair
print(lstm_state)

(tensor([[[-0.0698,  0.0776, -0.2030,  0.2142,  0.1088]],

        [[ 0.1284, -0.0920, -0.1087,  0.1605, -0.2038]]],
       grad_fn=&lt;StackBackward&gt;),
 tensor([[[-0.1365,  0.2202, -0.3863,  0.3490,  0.4019]],

        [[ 0.3225, -0.1806, -0.2759,  0.3992, -0.3076]]],
       grad_fn=&lt;StackBackward&gt;))


In [36]:
sequence_output = output[0]  # per-time-step features: (batch, seq_len, 2 * hidden_size)
print(sequence_output.shape)

torch.Size([1, 1, 10])
