## Import Libraries

In [None]:
import torch
import torch.nn as nn
from torchcrf import CRF
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from sklearn.metrics import classification_report

from datasets import load_from_disk
from reference_parsing.config import (LABEL2ID)

from reference_parsing.utils.data_preparation import prepare_bilstm_crf_data, collate_bilstm_crf
from reference_parsing.embeddings.HandFeatureEmbedding import  HandFeatureEmbedding
from reference_parsing.embeddings.BytePairReferenceEmbedding import BytePairReferenceEmbedding

  from .autonotebook import tqdm as notebook_tqdm


## Load Dataset

In [2]:
# Load the previously saved dataset from a path relative to the working directory.
# The cells below index it with the "train", "valid" and "test" keys.
prepared_dataset = load_from_disk("./datasets/prepared_dataset")

In [3]:
# Instantiate the two project embedding helpers with their default arguments.
# Both are later passed to prepare_bilstm_crf_data; hand_emb also supplies the
# hand-feature vocabulary used by the collate function and the model size.
bp_emb = BytePairReferenceEmbedding()
hand_emb = HandFeatureEmbedding()

In [None]:
class ReferenceDataset(Dataset):
    """Map-style dataset over three parallel, index-aligned sequences.

    Item ``idx`` is the triple ``(X_bpe[idx], X_hand[idx], Y[idx])``.
    """

    def __init__(self, X_bpe, X_hand, Y):
        # The containers are kept by reference; nothing is copied.
        self.X_bpe, self.X_hand, self.Y = X_bpe, X_hand, Y

    def __len__(self):
        """Return the number of items, measured on ``X_bpe``."""
        size = len(self.X_bpe)
        return size

    def __getitem__(self, idx):
        """Return the ``idx``-th entry of each stored sequence as a tuple."""
        columns = (self.X_bpe, self.X_hand, self.Y)
        return tuple(column[idx] for column in columns)

In [None]:
# Label-name -> id mapping comes straight from the project configuration.
label2id = LABEL2ID
# Inverse lookup (id -> label name), used to turn predicted ids into strings.
id2label = {tag_id: tag_name for tag_name, tag_id in label2id.items()}

## Data Segmentation

In [None]:
# Keep only the leading rows of each split: 5,000,000 for training and
# 200,000 each for validation and test. The "_5mil" suffix names the
# experiment (training-set size), not the size of every split.
train_5mil = prepared_dataset["train"].select(range(5_000_000))
valid_5mil = prepared_dataset["valid"].select(range(200_000))
test_5mil = prepared_dataset["test"].select(range(200_000))

In [None]:
# Featurise the training split, wrap it in a dataset and build a shuffled
# loader. The collate callback looks up label2id and the hand-feature
# vocabulary each time a batch is assembled.
X_bpe, X_hand, Y = prepare_bilstm_crf_data(bp_emb, hand_emb, train_5mil)
train_ds = ReferenceDataset(X_bpe=X_bpe, X_hand=X_hand, Y=Y)
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=32,
    shuffle=True,
    collate_fn=lambda examples: collate_bilstm_crf(
        examples, label2id, hand_emb.get_hand_feature_vocab()
    ),
)

In [None]:
# Featurise the validation split and build its loader.
# shuffle=False: validation only measures loss, so a random order adds
# nothing and makes batch composition differ from epoch to epoch.
X_bpe, X_hand, Y = prepare_bilstm_crf_data(bp_emb, hand_emb, valid_5mil)
valid_ds = ReferenceDataset(X_bpe, X_hand, Y)
valid_loader = DataLoader(valid_ds, batch_size=32, shuffle=False,
                          collate_fn=lambda batch: collate_bilstm_crf(batch, label2id, hand_emb.get_hand_feature_vocab()))

