# BiLSTM-CRF NER Exploration

This notebook explores the biomedical NER dataset and trained models.

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

from src.utils.vocab import Vocabulary, LabelVocabulary
from src.data.dataset import get_data_statistics
from src.models.bilstm_crf import BiLSTMCRF

# Notebook-wide plotting defaults: seaborn's white-grid theme, then a wide
# default figure size for every figure that does not set its own.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

## 1. Dataset Exploration

In [None]:
# Compute statistics for the train / dev / test splits
train_stats = get_data_statistics('../data/processed/train.txt')
dev_stats = get_data_statistics('../data/processed/dev.txt')
test_stats = get_data_statistics('../data/processed/test.txt')

# Build the whole report first, then emit it in one print call
rule = "="*60
report_lines = [
    "Dataset Statistics:",
    rule,
    f"Train: {train_stats['num_sentences']} sentences",
    f"Dev:   {dev_stats['num_sentences']} sentences",
    f"Test:  {test_stats['num_sentences']} sentences",
    rule,
]
print("\n".join(report_lines))

In [None]:
# Visualize sentence length distribution
def read_sentence_lengths(filepath):
    """Return the number of tokens in every sentence of a one-token-per-line file.

    Sentences are separated by one or more blank (whitespace-only) lines.
    Every non-blank line counts as one token, whatever columns it contains.

    Args:
        filepath: Path of the text file to scan.

    Returns:
        A list of ints, one per sentence, in file order. Runs of blank lines
        never yield zero-length sentences, and a final sentence that is not
        followed by a blank line is still counted.
    """
    lengths = []
    current_length = 0

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding (e.g. cp1252 on Windows).
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                current_length += 1
            elif current_length:
                # Blank line: close the sentence that was being counted
                lengths.append(current_length)
                current_length = 0

    # The file may end without a trailing blank line
    if current_length:
        lengths.append(current_length)

    return lengths

train_lengths = read_sentence_lengths('../data/processed/train.txt')

# The mean is used twice: for the marker line and for its legend entry
mean_length = np.mean(train_lengths)

plt.figure(figsize=(12, 5))
plt.hist(train_lengths, bins=30, edgecolor='black', alpha=0.7)
plt.axvline(mean_length, color='red', linestyle='--', label=f'Mean: {mean_length:.1f}')
plt.title('Distribution of Sentence Lengths in Training Set')
plt.xlabel('Sentence Length (tokens)')
plt.ylabel('Frequency')
plt.legend()
plt.show()

In [None]:
# Label distribution of the training split
label_dist = train_stats['label_distribution']
labels = list(label_dist.keys())
counts = list(label_dist.values())

# Colour code: 'O' is green, labels mentioning 'Chemical' are blue,
# every other label is red.
colors = []
for name in labels:
    if name == 'O':
        colors.append('green')
    elif 'Chemical' in name:
        colors.append('blue')
    else:
        colors.append('red')

plt.figure(figsize=(10, 6))
plt.bar(labels, counts, color=colors, alpha=0.7, edgecolor='black')
plt.title('Label Distribution in Training Set')
plt.xlabel('Label')
plt.ylabel('Count')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

# Share of each label among all counted labels
total = sum(counts)
print("\nLabel Percentages:")
for label, count in zip(labels, counts):
    percentage = count / total * 100
    print(f"{label:15s}: {count:6d} ({percentage:5.2f}%)")

## 2. Sample Sentences

In [None]:
def read_sample_sentences(filepath, n=5):
    """Read up to ``n`` tagged sentences from the start of a token/tag file.

    Each non-blank line is expected to hold ``token<TAB>tag``; sentences are
    separated by one or more blank lines. Lines that do not split into
    exactly two tab-separated fields are skipped.

    Args:
        filepath: Path of the text file to read.
        n: Maximum number of sentences to return. Values <= 0 return an
            empty list.

    Returns:
        A list of ``(tokens, tags)`` tuples in file order, where ``tokens``
        and ``tags`` are equal-length lists of strings. A final sentence
        that is not followed by a blank line is included as long as the
        limit has not been reached.
    """
    sentences = []
    if n <= 0:
        return sentences

    current_tokens = []
    current_tags = []

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split('\t')
                if len(parts) == 2:
                    current_tokens.append(parts[0])
                    current_tags.append(parts[1])
                continue

            # Blank line: close the sentence collected so far, if any
            if current_tokens:
                sentences.append((current_tokens, current_tags))
                current_tokens = []
                current_tags = []
                if len(sentences) >= n:
                    # Stop early; the rest of the file is not needed
                    return sentences

    # The file may end without a trailing blank line. Reaching this point
    # means fewer than n sentences were collected, so the last one fits.
    if current_tokens:
        sentences.append((current_tokens, current_tags))

    return sentences

