In [1]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from pandas import read_parquet
from transformers import BertModel, BertTokenizerFast
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support
from torch.nn.utils.rnn import pad_sequence
import evaluate
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory of the BERT checkpoint used throughout the notebook
# (presumably a fine-tuned multilingual BERT, judging by the path -- confirm).
model_path = "data/mBERT/fine"

# A *fast* tokenizer is loaded because collate_fn calls `word_ids()` on its
# output to map sub-tokens back to the original words.
tokenizer = BertTokenizerFast.from_pretrained(model_path)

In [3]:
# Load the three dataset splits; later cells read their "tokens" and
# "ner_tags" columns.
train_data = read_parquet("data/merge/train.parquet")
dev_data = read_parquet("data/merge/dev.parquet")
test_data = read_parquet("data/merge/test.parquet")

# JSON is read explicitly as UTF-8 so the result does not depend on the
# platform's default text encoding.
with open("data/merge/tags_2_idx.json", "r", encoding="utf-8") as f:
    tags2idx = json.load(f)

# Inverse mapping (label id -> tag string), used to decode predictions.
idx2tags = {idx: tag for tag, idx in tags2idx.items()}

# Character vocabulary (character -> id); collate_fn expects it to contain
# an "<unk>" entry for characters that are not in the vocabulary.
with open("data/merge/chars2idx.json", "r", encoding="utf-8") as f:
    chars2idx = json.load(f)

In [4]:
def _tokens_and_tags(frame):
    """Return the ``tokens`` and ``ner_tags`` columns of ``frame`` as Python lists."""
    return frame["tokens"].values.tolist(), frame["ner_tags"].values.tolist()


# One (sentences, tags) pair per split; each list holds one entry per row.
sentences_train, tags_train = _tokens_and_tags(train_data)
sentences_dev, tags_dev = _tokens_and_tags(dev_data)
sentences_test, tags_test = _tokens_and_tags(test_data)

In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reuse `model_path` (the directory the tokenizer was loaded from) so the
# tokenizer and the encoder always come from the same checkpoint.
bert = BertModel.from_pretrained(model_path).to(device)
# BERT is only used as a frozen feature extractor (collate_fn calls it under
# torch.no_grad()), so switch it to eval mode to disable dropout.
bert.eval()

Some weights of the model checkpoint at data/mBERT/fine were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
         

In [6]:
def pooling_embedding(batch_word_ids, batch_embeddings):
    """Average sub-token embeddings so that every word gets one vector.

    Args:
        batch_word_ids: one sequence per sentence giving, for each sub-token
            position, the index of the word it belongs to, or ``None`` for
            positions that belong to no word (those positions are skipped).
        batch_embeddings: one 2-D tensor ``(num_sub_tokens, hidden)`` per
            sentence (or a single 3-D tensor holding all of them), aligned
            position by position with ``batch_word_ids``.

    Returns:
        A list with one tensor ``(num_words, hidden)`` per sentence. Runs of
        consecutive non-``None`` positions that share a word index are
        replaced by their mean. A sentence without any real sub-token yields
        an empty ``(0, hidden)`` tensor instead of raising.
    """
    processed_embeddings = []

    for word_ids, embeddings in zip(batch_word_ids, batch_embeddings):
        # Collect, for every word, the sub-token positions that belong to it.
        groups = []
        previous_word_idx = None
        for position, word_idx in enumerate(word_ids):
            if word_idx is None:
                continue
            if groups and word_idx == previous_word_idx:
                groups[-1].append(position)
            else:
                groups.append([position])
                previous_word_idx = word_idx

        if not groups:
            # torch.stack() cannot handle an empty list; return an empty
            # tensor with the right hidden size, dtype and device instead.
            processed_embeddings.append(
                embeddings.new_zeros((0, embeddings.shape[-1]))
            )
            continue

        processed_embeddings.append(
            torch.stack([embeddings[positions].mean(dim=0) for positions in groups])
        )

    return processed_embeddings
                    