In [None]:
# Featurise the test split and build its loader.
# shuffle=False: evaluation metrics do not depend on order, and a fixed
# order keeps per-batch results reproducible between runs.
X_bpe, X_hand, Y = prepare_bilstm_crf_data(bp_emb, hand_emb, test_5mil)
test_ds = ReferenceDataset(X_bpe, X_hand, Y)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False,
                         collate_fn=lambda batch: collate_bilstm_crf(batch, label2id, hand_emb.get_hand_feature_vocab()))

## BiLSTM - Hand Features

### Model Coding

In [None]:
class BiLSTMCRFModel(nn.Module):
    """BiLSTM-CRF tagger that adds projected hand-feature embeddings.

    The batch-first bidirectional LSTM output and a linear projection of the
    hand-feature embedding are summed, passed through a linear classifier to
    obtain per-tag emission scores, and scored or decoded by a CRF layer.
    Dropout is applied to the LSTM output, to the hand-feature embedding and
    to their sum.
    """

    def __init__(self, bpe_dim, lstm_hidden_dim, hand_vocab_size, hand_emb_dim, num_tags, dropout_rate=0.5):
        super().__init__()
        # Both LSTM directions are concatenated, hence twice the hidden size.
        encoder_width = 2 * lstm_hidden_dim
        self.lstm = nn.LSTM(
            input_size=bpe_dim,
            hidden_size=lstm_hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.hand_emb = nn.Embedding(hand_vocab_size, hand_emb_dim)
        # Project hand features to the LSTM output width so the two can be summed.
        self.hand_proj = nn.Linear(hand_emb_dim, encoder_width)
        self.classifier = nn.Linear(encoder_width, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, bpe_embeddings, hand_indices, tags=None, mask=None):
        """Return ``(loss, emissions)`` when ``tags`` is given, else decoded tags.

        Args:
            bpe_embeddings: Input fed to the LSTM.
            hand_indices: Indices looked up in the hand-feature embedding.
            tags: Optional gold tags; when present the negated CRF score
                (``reduction='mean'``) is returned as the loss.
            mask: Optional mask forwarded unchanged to the CRF layer.
        """
        encoded, _ = self.lstm(bpe_embeddings)
        encoded = self.dropout(encoded)

        hand_vectors = self.dropout(self.hand_emb(hand_indices))
        combined = self.dropout(encoded + self.hand_proj(hand_vectors))
        emissions = self.classifier(combined)

        if tags is None:
            # Inference: let the CRF decode the emission scores.
            return self.crf.decode(emissions, mask=mask)

        nll = -self.crf(emissions, tags, mask=mask, reduction='mean')
        return nll, emissions


In [None]:
def train_model(model, train_loader, val_loader, optimizer, num_epochs, device):
    """Train ``model`` and checkpoint it whenever the validation loss improves.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``val_loader``. The printed losses are the mean of
    the per-batch losses returned by the model.

    Args:
        model: Module called as ``model(bpe_inputs, hand_inputs, tags, mask)``
            and returning ``(loss, emissions)``.
        train_loader: Iterable of ``(bpe_inputs, hand_inputs, tags, mask)``
            tensor batches; must support ``len()``.
        val_loader: Same batch layout as ``train_loader``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of epochs to run.
        device: Device on which the model and every batch are placed.

    Returns:
        The same ``model`` object with the weights of the final epoch; the
        best-scoring weights exist only in the checkpoint file.
    """
    import os  # local import so this cell does not rely on the import cell

    checkpoint_path = "models/bilstm_5mil/bilstm_crf_best.pt"

    # Batches are moved to `device` below, so the weights must live there too;
    # otherwise the first forward pass fails with a device mismatch on CUDA.
    # Module.to() works in place, so the optimizer's parameter references stay valid.
    model.to(device)

    best_val_loss = float('inf')
    model.train()
    for epoch in range(1, num_epochs + 1):
        total_loss = 0.0
        for batch in train_loader:
            bpe_inputs, hand_inputs, tags, mask = (t.to(device) for t in batch)

            optimizer.zero_grad()
            loss, _ = model(bpe_inputs, hand_inputs, tags, mask)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch}/{num_epochs}: Training Loss = {avg_loss:.4f}")

        # Validation step: eval mode disables dropout, no_grad skips autograd.
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                bpe_inputs, hand_inputs, tags, mask = (t.to(device) for t in batch)
                loss, _ = model(bpe_inputs, hand_inputs, tags, mask)
                total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch}/{num_epochs}: Validation Loss = {avg_val_loss:.4f}")

        # Save the model only if the validation loss improved
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved at epoch {epoch} with validation loss {avg_val_loss:.4f}")

        # Back to training mode for the next epoch.
        model.train()

    return model


