In [64]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from pandas import read_parquet
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

In [3]:
class BiLSTM(nn.Module):
    def __init__(
            self, 
            vocab_size, embedding_dim, 
            lstm_hidden_dim, lstm_num_layers, lstm_dropout, 
            linear_output_dim, label_size
        ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_num_layers,
            batch_first=True,
            dropout=lstm_dropout,
            # no use of dropout with only one layer
            bidirectional=True
        )
        self.fc1 = nn.Linear(2 * lstm_hidden_dim, linear_output_dim)
        self.dropout = nn.Dropout(lstm_dropout)
        self.fc2 = nn.Linear(linear_output_dim, label_size)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.dropout(F.elu(self.fc1(x), inplace=True))
        x = self.fc2(x)
        return x

In [19]:
train_data = read_parquet("data/merge/train.parquet").sample(5000)
dev_data = read_parquet("data/merge/dev.parquet").sample(1000)

In [36]:
train_text = train_data["tokens"].values.tolist()
train_tags = train_data["ner_tags"].values.tolist()
dev_text = dev_data["tokens"].values.tolist()
dev_tags = dev_data["ner_tags"].values.tolist()

In [51]:
def generate_mappings(train_text):
    train_words = set([word for sentence in train_text for word in sentence])
    word2idx = {word: idx + 2 for idx, word in enumerate(train_words)}
    word2idx['<pad>'] = 0  # use <pad> for padding word
    word2idx['<unk>'] = 1   # use <unk> for unknown words and low frequency words
    return word2idx

In [52]:
word2idx = generate_mappings(train_text)

In [54]:
class CustomizedDataset(Dataset):
    def __init__(self, word2idx, sentences, labels) -> None:
        super().__init__()
        self.sentences = sentences
        self.labels = labels
        self.word2idx = word2idx
        self.numerize()

    def numerize(self):
        self.num_data = []
        for sentence in self.sentences:
            num_words = [self.word2idx.get(word, self.word2idx['<unk>']) for word in sentence]
            self.num_data.append(num_words)
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, index):
        word_seq = self.num_data[index]
        tag_seq = self.labels[index]
        return torch.tensor(word_seq), torch.tensor(tag_seq)
    
def collate_fn(batch):
    word_seqs, tag_seqs = zip(*batch)
    word_seqs = pad_sequence(word_seqs, batch_first=True, padding_value=0)
    tag_seqs = pad_sequence(tag_seqs, batch_first=True, padding_value=9)
    return word_seqs, tag_seqs

In [55]:
train_dataset = CustomizedDataset(word2idx, train_text, train_tags)
dev_dataset = CustomizedDataset(word2idx, dev_text, dev_tags)

In [56]:
TRAIN_BATCH_SIZE = 32
DEV_BATCH_SIZE = 16
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn
)
dev_loader = DataLoader(
    dataset=dev_dataset,
    batch_size=DEV_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn
)

In [61]:
vocab_size = len(word2idx)
label_size = 10
embedding_dim = 100
lstm_num_layers = 2
lstm_hidden_dim = 256
lstm_dropout = 0.33
linear_output_dim = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = BiLSTM(vocab_size, embedding_dim, lstm_hidden_dim, lstm_num_layers, lstm_dropout, linear_output_dim, label_size)
optimizer = Adam(model.parameters(), lr=3e-3, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=9) 
epochs = 10

In [67]:
model = model.to(device)

best_f1 = 0
# best_loss = torch.inf

for epoch in range(epochs):
    print(f'Epoch: {epoch + 1}/{epochs}')
    train_loss = 0
    correct = 0
    amount = 0
    model.train()
    for data, target in tqdm(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        pred = model(data)
        pred = pred.view(-1, label_size)
        target = target.view(-1)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.cpu().item() * data.shape[0]
        target = target.cpu()
        pred = torch.argmax(pred, dim=1).cpu()
        mask = target != 9
        masked_pred = pred[mask]
        masked_target = target[mask]
        correct += sum(masked_pred == masked_target)
        amount += sum(mask)

    train_loss /= len(train_dataset)
    train_acc = correct / amount

    val_loss = 0
    y_true, y_pred = [], []
    correct = 0
    amount = 0
    model.eval()
    with torch.no_grad():
        for data, target in tqdm(dev_loader):
            data = data.to(device)
            target = target.to(device)

            pred = model(data)
            pred = pred.view(-1, label_size)
            target = target.view(-1)
            loss = criterion(pred, target)

            val_loss += loss.cpu().item() * data.shape[0]

            target = target.cpu()
            pred = torch.argmax(pred, dim=1).cpu()
            mask = target != 9
            masked_pred = pred[mask]
            masked_target = target[mask]
            correct += sum(masked_pred == masked_target)
            amount += sum(mask)
            y_pred.extend(masked_pred)
            y_true.extend(masked_target)

    val_loss /= len(dev_dataset)
    val_acc = correct / amount
    val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)
    print('train_loss: {:.4f}, train_acc: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}'.format(train_loss, train_acc, val_loss, val_acc))
    print('val_precision: {:.4f}, val_recall: {:.4f}, val_f1: {:.4f}'.format(val_precision, val_recall, val_f1))