In [7]:
class MultilingualDataset(Dataset):
    """Pairs every tokenised sentence with its NER label ids.

    ``sentences`` and ``labels`` are parallel, indexable collections whose
    items expose ``.tolist()`` (e.g. NumPy arrays).
    """

    def __init__(self, sentences, labels):
        self.sentences = sentences
        self.labels = labels

    def __len__(self):
        # One sample per label sequence.
        return len(self.labels)

    def __getitem__(self, idx):
        # Returns (list of words, integer label tensor, sample index); the
        # index is passed along so collate_fn can report problematic samples.
        words = self.sentences[idx].tolist()
        tag_ids = torch.tensor(self.labels[idx].tolist(), dtype=int)
        return words, tag_ids, idx
    
def collate_fn(batch):
    """Turn a list of dataset samples into padded model inputs.

    Each sample is ``(words, label_tensor, sample_index)``. Every sentence is
    tokenised on its own, run through the frozen ``bert`` encoder, and its
    sub-token embeddings are averaged per word by ``pooling_embedding``.

    Returns:
        batch_embeddings: ``(B, T, hidden)`` word embeddings, zero-padded;
            they stay on the device ``bert`` runs on.
        batch_labels: ``(B, T)`` label ids, padded with ``-100``.
        batch_char_ids: ``(B, T, L)`` character ids per word, zero-padded,
            where ``L`` is the longest word in the batch.

    Raises:
        ValueError: if labels and pooled embeddings disagree on the
            sequence length (the message lists the offending sample ids).
    """
    sentences, labels, ids = zip(*batch)
    B = len(labels)
    tokenized_inputs_list = [
        tokenizer(sentence, is_split_into_words=True, truncation=True, return_tensors="pt")
        for sentence in sentences
    ]
    input_ids = pad_sequence(
        [x["input_ids"][0] for x in tokenized_inputs_list],
        batch_first=True, padding_value=0,
    ).to(device)
    attention_mask = pad_sequence(
        [x["attention_mask"][0] for x in tokenized_inputs_list],
        batch_first=True, padding_value=0,
    ).to(device)
    # BERT is a frozen feature extractor here, so no gradients are needed.
    with torch.no_grad():
        bert_output = bert(input_ids, attention_mask)
    bert_embeddings = bert_output["last_hidden_state"]

    word_ids = [x.word_ids() for x in tokenized_inputs_list]
    pooled_embeddings = pooling_embedding(word_ids, bert_embeddings)
    batch_embeddings = pad_sequence(pooled_embeddings, batch_first=True, padding_value=0)

    # Tokenisation truncates long inputs, so a sentence may have more words
    # than pooled vectors; cut labels and words down to the pooled length.
    T = batch_embeddings.shape[1]
    labels = [x[:T] for x in labels]
    sentences = [x[:T] for x in sentences]
    batch_labels = pad_sequence(labels, batch_first=True, padding_value=-100)

    # default=0 keeps a batch without any word from crashing max().
    L = max((len(word) for sentence in sentences for word in sentence), default=0)
    batch_char_ids = torch.zeros(B, T, L, dtype=int)
    unk_id = chars2idx["<unk>"]
    for i, sentence in enumerate(sentences):
        for j, cur_word in enumerate(sentence):
            if not cur_word:
                continue
            # One slice assignment per word instead of one write per character.
            batch_char_ids[i, j, :len(cur_word)] = torch.tensor(
                [chars2idx.get(ch, unk_id) for ch in cur_word], dtype=int
            )

    # Fail loudly instead of returning tensors that cannot be aligned; the
    # previous bare `except:` only printed and let the bad batch through.
    if batch_labels.shape[1] != T:
        raise ValueError(
            f"sequence length mismatch for samples {ids}: "
            f"batch_embeddings: {tuple(batch_embeddings.shape)} "
            f"batch_labels: {tuple(batch_labels.shape)} "
            f"batch_char_ids: {tuple(batch_char_ids.shape)}"
        )
    return batch_embeddings, batch_labels, batch_char_ids

The next two cells are only used to test `collate_fn`; they are not needed for training.

