In [1]:
import json
import os

import evaluate
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from pandas import read_parquet
from sklearn.metrics import precision_recall_fscore_support
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from TorchCRF import CRF
from tqdm import tqdm
from transformers import BertModel, BertTokenizerFast

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local checkpoint directory; the same path is used again later in the
# notebook when the BertModel encoder is loaded.
model_path = "data/mBERT/fine"

# Fast tokenizer: collate_fn relies on its `word_ids()` mapping to pool
# sub-word embeddings back into word-level embeddings.
tokenizer = BertTokenizerFast.from_pretrained(model_path)

In [3]:
# Train / dev / test splits, each providing "tokens" and "ner_tags" columns
# (those are the columns read in the next cell).
train_data = read_parquet("data/merge/train.parquet")
dev_data = read_parquet("data/merge/dev.parquet")
test_data = read_parquet("data/merge/test.parquet")

# JSON is read as UTF-8 explicitly: without `encoding`, open() falls back to
# the platform locale encoding, which can mis-decode non-ASCII vocabulary
# entries (notably on Windows).
with open("data/merge/tags_2_idx.json", "r", encoding="utf-8") as f:
    tags2idx = json.load(f)

# Reverse lookup (label id -> tag string) used when decoding predictions.
idx2tags = {idx: tag for tag, idx in tags2idx.items()}

with open("data/merge/chars2idx.json", "r", encoding="utf-8") as f:
    chars2idx = json.load(f)

In [4]:
# Pull the token and tag columns out of each split as plain Python lists
# (one entry per sentence). The column order in the tuple fixes which name
# receives which column.
_COLUMNS = ("tokens", "ner_tags")

sentences_train, tags_train = (train_data[col].values.tolist() for col in _COLUMNS)
sentences_dev, tags_dev = (dev_data[col].values.tolist() for col in _COLUMNS)
sentences_test, tags_test = (test_data[col].values.tolist() for col in _COLUMNS)