In [None]:
def flatten(sequences):
    """Concatenate an iterable of sequences into one flat list (one level deep)."""
    flat = []
    for seq in sequences:
        flat.extend(seq)
    return flat

def evaluate_and_report(model, test_loader, label2id, id2label):
    """Decode every batch of ``test_loader`` and print a classification report.

    Args:
        model: Module called as ``model(bpe_inputs, hand_inputs, mask=mask)``
            and returning one list of predicted tag ids per sequence.
        test_loader: Iterable of ``(bpe_inputs, hand_inputs, tags, mask)``
            tensor batches.
        label2id: Label-name -> id mapping; its key order fixes the row order
            of the report.
        id2label: Id -> label-name mapping used to render ids as strings.

    Returns:
        None. The report is printed.
    """
    model.eval()
    # Run on whichever device the weights live on; the loader yields tensors
    # that have not been moved, which breaks once the model sits on the GPU.
    device = next(model.parameters()).device

    gold_ids = []
    pred_ids = []
    with torch.no_grad():
        for bpe_inputs, hand_inputs, tags, mask in test_loader:
            decoded = model(
                bpe_inputs.to(device),
                hand_inputs.to(device),
                mask=mask.to(device),
            )
            # Pair gold and predicted ids position by position so the two
            # lists can never drift apart. zip() trims each gold row to the
            # decoded length, i.e. its leading positions.
            # NOTE(review): assumes right-padded batches -- confirm against
            # collate_bilstm_crf.
            for gold_seq, pred_seq in zip(tags.cpu().tolist(), decoded):
                for gold, pred in zip(gold_seq, pred_seq):
                    # -1 is skipped, as in the original filter (presumably the
                    # padding / ignore id -- confirm against the collate function).
                    if gold != -1:
                        gold_ids.append(gold)
                        pred_ids.append(pred)

    y_true_labels = [id2label[idx] for idx in gold_ids]
    y_pred_labels = [id2label[idx] for idx in pred_ids]

    label_order = list(label2id.keys())

    print("Classification Report:")
    print(classification_report(y_true_labels, y_pred_labels, labels=label_order, zero_division=0))


### Model Training

In [67]:
# Hyperparameters for the hand-feature model.
bpe_dim = 600  # LSTM input width; presumably the BytePairReferenceEmbedding vector size -- confirm
lstm_hidden_dim = 128
hand_vocab_size = len(hand_emb.get_hand_feature_vocab())
hand_emb_dim = 50
num_tags = 26  # NOTE(review): should equal the number of ids in label2id -- confirm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to `device` before building the optimizer: the training loop
# sends every batch to `device`, so CPU-resident weights would fail on CUDA.
bilstm_crf_5mil = BiLSTMCRFModel(bpe_dim, lstm_hidden_dim, hand_vocab_size, hand_emb_dim, num_tags).to(device)
optimizer = optim.AdamW(bilstm_crf_5mil.parameters(), lr=0.003)

In [68]:
# Train for 50 epochs; train_model returns the same model object it was given.
bilstm_crf_5mil = train_model(bilstm_crf_5mil, train_loader, valid_loader, optimizer, 50, device)