In [8]:
# Smoke test: build a tiny loader and inspect the shapes of the first batch.
train_dataset = MultilingualDataset(sentences_train, tags_train)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    collate_fn=collate_fn
)
for data in train_loader:
    # collate_fn returns (word embeddings, padded labels, character ids).
    bert_embeddings, labels, char_ids = data
    print(bert_embeddings.shape, labels.shape, char_ids.shape)
    break  # one batch is enough for the shape check

torch.Size([2, 17, 768]) torch.Size([2, 17]) torch.Size([2, 17, 1])


In [9]:
# Manual replay of the collate_fn steps on two hand-picked training samples
# (indices 2418 and 2419), used to debug the character-id filling.
sentences = [train_dataset[2418][0], train_dataset[2419][0]]
labels = [train_dataset[2418][1], train_dataset[2419][1]]
print(f"original labels size: {[len(x) for x in labels]}")
tokenized_inputs_list = [
    tokenizer(sentence, is_split_into_words=True, truncation=True, return_tensors="pt")
    for sentence in sentences
]
input_ids = [x["input_ids"][0] for x in tokenized_inputs_list]
# Sub-token counts per sentence.
print([len(x) for x in input_ids])
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0).to(device)
attention_mask = [x["attention_mask"][0] for x in tokenized_inputs_list]
attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0).to(device)
with torch.no_grad():
    bert_output = bert(input_ids, attention_mask)
bert_embeddings = bert_output["last_hidden_state"]
word_ids = [x.word_ids() for x in tokenized_inputs_list]
pooled_embeddings = pooling_embedding(word_ids, bert_embeddings)
batch_embeddings = pad_sequence(pooled_embeddings, batch_first=True, padding_value=0)
# print(batch_embeddings.shape)
B = 2
T = batch_embeddings.shape[1]
# Cut labels and words to the pooled sequence length before padding.
labels = [x[:T] for x in labels]
sentences = [x[:T] for x in sentences]
batch_labels = pad_sequence(labels, batch_first=True, padding_value=-100)
L = max(len(word) for sentence in sentences for word in sentence)
batch_char_ids = torch.zeros(B, T, L, dtype=int)
for i in range(B):
    for j in range(len(sentences[i])):
        cur_word = sentences[i][j]
        for k in range(len(cur_word)):
            try:
                batch_char_ids[i][j][k] = chars2idx.get(cur_word[k], chars2idx["<unk>"])
            # `except Exception` still reports every failing position but no
            # longer traps KeyboardInterrupt/SystemExit like a bare `except:`.
            except Exception:
                print(i, j, k)

original labels size: [9, 3]
[11, 5]


# BiLSTM model training and evaluation

In [10]:
# class BiLSTM(nn.Module):
#     def __init__(
#             self, input_size,
#             lstm_hidden_dim, lstm_num_layers, lstm_dropout, 
#             linear_output_dim, label_size
#         ):
#         super().__init__()
#         self.lstm = nn.LSTM(
#             input_size=input_size,
#             hidden_size=lstm_hidden_dim,
#             num_layers=lstm_num_layers,
#             batch_first=True,
#             dropout=lstm_dropout,
#             bidirectional=True
#         )
#         self.fc1 = nn.Linear(2 * lstm_hidden_dim, linear_output_dim)
#         self.dropout = nn.Dropout(lstm_dropout)
#         self.fc2 = nn.Linear(linear_output_dim, label_size)

#     def forward(self, x):
#         x, _ = self.lstm(x)
#         x = self.dropout(F.elu(self.fc1(x), inplace=True))
#         x = self.fc2(x)
#         return x