In [5]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the encoder from `model_path` (the same directory the tokenizer was
# loaded from) instead of repeating the literal path, so the two can never
# point at different checkpoints.
bert = BertModel.from_pretrained(model_path).to(device)

# Inference mode (disables dropout); collate_fn only calls the encoder under
# torch.no_grad().
bert.eval()

  return self.fget.__get__(instance, owner)()


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
 

In [7]:
def pooling_embedding(batch_word_ids, batch_embeddings):
    """Mean-pool sub-word embeddings into one embedding per word.

    Args:
        batch_word_ids: one sequence per sentence; element ``i`` is the word
            index that token ``i`` belongs to, or ``None`` for tokens that
            belong to no word (these are skipped).
        batch_embeddings: per-sentence token embeddings, iterated in parallel
            with ``batch_word_ids``; ``batch_embeddings[b][i]`` is the
            embedding of token ``i`` of sentence ``b``.

    Returns:
        A list with one tensor per sentence, each of shape
        ``(num_words, dim)``. A sentence without any word token yields an
        empty ``(0, dim)`` tensor instead of raising.
    """
    processed_embeddings = []

    for word_ids, embeddings in zip(batch_word_ids, batch_embeddings):
        # Collect the token positions of every word. A new group starts each
        # time the (non-None) word index changes.
        groups = []
        previous_word_idx = None
        for position, word_idx in enumerate(word_ids):
            if word_idx is None:
                continue
            if word_idx == previous_word_idx:
                groups[-1].append(position)
            else:
                groups.append([position])
                previous_word_idx = word_idx

        if groups:
            pooled = torch.stack([
                torch.mean(torch.stack([embeddings[p] for p in group]), dim=0)
                for group in groups
            ])
        else:
            # torch.stack([]) raises; return an empty tensor with the same
            # dtype/device/feature size so callers can still pad the batch.
            pooled = embeddings.new_zeros((0, embeddings.shape[-1]))

        processed_embeddings.append(pooled)

    return processed_embeddings
                    

In [8]:
class MultilingualDataset(Dataset):
    """Index-addressable pairs of token sequences and label-id sequences."""

    def __init__(self, sentences, labels):
        # Both containers are stored by reference; per-item conversion
        # happens lazily in __getitem__.
        self.labels = labels
        self.sentences = sentences

    def __len__(self):
        """Number of examples, measured on the label container."""
        n_examples = len(self.labels)
        return n_examples

    def __getitem__(self, idx):
        """Return ``(tokens, label_tensor, idx)`` for example ``idx``.

        Each stored item must expose ``.tolist()``. The index is returned as
        well so a collate function can report which examples are in a batch.
        """
        tokens = self.sentences[idx].tolist()
        tag_ids = torch.tensor(self.labels[idx].tolist(), dtype=int)
        return tokens, tag_ids, idx
    
def collate_fn(batch):
    """Turn a list of dataset items into padded word-level batch tensors.

    Args:
        batch: iterable of ``(tokens, label_tensor, idx)`` triples as produced
            by ``MultilingualDataset.__getitem__``.

    Returns:
        ``(batch_embeddings, batch_labels, batch_char_ids)`` where
        * ``batch_embeddings`` holds one mean-pooled BERT embedding per word,
          zero-padded along the word axis to length ``T``;
        * ``batch_labels`` holds the label ids, padded with ``-100``;
        * ``batch_char_ids`` has shape ``(B, T, L)`` with the ``chars2idx`` id
          of every character, zero-padded (``L`` = longest word in the batch).

    Raises:
        ValueError: if the word axis of the embeddings and of the labels
            disagree, i.e. the batch would be misaligned.
    """
    sentences, labels, ids = zip(*batch)
    B = len(labels)

    # Every sentence is tokenized on its own (truncated to the model limit)
    # and the batch is padded manually below.
    tokenized_inputs_list = [
        tokenizer(sentence, is_split_into_words=True, truncation=True, return_tensors="pt")
        for sentence in sentences
    ]
    input_ids = [x["input_ids"][0] for x in tokenized_inputs_list]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0).to(device)
    attention_mask = [x["attention_mask"][0] for x in tokenized_inputs_list]
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0).to(device)

    # The encoder is only a feature extractor here, so no gradients are needed.
    with torch.no_grad():
        bert_output = bert(input_ids, attention_mask)
    bert_embeddings = bert_output["last_hidden_state"]

    # Collapse sub-word embeddings to one embedding per word, then pad.
    word_ids = [x.word_ids() for x in tokenized_inputs_list]
    pooled_embeddings = pooling_embedding(word_ids, bert_embeddings)
    batch_embeddings = pad_sequence(pooled_embeddings, batch_first=True, padding_value=0)

    # Tokenizer truncation can leave fewer words than the original sentence
    # has, so labels and tokens are cut to the same length T.
    T = batch_embeddings.shape[1]
    labels = [x[:T] for x in labels]
    sentences = [x[:T] for x in sentences]
    batch_labels = pad_sequence(labels, batch_first=True, padding_value=-100)

    # default=0 keeps max() from raising when the batch contains no words.
    L = max((len(word) for sentence in sentences for word in sentence), default=0)
    batch_char_ids = torch.zeros(B, T, L, dtype=int)
    unk_id = chars2idx["<unk>"]
    for i, sentence in enumerate(sentences):
        for j, cur_word in enumerate(sentence):
            if not cur_word:
                continue
            # One slice assignment per word instead of one tensor write per
            # character.
            batch_char_ids[i, j, :len(cur_word)] = torch.tensor(
                [chars2idx.get(ch, unk_id) for ch in cur_word], dtype=int
            )

    # A mismatch means embeddings and labels are no longer aligned. Fail
    # loudly with diagnostics rather than printing and handing a corrupt
    # batch to the caller, which would fail later with a less helpful error.
    if not (batch_embeddings.shape[1] == batch_labels.shape[1] == batch_char_ids.shape[1]):
        raise ValueError(
            f"collate_fn: sequence-length mismatch for dataset ids {ids}: T={T}, "
            f"batch_embeddings: {batch_embeddings.shape} "
            f"batch_labels: {batch_labels.shape} "
            f"batch_char_ids: {batch_char_ids.shape}"
        )
    return batch_embeddings, batch_labels, batch_char_ids

The next two cells are only used for testing the dataset and `collate_fn`.

