In [None]:
!pip install seqeval

In [None]:
!pip install pytorch-crf -i https://pypi.tuna.tsinghua.edu.cn/simple/

In [None]:
# Named Entity Recognition with BERT + BiLSTM + CRF
# Homework 3

import os
import re
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns

# Make CUDA kernel launches synchronous so a device-side error is reported at
# the call that caused it (a debugging aid; it slows GPU execution).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Set to True to run everything on the CPU even when a GPU is available.
FORCE_CPU = False

# Importable module name -> pip distribution that provides it.
required_packages = {
    'transformers': 'transformers',
    'torchcrf': 'pytorch-crf',
}

# Stop early, with install instructions, if a required package is missing.
for module_name, package_name in required_packages.items():
    try:
        __import__(module_name)
    except ImportError:
        print(f"Package {module_name} is required. Please run:")
        print(f"pip install {package_name}")
        print("Then run this program again")
        sys.exit(1)

# These imports are only reached once the package check above has passed.
from transformers import AutoModel, AutoTokenizer
from torchcrf import CRF


def _select_device(force_cpu):
    """Pick the torch device for training and report the choice.

    When CUDA is reported available, a tiny tensor operation is executed on the
    GPU first; any failure there (or while querying the GPU) falls back to CPU.
    """
    if force_cpu:
        print("Using CPU as specified")
        return torch.device('cpu')
    try:
        if not torch.cuda.is_available():
            chosen = torch.device('cpu')
            print(f"CUDA not available. Using device: {chosen}")
            return chosen
        # Smoke-test the GPU before committing to it.
        probe = torch.tensor([1.0]).cuda()
        probe = probe + 1.0
        chosen = torch.device('cuda')
        print(f"CUDA is available and working. Using device: {chosen}")
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        return chosen
    except Exception as e:
        print(f"Error initializing CUDA: {e}")
        print("Falling back to CPU")
        return torch.device('cpu')


device = _select_device(FORCE_CPU)

# Seed torch (and every CUDA device when running on GPU) plus NumPy.
try:
    torch.manual_seed(42)
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(42)
except Exception as e:
    print(f"Warning: Could not set CUDA random seed: {e}")

np.random.seed(42)

# CoNLL-format dataset splits (word ... tag per line, blank line between sentences).
TRAIN_FILE = "train.txt"
VALID_FILE = "valid.txt"
TEST_FILE = "test.txt"

# Hyper-parameters.
MAX_LEN = 128  # word-piece length every sequence is padded/truncated to
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
EPOCHS = 10
MODEL_NAME = "bert-base-cased"  # cased checkpoint: capitalisation is a strong NER cue