class BiLSTMNERwithCNN(nn.Module):
    """BiLSTM tagger over word embeddings plus a CNN sentence feature.

    A 1-D convolution runs over the word-embedding sequence, is max-pooled
    over the whole sentence into one ``cnn_channels``-dimensional vector, and
    that vector is appended to every word embedding before the BiLSTM.

    Args:
        target_size: number of output classes.
        embedding_dim: input size of the legacy ``fc`` layer (see below).
        lstm_hidden_dim: hidden size of each LSTM direction.
        lstm_layers: number of stacked LSTM layers.
        lstm_dropout: dropout probability applied to the LSTM output.
        linear_dim: size of the hidden linear layer before the classifier.
        bert_dim: size of the incoming word embeddings (default 768).
        cnn_channels: number of convolution output channels (default 30).
    """

    def __init__(self, target_size=8, embedding_dim=100, lstm_hidden_dim=256, lstm_layers=1,
                 lstm_dropout=0.33, linear_dim=128, bert_dim=768, cnn_channels=30):
        super().__init__()
        # Honour `lstm_dropout` (previously ignored in favour of a literal 0.33;
        # the default keeps the old behaviour).
        self.dropout = nn.Dropout(lstm_dropout)
        # Not used in forward(); kept so state dicts saved with earlier
        # versions of this class (which contain `fc.*`) still load.
        self.fc = nn.Linear(embedding_dim, 30)
        # LSTM input = word embedding + broadcast CNN feature (768 + 30 = 798
        # with the defaults).
        self.lstm = nn.LSTM(bert_dim + cnn_channels, lstm_hidden_dim, num_layers=lstm_layers,
                            batch_first=True, bidirectional=True)
        self.linear = nn.Linear(lstm_hidden_dim * 2, linear_dim)
        self.elu = nn.ELU()
        self.classifier = nn.Linear(linear_dim, target_size)
        self.conv1d = nn.Conv1d(in_channels=bert_dim, out_channels=cnn_channels, kernel_size=3, padding=1)
        self.maxpool = nn.AdaptiveMaxPool1d(output_size=1)

    def forward(self, sentence):
        """Map ``(batch, seq_len, bert_dim)`` embeddings to ``(batch, seq_len, target_size)`` logits."""
        sequence_length = sentence.size(1)

        # Conv1d expects (batch, channels, seq_len).
        cnn_out = self.conv1d(sentence.permute(0, 2, 1))
        # Max over the sequence -> (batch, cnn_channels, 1).
        pool_out = self.maxpool(cnn_out)
        # Broadcast the sentence-level feature to every position; expand()
        # creates a view, torch.cat below materialises it.
        sentence_feature = pool_out.permute(0, 2, 1).expand(-1, sequence_length, -1)
        concatenated_tensor = torch.cat((sentence, sentence_feature), dim=2)

        lstm_out, _ = self.lstm(concatenated_tensor)
        lstm_out = self.dropout(lstm_out)
        linear_out = self.linear(lstm_out)
        elu_out = self.elu(linear_out)
        output = self.classifier(elu_out)
        return output

In [11]:
# Wrap each split in a dataset object.
train_dataset, dev_dataset, test_dataset = (
    MultilingualDataset(split_sentences, split_tags)
    for split_sentences, split_tags in (
        (sentences_train, tags_train),
        (sentences_dev, tags_dev),
        (sentences_test, tags_test),
    )
)

# Mini-batch loaders: training data is reshuffled, dev data keeps its order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(dev_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Earlier baseline, kept for reference:
# lstm_model = BiLSTM(
#     input_size=768,
#     lstm_hidden_dim=256, lstm_num_layers=2, lstm_dropout=0.33,
#     linear_output_dim=128, label_size=len(idx2tags)
# )
lstm_model = BiLSTMNERwithCNN().to(device)

# Positions labelled -100 (the label padding value) do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = Adam(lstm_model.parameters(), lr=1e-3)
epochs = 10