Epoch: 1/10


100%|██████████| 157/157 [00:25<00:00,  6.10it/s]
100%|██████████| 63/63 [00:03<00:00, 19.37it/s]


train_loss: 0.5235, train_acc: 0.8287, val_loss: 0.6546, val_acc: 0.7887
val_precision: 0.7268, val_recall: 0.5894, val_f1: 0.6153
Epoch: 2/10


100%|██████████| 157/157 [00:28<00:00,  5.53it/s]
100%|██████████| 63/63 [00:03<00:00, 17.32it/s]


train_loss: 0.3858, train_acc: 0.8742, val_loss: 0.7745, val_acc: 0.7552
val_precision: 0.6259, val_recall: 0.6529, val_f1: 0.6245
Epoch: 3/10


100%|██████████| 157/157 [00:29<00:00,  5.34it/s]
100%|██████████| 63/63 [00:03<00:00, 19.99it/s]


train_loss: 0.3213, train_acc: 0.8938, val_loss: 0.6637, val_acc: 0.7951
val_precision: 0.6661, val_recall: 0.6649, val_f1: 0.6608
Epoch: 4/10


100%|██████████| 157/157 [00:27<00:00,  5.71it/s]
100%|██████████| 63/63 [00:03<00:00, 16.84it/s]


train_loss: 0.2402, train_acc: 0.9196, val_loss: 0.6631, val_acc: 0.8151
val_precision: 0.7001, val_recall: 0.6733, val_f1: 0.6818
Epoch: 5/10


100%|██████████| 157/157 [00:37<00:00,  4.24it/s]
100%|██████████| 63/63 [00:04<00:00, 14.61it/s]


train_loss: 0.2250, train_acc: 0.9255, val_loss: 0.6271, val_acc: 0.8154
val_precision: 0.6868, val_recall: 0.6785, val_f1: 0.6790
Epoch: 6/10


100%|██████████| 157/157 [00:29<00:00,  5.24it/s]
100%|██████████| 63/63 [00:03<00:00, 19.00it/s]


train_loss: 0.1950, train_acc: 0.9370, val_loss: 0.7061, val_acc: 0.7957
val_precision: 0.6729, val_recall: 0.6804, val_f1: 0.6718
Epoch: 7/10


100%|██████████| 157/157 [00:30<00:00,  5.14it/s]
100%|██████████| 63/63 [00:03<00:00, 20.27it/s]


train_loss: 0.1761, train_acc: 0.9422, val_loss: 0.6991, val_acc: 0.8258
val_precision: 0.7241, val_recall: 0.6627, val_f1: 0.6900
Epoch: 8/10


100%|██████████| 157/157 [00:26<00:00,  6.02it/s]
100%|██████████| 63/63 [00:03<00:00, 19.44it/s]


train_loss: 0.1491, train_acc: 0.9514, val_loss: 0.7092, val_acc: 0.8182
val_precision: 0.7003, val_recall: 0.6886, val_f1: 0.6917
Epoch: 9/10


100%|██████████| 157/157 [00:27<00:00,  5.75it/s]
100%|██████████| 63/63 [00:03<00:00, 19.05it/s]


train_loss: 0.1490, train_acc: 0.9516, val_loss: 0.7153, val_acc: 0.8222
val_precision: 0.7135, val_recall: 0.6811, val_f1: 0.6957
Epoch: 10/10


100%|██████████| 157/157 [00:27<00:00,  5.80it/s]
100%|██████████| 63/63 [00:03<00:00, 17.54it/s]


train_loss: 0.1325, train_acc: 0.9584, val_loss: 0.7351, val_acc: 0.8293
val_precision: 0.7275, val_recall: 0.6887, val_f1: 0.7028
