In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
try:
    from torchcrf import CRF
except ImportError:
    # The `torchcrf` module is provided by the `pytorch-crf` distribution; the
    # similarly named `TorchCRF` package exposes a different module and API.
    print("Error: torchcrf not found. Please install it using 'pip install pytorch-crf' in the correct environment.")
    # Raise SystemExit directly: the `exit` helper is injected by `site` and,
    # inside a notebook kernel, does not stop the running cell.
    raise SystemExit(1)
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import random
import copy
import os
import logging

# Configure root logging once: every record carries a timestamp, its level
# and the message text.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# Utility functions
def load_embedding_dict(path, max_words=10000):
    """Load word vectors from a GloVe-style text file.

    Every usable line holds a token followed by its whitespace-separated
    vector components. Blank lines, lines whose components are not numeric
    and lines whose dimensionality differs from the first accepted vector
    are skipped instead of aborting the load.

    Args:
        path: Path of the UTF-8 encoded embedding file.
        max_words: Maximum number of vectors to load, counted from the start
            of the file.

    Returns:
        dict mapping each token to a float32 numpy vector.

    Raises:
        SystemExit: With exit code 1 when ``path`` does not exist.
    """
    logger.info(f"Loading embeddings from {path}")
    embeddings = {}
    expected_dim = None
    try:
        with open(path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                # Count loaded vectors, not raw lines, so skipped lines do
                # not eat into the budget.
                if len(embeddings) >= max_words:
                    break
                values = line.split()
                if len(values) < 2:
                    # Blank line or a token without any vector components.
                    continue
                try:
                    vector = np.array(values[1:], dtype='float32')
                except ValueError:
                    logger.warning(f"Skipping non-numeric embedding at line {line_no} in {path}")
                    continue
                if expected_dim is None:
                    expected_dim = vector.shape[0]
                elif vector.shape[0] != expected_dim:
                    # Mixed dimensions would break stacking the vectors later.
                    logger.warning(f"Skipping embedding with unexpected dimension at line {line_no} in {path}")
                    continue
                embeddings[values[0]] = vector
    except FileNotFoundError:
        logger.error(f"Embedding file {path} not found.")
        # `exit` is a `site` convenience and does not stop a notebook cell.
        raise SystemExit(1)
    logger.info(f"Loaded {len(embeddings)} embeddings")
    return embeddings

def read_ner_data_from_connl(file_path):
    """Read sentences and tag sequences from a CoNLL-style file.

    Each non-blank line must have at least four whitespace-separated
    columns; the first column is taken as the token and the last column as
    its tag. Blank lines separate sentences. Lines with fewer than four
    columns are skipped with a warning.

    NOTE(review): document marker rows (e.g. ``-DOCSTART-``) are not
    filtered and therefore end up as one-token sentences - confirm whether
    that is intended.

    Args:
        file_path: Path of the UTF-8 encoded data file.

    Returns:
        Tuple ``(words, tags)`` of parallel lists; each entry is the list of
        tokens (respectively tags) of one sentence.

    Raises:
        SystemExit: With exit code 1 when the file is missing or yields no
            sentences.
    """
    logger.info(f"Reading data from {file_path}")
    words, tags = [], []
    sentence, sentence_tags = [], []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Blank line: close the sentence collected so far.
                    if sentence:
                        words.append(sentence)
                        tags.append(sentence_tags)
                        sentence, sentence_tags = [], []
                    continue
                if len(parts) < 4:
                    logger.warning(f"Skipping malformed line in {file_path}: {line.strip()}")
                    continue
                # Index instead of unpacking so rows with extra columns do
                # not raise; the tag is the final column.
                sentence.append(parts[0])
                sentence_tags.append(parts[-1])
    except FileNotFoundError:
        logger.error(f"Data file {file_path} not found.")
        raise SystemExit(1)
    # The file may end without a trailing blank line.
    if sentence:
        words.append(sentence)
        tags.append(sentence_tags)
    logger.info(f"Loaded {len(words)} sentences from {file_path}")
    if not words:
        logger.error(f"No data loaded from {file_path}. Check file format or path.")
        raise SystemExit(1)
    return words, tags

def augment_data(words, tags, glove, augment_factor=0.05):
    """Return a copy of the data with some tokens swapped for a close neighbour.

    Each token is considered for replacement with probability
    ``augment_factor``. A considered token that has a vector in ``glove`` is
    replaced by the most cosine-similar *other* word among the first 5000
    entries of ``glove``, provided that similarity exceeds 0.7. The inputs
    are not modified and the tags are copied unchanged.

    Args:
        words: List of sentences, each a list of tokens.
        tags: List of tag sequences parallel to ``words``.
        glove: dict mapping tokens to equally sized numeric vectors.
        augment_factor: Per-token probability of attempting a replacement.

    Returns:
        Tuple ``(augmented_words, augmented_tags)`` with the same shape as
        the inputs.
    """
    logger.info("Augmenting data...")
    augmented_words, augmented_tags = copy.deepcopy(words), copy.deepcopy(tags)
    frequent_words = list(glove)[:5000]

    if frequent_words:
        word_vectors = np.array([glove[w] for w in frequent_words])
        # Normalise the candidates once instead of on every replacement;
        # zero vectors keep a divisor of 1 so they score 0 rather than NaN.
        norms = np.linalg.norm(word_vectors, axis=1)
        unit_vectors = word_vectors / np.where(norms > 0, norms, 1.0)[:, None]
        candidate_index = {w: k for k, w in enumerate(frequent_words)}

        for i, sentence in enumerate(tqdm(words, desc="Augmenting sentences")):
            for j, word in enumerate(sentence):
                if random.random() >= augment_factor:
                    continue
                word_vec = glove.get(word)
                if word_vec is None:
                    continue
                query_norm = np.linalg.norm(word_vec)
                if query_norm == 0:
                    continue
                similarities = np.dot(unit_vectors, word_vec) / query_norm
                # A word is always its own nearest neighbour; mask it out so
                # the replacement is an actual substitution.
                own_idx = candidate_index.get(word)
                if own_idx is not None:
                    similarities[own_idx] = -np.inf
                top_idx = int(np.argmax(similarities))
                if similarities[top_idx] > 0.7:
                    augmented_words[i][j] = frequent_words[top_idx]

    logger.info(f"Augmented data size: {len(augmented_words)} sentences")
    return augmented_words, augmented_tags

# Indexer class
class Indexer:
    """Two-way mapping between elements (tokens or tags) and integer ids.

    Ids 0 and 1 are reserved for ``'<PAD>'`` and ``'<UNK>'``. All other
    elements receive consecutive ids in order of first appearance, so the
    same input always produces the same mapping (required for a saved model
    to be reusable with a rebuilt vocabulary).
    """

    def __init__(self, elements):
        """Build the mapping from an iterable of element sequences."""
        self.element_to_index = {'<PAD>': 0, '<UNK>': 1}
        self.index_to_element = {0: '<PAD>', 1: '<UNK>'}
        # Single pass over the data: avoids the quadratic list concatenation
        # of sum(elements, []) and the hash-dependent ordering of a set.
        for sequence in elements:
            for element in sequence:
                if element not in self.element_to_index:
                    index = len(self.element_to_index)
                    self.element_to_index[element] = index
                    self.index_to_element[index] = element

    def size(self):
        """Return the number of known elements, including the two reserved ones."""
        return len(self.element_to_index)

    def elements_to_index(self, elements):
        """Convert sequences of elements to sequences of ids; unknown elements map to 1."""
        return [[self.element_to_index.get(e, 1) for e in seq] for seq in elements]

    def get_element_to_index_dict(self):
        """Return the element -> id dictionary itself (not a copy)."""
        return self.element_to_index

# Embedding fabric
class EmbeddingFabric:
    """Factory for the word-embedding layer used by the tagger."""

    @staticmethod
    def get_embedding_layer(word_indexer, glove, strategy):
        """Create an embedding layer initialised from pre-trained vectors.

        Rows of words present in ``glove`` are copied from it. Rows of
        missing words depend on ``strategy``:

        * ``'strategy_a'``: zeros, and the layer is frozen;
        * ``'strategy_b'``: samples from N(0, 0.1), layer trainable;
        * anything else: the mean of all ``glove`` vectors, layer trainable.

        The embedding width is taken from the vectors in ``glove`` (100 when
        ``glove`` is empty); the caller must size the model accordingly.

        Args:
            word_indexer: Object exposing ``size()`` and an
                ``element_to_index`` dict of word -> row index.
            glove: dict mapping words to equally sized numeric vectors.
            strategy: Name of the initialisation strategy (see above).

        Returns:
            ``nn.Embedding`` with ``word_indexer.size()`` rows.
        """
        logger.info(f"Creating embedding layer for strategy {strategy}")
        vocab_size = word_indexer.size()
        embedding_dim = len(next(iter(glove.values()))) if glove else 100
        embedding_matrix = np.zeros((vocab_size, embedding_dim))

        # Compute the shared fallback once; evaluating it as a dict.get
        # default would redo the mean over all vectors for every word.
        mean_vector = None
        if strategy not in ('strategy_a', 'strategy_b'):
            if glove:
                mean_vector = np.mean(list(glove.values()), axis=0)
            else:
                mean_vector = np.zeros(embedding_dim)

        for word, idx in word_indexer.element_to_index.items():
            vector = glove.get(word)
            if vector is None:
                if strategy == 'strategy_a':
                    continue  # row is already all zeros
                if strategy == 'strategy_b':
                    vector = np.random.normal(0, 0.1, embedding_dim)
                else:
                    vector = mean_vector
            embedding_matrix[idx] = vector

        embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        embedding_layer.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype=torch.float32))
        embedding_layer.weight.requires_grad = (strategy != 'strategy_a')
        return embedding_layer