Epoch 1/50: Training Loss = 149.4699
Epoch 1/50: Validation Loss = 102.7207
Model saved at epoch 1 with validation loss 102.7207
Epoch 2/50: Training Loss = 101.8513
Epoch 2/50: Validation Loss = 91.2631
Model saved at epoch 2 with validation loss 91.2631
Epoch 3/50: Training Loss = 82.3033
Epoch 3/50: Validation Loss = 85.1576
Model saved at epoch 3 with validation loss 85.1576
Epoch 4/50: Training Loss = 79.9737
Epoch 4/50: Validation Loss = 75.8468
Model saved at epoch 4 with validation loss 75.8468
Epoch 5/50: Training Loss = 71.6131
Epoch 5/50: Validation Loss = 68.7395
Model saved at epoch 5 with validation loss 68.7395
Epoch 6/50: Training Loss = 64.2328
Epoch 6/50: Validation Loss = 61.9641
Model saved at epoch 6 with validation loss 61.9641
Epoch 7/50: Training Loss = 58.1300
Epoch 7/50: Validation Loss = 55.0834
Model saved at epoch 7 with validation loss 55.0834
Epoch 8/50: Training Loss = 51.9631
Epoch 8/50: Validation Loss = 49.4444
Model saved at epoch 8 with validation l

In [None]:
# Print the per-label classification report for the hand-feature model on the test loader.
evaluate_and_report(bilstm_crf_5mil, test_loader, label2id, id2label)

                   precision    recall  f1-score      support

         B-AUTHOR       0.86      0.86      0.86       190515
         I-AUTHOR       0.90      0.97      0.93      1735612
           B-YEAR       0.98      0.88      0.93       172556
           I-YEAR       0.98      0.90      0.94         1350
          B-TITLE       0.82      0.90      0.86      1815762
          I-TITLE       0.78      0.99      0.87     18249721
B-CONTAINER-TITLE       0.44      0.40      0.42       133714
I-CONTAINER-TITLE       0.84      0.44      0.58      1100358
         B-VOLUME       0.62      0.98      0.76        39419
         I-VOLUME       0.85      0.95      0.90          155
          B-ISSUE       0.67      0.98      0.80        13557
          I-ISSUE       0.85      0.98      0.91          343
           B-PAGE       0.85      0.98      0.91       157135
           I-PAGE       0.90      0.97      0.93        29258
           B-ISBN       0.97      0.83      0.89        45327
       

## BiLSTM - No Hand Features

### Model Coding

In [None]:
class BiLSTMCRFModelNoHand(nn.Module):
    """BiLSTM-CRF tagger that uses only the BPE embeddings (no hand features).

    The batch-first bidirectional LSTM output goes through dropout and a
    linear classifier to obtain per-tag emission scores, which a CRF layer
    scores or decodes.
    """

    def __init__(self, bpe_dim, lstm_hidden_dim, num_tags, dropout_rate=0.5):
        super(BiLSTMCRFModelNoHand, self).__init__()
        self.lstm = nn.LSTM(input_size=bpe_dim, hidden_size=lstm_hidden_dim,
                            batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout_rate)
        # Both LSTM directions are concatenated, hence 2 * lstm_hidden_dim.
        self.classifier = nn.Linear(2 * lstm_hidden_dim, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, bpe_embeddings, tags=None, mask=None):
        """Return ``(loss, emissions)`` when ``tags`` is given, else decoded tags.

        Args:
            bpe_embeddings: Input fed to the LSTM.
            tags: Optional gold tags; when present the negated CRF score
                (``reduction='mean'``) is returned as the loss.
            mask: Optional mask forwarded unchanged to the CRF layer.
        """
        # The encoder call was missing, so `lstm_out` was read before being
        # assigned and every forward pass raised UnboundLocalError.
        lstm_out, _ = self.lstm(bpe_embeddings)
        lstm_out = self.dropout(lstm_out)
        emissions = self.classifier(lstm_out)

        if tags is not None:
            loss = -self.crf(emissions, tags, mask=mask, reduction='mean')
            return loss, emissions
        else:
            predictions = self.crf.decode(emissions, mask=mask)
            return predictions