In [8]:
# Smoke test: build a tiny-batch loader and print the tensor shapes of the
# first batch only.
train_dataset = MultilingualDataset(sentences_train, tags_train)
train_loader = DataLoader(train_dataset, batch_size=2, collate_fn=collate_fn)
for data in train_loader:
    bert_embeddings, labels, char_ids = data
    # Prints embeddings / labels / char-id shapes, space separated.
    print(*(tensor.shape for tensor in (bert_embeddings, labels, char_ids)))
    break

torch.Size([2, 17, 768]) torch.Size([2, 17]) torch.Size([2, 17, 1])


In [9]:
# Debug cell: replays the steps of collate_fn by hand on two specific
# training examples so intermediate sizes can be inspected.
sentences = [train_dataset[2418][0], train_dataset[2419][0]]
labels = [train_dataset[2418][1], train_dataset[2419][1]]
print(f"original labels size: {[len(x) for x in labels]}")
tokenized_inputs_list = [
    tokenizer(sentence, is_split_into_words=True, truncation=True, return_tensors="pt")
    for sentence in sentences
]
input_ids = [x["input_ids"][0] for x in tokenized_inputs_list]
# Sub-word lengths, including the special tokens added by the tokenizer.
print([len(x) for x in input_ids])
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0).to(device)
attention_mask = [x["attention_mask"][0] for x in tokenized_inputs_list]
attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0).to(device)
with torch.no_grad():
    bert_output = bert(input_ids, attention_mask)
bert_embeddings = bert_output["last_hidden_state"]
word_ids = [x.word_ids() for x in tokenized_inputs_list]
pooled_embeddings = pooling_embedding(word_ids, bert_embeddings)
batch_embeddings = pad_sequence(pooled_embeddings, batch_first=True, padding_value=0)
# Derive the batch size from the data instead of hard-coding it.
B = len(sentences)
T = batch_embeddings.shape[1]
labels = [x[:T] for x in labels]
sentences = [x[:T] for x in sentences]
batch_labels = pad_sequence(labels, batch_first=True, padding_value=-100)
L = max(len(word) for sentence in sentences for word in sentence)
batch_char_ids = torch.zeros(B, T, L, dtype=int)
for i in range(B):
    for j in range(len(sentences[i])):
        cur_word = sentences[i][j]
        for k in range(len(cur_word)):
            try:
                batch_char_ids[i][j][k] = chars2idx.get(cur_word[k], chars2idx["<unk>"])
            except (IndexError, KeyError):
                # Report the offending position. Only the indexing and lookup
                # failures this probe is meant to expose are caught, so
                # KeyboardInterrupt and unrelated errors still propagate.
                print(i, j, k)

original labels size: [9, 3]
[11, 5]


# BiLSTM_CRF model training and evaluation

In [10]:
class BiLSTM_CRF(nn.Module):
    """Bidirectional LSTM followed by a linear layer that scores each label.

    A ``CRF`` layer is instantiated as ``self.crf``, but ``forward`` does not
    call it: the module returns the raw per-token emission scores.
    """

    def __init__(
            self, input_size,
            lstm_hidden_dim, lstm_num_layers, lstm_dropout,
            linear_output_dim, label_size
    ):
        """Build the LSTM encoder, the emission layer and the CRF layer.

        ``linear_output_dim`` is accepted for interface compatibility but is
        not read; the emission layer maps straight to ``label_size``.
        """
        super().__init__()
        lstm_config = dict(
            input_size=input_size,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_num_layers,
            batch_first=True,
            dropout=lstm_dropout,
            bidirectional=True,
        )
        self.lstm = nn.LSTM(**lstm_config)
        # Forward and backward hidden states are concatenated, hence 2x.
        self.fc = nn.Linear(lstm_hidden_dim * 2, label_size)
        self.crf = CRF(label_size)

    def forward(self, x, mask=None):
        """Return emission scores for ``x``.

        With ``mask`` given, the emissions are indexed by it
        (``emissions[mask]``); otherwise they are returned unchanged.
        """
        encoded = self.lstm(x)[0]
        scores = self.fc(encoded)
        if mask is None:
            return scores
        return scores[mask]

In [11]:
# Datasets for the three splits.
train_dataset = MultilingualDataset(sentences_train, tags_train)
dev_dataset = MultilingualDataset(sentences_dev, tags_dev)
test_dataset = MultilingualDataset(sentences_test, tags_test)