# Metrics handler
class MetricsHandler:
    """Collect per-sequence precision/recall/F-scores and average them.

    Label index 0 is the only value treated as "negative": a position counts
    as correct when prediction and truth agree on a non-zero label.

    NOTE(review): if index 0 is a padding id rather than the outside tag,
    these figures are close to plain token accuracy - confirm against the
    tag indexer.
    """

    def __init__(self, labels):
        self.metrics_dict = defaultdict(list)
        self.labels = labels

    def update(self, predictions, true_vals):
        """Append one precision/recall/f1/f0.5 entry per (prediction, truth) pair."""
        beta_sq = 0.5**2
        for predicted_seq, gold_seq in zip(predictions, true_vals):
            hits = len([1 for p, t in zip(predicted_seq, gold_seq) if p == t and t != 0])
            predicted_count = len([p for p in predicted_seq if p != 0])
            gold_count = len([t for t in gold_seq if t != 0])

            if predicted_count:
                precision = hits / predicted_count
            else:
                precision = 0
            if gold_count:
                recall = hits / gold_count
            else:
                recall = 0

            # Both F-scores are defined as 0 when precision and recall vanish.
            if precision + recall:
                f1 = 2 * precision * recall / (precision + recall)
                f_half = (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)
            else:
                f1 = 0
                f_half = 0

            self.metrics_dict['precision'].append(precision)
            self.metrics_dict['recall'].append(recall)
            self.metrics_dict['f1'].append(f1)
            self.metrics_dict['f0.5'].append(f_half)

    def collect(self):
        """Replace every metric's list by a one-element list holding its mean."""
        for name in self.metrics_dict:
            recorded = self.metrics_dict[name]
            self.metrics_dict[name] = [np.mean(recorded)] if recorded else [0.0]

    def get_metrics(self):
        """Return the underlying metric-name -> list-of-values mapping."""
        return self.metrics_dict