# The tokenizer must match the pretrained encoder loaded later.
print(f"Loading tokenizer for pretrained model '{MODEL_NAME}'...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print("Tokenizer loaded successfully")

# Function to read CoNLL format NER data
def read_conll(file_path):
    """Read a CoNLL-style NER file into parallel lists of words and tags.

    Every non-blank line is split on whitespace; the first column is the word
    and the last column is its NER tag. Blank lines separate sentences.
    Lines with a single column are malformed and skipped silently.
    ``-DOCSTART-`` document markers are treated as sentence boundaries rather
    than being read as a word (otherwise each becomes a one-token "sentence"
    tagged with its last column, polluting training and evaluation data).

    Args:
        file_path: path to a UTF-8 encoded text file.

    Returns:
        ``(sentences, labels)``: equal-length lists where ``sentences[i]`` is a
        list of word strings and ``labels[i]`` the tag list of the same length.
    """
    sentences = []
    labels = []

    sentence = []
    label = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()

            # Blank lines and document markers both close the current sentence.
            if not parts or parts[0] == '-DOCSTART-':
                if sentence:
                    sentences.append(sentence)
                    labels.append(label)
                    sentence = []
                    label = []
                continue

            # A lone column (e.g. a stray tag such as 'O') is malformed: skip it.
            if len(parts) == 1:
                continue

            sentence.append(parts[0])
            label.append(parts[-1])  # last column is the NER tag

    # The file may not end with a blank line.
    if sentence:
        sentences.append(sentence)
        labels.append(label)

    # Print sample to verify format
    if sentences:
        print(f"Data format sample (from {file_path}):")
        for token, tag in zip(sentences[0][:5], labels[0][:5]):
            print(f"  {token} -> {tag}")

    return sentences, labels

# Read datasets
print(f"Reading training set: {TRAIN_FILE}")
train_sentences, train_labels = read_conll(TRAIN_FILE)
print(f"Reading validation set: {VALID_FILE}")
valid_sentences, valid_labels = read_conll(VALID_FILE)
print(f"Reading test set: {TEST_FILE}")
test_sentences, test_labels = read_conll(TEST_FILE)

print(f"Training set: {len(train_sentences)} sentences")
print(f"Validation set: {len(valid_sentences)} sentences")
print(f"Test set: {len(test_sentences)} sentences")

# Display a sample from training data (guarded: an empty file has no sentence 0)
if train_sentences:
    print("\nSample from training data:")
    for token, label in zip(train_sentences[0][:10], train_labels[0][:10]):
        print(f"{token} -> {label}")

# Build the tag vocabulary from the training labels.
# 'O' is deliberately given index 0. Elsewhere index 0 is used as a filler and
# default tag: it replaces -100 at CLS/SEP/continuation positions before the
# CRF, and it is the dataset's fallback for unknown tags. With a plain
# alphabetical sort, index 0 would be an entity tag such as 'B-LOC', and the
# CRF would be trained to emit that entity at every special/sub-word position.
all_train_tags = {tag for doc in train_labels for tag in doc}
unique_tags = (['O'] if 'O' in all_train_tags else []) + sorted(all_train_tags - {'O'})
tag2idx = {tag: idx for idx, tag in enumerate(unique_tags)}
idx2tag = {idx: tag for idx, tag in enumerate(unique_tags)}

print(f"\nUnique tags: {unique_tags}")
print(f"Number of unique tags: {len(unique_tags)}")

# NER Dataset class
class NERDataset(Dataset):
    """Torch dataset that turns word-level NER sentences into BERT inputs.

    Each item is a dict of ``max_len``-long tensors: ``input_ids``,
    ``attention_mask`` and ``labels``. Only the first word-piece of each word
    carries that word's tag index. Special tokens, continuation word-pieces and
    padding are labelled -100.
    """

    def __init__(self, sentences, labels, tokenizer, max_len, tag2idx):
        """
        Args:
            sentences: list of sentences, each a list of word strings.
            labels: list of tag-string lists, parallel to ``sentences``.
            tokenizer: Hugging Face tokenizer. ``word_ids()`` is used for
                word/word-piece alignment, which requires a "fast" tokenizer.
            max_len: length every sequence is padded/truncated to.
            tag2idx: mapping from tag string to class index.
        """
        self.sentences = sentences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.tag2idx = tag2idx

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        word_labels = self.labels[idx]

        # Tokenize the pre-split words into word-pieces
        encoding = self.tokenizer(
            sentence,
            is_split_into_words=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # Remove the batch dimension added by return_tensors='pt'
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # -100 marks positions without a word-level label
        labels = torch.full(input_ids.shape, -100, dtype=torch.long)

        # Tags absent from tag2idx (e.g. seen only in valid/test) map to the 'O'
        # tag when the vocabulary has one. Index 0 is only a last resort,
        # because with sorted tags index 0 may be an entity tag.
        fallback_idx = self.tag2idx.get('O', 0)

        previous_word_idx = None
        for position, word_idx in enumerate(encoding.word_ids()):
            # Special tokens and padding have no word index
            if word_idx is None:
                continue

            # Only the first word-piece of a word receives the word's tag
            if word_idx != previous_word_idx and word_idx < len(word_labels):
                labels[position] = self.tag2idx.get(word_labels[word_idx], fallback_idx)

            previous_word_idx = word_idx

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

# Wrap each split in an NERDataset sharing one tokenizer, length and tag map.
train_dataset, valid_dataset, test_dataset = (
    NERDataset(split_sentences, split_labels, tokenizer, MAX_LEN, tag2idx)
    for split_sentences, split_labels in (
        (train_sentences, train_labels),
        (valid_sentences, valid_labels),
        (test_sentences, test_labels),
    )
)

# Only the training split is shuffled. Validation and test keep file order, so
# batch outputs stay aligned with the original sentence lists.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# BERT encoder -> BiLSTM -> linear emissions -> CRF
class BERTBiLSTMCRF(nn.Module):
    """Sequence tagger stacking a pretrained BERT, a BiLSTM and a CRF.

    ``forward`` returns the mean CRF negative log-likelihood when ``labels`` is
    supplied. Otherwise it returns the Viterbi-decoded tag-index lists, one per
    sequence, covering the positions where ``attention_mask`` is set.
    """

    def __init__(self, bert_model_name, num_tags, lstm_hidden_dim=768, lstm_layers=2, dropout=0.1):
        """
        Args:
            bert_model_name: model id passed to ``AutoModel.from_pretrained``.
            num_tags: number of tags (the CRF state space).
            lstm_hidden_dim: total BiLSTM output width. Each direction gets
                ``lstm_hidden_dim // 2``, so this should be even for the linear
                layer's input size to match.
            lstm_layers: number of stacked BiLSTM layers.
            dropout: dropout probability applied to the BiLSTM output.
        """
        super().__init__()

        # Attribute names double as state_dict key prefixes; keep them stable so
        # previously saved checkpoints still load.
        self.bert = AutoModel.from_pretrained(bert_model_name)

        per_direction_dim = lstm_hidden_dim // 2
        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=per_direction_dim,
            num_layers=lstm_layers,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.hidden2tag = nn.Linear(lstm_hidden_dim, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute the CRF loss (training) or decode tag paths (inference).

        ``input_ids`` / ``attention_mask`` are assumed to be (batch, seq_len),
        matching the batch_first LSTM and CRF. ``labels`` uses -100 for
        positions without a word-level tag.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        recurrent, _ = self.lstm(encoded)
        emissions = self.hidden2tag(self.dropout(recurrent))
        crf_mask = attention_mask.bool()

        if labels is None:
            return self.crf.decode(emissions, mask=crf_mask)

        # -100 positions (special tokens, continuation word-pieces) are still
        # inside crf_mask, so they need *some* valid tag and get index 0.
        # NOTE(review): this only makes sense if tag index 0 is the neutral 'O'
        # tag. Confirm the caller's tag2idx maps 'O' to 0; otherwise the CRF is
        # trained to emit whatever tag sits at index 0 on those positions.
        gold_tags = labels.masked_fill(labels == -100, 0)
        return -self.crf(emissions, gold_tags, mask=crf_mask, reduction='mean')

# Build the tagger on the selected device. Log and re-raise on failure so
# the run stops here instead of failing later with an undefined model.
try:
    print("Initializing model...")
    model = BERTBiLSTMCRF(MODEL_NAME, len(tag2idx))
    model.to(device)
    print(f"Model successfully initialized and moved to {device}")
except Exception as e:
    print(f"Error initializing model: {e}")
    raise

# AdamW with weight decay 0.01 on everything except biases and LayerNorm weights.
print("Setting up optimizer...")
no_decay = ['bias', 'LayerNorm.weight']
decay_params = []
no_decay_params = []
for param_name, param in model.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)
optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = optim.AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)
print("Optimizer configured successfully")

# Training and evaluation functions
def train_epoch(model, dataloader, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    tensors. The model is called with those keyword arguments and must return a
    scalar loss. A batch that raises is logged with its traceback and skipped,
    so one bad batch does not abort the epoch (deliberate best-effort).

    Args:
        model: module whose forward returns the training loss when ``labels``
            is passed.
        dataloader: iterable of batch dicts.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batch tensors are moved to.

    Returns:
        Mean loss over the batches that completed successfully, or ``nan``
        when none did (empty loader or every batch failed). Previously failed
        batches were counted in the denominator, and an all-failed epoch
        reported a perfect-looking 0.0.
    """
    import traceback

    model.train()
    total_loss = 0.0
    completed_batches = 0

    progress_bar = tqdm(dataloader, desc="Training")
    for batch in progress_bar:
        try:
            # Clear gradients
            optimizer.zero_grad()

            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward, backward, update
            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss.backward()
            optimizer.step()

            # Read the scalar inside the try so a failure here is handled like
            # any other batch error.
            batch_loss = loss.item()
        except Exception as e:
            print(f"Error in batch training: {e}")
            traceback.print_exc()
            continue

        # Running mean over successful batches only
        total_loss += batch_loss
        completed_batches += 1
        progress_bar.set_postfix({"Loss": total_loss / completed_batches})

    if completed_batches == 0:
        print("Warning: no training batch completed successfully this epoch")
        return float('nan')
    return total_loss / completed_batches

def evaluate(model, dataloader, device, tag2idx, idx2tag):
    """Decode ``dataloader`` with ``model`` and build a token-level report.

    Only positions that are inside the attention mask and carry a real label
    (not -100) are scored. With the dataset in this file, those are the first
    word-piece of every word.

    Args:
        model: tagger whose forward (without ``labels``) returns one decoded
            tag-index path per sequence.
        dataloader: iterable of batch dicts with ``input_ids``,
            ``attention_mask`` and ``labels``.
        device: device the inputs are moved to.
        tag2idx: unused; kept for call-site compatibility.
        idx2tag: mapping from tag index to tag string (unknown -> "O").

    Returns:
        ``(report, predictions, true_labels)``. ``report`` is a
        ``classification_report`` string (or an explanatory message).
        ``predictions[i]`` / ``true_labels[i]`` are the tag strings of the i-th
        sequence in loader order. A batch that fails contributes empty lists
        for each of its sequences, so indices stay aligned with the source
        sentences.
    """
    import traceback

    model.eval()

    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            # Collect per batch first so a mid-batch failure cannot leave a
            # partially appended, misaligned result.
            batch_preds = []
            batch_trues = []
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                best_paths = model(input_ids=input_ids, attention_mask=attention_mask)

                masks = attention_mask.cpu().numpy()
                gold = batch['labels'].cpu().numpy()

                for path, label_row, mask_row in zip(best_paths, gold, masks):
                    pred_list = []
                    true_list = []
                    for j, (m, gold_idx) in enumerate(zip(mask_row, label_row)):
                        # Skip padding and positions without a word-level label
                        if m and gold_idx != -100 and j < len(path):
                            pred_list.append(idx2tag.get(path[j], "O"))
                            true_list.append(idx2tag.get(int(gold_idx), "O"))
                    batch_preds.append(pred_list)
                    batch_trues.append(true_list)
            except Exception as e:
                print(f"Error in evaluation: {e}")
                traceback.print_exc()
                # Keep one (empty) entry per sequence to preserve alignment
                n_sequences = len(batch['input_ids'])
                batch_preds = [[] for _ in range(n_sequences)]
                batch_trues = [[] for _ in range(n_sequences)]

            predictions.extend(batch_preds)
            true_labels.extend(batch_trues)

    # Per-sequence lists are built in lockstep, so the flat lists match in length
    flat_preds = [tag for seq in predictions for tag in seq]
    flat_labels = [tag for seq in true_labels for tag in seq]

    if flat_preds and flat_labels:
        try:
            report = classification_report(flat_labels, flat_preds, digits=4)
        except Exception as e:
            print(f"Error generating classification report: {e}")
            report = "Classification report generation failed"
    else:
        report = "No valid predictions to evaluate"

    return report, predictions, true_labels

# Train the model with error handling: one train + validation pass per epoch,
# keeping the weights with the best validation weighted-average F1.
train_losses = []
val_reports = []
best_f1 = 0
best_model_state = None

try:
    for epoch in range(EPOCHS):
        print(f"\n{'='*20} Epoch {epoch + 1}/{EPOCHS} {'='*20}")

        try:
            # Train
            print("Starting training phase...")
            train_loss = train_epoch(model, train_loader, optimizer, device)
            train_losses.append(train_loss)

            # Evaluate on validation set
            print("Starting validation phase...")
            val_report, _, _ = evaluate(model, valid_loader, device, tag2idx, idx2tag)
            val_reports.append(val_report)

            # Extract F1 score to track best model
            try:
                # The F1 value is the second-to-last field of the report's
                # "weighted avg" row (precision, recall, f1, support).
                report_lines = val_report.strip().split('\n')
                weighted_avg_line = [line for line in report_lines if 'weighted avg' in line][0]
                current_f1 = float(weighted_avg_line.strip().split()[-2])

                # Save best model
                if current_f1 > best_f1:
                    best_f1 = current_f1
                    # state_dict() holds references to the live parameter
                    # tensors, and dict.copy() is shallow, so a copied dict kept
                    # changing as training continued and "best" silently became
                    # "last". Clone every tensor (on CPU, to avoid a second copy in
                    # GPU memory) to freeze the snapshot.
                    best_model_state = {
                        name: tensor.detach().cpu().clone()
                        for name, tensor in model.state_dict().items()
                    }
                    print(f"New best model saved with F1: {best_f1:.4f}")
            except Exception as e:
                print(f"Error extracting F1 score: {e}")

            print(f"Training Loss: {train_loss:.4f}")
            print("Validation Report:")
            print(val_report)
            print("-" * 60)

        except Exception as e:
            print(f"Error during epoch {epoch + 1}: {e}")
            if epoch > 0:  # Only continue if we have at least one successful epoch
                print("Continuing to next epoch...")
                continue
            else:
                raise

    # Save the trained model
    if best_model_state is not None:
        print("Saving best model...")
        torch.save(best_model_state, "bert_bilstm_crf_ner_model.pt")
        # Load best model for final evaluation (load_state_dict copies the CPU
        # snapshot into the parameters on their current device)
        model.load_state_dict(best_model_state)
    else:
        print("Saving final model...")
        torch.save(model.state_dict(), "bert_bilstm_crf_ner_model.pt")

    print("Model saved!")

except Exception as e:
    print(f"Error during training: {e}")
    # Save checkpoint if possible
    try:
        if len(train_losses) > 0:
            print("Saving model checkpoint from last successful epoch...")
            torch.save(model.state_dict(), "bert_bilstm_crf_ner_model_checkpoint.pt")
            print("Checkpoint saved!")
    except Exception as ce:
        print(f"Could not save checkpoint: {ce}")

# Chart the mean training loss of each completed epoch (saved and displayed).
if not train_losses:
    print("No training loss data to plot")
else:
    epoch_numbers = range(1, len(train_losses) + 1)
    plt.figure(figsize=(10, 6))
    plt.plot(epoch_numbers, train_losses, marker='o')
    plt.title('Training Loss over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.savefig('training_loss.png')
    plt.show()

# Score the final (best) model on the held-out test split.
try:
    print("\nEvaluating on test set...")
    test_report, test_predictions, test_true_labels = evaluate(model, test_loader, device, tag2idx, idx2tag)
    print("Test Report:")
    print(test_report)
except Exception as e:
    print(f"Error during test evaluation: {e}")
    test_predictions, test_true_labels = [], []
    test_report = "Evaluation failed"

# Write one "word gold predicted" line per word, blank line between sentences.
try:
    print("\nSaving test predictions...")
    with open("test_predictions.txt", "w", encoding="utf-8") as f:
        for sentence, true_labels, pred_labels in zip(test_sentences, test_true_labels, test_predictions):
            for word, true_label, pred_label in zip(sentence, true_labels, pred_labels):
                f.write(f"{word} {true_label} {pred_label}\n")
            f.write("\n")
    print("Test predictions saved to test_predictions.txt")
except Exception as e:
    print(f"Error saving predictions: {e}")
    print("Trying alternative saving method...")
    # Fallback: emit every source word, using "O" wherever no tag is available.
    try:
        with open("test_predictions_simple.txt", "w", encoding="utf-8") as f:
            for i, sentence in enumerate(test_sentences):
                sentence_preds = test_predictions[i] if i < len(test_predictions) else []
                sentence_trues = test_true_labels[i] if i < len(test_true_labels) else []
                for j, word in enumerate(sentence):
                    pred_label = sentence_preds[j] if j < len(sentence_preds) else "O"
                    true_label = sentence_trues[j] if j < len(sentence_trues) else "O"
                    f.write(f"{word} {true_label} {pred_label}\n")
                f.write("\n")
        print("Simplified test predictions saved to test_predictions_simple.txt")
    except Exception as e2:
        print(f"Alternative saving also failed: {e2}")

# Print static guidance for common CUDA failures after the run.
banner_rule = "=" * 80
print("\n" + banner_rule)
print("TROUBLESHOOTING CUDA ISSUES")
print(banner_rule + "\n")
troubleshooting_text = """
If you're encountering CUDA errors like "device-side assert triggered" or other GPU-related issues,
try the following solutions:

1. Set FORCE_CPU = True at the top of this notebook to run everything on CPU
   - This is slower but more reliable

2. Check CUDA version compatibility:
   - Run the following to check your CUDA version:
     ```
     import torch
     print(torch.version.cuda)
     ```
   - Make sure it's compatible with your PyTorch version

3. Check GPU memory:
   - You might be running out of GPU memory
   - Reduce BATCH_SIZE (e.g., from 16 to 8 or 4)
   - Reduce MAX_LEN (e.g., from 128 to 64)

4. Environment variables:
   - Try setting these environment variables before running:
     ```
     import os
     os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
     ```

5. Driver issues:
   - Update your NVIDIA drivers
   - Restart your runtime or machine

"""
print(troubleshooting_text)

# Function to generate predictions for new sentences
def predict_entities(model, tokenizer, sentences, tag2idx, idx2tag, device, max_len=128):
    """Tag pre-tokenised sentences with the trained model.

    Args:
        model: tagger whose forward (without ``labels``) returns a list of
            decoded tag-index paths, one per sequence in the batch.
        tokenizer: Hugging Face tokenizer. ``word_ids()`` is used for
            alignment, which requires a "fast" tokenizer.
        sentences: list of sentences, each a list of word strings.
        tag2idx: unused; kept for call-site compatibility.
        idx2tag: mapping from tag index to tag string (unknown -> "O").
        device: device the encoded inputs are moved to.
        max_len: word-piece length each sentence is padded/truncated to.

    Returns:
        One tag list per sentence, always exactly ``len(sentence)`` long. Each
        word takes the tag decoded at its first word-piece, mirroring how the
        training labels are aligned. Words with no word-piece inside the
        (possibly truncated) window get "O". If one sentence fails, only that
        sentence falls back to all-"O"; results for the others are kept.
    """
    model.eval()
    predictions = []

    for sentence in sentences:
        try:
            encoding = tokenizer(
                sentence,
                is_split_into_words=True,
                padding='max_length',
                truncation=True,
                max_length=max_len,
                return_tensors='pt'
            )

            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)

            with torch.no_grad():
                best_path = model(input_ids=input_ids, attention_mask=attention_mask)[0]

            # First word-piece of each word -> that word's predicted tag.
            # best_path only covers unmasked positions, hence the bound check.
            word_tags = {}
            for position, word_idx in enumerate(encoding.word_ids()):
                if word_idx is None or word_idx in word_tags or position >= len(best_path):
                    continue
                word_tags[word_idx] = idx2tag.get(best_path[position], "O")

            pred_tags = [word_tags.get(w, "O") for w in range(len(sentence))]
        except Exception as e:
            print(f"Error during prediction: {e}")
            print("Using default 'O' predictions for this sentence")
            pred_tags = ['O'] * len(sentence)

        predictions.append(pred_tags)

    return predictions

# Smoke-test the inference helper on a hand-written sentence.
example_sentences = [["John", "lives", "in", "New", "York", "and", "works", "for", "Google"]]
try:
    print("\nTesting prediction function with example sentence...")
    example_predictions = predict_entities(model, tokenizer, example_sentences, tag2idx, idx2tag, device)

    print("\nExample prediction:")
    for sentence, preds in zip(example_sentences, example_predictions):
        for token, tag in zip(sentence, preds):
            print(f"{token} -> {tag}")
except Exception as e:
    print(f"Error in example prediction: {e}")

# Summarise the run configuration and the test results.
summary_lines = [
    "\nTraining Summary:",
    "Model: BERT + BiLSTM + CRF",
    f"Training set size: {len(train_sentences)} sentences",
    f"Validation set size: {len(valid_sentences)} sentences",
    f"Test set size: {len(test_sentences)} sentences",
    f"Number of epochs: {EPOCHS}",
    f"Batch size: {BATCH_SIZE}",
    f"Learning rate: {LEARNING_RATE}",
]
if train_losses:
    summary_lines.append(f"Final training loss: {train_losses[-1]:.4f}")
summary_lines.append("\nTest Results:")
for summary_line in summary_lines:
    print(summary_line)
print(test_report)

# Confusion matrix over the word-level test tags returned by evaluate().
try:
    from sklearn.metrics import confusion_matrix

    flat_preds = [tag for sublist in test_predictions for tag in sublist]
    flat_labels = [tag for sublist in test_true_labels for tag in sublist]

    if not flat_labels:
        raise ValueError("no test predictions available to plot")

    # Axes cover gold AND predicted tags. Restricting them to gold tags only
    # silently drops every prediction of a tag that never occurs in the gold
    # data, so rows would no longer add up to each tag's support.
    unique_tags_test = sorted(set(flat_labels) | set(flat_preds))

    cm = confusion_matrix(flat_labels, flat_preds, labels=unique_tags_test)

    # Rows are true tags, columns predicted tags
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_tags_test, yticklabels=unique_tags_test)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    plt.show()
    print("Confusion matrix saved to confusion_matrix.png")
except Exception as e:
    print(f"Error creating confusion matrix: {e}")