# packages

In [1]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from pandas import read_parquet
from transformers import BertModel, BertTokenizerFast
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support

  from .autonotebook import tqdm as notebook_tqdm


# dataset reading

In [2]:
model_path = "data/mBERT/fine"

tokenizer = BertTokenizerFast.from_pretrained(model_path)

In [3]:
train_data = read_parquet("data/merge/train.parquet")
dev_data = read_parquet("data/merge/dev.parquet")
test_data = read_parquet("data/merge/test.parquet")

with open("data/merge/tags_2_idx.json", "r") as f:
    tags2idx = json.load(f)

with open("data/merge/idx_2_tags.json", "r") as f:
    idx2tags = json.load(f)

with open("data/merge/chars2idx.json", "r") as f:
    chars2idx = json.load(f)

In [4]:
sentences_train = train_data["tokens"].values.tolist()
tags_train = train_data["ner_tags"].values.tolist()

sentences_dev = dev_data["tokens"].values.tolist()
tags_dev = dev_data["ner_tags"].values.tolist()

sentences_test = test_data["tokens"].values.tolist()
tags_test = test_data["ner_tags"].values.tolist()

In [5]:
MAX_WORD_LEN = 30
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
def align_label(tokenized_input, tags, tags_2_idx, idx_2_tags, label_all_tokens=True): 
    # tokenized_input refers to the sequences after tokenized
    # tags refers to the original tags from dataset
    # False:只为每个拆分token的第一个子词提供一个标签。
    # True:在属于同一 token 的所有子词中提供相同的标签。
    word_ids = tokenized_input.word_ids()
    previous_word_idx = None
    label_ids = []   
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)                
        elif word_idx != previous_word_idx:
            try:
                label_ids.append(tags[word_idx])
            except:
                label_ids.append(-100) 
        else:
            label_ids.append(tags[word_idx] if label_all_tokens else -100)
        previous_word_idx = word_idx      
    return label_ids

def generate_tokenized_input(sentences_raw, tags_raw):
    sentences = []
    tags = []
    for i in range(len(sentences_raw)):
        tokenized_text = tokenizer(sentences_raw[i].tolist(), padding="max_length", max_length=512, truncation=True, return_tensors="pt", is_split_into_words=True)
        extended_tags = align_label(tokenized_text, tags_raw[i], tags2idx, idx2tags)
        sentences.append(tokenized_text)
        tags.append(extended_tags)
    return sentences, tags

def generate_tokenized_input_with_words(sentences_raw, tags_raw):
    sentences = []
    tags = []
    words = []
    chars = []
    for i in range(len(sentences_raw)):
        tokenized_text = tokenizer(sentences_raw[i].tolist(), padding="max_length", max_length=512, truncation=True, return_tensors="pt", is_split_into_words=True)
        extended_tags = align_label(tokenized_text, tags_raw[i], tags2idx, idx2tags)
        sentences.append(tokenized_text)
        tags.append(extended_tags)
        token_ids = tokenized_text["input_ids"][0]
        token_words = tokenizer.convert_ids_to_tokens(token_ids)
        words.append(token_words)
        char_ids = torch.zeros(512, MAX_WORD_LEN)
        for i in range(len(token_words)):
            for j in range(len(token_words[i])):
                char_ids[i][j] = chars2idx.get(token_words[i][j], chars2idx['<unk>'])
        chars.append(char_ids)
    return sentences, tags, words, chars

In [7]:
train_sentences, train_tags = generate_tokenized_input(sentences_train, tags_train)
dev_sentences, dev_tags = generate_tokenized_input(sentences_dev, tags_dev)
test_sentences, test_tags = generate_tokenized_input(sentences_test, tags_test)

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# train_sentences1, train_tags1, train_words, train_chars = generate_tokenized_input_with_words(sentences_train, tags_train)

In [10]:
bert = BertModel.from_pretrained("data/mBERT/fine").to(device)

Some weights of the model checkpoint at data/mBERT/fine were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at data/mBERT/fine and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
class MultilingualDataset(Dataset):
    def __init__(self, sentences, labels):
        self.sentences = sentences
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def get_text_tokenized(self, idx):
        return self.sentences[idx]

    def get_labels(self, idx):
        return torch.LongTensor(self.labels[idx])

    def __getitem__(self, idx):
        text_tokenized = self.get_text_tokenized(idx)
        labels = self.get_labels(idx)
        return text_tokenized, labels.unsqueeze(0)  # shap: [1, 512]
    
def collate_fn(batch):
    text_tokenized_seqs, labels_seqs = zip(*batch)
    B = len(labels_seqs)
    batch_input_ids = []
    batch_attention_masks = []
    batch_label_seqs = torch.concat(labels_seqs)
    for i in range(B):
        batch_input_ids.append(text_tokenized_seqs[i]["input_ids"])
        batch_attention_masks.append(text_tokenized_seqs[i]["attention_mask"])
    batch_input_ids = torch.concat(batch_input_ids).to(device)
    batch_attention_masks = torch.concat(batch_attention_masks).to(device)
    with torch.no_grad():
        bert_output = bert(batch_input_ids, batch_attention_masks)
    bert_embeddings = bert_output["last_hidden_state"]
    return bert_embeddings, batch_label_seqs, batch_attention_masks