def format_tagged_sentence(tokens, tags):
    """Render one BIO-tagged sentence as a single line of text.

    Tokens are joined with single spaces. Each entity span is wrapped in
    brackets with its type appended to the last token, e.g.
    ``[heart attack/Disease]``.

    Args:
        tokens: Sequence of token strings.
        tags: Sequence of tag strings parallel to ``tokens``; tags starting
            with ``B-`` or ``I-`` mark entity tokens, anything else is
            treated as outside an entity.

    Returns:
        The formatted sentence as a string.
    """
    pieces = []
    open_type = None  # type of the entity span currently open, if any

    # Work by position: looking a token up by value would find its first
    # occurrence and misplace brackets when a (token, tag) pair repeats.
    for idx, (token, tag) in enumerate(zip(tokens, tags)):
        if not tag.startswith(('B-', 'I-')):
            pieces.append(token)
            open_type = None
            continue

        entity_type = tag[2:]
        # A span opens on B-, and also on an I- tag that does not continue
        # an open span of the same type (keeps the brackets balanced).
        opens = tag.startswith('B-') or open_type != entity_type
        # A span closes unless the next tag continues it with the same type
        next_tag = tags[idx + 1] if idx + 1 < len(tags) else 'O'
        closes = next_tag != f'I-{entity_type}'

        text = f"[{token}" if opens else token
        if closes:
            text += f"/{entity_type}]"
        pieces.append(text)
        open_type = None if closes else entity_type

    return ' '.join(pieces)


# Display sample sentences with highlighting
samples = read_sample_sentences('../data/processed/train.txt', n=5)

print("Sample Annotated Sentences:")
print("="*80)

for i, (tokens, tags) in enumerate(samples, 1):
    print(f"\nSentence {i}:")
    print("-" * 80)
    print(format_tagged_sentence(tokens, tags))

print("\n" + "="*80)

## 3. Vocabulary Analysis

In [None]:
# Load the word and label vocabularies from ../artifacts, if they exist.
# When either file is missing, the names stay undefined and only a hint is
# printed.
try:
    word_vocab = Vocabulary.load('../artifacts/vocab_word.pkl')
    label_vocab = LabelVocabulary.load('../artifacts/vocab_label.pkl')
except FileNotFoundError:
    print("Vocabularies not found. Run training first to generate vocabularies.")
else:
    # Both files were read: report their sizes / contents
    print(f"Word vocabulary size: {len(word_vocab)}")
    print(f"Label vocabulary: {list(label_vocab.label2idx.keys())}")

## 4. Model Analysis (After Training)

In [None]:
def format_metric(value, precision=4):
    """Format a metric for display, tolerating missing or non-numeric values.

    Args:
        value: The metric to format. Anything accepted by ``float()`` (Python
            numbers, NumPy scalars, single-element tensors) is rendered with
            fixed precision.
        precision: Number of digits after the decimal point.

    Returns:
        The fixed-point string for numeric values, ``'N/A'`` for ``None``,
        and ``str(value)`` for anything else that cannot be converted.
    """
    if value is None:
        return 'N/A'
    try:
        return f"{float(value):.{precision}f}"
    except (TypeError, ValueError):
        # e.g. the 'N/A' placeholder itself: a float format spec cannot be
        # applied to a string, so show the value unchanged instead.
        return str(value)


# Load trained model (if it exists)
# NOTE: torch.load unpickles the file; only open checkpoints you trust.
try:
    checkpoint = torch.load('../artifacts/best_model.pt', map_location='cpu')

    print("Model Checkpoint Information:")
    print("="*60)
    print(f"Best epoch: {checkpoint.get('epoch', 'N/A')}")
    # A checkpoint without 'best_f1' must print N/A rather than raise
    print(f"Best F1 score: {format_metric(checkpoint.get('best_f1', 'N/A'))}")
    print("="*60)

except FileNotFoundError:
    print("Model checkpoint not found. Run training first.")