In [13]:
# Train for `epochs` epochs; after each epoch evaluate on the dev set and
# keep the checkpoint with the best dev token accuracy in "best_channel.pt".
best_loss = torch.inf   # only needed by the loss-based selection commented out below
best_acc = 0
for epoch in range(epochs):
    print(f'Epoch: {epoch + 1}/{epochs}')
    train_loss = 0
    cur_total = 0
    correct = 0
    amount = 0
    lstm_model.train()
    train_loop = tqdm(train_loader, desc=f'Epoch: {epoch + 1}/{epochs}')
    for bert_embeddings, labels, char_ids in train_loop:
        B = bert_embeddings.shape[0]
        bert_embeddings = bert_embeddings.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = lstm_model(bert_embeddings)
        # Flatten to (B*T, classes) / (B*T,) for the token-level loss.
        logits_reshape = logits.view(-1, logits.shape[-1])
        labels_reshape = labels.view(-1)
        loss = criterion(logits_reshape, labels_reshape)
        loss.backward()
        optimizer.step()

        # The batch loss is weighted by the number of sentences in the batch.
        train_loss += loss.item() * B

        # Accuracy only counts real tokens, not label padding.
        mask = labels_reshape != -100

        logits_non_pad = logits_reshape[mask]
        labels_non_pad = labels_reshape[mask]

        _, predictions = torch.max(logits_non_pad, 1)
        # Tensor.sum() reduces on the device in one call; the Python builtin
        # sum() would iterate over the tensor element by element.
        correct += (predictions == labels_non_pad).sum().item()
        amount += labels_non_pad.numel()

        cur_total += B
        running_loss = train_loss / cur_total
        running_acc = correct / max(amount, 1)
        train_loop.set_postfix(loss=running_loss, acc=running_acc)

    train_loss /= len(train_dataset)
    train_acc = correct / max(amount, 1)

    dev_loss = 0
    y_true, y_pred = [], []
    correct = 0
    amount = 0
    lstm_model.eval()
    dev_loop = tqdm(dev_loader, desc=f'Epoch: {epoch + 1}/{epochs}')
    with torch.no_grad():
        for bert_embeddings, labels, char_ids in dev_loop:
            B = bert_embeddings.shape[0]
            bert_embeddings = bert_embeddings.to(device)
            labels = labels.to(device)

            logits = lstm_model(bert_embeddings)
            logits_reshape = logits.view(-1, logits.shape[-1])
            labels_reshape = labels.view(-1)
            loss = criterion(logits_reshape, labels_reshape)
            dev_loss += loss.item() * B

            mask = labels_reshape != -100
            logits_non_pad = logits_reshape[mask]
            labels_non_pad = labels_reshape[mask]

            _, predictions = torch.max(logits_non_pad, 1)
            correct += (predictions == labels_non_pad).sum().item()
            amount += labels_non_pad.numel()
            # Plain Python ints for scikit-learn instead of 0-dim tensors.
            y_pred.extend(predictions.tolist())
            y_true.extend(labels_non_pad.tolist())

    dev_loss /= len(dev_dataset)
    val_acc = correct / max(amount, 1)
    # Token-level macro-averaged metrics over all label ids.
    val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)
    print('train_loss: {:.4f}, train_acc: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}'.format(train_loss, train_acc, dev_loss, val_acc))
    print('val_precision: {:.4f}, val_recall: {:.4f}, val_f1: {:.4f}'.format(val_precision, val_recall, val_f1))

#     if dev_loss < best_loss:
#         print(f"update dev_loss from {best_loss} -> {dev_loss}")
#         best_loss = dev_loss
#         torch.save(lstm_model.state_dict(), "best_channel.pt")
    # Model selection: keep the weights with the highest dev accuracy.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(lstm_model.state_dict(), "best_channel.pt")

Epoch: 1/10


Epoch: 1/10: 100%|██████████████████████████████████████████| 5007/5007 [07:15<00:00, 11.50it/s, acc=0.897, loss=0.304]
Epoch: 1/10: 100%|█████████████████████████████████████████████████████████████████| 2507/2507 [02:45<00:00, 15.10it/s]


train_loss: 0.3043, train_acc: 0.8965, val_loss: 0.2404, val_acc: 0.8993
val_precision: 0.8568, val_recall: 0.8307, val_f1: 0.8416
Epoch: 2/10


Epoch: 2/10: 100%|██████████████████████████████████████████| 5007/5007 [07:25<00:00, 11.24it/s, acc=0.906, loss=0.273]
Epoch: 2/10: 100%|█████████████████████████████████████████████████████████████████| 2507/2507 [02:47<00:00, 14.99it/s]


train_loss: 0.2728, train_acc: 0.9064, val_loss: 0.2282, val_acc: 0.9047
val_precision: 0.8751, val_recall: 0.8282, val_f1: 0.8499
Epoch: 3/10