# BiLSTM-CRF Model
class BiLSTMCRFTagger(nn.Module):
    """Sequence tagger: embeddings -> 2-layer BiLSTM -> linear emissions -> CRF."""

    def __init__(self, embedding_dim, hidden_dim, tagset_size, embedding_layer, dropout=0.3):
        super().__init__()
        self.embedding = embedding_layer
        # Each direction gets half of hidden_dim so the concatenated output
        # matches the input width of the projection layer below.
        self.bilstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim // 2,
            num_layers=2,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.crf = CRF(tagset_size, batch_first=True)

    def forward(self, sentence, tags=None, mask=None):
        """Return the negated CRF score when ``tags`` is given, else decoded tag sequences."""
        lstm_out, _ = self.bilstm(self.embedding(sentence))
        emissions = self.hidden2tag(lstm_out)
        if tags is None:
            return self.crf.decode(emissions, mask=mask)
        crf_score = self.crf(emissions, tags, mask=mask)
        return -crf_score

# Training function
def train_model(model, optimizer, scheduler, data_dict, batch_size, words_indexer, tags_indexer,
                metric_handler, valid_metric, num_epochs=5, patience=2):
    """Train the model with per-epoch checkpoints and early stopping.

    After every epoch the model is scored on the training and validation
    data, a checkpoint is written to ``checkpoints/epoch_<n>.pt`` and the
    validation F1 is compared with the best seen so far. Training stops
    once ``patience`` consecutive epochs bring no improvement. The weights
    with the best validation F1 are restored before returning, whether or
    not early stopping triggered.

    Args:
        model: Module called as ``model(batch, tags, mask=mask)`` for the
            loss and ``model(batch, mask=mask)`` for decoded tag ids.
        optimizer: Optimiser over the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        data_dict: dict whose ``'train'`` and ``'dev'`` entries are
            ``(sentences, tag_sequences)`` pairs.
        batch_size: Training batch size; reduced when it exceeds the number
            of training sentences.
        words_indexer: Converts token sequences to id sequences (id 0 is
            used as padding).
        tags_indexer: Converts tag sequences to id sequences.
        metric_handler: Metrics accumulator for the training data.
        valid_metric: Metrics accumulator for the validation data.
        num_epochs: Maximum number of epochs.
        patience: Number of non-improving epochs tolerated.

    Returns:
        Tuple ``(model, metric_handler, valid_metric, losses)``. On return
        each handler's ``metrics_dict`` holds one value per completed epoch
        for every metric, and ``losses`` holds the mean batch loss per epoch.

    Raises:
        SystemExit: With exit code 1 when the training data is empty.
    """
    logger.info("Starting training...")
    model.to(device)
    train_data, train_tags = data_dict['train']
    valid_data, valid_tags = data_dict['dev']

    # Validate data size
    if not train_data:
        logger.error("Training data is empty. Check data loading.")
        raise SystemExit(1)
    if len(train_data) < batch_size:
        logger.warning(f"Training data size ({len(train_data)}) is smaller than batch size ({batch_size}). Adjusting batch size.")
        batch_size = max(1, len(train_data) // 2)

    # Ceiling division: the final, shorter batch is trained on as well.
    num_batches = -(-len(train_data) // batch_size)
    logger.info(f"Training with {len(train_data)} samples, {num_batches} batches per epoch")

    # Log sequence lengths
    seq_lengths = [len(seq) for seq in train_data]
    logger.info(f"Sequence lengths: min={min(seq_lengths)}, max={max(seq_lengths)}, mean={np.mean(seq_lengths):.2f}")

    # Index the data once; the mapping does not change between epochs.
    train_data_idx = words_indexer.elements_to_index(train_data)
    train_tags_idx = tags_indexer.elements_to_index(train_tags)
    valid_data_idx = words_indexer.elements_to_index(valid_data)
    valid_tags_idx = tags_indexer.elements_to_index(valid_tags)

    losses = []
    train_history, valid_history = {}, {}
    # -inf guarantees the first epoch is recorded as the best so far, so
    # there is always a state to restore.
    best_f1 = float('-inf')
    patience_counter = 0
    best_model_state = None
    checkpoint_dir = 'checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        # Visit the sentences in a fresh random order every epoch.
        order = list(range(len(train_data_idx)))
        random.shuffle(order)

        for start in tqdm(range(0, len(order), batch_size), desc=f"Epoch {epoch+1}/{num_epochs}"):
            batch_ids = order[start:start + batch_size]
            batch_data = [torch.tensor(train_data_idx[j], dtype=torch.long) for j in batch_ids]
            batch_tags = [torch.tensor(train_tags_idx[j], dtype=torch.long) for j in batch_ids]

            # Pad sequences; padded positions are excluded through the mask.
            batch_data_padded = pad_sequence(batch_data, batch_first=True, padding_value=0).to(device)
            batch_tags_padded = pad_sequence(batch_tags, batch_first=True, padding_value=0).to(device)
            mask = batch_data_padded != 0

            optimizer.zero_grad()
            loss = model(batch_data_padded, batch_tags_padded, mask=mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        losses.append(total_loss / num_batches)

        # Score both splits and remember the epoch's values before the
        # handlers are reset for the next epoch.
        _score_dataset(model, metric_handler, train_data_idx, train_tags_idx)
        _score_dataset(model, valid_metric, valid_data_idx, valid_tags_idx)
        _record_epoch(train_history, metric_handler)
        _record_epoch(valid_history, valid_metric)

        current_f1 = _latest_metric(valid_metric, 'f1')
        logger.info(f"Epoch {epoch+1}/{num_epochs}")
        logger.info(f"Loss: {losses[-1]:.4f}")
        logger.info(f"Train - Precision: {_latest_metric(metric_handler, 'precision'):.4f}, "
                   f"Recall: {_latest_metric(metric_handler, 'recall'):.4f}, "
                   f"F1: {_latest_metric(metric_handler, 'f1'):.4f}")
        logger.info(f"Valid - Precision: {_latest_metric(valid_metric, 'precision'):.4f}, "
                   f"Recall: {_latest_metric(valid_metric, 'recall'):.4f}, "
                   f"F1: {current_f1:.4f}")

        # Checkpointing
        checkpoint_path = os.path.join(checkpoint_dir, f'epoch_{epoch+1}.pt')
        torch.save(model.state_dict(), checkpoint_path)
        logger.info(f"Saved checkpoint to {checkpoint_path}")

        # Early stopping
        if current_f1 > best_f1:
            best_f1 = current_f1
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping at epoch {epoch+1}")
                break

    # Hand back the best validation weights even when all epochs ran.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Expose the per-epoch history through the handlers (in place, so
    # references held by the caller stay valid).
    if losses:
        for handler, history in ((metric_handler, train_history), (valid_metric, valid_history)):
            handler.metrics_dict.clear()
            handler.metrics_dict.update(history)

    return model, metric_handler, valid_metric, losses


def _score_dataset(model, handler, data_idx, tags_idx):
    """Reset ``handler``, tag every sequence one at a time and aggregate the scores."""
    handler.metrics_dict.clear()
    model.eval()
    with torch.no_grad():
        for data, tags in zip(data_idx, tags_idx):
            data_tensor = torch.tensor([data], dtype=torch.long).to(device)
            mask = (data_tensor != 0).to(device)
            pred = model(data_tensor, mask=mask)[0]
            handler.update([pred], [tags])
        handler.collect()


def _record_epoch(history, handler):
    """Append the handler's latest value of every metric to ``history``."""
    for metric_name, values in handler.metrics_dict.items():
        if values:
            history.setdefault(metric_name, []).append(values[-1])


def _latest_metric(handler, name):
    """Return the most recent value of ``name``, or 0.0 if none was recorded."""
    values = handler.metrics_dict.get(name)
    return values[-1] if values else 0.0

# Visualization function
def build_training_visualization(name, train_metrics, losses, valid_metrics, output_path):
    """Save a two-panel figure with the training loss and the metric curves.

    The left panel plots ``losses``; the right panel plots precision,
    recall and F1 for training (solid) and validation (dashed), using the
    same colour for both curves of a metric. Point markers are drawn so
    that a series with a single value is still visible.

    Args:
        name: Title prefix for both panels.
        train_metrics: Mapping with ``'precision'``, ``'recall'`` and
            ``'f1'`` value sequences for the training data.
        losses: Sequence of loss values.
        valid_metrics: Same layout as ``train_metrics`` for validation data.
        output_path: Destination file of the figure.
    """
    logger.info(f"Saving visualization to {output_path}")
    # Fixed palette: hash()-derived colours change between interpreter runs.
    metric_colors = {'precision': '#1f77b4', 'recall': '#ff7f0e', 'f1': '#2ca02c'}

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        ax1.plot(losses, label='Loss', color='#1f77b4', marker='o')
        ax1.set_title(f'{name} - Training Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()

        for metric, color in metric_colors.items():
            ax2.plot(train_metrics[metric], label=f'Train {metric}', color=color, marker='o')
            ax2.plot(valid_metrics[metric], label=f'Valid {metric}', color=color,
                     linestyle='--', marker='o')
        ax2.set_title(f'{name} - Metrics')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Score')
        ax2.legend()

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

# Evaluation helper for the main script
def _evaluate_split(model, sentences, labels, words_indexer, tags_indexer):
    """Tag every sentence one at a time and return the collected metrics handler."""
    split_metrics = MetricsHandler(list(tags_indexer.get_element_to_index_dict().values()))
    model.eval()
    with torch.no_grad():
        for token_ids, tag_ids in zip(words_indexer.elements_to_index(sentences),
                                      tags_indexer.elements_to_index(labels)):
            data_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)
            mask = (data_tensor != 0).to(device)
            pred = model(data_tensor, mask=mask)[0]
            split_metrics.update([pred], [tag_ids])
        split_metrics.collect()
    return split_metrics


# Main execution
if __name__ == "__main__":
    # File paths
    TRAIN_PATH = 'train.txt'
    DEV_PATH = 'dev.txt'
    TEST_PATH = 'test.txt'
    EMBEDDINGS_PATH = 'glove.6B.100d.txt'

    # Load data
    logger.info("Loading data...")
    glove = load_embedding_dict(EMBEDDINGS_PATH, max_words=10000)
    words, tags = read_ner_data_from_connl(TRAIN_PATH)
    val_words, val_tags = read_ner_data_from_connl(DEV_PATH)
    test_words, test_tags = read_ner_data_from_connl(TEST_PATH)

    # Data augmentation (augment_data logs its own start message).
    aug_words, aug_tags = augment_data(words, tags, glove, augment_factor=0.05)
    words.extend(aug_words)
    tags.extend(aug_tags)

    data_dict = {
        'train': (words, tags),
        'dev': (val_words, val_tags),
        'test': (test_words, test_tags)
    }

    words_indexer = Indexer(words)
    tags_indexer = Indexer(tags)
    EMBEDDING_DIM = 100
    HIDDEN_DIM = 200

    # Simplified hyperparameter search for faster execution
    learning_rate = 0.001
    dropout = 0.3

    # The strategy is chosen on the dev split; the test split is only used
    # to report the chosen model's score, so it does not leak into selection.
    best_dev_f1 = float('-inf')
    best_f1 = 0
    best_model = None
    best_strategy = None

    for strat in ['a', 'b', 'c']:
        strategy = f"strategy_{strat}"
        logger.info(f"Training {strategy} with lr={learning_rate}, dropout={dropout}")

        model = BiLSTMCRFTagger(
            embedding_dim=EMBEDDING_DIM,
            hidden_dim=HIDDEN_DIM,
            tagset_size=tags_indexer.size(),
            embedding_layer=EmbeddingFabric.get_embedding_layer(words_indexer, glove, strategy),
            dropout=dropout
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
        metric_handler = MetricsHandler(list(tags_indexer.get_element_to_index_dict().values()))
        valid_metric = MetricsHandler(list(tags_indexer.get_element_to_index_dict().values()))

        model, train_metrics, valid_metrics, losses = train_model(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            data_dict=data_dict,
            batch_size=64,
            words_indexer=words_indexer,
            tags_indexer=tags_indexer,
            metric_handler=metric_handler,
            valid_metric=valid_metric,
            num_epochs=5,
            patience=2
        )

        # Score the returned model on dev (for selection) and test (for
        # reporting). The helper keeps its loop variables local, so the
        # module-level `tags` list is no longer overwritten.
        dev_metrics = _evaluate_split(model, val_words, val_tags, words_indexer, tags_indexer)
        dev_f1 = dev_metrics.metrics_dict['f1'][-1]
        test_metrics = _evaluate_split(model, test_words, test_tags, words_indexer, tags_indexer)
        f1_score = test_metrics.metrics_dict['f1'][-1]

        logger.info(f"{strategy} dev F1 used for model selection: {dev_f1:.4f}")
        logger.info(f"{strategy} results on test set:")
        for metric in test_metrics.metrics_dict.keys():
            logger.info(f"{metric} - {test_metrics.metrics_dict[metric][-1]:.4f}")

        if dev_f1 > best_dev_f1:
            best_dev_f1 = dev_f1
            best_f1 = f1_score
            best_model = copy.deepcopy(model)
            best_strategy = strategy

        build_training_visualization(
            f"{strategy}_lr{learning_rate}_drop{dropout}",
            train_metrics.get_metrics(),
            losses,
            valid_metrics.get_metrics(),
            f'{strategy}_lr{learning_rate}_drop{dropout}.png'
        )

    if best_model is None:
        # Only reachable when no strategy produced a comparable dev score.
        logger.error("No model was selected; nothing to save.")
        raise SystemExit(1)

    logger.info(f"Best model: {best_strategy} with F1 score: {best_f1:.4f}")
    torch.save(best_model.state_dict(), 'best_model.pt')
    logger.info("Saved best model to best_model.pt")

2025-08-14 20:39:50,192 - INFO - Using device: cuda
2025-08-14 20:39:50,196 - INFO - Loading data...
2025-08-14 20:39:50,197 - INFO - Loading embeddings from glove.6B.100d.txt
2025-08-14 20:39:50,375 - INFO - Loaded 10000 embeddings
2025-08-14 20:39:50,378 - INFO - Reading data from train.txt
2025-08-14 20:39:50,464 - INFO - Loaded 14987 sentences from train.txt
2025-08-14 20:39:50,465 - INFO - Reading data from dev.txt
2025-08-14 20:39:50,488 - INFO - Loaded 3466 sentences from dev.txt
2025-08-14 20:39:50,489 - INFO - Reading data from test.txt
2025-08-14 20:39:50,516 - INFO - Loaded 3684 sentences from test.txt
2025-08-14 20:39:50,517 - INFO - Augmenting data...
2025-08-14 20:39:50,518 - INFO - Augmenting data...
Augmenting sentences: 100%|█████████████████████████████| 14987/14987 [00:02<00:00, 6749.96it/s]
2025-08-14 20:39:52,894 - INFO - Augmented data size: 14987 sentences
2025-08-14 20:40:39,463 - INFO - Training strategy_a with lr=0.001, dropout=0.3
2025-08-14 20:40:39,463 - IN