In [None]:
# Visualize CRF transition matrix (if model is loaded)
# Relies on `checkpoint` and `label_vocab` from the earlier cells; if either
# is missing the NameError is caught below and a hint is printed instead.
# NOTE(review): catching NameError also hides typos inside this cell.
try:
    # Get transition matrix from checkpoint
    transitions = checkpoint['model_state_dict']['crf.transitions'].cpu().numpy()

    # Get label names ordered by label index, so tick i always names index i
    # regardless of the order in which idx2label was populated.
    # NOTE(review): assumes the matrix has exactly one row/column per entry
    # of idx2label (no extra start/stop states) - confirm against the CRF.
    label_names = [label_vocab.idx2label[idx] for idx in sorted(label_vocab.idx2label)]

    # Plot heatmap
    # NOTE(review): the axis titles assume transitions[i, j] scores moving
    # from tag i to tag j - confirm against the CRF implementation.
    plt.figure(figsize=(10, 8))
    sns.heatmap(transitions,
                xticklabels=label_names,
                yticklabels=label_names,
                cmap='RdYlGn',
                center=0,
                annot=True,
                fmt='.2f',
                cbar_kws={'label': 'Transition Score'})
    plt.xlabel('To Tag')
    plt.ylabel('From Tag')
    plt.title('CRF Transition Matrix')
    plt.tight_layout()
    plt.show()

    print("\nInterpretation:")
    print("- Green (positive): Likely transition")
    print("- Red (negative): Unlikely transition")
    print("- Notice that I-X tags typically follow B-X or I-X tags of the same type")

except (NameError, KeyError, FileNotFoundError):
    print("Cannot visualize transitions. Train model first.")

## 5. Prediction Examples (After Training)

In [None]:
# Example: Make predictions on custom sentences
# This requires the full model to be loaded - implementation left as exercise

# The workflow is only described here, one printed line per step
prediction_steps = (
    "1. Load the trained model",
    "2. Tokenize your input",
    "3. Convert tokens to IDs using word_vocab",
    "4. Run model.predict()",
    "5. Convert predicted IDs to tags using label_vocab",
)

print("To make predictions on custom text:")
for step in prediction_steps:
    print(step)

## 6. Error Analysis (After Evaluation)

In [None]:
# Values of the CORRECT column that flag a wrong prediction: U+2717 (ballot
# X), plus the same character as it appears when its UTF-8 bytes are
# mis-decoded as cp1252, in case the report was written with that fault.
ERROR_MARKS = ('\u2717', '\u00e2\u0153\u2014')


def load_predictions(filepath):
    """Load a tab-separated predictions report as a DataFrame of token rows.

    Args:
        filepath: Path of the report; columns are read positionally as
            TOKEN, TRUE_TAG, PRED_TAG, CORRECT.

    Returns:
        A DataFrame of strings containing only token rows: blank lines, the
        header row (TOKEN == 'TOKEN') and separator rows (TOKEN starting
        with '=') are removed.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
    """
    df = pd.read_csv(filepath,
                     sep='\t',
                     skip_blank_lines=False,
                     names=['TOKEN', 'TRUE_TAG', 'PRED_TAG', 'CORRECT'],
                     dtype=str,             # keep numeric-looking tokens as text
                     keep_default_na=False,  # tokens such as 'NA' or 'null' are data
                     na_values=[''],         # only empty fields count as missing
                     encoding='utf-8')

    # Blank lines come back as rows with a missing TOKEN. Drop them first so
    # they are not counted as predictions and so the string test below never
    # sees a non-string value.
    df = df[df['TOKEN'].notna()]

    # Filter out header and separator rows
    df = df[df['TOKEN'] != 'TOKEN']
    df = df[~df['TOKEN'].str.startswith('=')]
    return df


# Read predictions file if it exists
try:
    predictions_df = load_predictions('../reports/predictions.txt')

    # Analyze errors
    errors = predictions_df[predictions_df['CORRECT'].isin(ERROR_MARKS)]
    num_predictions = len(predictions_df)

    print(f"Total predictions: {num_predictions}")
    print(f"Errors: {len(errors)}")
    if num_predictions:
        print(f"Accuracy: {(num_predictions - len(errors)) / num_predictions * 100:.2f}%")
    else:
        # Avoid dividing by zero when the report holds no token rows
        print("Accuracy: N/A (no predictions found)")

    # Show some error examples
    print("\nSample Errors:")
    print(errors.head(20))

except FileNotFoundError:
    print("Predictions file not found. Run evaluation first.")

## Summary

This notebook provides:
1. Dataset statistics and visualizations
2. Sample annotated sentences
3. Vocabulary analysis
4. Model checkpoint inspection
5. CRF transition matrix visualization
6. Error analysis tools

Use this as a starting point for deeper analysis of your NER system!