In [52]:
def train_model_no_hand(model, train_loader, val_loader, optimizer, num_epochs, device):
    """Train a hand-feature-free model, checkpointing on validation improvement.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``val_loader``. The printed losses are the mean of
    the per-batch losses returned by the model.

    Args:
        model: Module called as ``model(bpe_inputs, tags=tags, mask=mask)`` and
            returning ``(loss, emissions)``.
        train_loader: Iterable of 4-tuples ``(bpe_inputs, hand_inputs, tags,
            mask)``; the hand-feature entry is ignored. Must support ``len()``.
        val_loader: Same batch layout as ``train_loader``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of epochs to run.
        device: Device on which the model and the used tensors are placed.

    Returns:
        The same ``model`` object with the weights of the final epoch; the
        best-scoring weights exist only in the checkpoint file.
    """
    import os  # local import so this cell does not rely on the import cell

    checkpoint_path = "models/bilstm_crf_no_hand_best.pt"

    # Batches are moved to `device` below, so the weights must live there too;
    # otherwise the first forward pass fails with a device mismatch on CUDA.
    # Module.to() works in place, so the optimizer's parameter references stay valid.
    model.to(device)

    best_val_loss = float('inf')
    model.train()
    for epoch in range(1, num_epochs + 1):
        total_loss = 0.0
        for batch in train_loader:
            # The shared loaders still yield 4-tuples; the second entry
            # (hand features) is unused by this model.
            bpe_inputs, _, tags, mask = batch
            bpe_inputs = bpe_inputs.to(device)
            tags = tags.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            loss, _ = model(bpe_inputs, tags=tags, mask=mask)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch}/{num_epochs}: Training Loss = {avg_loss:.4f}")

        # Validation step: eval mode disables dropout, no_grad skips autograd.
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                bpe_inputs, _, tags, mask = batch
                bpe_inputs = bpe_inputs.to(device)
                tags = tags.to(device)
                mask = mask.to(device)
                loss, _ = model(bpe_inputs, tags=tags, mask=mask)
                total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch}/{num_epochs}: Validation Loss = {avg_val_loss:.4f}")

        # Save model checkpoint if validation loss improves.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved at epoch {epoch} with validation loss {avg_val_loss:.4f}")

        # Back to training mode for the next epoch.
        model.train()

    return model

In [None]:
def evaluate_and_report_no_hand(model, test_loader, label2id, id2label):
    """Decode every batch with a hand-feature-free model and print a report.

    Args:
        model: Module called as ``model(bpe_inputs, mask=mask)`` and returning
            one list of predicted tag ids per sequence.
        test_loader: Iterable of 4-tuples ``(bpe_inputs, hand_inputs, tags,
            mask)``; the hand-feature entry is ignored.
        label2id: Label-name -> id mapping; its key order fixes the row order
            of the report.
        id2label: Id -> label-name mapping used to render ids as strings.

    Returns:
        None. The report is printed.
    """
    model.eval()
    # Run on whichever device the weights live on; the loader yields tensors
    # that have not been moved, which breaks once the model sits on the GPU.
    device = next(model.parameters()).device

    gold_ids = []
    pred_ids = []
    with torch.no_grad():
        for bpe_inputs, _, tags, mask in test_loader:
            decoded = model(bpe_inputs.to(device), mask=mask.to(device))
            # Pair gold and predicted ids position by position so the two
            # lists can never drift apart. zip() trims each gold row to the
            # decoded length, i.e. its leading positions.
            # NOTE(review): assumes right-padded batches -- confirm against
            # collate_bilstm_crf.
            for gold_seq, pred_seq in zip(tags.cpu().tolist(), decoded):
                for gold, pred in zip(gold_seq, pred_seq):
                    # -1 is skipped, as in the original filter (presumably the
                    # padding / ignore id -- confirm against the collate function).
                    if gold != -1:
                        gold_ids.append(gold)
                        pred_ids.append(pred)

    y_true_labels = [id2label[idx] for idx in gold_ids]
    y_pred_labels = [id2label[idx] for idx in pred_ids]

    label_order = list(label2id.keys())

    print("Classification Report:")
    print(classification_report(y_true_labels, y_pred_labels, labels=label_order, zero_division=0))