Epoch: 3/10: 100%|██████████████████████████████████████████| 5007/5007 [07:18<00:00, 11.41it/s, acc=0.914, loss=0.248]
Epoch: 3/10: 100%|█████████████████████████████████████████████████████████████████| 2507/2507 [02:46<00:00, 15.07it/s]


train_loss: 0.2483, train_acc: 0.9141, val_loss: 0.2232, val_acc: 0.9088
val_precision: 0.8698, val_recall: 0.8421, val_f1: 0.8547
Epoch: 4/10


Epoch: 4/10: 100%|███████████████████████████████████████████| 5007/5007 [07:22<00:00, 11.33it/s, acc=0.92, loss=0.229]
Epoch: 4/10: 100%|█████████████████████████████████████████████████████████████████| 2507/2507 [02:45<00:00, 15.16it/s]


train_loss: 0.2289, train_acc: 0.9198, val_loss: 0.2252, val_acc: 0.9077
val_precision: 0.8752, val_recall: 0.8363, val_f1: 0.8537
Epoch: 5/10


Epoch: 5/10: 100%|███████████████████████████████████████████| 5007/5007 [07:24<00:00, 11.25it/s, acc=0.926, loss=0.21]
Epoch: 5/10: 100%|█████████████████████████████████████████████████████████████████| 2507/2507 [02:41<00:00, 15.49it/s]


train_loss: 0.2098, train_acc: 0.9265, val_loss: 0.2188, val_acc: 0.9107
val_precision: 0.8763, val_recall: 0.8433, val_f1: 0.8587
Epoch: 6/10


Epoch: 6/10: 100%|██████████████████████████████████████████| 5007/5007 [07:19<00:00, 11.40it/s, acc=0.931, loss=0.196]
Epoch: 6/10: 100%|█████████████████████████████████████████████████████████████████| 2507/2507 [02:47<00:00, 14.93it/s]


train_loss: 0.1956, train_acc: 0.9307, val_loss: 0.2376, val_acc: 0.9121
val_precision: 0.8713, val_recall: 0.8522, val_f1: 0.8614
Epoch: 7/10


Epoch: 7/10: 100%|██████████████████████████████████████████| 5007/5007 [07:19<00:00, 11.39it/s, acc=0.936, loss=0.182]
Epoch: 7/10: 100%|█████████████████████████████████████████████████████████████████| 2507/2507 [02:43<00:00, 15.33it/s]


train_loss: 0.1817, train_acc: 0.9356, val_loss: 0.2305, val_acc: 0.9098
val_precision: 0.8636, val_recall: 0.8594, val_f1: 0.8613
Epoch: 8/10


Epoch: 8/10: 100%|██████████████████████████████████████████| 5007/5007 [07:16<00:00, 11.47it/s, acc=0.939, loss=0.171]
Epoch: 8/10: 100%|█████████████████████████████████████████████████████████████████| 2507/2507 [02:40<00:00, 15.62it/s]


train_loss: 0.1712, train_acc: 0.9393, val_loss: 0.2506, val_acc: 0.9094
val_precision: 0.8708, val_recall: 0.8500, val_f1: 0.8598
Epoch: 9/10


Epoch: 9/10: 100%|██████████████████████████████████████████| 5007/5007 [07:16<00:00, 11.46it/s, acc=0.943, loss=0.161]
Epoch: 9/10: 100%|█████████████████████████████████████████████████████████████████| 2507/2507 [02:48<00:00, 14.87it/s]


train_loss: 0.1609, train_acc: 0.9429, val_loss: 0.2566, val_acc: 0.9102
val_precision: 0.8629, val_recall: 0.8575, val_f1: 0.8601
Epoch: 10/10


Epoch: 10/10: 100%|█████████████████████████████████████████| 5007/5007 [07:13<00:00, 11.55it/s, acc=0.945, loss=0.153]
Epoch: 10/10: 100%|████████████████████████████████████████████████████████████████| 2507/2507 [02:41<00:00, 15.49it/s]


