In [9]:
import pickle
import requests

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from sklearn.metrics import classification_report
from transformers import AutoModel, BertTokenizerFast, BertTokenizer

In [None]:
# Download model
!wget https://storage.googleapis.com/postdata-models/stanzas/eval/saved_weights_bert.pt -O bert_data/saved_weights_bert.pt

In [2]:
BERT_NAME = 'dccuchile/bert-base-spanish-wwm-cased'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
class BERT_lstm(nn.Module):

    def __init__(self):
        super(BERT_lstm, self).__init__()
        self.bert = AutoModel.from_pretrained(BERT_NAME)
        self.lstm = nn.LSTM(768, 768, batch_first=True, num_layers=3, dropout=0.2, bidirectional=True)
        self.linear1 = nn.Linear(768*2, 300)
        self.linear2 = nn.Linear(300, 46)

    def forward(self, sent_id, mask):
        self.bert.config.return_dict=False
        self.bert.config.output_hidden_states=True
        with torch.no_grad():  # disable gradient calculation to freeze the model
            # pass the inputs to the model
            cls_emb, pool, hidden_states, *_ = self.bert(sent_id, attention_mask=mask)
        output, (h_n, c_n) = self.lstm(cls_emb)
        out = self.linear1(output)
        out = torch.sum(out, 1).squeeze(1)
        out = nn.functional.gelu(out)
        out = self.linear2(out)
        return out

In [None]:
test_seq = pickle.load(open("bert_data/test_seq.p","rb"))
test_mask = pickle.load(open("bert_data/test_mask.p","rb"))
test_y = pickle.load(open("bert_data/test_y.p","rb"))

In [5]:
model = BERT_lstm()
path = 'bert_data/saved_weights_bert.pt'
model.load_state_dict(torch.load(path))
with torch.no_grad():
    preds = model(test_seq, test_mask)
    preds = preds.detach().cpu().numpy()

In [None]:
preds = np.argmax(preds, axis = 1)

In [12]:
print(classification_report(test_y, preds, zero_division=0, digits = 4))

              precision    recall  f1-score   support

           0     0.6000    0.1034    0.1765        29
           1     0.5000    0.1818    0.2667        11
           2     0.7500    0.6429    0.6923        14
           3     0.4737    0.6429    0.5455        14
           4     0.2708    0.4643    0.3421        28
           5     0.0000    0.0000    0.0000         2
           6     0.5000    0.1429    0.2222        28
           7     0.4667    0.2500    0.3256        28
           8     0.4444    0.8889    0.5926        27
           9     0.3636    0.1538    0.2162        26
          10     0.3333    0.0769    0.1250        26
          11     0.1237    1.0000    0.2202        12
          12     0.3478    0.5333    0.4211        15
          13     0.5000    0.7857    0.6111        14
          14     0.2712    0.5926    0.3721        27
          15     0.2727    0.5000    0.3529        12
          16     0.6000    0.2143    0.3158        28
          17     0.6471    