In [13]:
train_dataset = MultilingualDataset(train_sentences, train_tags)
dev_dataset = MultilingualDataset(dev_sentences, dev_tags)
test_dataset = MultilingualDataset(test_sentences, test_tags)

# model class

In [14]:
#output = bert(train_sentences[0]["input_ids"], attention_mask=train_sentences[0]["attention_mask"])

In [15]:
class BiLSTM(nn.Module):
    def __init__(
            self, input_size,
            lstm_hidden_dim, lstm_num_layers, lstm_dropout, 
            linear_output_dim, label_size
        ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_num_layers,
            batch_first=True,
            dropout=lstm_dropout,
            bidirectional=True
        )
        self.fc1 = nn.Linear(2 * lstm_hidden_dim, linear_output_dim)
        self.dropout = nn.Dropout(lstm_dropout)
        self.fc2 = nn.Linear(linear_output_dim, label_size)

    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.dropout(F.elu(self.fc1(x), inplace=True))
        x = self.fc2(x)
        return x

In [16]:
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)
dev_loader = DataLoader(
    dataset=dev_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn
)
lstm_model = BiLSTM(
    input_size=768,
    lstm_hidden_dim=256, lstm_num_layers=2, lstm_dropout=0.33,
    linear_output_dim=128, label_size=len(idx2tags)
)
lstm_model = lstm_model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(lstm_model.parameters(), lr=1e-3)
epochs = 5

In [None]:
for epoch in range(epochs):
    print(f'Epoch: {epoch + 1}/{epochs}')
    train_loss = 0
    cur_total = 0
    correct = 0
    amount = 0
    lstm_model.train()
    train_loop = tqdm(train_loader, desc=f'Epoch: {epoch + 1}/{epochs}')
    for bert_embeddings, labels, attention_masks in train_loop:
        bert_embeddings = bert_embeddings.to(device)
        labels = labels.to(device)
        attention_masks = attention_masks.to(device)

        optimizer.zero_grad()
        logits = lstm_model(bert_embeddings)

        logits_reshape = logits.view(-1, logits.shape[-1])  # [B*T, label_size]
        labels_reshape = labels.view(-1)    # [B*T]
        reshape_masks = attention_masks.view(-1)    # [B*T]
        logits_non_pad = logits_reshape[reshape_masks == 1]
        labels_non_pad = labels_reshape[reshape_masks == 1]

        loss = criterion(logits_non_pad, labels_non_pad)
        loss.backward()
        optimizer.step()

        train_loss += loss.cpu().item() * len(labels)

        _, predictions = torch.max(logits_non_pad, 1)
        correct += sum(predictions == labels_non_pad)
        amount += len(labels_non_pad)

        cur_total += len(labels)
        running_loss = train_loss / cur_total
        running_acc = correct / amount
        train_loop.set_postfix(loss=running_loss, acc=running_acc.item())

    train_loss /= len(train_dataset)
    train_acc = correct / amount

    dev_loss = 0
    y_true, y_pred = [], []
    correct = 0
    amount = 0
    lstm_model.eval()
    with torch.no_grad():
        for bert_embeddings, labels, attention_masks in tqdm(dev_loader):
            bert_embeddings = bert_embeddings.to(device)
            labels = labels.to(device)
            attention_masks = attention_masks.to(device)

            logits = lstm_model(bert_embeddings)

            logits_reshape = logits.view(-1, logits.shape[-1])  # [B*T, label_size]
            labels_reshape = labels.view(-1)    # [B*T]
            reshape_masks = attention_masks.view(-1)    # [B*T]
            logits_non_pad = logits_reshape[reshape_masks == 1]
            labels_non_pad = labels_reshape[reshape_masks == 1]

            loss = criterion(logits_non_pad, labels_non_pad)

            dev_loss += loss.cpu().item() * len(labels)

            _, predictions = torch.max(logits_non_pad, 1)
            correct += sum(predictions == labels_non_pad)
            amount += len(labels_non_pad)
            y_pred.extend(predictions.cpu())
            y_true.extend(labels_non_pad.cpu())

    dev_loss /= len(dev_dataset)
    val_acc = correct / amount
    val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)
    print('train_loss: {:.4f}, train_acc: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}'.format(train_loss, train_acc, dev_loss, val_acc))
    print('val_precision: {:.4f}, val_recall: {:.4f}, val_f1: {:.4f}'.format(val_precision, val_recall, val_f1))

Epoch: 1/5


Epoch: 1/5:  77%|████████████████████████████████▍         | 7727/10013 [41:33<12:13,  3.12it/s, acc=0.675, loss=0.775]

# CSV output