train_loss: 0.1532, train_acc: 0.9448, val_loss: 0.2552, val_acc: 0.9097
val_precision: 0.8731, val_recall: 0.8423, val_f1: 0.8571


In [12]:
# Restore the best weights saved by the training loop. map_location makes
# the load work when the checkpoint was written on a different device
# (e.g. saved on a GPU, loaded on a CPU-only machine).
state_dict = torch.load("best_channel.pt", map_location=device)
lstm_model.load_state_dict(state_dict)

<All keys matched successfully>

## Evaluation

In [14]:
# Load the seqeval metric through the `evaluate` library; it is used below
# to score the predicted tag sequences against the reference tags.
seqeval = evaluate.load("seqeval")

In [15]:
# Predict a tag sequence for every training sentence and collect
# predictions and references as lists of tag strings for seqeval.
train_predictions = []
train_references = []
lstm_model.eval()
# A separate loader name keeps the shuffled, batched `train_loader` used by
# the training loop intact, so re-running the training cell after this one
# still trains with the intended batch size.
train_eval_loader = DataLoader(
    dataset=train_dataset,
    batch_size=1,   # use batch_size=1 => no padding needed
    shuffle=False,
    collate_fn=collate_fn
)
for bert_embeddings, labels, char_ids in tqdm(train_eval_loader):
    bert_embeddings = bert_embeddings.to(device)
    with torch.no_grad():
        logits = lstm_model(bert_embeddings)
    _, prediction = torch.max(logits[0].cpu(), dim=1)
    # Convert once to Python ints instead of calling .item() per token.
    prediction_ids = prediction.tolist()
    label_ids = labels[0].tolist()
    assert len(prediction_ids) == len(label_ids)
    train_predictions.append([idx2tags[p_idx] for p_idx in prediction_ids])
    train_references.append([idx2tags[r_idx] for r_idx in label_ids])

100%|████████████████████████████████████████████████████████████████████████████| 80100/80100 [23:51<00:00, 55.94it/s]


In [16]:
# Score the training-set predictions with seqeval; the result holds
# precision/recall/F1 per entity type plus overall figures.
train_result = seqeval.compute(predictions=train_predictions, references=train_references)
train_result

{'LOC': {'precision': 0.7070236009481604,
  'recall': 0.7144233272585264,
  'f1': 0.7107042034653337,
  'number': 38410},
 'ORG': {'precision': 0.5663513351694408,
  'recall': 0.6927786729163615,
  'f1': 0.6232178152589275,
  'number': 34135},
 'PER': {'precision': 0.756505379193386,
  'recall': 0.8176949074869679,
  'f1': 0.7859109177999228,
  'number': 34914},
 'overall_precision': 0.6731583618612907,
 'overall_recall': 0.7411012572236854,
 'overall_f1': 0.7054977764391134,
 'overall_accuracy': 0.9260698160482358}

In [17]:
# Predict a tag sequence for every dev sentence and collect predictions
# and references as lists of tag strings for seqeval.
dev_predictions = []
dev_references = []
lstm_model.eval()
# A separate loader name keeps the batched `dev_loader` used by the training
# loop intact, so re-running the training cell after this one still
# validates with the intended batch size.
dev_eval_loader = DataLoader(
    dataset=dev_dataset,
    batch_size=1,   # use batch_size=1 => no padding needed
    shuffle=False,
    collate_fn=collate_fn
)
for bert_embeddings, labels, char_ids in tqdm(dev_eval_loader):
    bert_embeddings = bert_embeddings.to(device)
    with torch.no_grad():
        logits = lstm_model(bert_embeddings)
    _, prediction = torch.max(logits[0].cpu(), dim=1)
    # Convert once to Python ints instead of calling .item() per token.
    prediction_ids = prediction.tolist()
    label_ids = labels[0].tolist()
    assert len(prediction_ids) == len(label_ids)
    dev_predictions.append([idx2tags[p_idx] for p_idx in prediction_ids])
    dev_references.append([idx2tags[r_idx] for r_idx in label_ids])

100%|████████████████████████████████████████████████████████████████████████████| 40100/40100 [12:09<00:00, 54.94it/s]