# Only the training loader shuffles; dev keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(dev_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

# The tagger consumes the 768-dimensional pooled embeddings from collate_fn
# and scores one class per entry of idx2tags.
lstm_model = BiLSTM_CRF(
    input_size=768,
    lstm_hidden_dim=256,
    lstm_num_layers=2,
    lstm_dropout=0.33,
    linear_output_dim=len(idx2tags),
    label_size=len(idx2tags),
).to(device)

# -100 is the label padding value produced by collate_fn; those positions
# are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = Adam(lstm_model.parameters(), lr=1e-3)
epochs = 5

In [16]:
# Lowest dev loss seen so far; the checkpoint is rewritten whenever it improves.
best_loss = torch.inf

for epoch in range(epochs):
    print(f'Epoch: {epoch + 1}/{epochs}')

    # ---- training pass ----
    train_loss = 0
    cur_total = 0
    correct = 0
    amount = 0
    lstm_model.train()
    train_loop = tqdm(train_loader, desc=f'Epoch: {epoch + 1}/{epochs}')
    # char_ids is produced by collate_fn but not consumed by this model.
    for bert_embeddings, labels, char_ids in train_loop:
        B = bert_embeddings.shape[0]
        bert_embeddings = bert_embeddings.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = lstm_model(bert_embeddings)
        # Flatten (batch, seq) so every token is one classification example.
        logits_reshape = logits.view(-1, logits.shape[-1])
        labels_reshape = labels.view(-1)
        loss = criterion(logits_reshape, labels_reshape)
        loss.backward()
        optimizer.step()

        # The per-batch loss is weighted by the number of sentences so the
        # epoch value can be normalised by the dataset size.
        train_loss += loss.item() * B

        # Token accuracy over non-padding positions only.
        mask = labels_reshape != -100
        predictions = logits_reshape[mask].argmax(dim=1)
        labels_non_pad = labels_reshape[mask]
        # Tensor reduction plus one .item() instead of Python's built-in
        # sum(), which iterates over the tensor element by element.
        correct += (predictions == labels_non_pad).sum().item()
        amount += labels_non_pad.numel()

        cur_total += B
        running_loss = train_loss / cur_total
        # max(..., 1) guards against a division by zero while no labelled
        # token has been seen yet.
        running_acc = correct / max(amount, 1)
        train_loop.set_postfix(loss=running_loss, acc=running_acc)

    train_loss /= len(train_dataset)
    train_acc = correct / max(amount, 1)

    # ---- validation pass ----
    dev_loss = 0
    y_true, y_pred = [], []
    correct = 0
    amount = 0
    lstm_model.eval()
    dev_loop = tqdm(dev_loader, desc=f'Epoch: {epoch + 1}/{epochs}')
    with torch.no_grad():
        for bert_embeddings, labels, char_ids in dev_loop:
            B = bert_embeddings.shape[0]
            bert_embeddings = bert_embeddings.to(device)
            labels = labels.to(device)

            logits = lstm_model(bert_embeddings)
            logits_reshape = logits.view(-1, logits.shape[-1])
            labels_reshape = labels.view(-1)
            loss = criterion(logits_reshape, labels_reshape)
            dev_loss += loss.item() * B

            mask = labels_reshape != -100
            predictions = logits_reshape[mask].argmax(dim=1)
            labels_non_pad = labels_reshape[mask]

            correct += (predictions == labels_non_pad).sum().item()
            amount += labels_non_pad.numel()
            # Plain Python ints rather than 0-dim tensors: far cheaper for
            # scikit-learn to consume.
            y_pred.extend(predictions.tolist())
            y_true.extend(labels_non_pad.tolist())

    dev_loss /= len(dev_dataset)
    val_acc = correct / max(amount, 1)
    val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)
    print('train_loss: {:.4f}, train_acc: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}'.format(train_loss, train_acc, dev_loss, val_acc))
    print('val_precision: {:.4f}, val_recall: {:.4f}, val_f1: {:.4f}'.format(val_precision, val_recall, val_f1))

    # Keep only the weights of the best epoch by dev loss.
    if dev_loss < best_loss:
        print(f"update dev_loss from {best_loss} -> {dev_loss}")
        best_loss = dev_loss
        torch.save(lstm_model.state_dict(), "best_lstm_model.pt")