### Model Training

In [54]:
# Hyperparameters for the model without hand features.
bpe_dim = 600  # LSTM input width; presumably the BytePairReferenceEmbedding vector size -- confirm
lstm_hidden_dim = 128
num_tags = 26  # NOTE(review): should equal the number of ids in label2id -- confirm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to `device` before building the optimizer: the training loop
# sends every batch to `device`, so CPU-resident weights would fail on CUDA.
bilstm_crf_5mil_no_hand = BiLSTMCRFModelNoHand(bpe_dim, lstm_hidden_dim, num_tags).to(device)
optimizer = optim.AdamW(bilstm_crf_5mil_no_hand.parameters(), lr=0.003)

In [55]:
# Train for 50 epochs on the same loaders as the hand-feature model;
# train_model_no_hand returns the same model object it was given.
bilstm_crf_5mil_no_hand = train_model_no_hand(bilstm_crf_5mil_no_hand, train_loader, valid_loader, optimizer, 50, device)

Epoch 1/50: Training Loss = 146.3255
Epoch 1/50: Validation Loss = 123.7872
Model saved at epoch 1 with validation loss 123.7872
Epoch 2/50: Training Loss = 117.0938
Epoch 2/50: Validation Loss = 103.2188
Model saved at epoch 2 with validation loss 103.2188
Epoch 3/50: Training Loss = 91.2532
Epoch 3/50: Validation Loss = 86.7035
Model saved at epoch 3 with validation loss 86.7035
Epoch 4/50: Training Loss = 73.6974
Epoch 4/50: Validation Loss = 75.8566
Model saved at epoch 4 with validation loss 75.8566
Epoch 5/50: Training Loss = 64.7248
Epoch 5/50: Validation Loss = 66.9796
Model saved at epoch 5 with validation loss 66.9796
Epoch 6/50: Training Loss = 57.2982
Epoch 6/50: Validation Loss = 59.8148
Model saved at epoch 6 with validation loss 59.8148
Epoch 7/50: Training Loss = 50.7620
Epoch 7/50: Validation Loss = 54.0641
Model saved at epoch 7 with validation loss 54.0641
Epoch 8/50: Training Loss = 43.3658
Epoch 8/50: Validation Loss = 48.8672
Model saved at epoch 8 with validation

In [None]:
# Print the per-label classification report for the no-hand-feature model on the test loader.
evaluate_and_report_no_hand(bilstm_crf_5mil_no_hand, test_loader, label2id, id2label)

                   precision    recall  f1-score      support

         B-AUTHOR       0.75      0.86      0.80       190515
         I-AUTHOR       0.90      0.97      0.93      1735612
           B-YEAR       0.99      0.89      0.94       172556
           I-YEAR       0.96      0.88      0.92         1350
          B-TITLE       0.75      0.90      0.82      1815762
          I-TITLE       0.76      0.97      0.85     18249721
B-CONTAINER-TITLE       0.72      0.65      0.68       133714
I-CONTAINER-TITLE       0.77      0.60      0.67      1100358
         B-VOLUME       0.80      0.80      0.80        39419
         I-VOLUME       0.88      0.90      0.89          155
          B-ISSUE       0.80      0.99      0.88        13557
          I-ISSUE       0.85      0.95      0.90          343
           B-PAGE       0.99      0.89      0.94       157135
           I-PAGE       0.97      0.90      0.93        29258
           B-ISBN       0.94      0.86      0.90        45327
       