In [18]:
# Score the dev-set predictions with seqeval; the result holds
# precision/recall/F1 per entity type plus overall figures.
dev_result = seqeval.compute(predictions=dev_predictions, references=dev_references)
dev_result

{'LOC': {'precision': 0.6830333649089569,
  'recall': 0.6733942098163329,
  'f1': 0.6781795380917547,
  'number': 19274},
 'ORG': {'precision': 0.5298558347944672,
  'recall': 0.6765547263681592,
  'f1': 0.5942860264394187,
  'number': 16080},
 'PER': {'precision': 0.7335970613167562,
  'recall': 0.7841135608577469,
  'f1': 0.7580145985401459,
  'number': 16555},
 'overall_precision': 0.6437121040032151,
 'overall_recall': 0.7096842551388006,
 'overall_f1': 0.6750902527075812,
 'overall_accuracy': 0.8930089503400488}

In [19]:
# Tag every test sentence one at a time and keep predicted and gold tags
# as lists of tag strings, one list per sentence.
test_predictions, test_references = [], []
lstm_model.eval()
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1,   # use batch_size=1 => no padding needed
    shuffle=False,
    collate_fn=collate_fn
)
for bert_embeddings, labels, char_ids in tqdm(test_loader):
    with torch.no_grad():
        logits = lstm_model(bert_embeddings.to(device))
    # The batch holds a single sentence: take its logits and labels.
    logit, label = logits[0].cpu(), labels[0]
    prediction = torch.max(logit, dim=1)[1]
    assert len(prediction) == len(label)
    test_predictions.append([idx2tags[p_idx.item()] for p_idx in prediction])
    test_references.append([idx2tags[r_idx.item()] for r_idx in label])

100%|████████████████████████████████████████████████████████████████████████████| 40100/40100 [12:32<00:00, 53.27it/s]


In [20]:
# Score the test-set predictions with seqeval; the result holds
# precision/recall/F1 per entity type plus overall figures.
test_result = seqeval.compute(predictions=test_predictions, references=test_references)
test_result

{'LOC': {'precision': 0.6720250629140774,
  'recall': 0.6686596146967142,
  'f1': 0.6703381147540984,
  'number': 19569},
 'ORG': {'precision': 0.5002611461943878,
  'recall': 0.6233581824636137,
  'f1': 0.5550667755446094,
  'number': 16902},
 'PER': {'precision': 0.7081172610858292,
  'recall': 0.7780232558139535,
  'f1': 0.7414261177904592,
  'number': 17200},
 'overall_precision': 0.6226316675079926,
 'overall_recall': 0.6894412252426823,
 'overall_f1': 0.6543355054331969,
 'overall_accuracy': 0.8909087079646696}

In [19]:
def _tokens_for_predictions(dataset, predictions):
    """Return each sentence's words, cut to the length of its prediction.

    Predictions can be shorter than the original sentence because the
    collate step truncates long inputs; the words are trimmed so tokens and
    predicted tags line up one to one.
    """
    assert len(dataset) == len(predictions)
    aligned_tokens = []
    for i in range(len(dataset)):
        token = dataset[i][0]
        prediction = predictions[i]
        assert len(token) >= len(prediction)
        aligned_tokens.append(token[:len(prediction)])
    return aligned_tokens


train_tokens = _tokens_for_predictions(train_dataset, train_predictions)
dev_tokens = _tokens_for_predictions(dev_dataset, dev_predictions)
test_tokens = _tokens_for_predictions(test_dataset, test_predictions)

# One row per sentence: the words and the predicted tag for each word.
train_df = pd.DataFrame({
    "tokens": train_tokens,
    "predictions": train_predictions
})
dev_df = pd.DataFrame({
    "tokens": dev_tokens,
    "predictions": dev_predictions
})
test_df = pd.DataFrame({
    "tokens": test_tokens,
    "predictions": test_predictions
})

train_df.to_csv("Channel_train.csv")
dev_df.to_csv("Channel_dev.csv")
test_df.to_csv("Channel_test.csv")