Epoch: 1/5


Epoch: 1/5: 100%|██████████| 40050/40050 [21:25<00:00, 31.17it/s, acc=0.865, loss=0.404] 
Epoch: 1/5: 100%|██████████| 20050/20050 [06:59<00:00, 47.82it/s]


train_loss: 0.4036, train_acc: 0.8652, val_loss: 0.2976, val_acc: 0.8866
val_precision: 0.8604, val_recall: 0.7897, val_f1: 0.8195
update dev_loss from inf -> 0.29756352573061406
Epoch: 2/5


Epoch: 2/5: 100%|██████████| 40050/40050 [20:52<00:00, 31.97it/s, acc=0.889, loss=0.327]
Epoch: 2/5: 100%|██████████| 20050/20050 [06:52<00:00, 48.61it/s]


train_loss: 0.3269, train_acc: 0.8886, val_loss: 0.2731, val_acc: 0.8950
val_precision: 0.8638, val_recall: 0.8103, val_f1: 0.8331
update dev_loss from 0.29756352573061406 -> 0.2731456902275542
Epoch: 3/5


Epoch: 3/5: 100%|██████████| 40050/40050 [20:50<00:00, 32.03it/s, acc=0.896, loss=0.303]
Epoch: 3/5: 100%|██████████| 20050/20050 [06:51<00:00, 48.77it/s]


train_loss: 0.3029, train_acc: 0.8958, val_loss: 0.2605, val_acc: 0.9011
val_precision: 0.8480, val_recall: 0.8492, val_f1: 0.8483
update dev_loss from 0.2731456902275542 -> 0.26049925547421166
Epoch: 4/5


Epoch: 4/5: 100%|██████████| 40050/40050 [20:28<00:00, 32.59it/s, acc=0.901, loss=0.285]
Epoch: 4/5: 100%|██████████| 20050/20050 [06:41<00:00, 49.93it/s]


train_loss: 0.2849, train_acc: 0.9008, val_loss: 0.2781, val_acc: 0.9029
val_precision: 0.8546, val_recall: 0.8441, val_f1: 0.8487
Epoch: 5/5


Epoch: 5/5: 100%|██████████| 40050/40050 [20:31<00:00, 32.53it/s, acc=0.902, loss=0.277]
Epoch: 5/5: 100%|██████████| 20050/20050 [06:46<00:00, 49.31it/s]


train_loss: 0.2765, train_acc: 0.9024, val_loss: 0.2533, val_acc: 0.9059
val_precision: 0.8656, val_recall: 0.8414, val_f1: 0.8530
update dev_loss from 0.26049925547421166 -> 0.2532983063253707


In [12]:
# Restore the best checkpoint written by the training loop. map_location
# remaps the stored tensors onto the current device, so a checkpoint saved on
# a GPU can also be loaded on a CPU-only machine.
state_dict = torch.load("best_lstm_model.pt", map_location=device)
lstm_model.load_state_dict(state_dict)

<All keys matched successfully>

In [32]:
# Load the seqeval metric from the local copy of the `evaluate` repository.
# The path is assembled with os.path.join so it resolves on every platform;
# on Windows the result is identical to the former hard-coded backslash path.
seqeval = evaluate.load(
    os.path.join(".", "evaluate-main", "evaluate-main", "metrics", "seqeval")
)

