## BiLSTM for MIT Movies

In [1]:
import os
import sys
import json
import numpy as np
from tqdm import tqdm
sys.path.append("..")

from torch import nn
from torch.optim import Adam
from src.namedentityrecognizer.models.lstm import BiLSTM
from src.namedentityrecognizer.trainers import TrainerBiLstm
from src.namedentityrecognizer.utils.processors import NerPreProcessor
from src.namedentityrecognizer.data.build_dataset import Corpus, BuildData

In [2]:
# Resolve the repository root dynamically so the data paths used by the following
# cells work for any user, wherever the project was cloned.
PROJECT_DIR_NAME = "NamedEntityRecognizer"

# `_dh` is IPython's directory history; fall back to the current working directory
# when it is not available (e.g. the code is executed outside IPython).
for path in globals().get("_dh") or [os.getcwd()]:
    # Entries may be `str` or `pathlib.Path` depending on the IPython version.
    parts = os.fspath(path).split(os.sep)
    if PROJECT_DIR_NAME in parts:
        # Keep everything up to and including the project folder, so a kernel started
        # in a sub-folder (e.g. the notebooks directory) still yields the repo root.
        absolute_path = os.sep.join(parts[: parts.index(PROJECT_DIR_NAME) + 1])
        break
else:
    # Fail loudly instead of hitting a NameError (or a stale value) further down.
    raise RuntimeError(
        f"Could not locate a '{PROJECT_DIR_NAME}' directory in the kernel's directory "
        "history; start the notebook from inside the repository."
    )
print(absolute_path)

/home/karaz/Desktop/NamedEntityRecognizer


In [3]:
# Regenerate the tab-separated copies of the raw MIT Movies .bio files that the
# torchtext-based corpus reads. Per the original notes, each line is rewritten from
# tag-first to word-first order, for example:
#   "O<TAB>good"           ->  "good O"
#   "B-GENRE<TAB>romantic" ->  "romantic B-GENRE"
#   "I-GENRE<TAB>comedies" ->  "comedies I-GENRE"
# Presumably only needed when the converted files are missing or outdated.
bio_conversions = (
    ("data/raw/mitmovies/engtrain.bio", "data/modified/mitmovies_tab_format/train.txt"),
    ("data/raw/mitmovies/engtest.bio", "data/modified/mitmovies_tab_format/test.txt"),
)
for raw_rel_path, tab_rel_path in bio_conversions:
    BuildData.create_finaldata(
        os.path.join(absolute_path, raw_rel_path),
        os.path.join(absolute_path, tab_rel_path),
        splits="\t",
    )

In [4]:
# Build the corpus (train/test datasets plus word and tag vocabularies) from the
# tab-formatted MIT Movies files.
tab_data_dir = os.path.join(absolute_path, "data/modified/mitmovies_tab_format")
dataset = Corpus(
    input_folder=tab_data_dir,
    # words seen fewer than 3 times are left out of the vocabulary
    min_word_freq=3,
    batch_size=64,
)
print(f"Train set: {len(dataset.train_dataset)} sentences")
print(f"Test set: {len(dataset.test_dataset)} sentences")

Train set: 9775 sentences
Test set: 2443 sentences


In [5]:
# Hyperparameters of the BiLSTM tagger; vocabulary sizes and the padding index are
# taken from the corpus so the model always matches the loaded data.
bilstm_hyperparams = {
    "input_dim": len(dataset.word_field.vocab),
    "embedding_dim": 300,
    "hidden_dim": 64,
    "output_dim": len(dataset.tag_field.vocab),
    "lstm_layers": 4,
    "emb_dropout": 0.25,
    "lstm_dropout": 0.01,
    "fc_dropout": 0.1,
    "word_pad_idx": dataset.word_pad_idx,
}
model = BiLSTM(**bilstm_hyperparams)

# Initialise the weights first, then the embeddings (which need the padding index).
model.init_weights()
model.init_embeddings(word_pad_idx=dataset.word_pad_idx)

# Report the model size followed by its layer layout.
print(f"The model has {model.count_parameters():,} trainable parameters.")
print(model)

The model has 1,161,930 trainable parameters.
BiLSTM(
  (embedding): Embedding(2244, 300, padding_idx=1)
  (emb_dropout): Dropout(p=0.25, inplace=False)
  (lstm): LSTM(300, 64, num_layers=4, dropout=0.01, bidirectional=True)
  (fc_dropout): Dropout(p=0.1, inplace=False)
  (fc): Linear(in_features=128, out_features=26, bias=True)
)


In [6]:
# Wire the model and corpus into the trainer. The optimizer and loss are handed over
# as classes (not instances), as the `*_cls` parameter names indicate.
trainer_settings = {
    "model": model,
    "data": dataset,
    "optimizer_cls": Adam,
    "loss_fn_cls": nn.CrossEntropyLoss,
    "log_name": "bilstm_vanilla2",
}
ner = TrainerBiLstm(**trainer_settings)

# Train for 10 epochs; per-epoch metrics are printed below the cell.
ner.train(10)

Epoch: 01 | Epoch Time: 0m 15s
	Trn Loss: 1.573 | Trn Acc: 62.89%
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
	Val Loss: 1.092 | Val Acc: 69.93% | Val Precision: 34.57% | Val Recall: 36.35% | Val F1 Macro: 29.80% | Val F1 Micro: 71.98%
Epoch: 02 | Epoch Time: 0m 15s
	Trn Loss: 0.660 | Trn Acc: 82.62%
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifi

In [7]:
# Sanity-check the trained tagger on a hand-written query; the call is left as the
# last expression so the returned tuple is displayed under the cell.
sample_query = "list an r rated drama movie"
ner.infer(sample_query)

word 	unk  	pred tag
list 	     	O       
an   	     	O       
r    	     	B-RATING
rated	     	O       
drama	     	B-GENRE 
movie	     	O       


(['list', 'an', 'r', 'rated', 'drama', 'movie'],
 ['O', 'O', 'B-RATING', 'O', 'B-GENRE', 'O'],
 [])