In [26]:
# Tag every dev sentence and collect predicted / gold tag strings per sentence.
dev_predictions = []
dev_references = []
lstm_model.eval()
# batch_size=1 => no padding needed, so every label id maps to a real tag.
dev_loader = DataLoader(dev_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
for bert_embeddings, labels, char_ids in tqdm(dev_loader):
    with torch.no_grad():
        logits = lstm_model(bert_embeddings.to(device))
    # Highest-scoring label per token for the single sentence in the batch.
    predicted_ids = logits[0].cpu().argmax(dim=1)
    gold_ids = labels[0]
    assert len(predicted_ids) == len(gold_ids)
    id_pairs = list(zip(predicted_ids.tolist(), gold_ids.tolist()))
    dev_predictions.append([idx2tags[pred_id] for pred_id, _ in id_pairs])
    dev_references.append([idx2tags[gold_id] for _, gold_id in id_pairs])

100%|██████████| 40100/40100 [10:24<00:00, 64.18it/s]


In [31]:
# Score the predicted tag sequences against the gold references with the
# seqeval metric loaded earlier; the bare name on the last line displays the
# resulting dict as the cell output.
dev_result = seqeval.compute(predictions=dev_predictions, references=dev_references)
dev_result

{'LOC': {'precision': 0.7927988083620114,
  'recall': 0.8008197571858462,
  'f1': 0.7967890973853341,
  'number': 19274},
 'ORG': {'precision': 0.6547136756815942,
  'recall': 0.7437189054726369,
  'f1': 0.6963838583823444,
  'number': 16080},
 'PER': {'precision': 0.8282334480053111,
  'recall': 0.8289338568408335,
  'f1': 0.8285835044076801,
  'number': 16555},
 'overall_precision': 0.7571633765468474,
 'overall_recall': 0.7920977094530813,
 'overall_f1': 0.77423667535989,
 'overall_accuracy': 0.9060073944565635}

In [33]:
# Tag every test sentence and collect predicted / gold tag strings per sentence.
test_predictions = []
test_references = []
lstm_model.eval()
# batch_size=1 => no padding needed, so every label id maps to a real tag.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
for bert_embeddings, labels, char_ids in tqdm(test_loader):
    with torch.no_grad():
        logits = lstm_model(bert_embeddings.to(device))
    # Highest-scoring label per token for the single sentence in the batch.
    predicted_ids = logits[0].cpu().argmax(dim=1)
    gold_ids = labels[0]
    assert len(predicted_ids) == len(gold_ids)
    id_pairs = list(zip(predicted_ids.tolist(), gold_ids.tolist()))
    test_predictions.append([idx2tags[pred_id] for pred_id, _ in id_pairs])
    test_references.append([idx2tags[gold_id] for _, gold_id in id_pairs])

100%|██████████| 40100/40100 [16:05<00:00, 41.53it/s]


In [34]:
# Same seqeval scoring as for the dev split, applied to the test predictions;
# the bare name on the last line displays the resulting dict.
test_result = seqeval.compute(predictions=test_predictions, references=test_references)
test_result

{'LOC': {'precision': 0.7791982000409081,
  'recall': 0.7786805662016455,
  'f1': 0.7789392971246006,
  'number': 19569},
 'ORG': {'precision': 0.6256896013323618,
  'recall': 0.7112767719796473,
  'f1': 0.6657437146970872,
  'number': 16902},
 'PER': {'precision': 0.8182443272977309,
  'recall': 0.8302325581395349,
  'f1': 0.8241948516680134,
  'number': 17200},
 'overall_precision': 0.7388566753228274,
 'overall_recall': 0.7739747722233609,
 'overall_f1': 0.7560081169865231,
 'overall_accuracy': 0.9041126316132109}

In [40]:
def _tokens_matching_predictions(dataset, predictions):
    """Return each sentence's tokens cut to the length of its prediction.

    Predictions can be shorter than the original sentence (collate_fn
    truncates), so the token list is sliced to the predicted length to keep
    the two exported columns aligned.
    """
    assert len(dataset) == len(predictions)
    aligned_tokens = []
    for position, predicted_tags in enumerate(predictions):
        sentence_tokens = dataset[position][0]
        assert len(sentence_tokens) >= len(predicted_tags)
        aligned_tokens.append(sentence_tokens[:len(predicted_tags)])
    return aligned_tokens


dev_tokens = _tokens_matching_predictions(dev_dataset, dev_predictions)
test_tokens = _tokens_matching_predictions(test_dataset, test_predictions)

# One row per sentence: the token list and the predicted tag list.
dev_df = pd.DataFrame({"tokens": dev_tokens, "predictions": dev_predictions})
test_df = pd.DataFrame({"tokens": test_tokens, "predictions": test_predictions})

# Persist the evaluated weights and both prediction tables.
torch.save(lstm_model.state_dict(), "trained_model_BiLSTM_CRF.bin")
dev_df.to_csv("BiLSTM_CRF_dev.csv")
test_df.to_csv("BiLSTM_CRF_test.csv")