# BiLSTM Model Training Walkthrough

This notebook explains our BiLSTM (Bidirectional Long Short-Term Memory) model training process for Myanmar news sentiment classification.

## Environment Setup
**Conda Environment:** nlp  
**Purpose:** Train a deep learning model for 3-class sentiment classification

## Model Architecture Overview
```
Input (Myanmar Tokens) → Embedding → BiLSTM → Dense → Output (3 Classes)
```

**Classes:** 
- 0: neutral (DVB - opposition perspective)
- 1: red (Khitthit - critical perspective)
- 2: green (Myawady - government perspective)

## 1. Data Preparation for Training

### Loading and Preprocessing Training Data
Our BiLSTM model requires numerical sequences as input, so we convert Myanmar text tokens into numerical representations.
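To make the conversion concrete, here is a plain-Python sketch of what the Keras `Tokenizer` (configured with `num_words`, `oov_token='<UNK>'` and `split=' '`) and `pad_sequences(padding='post', truncating='post')` do to pre-tokenized text. It is a simplified model of that behavior, not the library's code. The helper names `build_word_index` and `texts_to_padded` are hypothetical.

```python
from collections import Counter


def build_word_index(texts, oov_token="<UNK>"):
    """Rank tokens by frequency (ties keep first-seen order) and assign integer IDs.

    ID 0 is left free for padding and the OOV token takes ID 1, following the
    layout of the Keras Tokenizer's word_index.
    """
    counts = Counter()
    for text in texts:
        counts.update(tok for tok in text.split(" ") if tok)
    ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
    word_index = {oov_token: 1}
    for idx, (tok, _) in enumerate(ranked, start=2):
        word_index[tok] = idx
    return word_index


def texts_to_padded(texts, word_index, num_words, maxlen, oov_token="<UNK>"):
    """Map tokens to IDs, send unknown or rare tokens to OOV, then post-truncate and post-pad with 0."""
    oov_id = word_index[oov_token]
    padded = []
    for text in texts:
        ids = []
        for tok in text.split(" "):
            if not tok:
                continue
            idx = word_index.get(tok, oov_id)
            # IDs >= num_words are outside the kept vocabulary, so they become OOV
            ids.append(idx if idx < num_words else oov_id)
        ids = ids[:maxlen]                               # truncating='post'
        padded.append(ids + [0] * (maxlen - len(ids)))   # padding='post'
    return padded


word_index = build_word_index(["a b a", "c a"])
print(word_index)  # {'<UNK>': 1, 'a': 2, 'b': 3, 'c': 4}
print(texts_to_padded(["a b a", "c a", "d"], word_index, num_words=4, maxlen=4))
# [[2, 3, 2, 0], [1, 2, 0, 0], [1, 0, 0, 0]]
```

With `num_words=4` only IDs 1 to 3 survive, so the rare token `c` is mapped to `<UNK>` just like the unseen token `d`.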

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import seaborn as sns

class TrainingDataProcessor:
    """
    Prepare Myanmar text data for BiLSTM training.
    
    Key Functions:
    - Convert tokens to numerical sequences
    - Handle vocabulary management
    - Create padded sequences for batch processing
    - Generate training/validation splits
    """
    
    def __init__(self, max_vocab_size=10000, max_sequence_length=500):
        """
        Initialize data processor.
        
        Args:
            max_vocab_size (int): Maximum vocabulary size (most frequent tokens)
            max_sequence_length (int): Maximum sequence length for padding
        
        Why these defaults?
        - 10k vocab: Covers most Myanmar tokens while keeping model manageable
        - 500 length: Caps each article at 500 tokens (longer articles are truncated at the end) while controlling memory usage
        """
        self.max_vocab_size = max_vocab_size
        self.max_sequence_length = max_sequence_length
        self.tokenizer = None
        self.label_mapping = {0: 'neutral', 1: 'red', 2: 'green'}
    
    def load_training_data(self, csv_path):
        """
        Load labeled training data from CSV.
        
        Expected CSV format:
        - full_text: Complete article text (title + content)
        - tokens: Space-separated Myanmar tokens from MyWord
        - label_numeric: Class labels (0, 1, 2)
        - source: Original news source
        
        Returns:
            pandas.DataFrame: Loaded training data
        """
        df = pd.read_csv(csv_path)
        
        print(f"📊 Training Data Summary:")
        print(f"   Total articles: {len(df)}")
        print(f"   Label distribution:")
        for label, count in df['label_numeric'].value_counts().sort_index().items():
            label_name = self.label_mapping[label]
            print(f"     {label} ({label_name}): {count} articles")
        
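        # token_count is not among the required columns listed above, so derive it
        # from the space-separated tokens if the CSV does not provide it
        if 'token_count' not in df.columns:
            df['token_count'] = df['tokens'].fillna('').astype(str).str.split().str.len()
        
        print(f"\n📈 Text Statistics:")
        print(f"   Avg token count: {df['token_count'].mean():.1f}")
        print(f"   Max token count: {df['token_count'].max()}")
        print(f"   Min token count: {df['token_count'].min()}")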
        
        return df
    
    def create_tokenizer(self, texts):
        """
        Create and fit Keras tokenizer on Myanmar text.
        
        Args:
            texts (list): List of tokenized text strings
        
        Returns:
            Tokenizer: Fitted tokenizer for text-to-sequence conversion
        """
        # Use pre-tokenized text (space-separated Myanmar tokens)
        self.tokenizer = Tokenizer(
            num_words=self.max_vocab_size,
            oov_token='<UNK>',  # Out-of-vocabulary token
            filters='',  # Don't filter anything - tokens are pre-processed
            lower=False,  # Keep tokens exactly as MyWord produced them (Myanmar script has no letter case)
            split=' '  # Split on spaces (tokens already separated)
        )
        
        self.tokenizer.fit_on_texts(texts)
        
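        print(f"\n🔤 Tokenizer Statistics:")
        print(f"   Unique tokens (incl. <UNK>): {len(self.tokenizer.word_index)}")
        # word_counts preserves insertion order, so sort by count to get the most frequent tokens
        most_common = sorted(self.tokenizer.word_counts.items(), key=lambda kv: kv[1], reverse=True)[:5]
        print(f"   Most common tokens: {most_common}")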
        
        return self.tokenizer
    
    def prepare_sequences(self, df):
        """
        Convert text to padded numerical sequences.
        
        Process:
        1. Convert tokens to sequences of integers
        2. Pad sequences to uniform length
        3. Convert labels to categorical format
        
        Args:
            df (DataFrame): Training data with tokens and labels
        
        Returns:
            tuple: (X_sequences, y_categorical)
        """
        # Use tokenized text (space-separated tokens)
        texts = df['tokens'].fillna('').astype(str).tolist()
        
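        # Create the tokenizer if it does not exist yet. Note: fitting it here on the
        # full dataset (before the train/validation split) lets validation-only tokens
        # into the vocabulary
        if self.tokenizer is None:
            self.create_tokenizer(texts)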
        
        # Convert texts to sequences
        sequences = self.tokenizer.texts_to_sequences(texts)
        
        # Pad sequences to uniform length
        X = pad_sequences(
            sequences, 
            maxlen=self.max_sequence_length,
            padding='post',  # Pad at the end
            truncating='post'  # Truncate at the end if too long
        )
        
        # Convert labels to categorical (one-hot encoding)
        y = to_categorical(df['label_numeric'].values, num_classes=3)
        
        print(f"\n📐 Sequence Preparation:")
        print(f"   Input shape: {X.shape}")
        print(f"   Output shape: {y.shape}")
        print(f"   Sequence length: {self.max_sequence_length}")
        print(f"   Vocabulary size: {min(len(self.tokenizer.word_index), self.max_vocab_size)}")
        
        return X, y
    
    def calculate_class_weights(self, y_labels):
        """
        Calculate class weights for handling imbalanced data.
        
        Why class weights?
        - Our dataset might have unequal representation of sources
        - Helps model learn minority classes better
        - Prevents bias toward majority class
        
        Args:
            y_labels (array): Numeric labels (not one-hot)
        
        Returns:
            dict: Class weight mapping
        """
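        classes = np.unique(y_labels)
        class_weights = compute_class_weight(
            class_weight='balanced',
            classes=classes,
            y=y_labels
        )
        
        # Key weights by the actual class values (enumerate() would assume 0..n-1)
        # and use plain Python types so Keras and json.dump both accept the dict
        class_weight_dict = {int(c): float(w) for c, w in zip(classes, class_weights)}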
        
        print(f"\n⚖️  Class Weights:")
        for class_idx, weight in class_weight_dict.items():
            label_name = self.label_mapping[class_idx]
            print(f"   {class_idx} ({label_name}): {weight:.3f}")
        
        return class_weight_dict

print("✅ Training data processor implementation complete")
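The `'balanced'` mode used above gives each class the weight `n_samples / (n_classes * count_c)`, so rarer classes contribute more to the loss. A minimal stdlib sketch of that formula (the function name `balanced_class_weights` is hypothetical), useful for sanity-checking the printed weights:

```python
from collections import Counter


def balanced_class_weights(labels):
    """weight_c = n_samples / (n_classes * count_c), the formula behind sklearn's 'balanced' mode."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}


weights = balanced_class_weights([0, 0, 0, 1, 2, 2])
print(weights)  # class 0 -> 0.667, class 1 -> 2.0, class 2 -> 1.0
```

A class with a single article gets three times the weight of a class with three articles, which is what pushes the model away from always predicting the majority source.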

## 2. BiLSTM Model Architecture

### Why BiLSTM for Myanmar News Classification?
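- **Sequential Nature:** Word order in news articles carries meaning, and an LSTM reads tokens in order
- **Context Understanding:** A BiLSTM reads each sequence both left-to-right and right-to-left, so its representation reflects both preceding and following tokens
- **Myanmar Language:** Handles variable-length Myanmar token sequences (padded to a fixed length and masked)
- **Sentiment Analysis:** Commonly used for text classification tasks such as opinion and bias detection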
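The bidirectional idea can be illustrated with a toy NumPy sketch. For brevity it uses a plain tanh RNN cell rather than an LSTM cell, and random weights; only the wiring is the point: one pass reads the tokens forward, a second reads them reversed, and the two final states are concatenated. This mirrors Keras `Bidirectional` with its default `merge_mode='concat'` and `return_sequences=False`, which is why 64 LSTM units yield a 128-dimensional output in the model below.

```python
import numpy as np


def simple_rnn(xs, W_x, W_h):
    """Run a plain tanh RNN over xs (seq_len, in_dim) and return the final hidden state."""
    h = np.zeros(W_h.shape[0])
    for x in xs:
        h = np.tanh(W_x @ x + W_h @ h)
    return h


def bidirectional_final_state(xs, fwd_params, bwd_params):
    """Concatenate a forward pass over xs with a backward pass over xs reversed."""
    h_fwd = simple_rnn(xs, *fwd_params)        # reads tokens left -> right
    h_bwd = simple_rnn(xs[::-1], *bwd_params)  # reads tokens right -> left
    return np.concatenate([h_fwd, h_bwd])


rng = np.random.default_rng(0)
seq_len, in_dim, units = 5, 8, 4
xs = rng.normal(size=(seq_len, in_dim))  # stands in for 5 embedded tokens
fwd = (rng.normal(size=(units, in_dim)), rng.normal(size=(units, units)))
bwd = (rng.normal(size=(units, in_dim)), rng.normal(size=(units, units)))
out = bidirectional_final_state(xs, fwd, bwd)
print(out.shape)  # (8,)
```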

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Bidirectional, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

class BiLSTMModel:
    """
    Bidirectional LSTM model for Myanmar news classification.
    
    Architecture:
    1. Embedding Layer: Convert token IDs to dense vectors
    2. BiLSTM Layer: Process sequences bidirectionally  
    3. Dropout: Prevent overfitting
    4. Dense Layer: Final classification
    5. Softmax: 3-class probability distribution
    """
    
    def __init__(self, vocab_size, embedding_dim=128, lstm_units=64, max_length=500):
        """
        Initialize BiLSTM model architecture.
        
        Args:
            vocab_size (int): Size of vocabulary
            embedding_dim (int): Dimension of word embeddings
            lstm_units (int): Number of LSTM units
            max_length (int): Maximum sequence length
        
        Design Decisions:
        - embedding_dim=128: Balance between representation and efficiency
        - lstm_units=64: Sufficient for Myanmar text complexity
        - Bidirectional: Captures context from both directions
        """
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.max_length = max_length
        self.model = None
        self.history = None
    
    def build_model(self):
        """
        Construct the BiLSTM model architecture.
        
        Layer-by-layer explanation:
        1. Embedding: Maps token IDs → dense vectors (learns word representations)
        2. BiLSTM: Processes sequences forward + backward (captures full context)
        3. Dropout: Randomly zeroes activations during training (prevents overfitting)
        4. Dense (32, ReLU): Fully connected hidden layer
        5. Dense (3, softmax): Output layer producing a probability for each class
        
        Returns:
            keras.Model: Compiled BiLSTM model
        """
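        model = Sequential([
            # Explicit input shape so the model is built immediately (required for
            # count_params() below; replaces the deprecated Embedding input_length)
            tf.keras.Input(shape=(self.max_length,)),
            
            # Layer 1: Embedding (token ID → dense vector)
            Embedding(
                input_dim=self.vocab_size + 1,  # +1 because index 0 is reserved for padding
                output_dim=self.embedding_dim,
                mask_zero=True,  # Ignore padding tokens
                name='embedding'
            ),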
            
            # Layer 2: Bidirectional LSTM (sequence processing)
            Bidirectional(
                LSTM(
                    self.lstm_units,
                    dropout=0.3,  # Dropout on the layer inputs
                    recurrent_dropout=0.3,  # Dropout on the recurrent state; any non-zero value disables the cuDNN kernel on GPU
                    return_sequences=False,  # Only return final output
                ),
                name='bilstm'
            ),
            
            # Layer 3: Dropout (regularization)
            Dropout(0.5, name='dropout'),
            
            # Layer 4: Dense classification layer
            Dense(
                32,  # Hidden units
                activation='relu',
                name='dense_hidden'
            ),
            
            # Layer 5: Output layer (3 classes)
            Dense(
                3,  # neutral, red, green
                activation='softmax',
                name='output'
            )
        ])
        
        # Compile model with appropriate loss and metrics
        model.compile(
            optimizer=Adam(learning_rate=0.001),  # Adaptive learning rate
            loss='categorical_crossentropy',  # Multi-class classification
            metrics=['accuracy', 'precision', 'recall']
        )
        
        self.model = model
        
        print(f"\n🏗️  BiLSTM Model Architecture:")
        print(f"   Vocabulary size: {self.vocab_size}")
        print(f"   Embedding dimension: {self.embedding_dim}")
        print(f"   LSTM units: {self.lstm_units} (×2 for bidirectional)")
        print(f"   Max sequence length: {self.max_length}")
        print(f"   Total parameters: {model.count_params():,}")
        
        return model
    
    def train_model(self, X_train, y_train, X_val, y_val, 
                   class_weights=None, epochs=50, batch_size=32):
        """
        Train the BiLSTM model.
        
        Training Strategy:
        - Early stopping: Prevent overfitting
        - Model checkpointing: Save best model
        - Class weights: Handle imbalanced data
        - Validation monitoring: Track generalization
        
        Args:
            X_train, y_train: Training data
            X_val, y_val: Validation data
            class_weights: Dict of class weights
            epochs: Maximum training epochs
            batch_size: Training batch size
        
        Returns:
            History: Training history for analysis
        """
        if self.model is None:
            raise ValueError("Model not built. Call build_model() first.")
        
        # Setup callbacks for training control
        callbacks = [
            # Early stopping: Stop if validation loss doesn't improve
            EarlyStopping(
                monitor='val_loss',
                patience=10,  # Wait 10 epochs for improvement
                restore_best_weights=True,
                verbose=1
            ),
            
            # Model checkpoint: Save best model
            ModelCheckpoint(
                'best_bilstm_model.h5',
                monitor='val_loss',  # Same metric as EarlyStopping, so the checkpoint matches the restored best weights
                save_best_only=True,
                verbose=1
            )
        ]
        
        print(f"\n🚀 Starting BiLSTM Training:")
        print(f"   Training samples: {len(X_train)}")
        print(f"   Validation samples: {len(X_val)}")
        print(f"   Batch size: {batch_size}")
        print(f"   Max epochs: {epochs}")
        print(f"   Early stopping patience: 10")
        
        # Train the model
        self.history = self.model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            class_weight=class_weights,
            callbacks=callbacks,
            verbose=1
        )
        
        return self.history
    
    def evaluate_model(self, X_test, y_test):
        """
        Evaluate trained model performance.
        
        Metrics:
        - Overall accuracy
        - Per-class precision, recall, F1
        - Confusion matrix
        
        Args:
            X_test, y_test: Test data
        
        Returns:
            dict: Evaluation results
        """
        if self.model is None:
            raise ValueError("Model not built. Call build_model() and train_model() first.")
        
        # Get predictions
        y_pred = self.model.predict(X_test)
        
        # Convert from categorical to class indices
        y_true_labels = np.argmax(y_test, axis=1)
        y_pred_labels = np.argmax(y_pred, axis=1)
        
        # Classification report
        target_names = ['neutral', 'red', 'green']
        class_report = classification_report(
            y_true_labels, y_pred_labels,
            target_names=target_names,
            output_dict=True
        )
        
        # Confusion matrix
        conf_matrix = confusion_matrix(y_true_labels, y_pred_labels)
        
        print(f"\n📊 Model Evaluation Results:")
        print(f"   Overall Accuracy: {class_report['accuracy']:.3f}")
        print(f"   Macro F1-Score: {class_report['macro avg']['f1-score']:.3f}")
        print(f"   Weighted F1-Score: {class_report['weighted avg']['f1-score']:.3f}")
        
        print(f"\n📈 Per-Class Performance:")
        for i, class_name in enumerate(target_names):
            metrics = class_report[class_name]
            print(f"   {class_name}: P={metrics['precision']:.3f}, "
                  f"R={metrics['recall']:.3f}, F1={metrics['f1-score']:.3f}")
        
        return {
            'classification_report': class_report,
            'confusion_matrix': conf_matrix,
            'predictions': y_pred,
            'true_labels': y_true_labels,
            'pred_labels': y_pred_labels
        }

print("✅ BiLSTM model implementation complete")
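As a cross-check on the `Total parameters` line printed by `build_model()`, the layer stack above can be counted by hand. The sketch below assumes the default sizes (embedding 128, 64 LSTM units, a 32-unit hidden layer, 3 classes) and that the vocabulary hits the 10,000 cap; Dropout and masking add no parameters. The function name `bilstm_param_count` is hypothetical.

```python
def bilstm_param_count(vocab_size, embedding_dim=128, lstm_units=64, hidden_units=32, n_classes=3):
    """Count trainable parameters of the Embedding -> BiLSTM -> Dense -> Dense stack."""
    embedding = (vocab_size + 1) * embedding_dim  # +1 row for the padding index 0
    # One LSTM direction: 4 gates, each with an input kernel, a recurrent kernel and a bias
    lstm_one_direction = 4 * (lstm_units * (embedding_dim + lstm_units) + lstm_units)
    bilstm = 2 * lstm_one_direction
    # The hidden Dense layer sees the concatenated forward + backward outputs
    dense_hidden = (2 * lstm_units) * hidden_units + hidden_units
    output = hidden_units * n_classes + n_classes
    return {
        "embedding": embedding,
        "bilstm": bilstm,
        "dense_hidden": dense_hidden,
        "output": output,
        "total": embedding + bilstm + dense_hidden + output,
    }


counts = bilstm_param_count(vocab_size=10000)
print(counts["total"])                        # 1383171
print(f"{counts['total'] * 4 / 1e6:.1f} MB")  # 5.5 MB of float32 weights
```

The embedding table accounts for about 93% of the parameters, so `max_vocab_size` is the main lever on model size.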

## 3. Training Process and Monitoring

### Complete Training Pipeline
Shows the end-to-end training workflow with monitoring and validation.
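The monitoring in this pipeline relies on `EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)`. A simplified plain-Python sketch of that patience rule (it ignores Keras options such as `baseline` and `start_from_epoch`; the function name `early_stopping_epoch` is hypothetical) shows which epoch training stops at and which epoch's weights are restored:

```python
def early_stopping_epoch(val_losses, patience=10, min_delta=0.0):
    """Return (stop_epoch, best_epoch), both 0-based, for a patience-based stopping rule.

    Training stops once `patience` consecutive epochs fail to improve on the best
    validation loss; with restore_best_weights the model is reset to best_epoch.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Never triggered: training ran for every epoch provided
    return len(val_losses) - 1, best_epoch


print(early_stopping_epoch([1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9], patience=3))  # (5, 2)
```

Because the best weights are restored, the final model corresponds to epoch 2 here even though training ran until epoch 5.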

In [None]:
def train_myanmar_news_classifier(csv_path, model_output_dir):
    """
    Complete training pipeline for Myanmar news classification.
    
    Pipeline Steps:
    1. Load and analyze training data
    2. Prepare sequences for neural network
    3. Create train/validation split
    4. Build and configure BiLSTM model
    5. Train with monitoring and callbacks
    6. Evaluate performance
    7. Save model and artifacts
    
    Args:
        csv_path (str): Path to labeled training CSV
        model_output_dir (str): Directory to save model and reports
    
    Returns:
        tuple: (trained_model, evaluation_results)
    """
    import os
    from datetime import datetime
    import pickle
    
    # Create output directory
    os.makedirs(model_output_dir, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    print(f"🇲🇲 Myanmar News Classification Training")
    print(f"   Start time: {datetime.now()}")
    print(f"   Data source: {csv_path}")
    print(f"   Output directory: {model_output_dir}")
    
    # Step 1: Load and prepare data
    print(f"\n📂 Step 1: Loading training data...")
    processor = TrainingDataProcessor(max_vocab_size=10000, max_sequence_length=500)
    df = processor.load_training_data(csv_path)
    
    # Step 2: Create sequences
    print(f"\n🔄 Step 2: Converting text to sequences...")
    X, y = processor.prepare_sequences(df)
    
    # Step 3: Train/validation split (stratified to maintain class balance)
    print(f"\n📊 Step 3: Creating train/validation split...")
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, stratify=df['label_numeric'], random_state=42
    )
    
    print(f"   Training set: {len(X_train)} samples")
    print(f"   Validation set: {len(X_val)} samples")
    
    # Class weights for balanced training (computed on the training split only,
    # so validation labels do not influence the loss weighting)
    class_weights = processor.calculate_class_weights(np.argmax(y_train, axis=1))
    
    # Step 4: Build BiLSTM model
    print(f"\n🏗️  Step 4: Building BiLSTM model...")
    vocab_size = min(len(processor.tokenizer.word_index), processor.max_vocab_size)
    
    model = BiLSTMModel(
        vocab_size=vocab_size,
        embedding_dim=128,
        lstm_units=64,
        max_length=processor.max_sequence_length
    )
    
    compiled_model = model.build_model()
    
    # Step 5: Train the model
    print(f"\n🚀 Step 5: Training BiLSTM model...")
    history = model.train_model(
        X_train, y_train,
        X_val, y_val,
        class_weights=class_weights,
        epochs=50,
        batch_size=32
    )
    
    # Step 6: Evaluate performance
    # Note: the validation set also drives early stopping, so these scores are
    # somewhat optimistic; a separate held-out test set gives a fairer estimate
    print(f"\n📊 Step 6: Evaluating model performance...")
    evaluation_results = model.evaluate_model(X_val, y_val)
    
    # Step 7: Save model and artifacts
    print(f"\n💾 Step 7: Saving model and artifacts...")
    
    # Save trained model
    model_path = f"{model_output_dir}/bilstm_model_{timestamp}.h5"
    compiled_model.save(model_path)
    
    # Save tokenizer for future predictions
    tokenizer_path = f"{model_output_dir}/tokenizer_{timestamp}.pickle"
    with open(tokenizer_path, 'wb') as f:
        pickle.dump(processor.tokenizer, f)
    
    # Save model parameters
    params_path = f"{model_output_dir}/model_params_{timestamp}.pickle"
    model_params = {
        'vocab_size': vocab_size,
        'embedding_dim': model.embedding_dim,
        'lstm_units': model.lstm_units,
        'max_length': model.max_length,
        'label_mapping': processor.label_mapping
    }
    with open(params_path, 'wb') as f:
        pickle.dump(model_params, f)
    
    # Generate training report
    create_training_report(
        model_output_dir, timestamp, history, evaluation_results, 
        model_params, df, class_weights
    )
    
    print(f"\n✅ Training Complete!")
    print(f"   Model saved: {model_path}")
    print(f"   Tokenizer saved: {tokenizer_path}")
    print(f"   Parameters saved: {params_path}")
    print(f"   Final validation accuracy: {evaluation_results['classification_report']['accuracy']:.3f}")
    
    return compiled_model, evaluation_results

def create_training_report(output_dir, timestamp, history, evaluation, 
                         model_params, df, class_weights):
    """
    Generate comprehensive training report with visualizations.
    
    Creates:
    1. Training history plots (loss, accuracy)
    2. Confusion matrix visualization
    3. Detailed text report
    4. Model architecture summary
    """
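    import json
    import os  # needed for os.makedirs below; os is otherwise only imported inside train_myanmar_news_classifier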
    
    report_dir = f"{output_dir}/training_report_{timestamp}"
    os.makedirs(report_dir, exist_ok=True)
    
    # 1. Plot training history
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Loss plot
    ax1.plot(history.history['loss'], label='Training Loss')
    ax1.plot(history.history['val_loss'], label='Validation Loss')
    ax1.set_title('Model Loss During Training')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    
    # Accuracy plot
    ax2.plot(history.history['accuracy'], label='Training Accuracy')
    ax2.plot(history.history['val_accuracy'], label='Validation Accuracy')
    ax2.set_title('Model Accuracy During Training')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    
    plt.tight_layout()
    plt.savefig(f"{report_dir}/training_history.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    # 2. Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        evaluation['confusion_matrix'],
        annot=True, fmt='d',
        xticklabels=['neutral', 'red', 'green'],
        yticklabels=['neutral', 'red', 'green'],
        cmap='Blues'
    )
    plt.title('Confusion Matrix - Myanmar News Classification')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.savefig(f"{report_dir}/confusion_matrix.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    # 3. Create detailed text report
    report = {
        'training_info': {
            'timestamp': timestamp,
            'model_type': 'Bidirectional LSTM',
            'dataset_size': len(df),
            'training_epochs': len(history.history['loss']),
            'early_stopping': len(history.history['loss']) < history.params['epochs']
        },
        'model_architecture': model_params,
        'class_weights': class_weights,
        'performance_metrics': evaluation['classification_report'],
        'final_metrics': {
            'final_train_loss': history.history['loss'][-1],
            'final_val_loss': history.history['val_loss'][-1],
            'final_train_acc': history.history['accuracy'][-1],
            'final_val_acc': history.history['val_accuracy'][-1],
            'best_val_acc': max(history.history['val_accuracy'])
        }
    }
    
    # Save as JSON
    with open(f"{report_dir}/training_report.json", 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, ensure_ascii=False)
    
    print(f"   📋 Training report saved: {report_dir}/")

print("✅ Training pipeline implementation complete")

## 4. Model Performance Analysis

### Key Performance Metrics

**Why These Metrics Matter:**
- **Accuracy:** Fraction of all articles classified correctly
- **Precision:** Of the articles predicted as a given class, the fraction that truly belong to it
- **Recall:** Of the articles that truly belong to a given class, the fraction the model identifies
- **F1-Score:** Harmonic mean of precision and recall
- **Confusion Matrix:** Detailed per-class performance breakdown

**Expected Performance:**
- Target accuracy: >75% (3-class classification)
- Balanced performance across all classes
- Low false positive rate between similar classes
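All of these metrics can be read off the confusion matrix. A small stdlib sketch, using made-up counts purely for illustration and the same orientation as scikit-learn's `confusion_matrix` (rows = true class, columns = predicted class); the function name `per_class_metrics` is hypothetical:

```python
def per_class_metrics(cm):
    """Precision, recall and F1 per class from a confusion matrix (rows = true, cols = predicted)."""
    n = len(cm)
    results = []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum: everything predicted as k
        actual_k = sum(cm[k])                          # row sum: everything truly k
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results.append((precision, recall, f1))
    accuracy = sum(cm[k][k] for k in range(n)) / sum(map(sum, cm))
    return results, accuracy


# Illustrative counts only, not results from this project
cm = [[5, 1, 0],   # true neutral
      [2, 3, 1],   # true red
      [0, 1, 7]]   # true green
metrics, accuracy = per_class_metrics(cm)
macro_f1 = sum(f1 for _, _, f1 in metrics) / len(metrics)
print(f"accuracy={accuracy:.2f}")  # accuracy=0.75
```

Macro F1 averages the three per-class F1 scores equally, so a weak minority class pulls it down even when overall accuracy looks fine.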

In [None]:
def analyze_model_predictions(model, X_test, y_test, tokenizer, test_df):
    """
    Detailed analysis of model predictions for interpretation.
    
    Analysis includes:
    - Confidence distribution (correct vs. incorrect predictions)
    - Misclassification patterns
    - Most confident correct and most confident incorrect predictions
    
    Args:
        model: Trained BiLSTM model
        X_test, y_test: Test data
        tokenizer: Fitted tokenizer (not used yet; kept for text-level inspection)
        test_df: Original test DataFrame (not used yet; kept for text-level inspection)
    
    Returns:
        dict: Analysis results
    """
    # Get predictions with confidence scores
    y_pred_proba = model.predict(X_test)
    y_pred_classes = np.argmax(y_pred_proba, axis=1)
    y_true_classes = np.argmax(y_test, axis=1)
    
    # Calculate prediction confidence
    prediction_confidence = np.max(y_pred_proba, axis=1)
    
    # Identify misclassifications
    misclassified_indices = np.where(y_pred_classes != y_true_classes)[0]
    correct_indices = np.where(y_pred_classes == y_true_classes)[0]
    
    print(f"\n🔍 Prediction Analysis:")
    print(f"   Total predictions: {len(y_pred_classes)}")
    print(f"   Correct predictions: {len(correct_indices)} ({len(correct_indices)/len(y_pred_classes)*100:.1f}%)")
    print(f"   Misclassifications: {len(misclassified_indices)} ({len(misclassified_indices)/len(y_pred_classes)*100:.1f}%)")
    print(f"   Average confidence: {np.mean(prediction_confidence):.3f}")
    
    # Confidence distribution by correctness
    correct_confidence = prediction_confidence[correct_indices]
    incorrect_confidence = prediction_confidence[misclassified_indices]
    
    print(f"\n📊 Confidence Analysis:")
    if len(correct_confidence) > 0:
        print(f"   Correct predictions confidence: {np.mean(correct_confidence):.3f} ± {np.std(correct_confidence):.3f}")
    if len(incorrect_confidence) > 0:
        print(f"   Incorrect predictions confidence: {np.mean(incorrect_confidence):.3f} ± {np.std(incorrect_confidence):.3f}")
    
    # Analyze misclassification patterns
    if len(misclassified_indices) > 0:
        print(f"\n❌ Misclassification Patterns:")
        label_names = ['neutral', 'red', 'green']
        
        for true_class in [0, 1, 2]:
            for pred_class in [0, 1, 2]:
                if true_class != pred_class:
                    pattern_count = np.sum(
                        (y_true_classes[misclassified_indices] == true_class) &
                        (y_pred_classes[misclassified_indices] == pred_class)
                    )
                    if pattern_count > 0:
                        print(f"   {label_names[true_class]} → {label_names[pred_class]}: {pattern_count} cases")
    
    # Sample analysis: Show best and worst predictions
    print(f"\n🎯 Sample Predictions:")
    
    # Best confident correct predictions
    if len(correct_indices) > 0:
        best_correct_idx = correct_indices[np.argmax(correct_confidence)]
        print(f"   Best confident correct prediction:")
        print(f"     True: {['neutral', 'red', 'green'][y_true_classes[best_correct_idx]]}")
        print(f"     Predicted: {['neutral', 'red', 'green'][y_pred_classes[best_correct_idx]]}")
        print(f"     Confidence: {prediction_confidence[best_correct_idx]:.3f}")
    
    # Worst misclassifications (high confidence but wrong)
    if len(misclassified_indices) > 0:
        worst_misclass_idx = misclassified_indices[np.argmax(incorrect_confidence)]
        print(f"   Worst misclassification (high confidence, wrong):")
        print(f"     True: {['neutral', 'red', 'green'][y_true_classes[worst_misclass_idx]]}")
        print(f"     Predicted: {['neutral', 'red', 'green'][y_pred_classes[worst_misclass_idx]]}")
        print(f"     Confidence: {prediction_confidence[worst_misclass_idx]:.3f}")
    
    return {
        'prediction_confidence': prediction_confidence,
        'correct_indices': correct_indices,
        'misclassified_indices': misclassified_indices,
        'confidence_stats': {
            'mean_confidence': np.mean(prediction_confidence),
            'correct_confidence': np.mean(correct_confidence) if len(correct_confidence) > 0 else 0,
            'incorrect_confidence': np.mean(incorrect_confidence) if len(incorrect_confidence) > 0 else 0
        }
    }

print("✅ Model analysis implementation complete")
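The nested loop over `(true_class, pred_class)` pairs in `analyze_model_predictions` amounts to reading the off-diagonal cells of the confusion matrix. A compact NumPy sketch of the same idea (the function name `misclassification_pairs` is hypothetical):

```python
import numpy as np


def misclassification_pairs(y_true, y_pred, label_names):
    """List (true, predicted, count) for every off-diagonal confusion-matrix cell with count > 0."""
    n = len(label_names)
    cm = np.zeros((n, n), dtype=int)
    # Unbuffered add so repeated (true, pred) pairs are all counted
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return [
        (label_names[t], label_names[p], int(cm[t, p]))
        for t in range(n) for p in range(n)
        if t != p and cm[t, p] > 0
    ]


pairs = misclassification_pairs([0, 0, 1, 2, 2, 2], [0, 1, 1, 2, 0, 0], ["neutral", "red", "green"])
print(pairs)  # [('neutral', 'red', 1), ('green', 'neutral', 2)]
```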

## 5. Training Best Practices and Optimization

### Hyperparameter Tuning Strategy

**Key Parameters to Optimize:**
1. **Embedding Dimension (64-256):** Balance between representation and efficiency
2. **LSTM Units (32-128):** Network capacity vs. overfitting risk
3. **Learning Rate (0.0001-0.01):** Convergence speed vs. stability
4. **Batch Size (16-64):** Memory usage vs. gradient stability
5. **Dropout Rate (0.2-0.6):** Regularization strength
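One way to explore these ranges is random search. The sketch below draws configurations from the listed ranges, sampling the learning rate log-uniformly because its range spans two orders of magnitude. The discrete choices for embedding dimension, LSTM units and batch size are an assumption (the list only gives ranges), and `SEARCH_SPACE` and `sample_config` are hypothetical names. Using the sampled learning rate and dropout would also require exposing them as arguments, since `BiLSTMModel` above hard-codes both.

```python
import random

# Assumed discrete choices inside the ranges listed above
SEARCH_SPACE = {
    "embedding_dim": [64, 128, 256],
    "lstm_units": [32, 64, 128],
    "batch_size": [16, 32, 64],
}


def sample_config(rng):
    """Draw one hypothetical hyperparameter configuration for random search."""
    config = {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}
    # Learning rate between 1e-4 and 1e-2, uniform in log space
    config["learning_rate"] = 10 ** rng.uniform(-4, -2)
    config["dropout"] = rng.uniform(0.2, 0.6)
    return config


rng = random.Random(42)
trials = [sample_config(rng) for _ in range(20)]
print(trials[0])
```

Each trial would then be trained with early stopping, keeping the configuration with the lowest validation loss.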

### Model Validation Strategy
- **Stratified Split:** Maintain class balance in train/val sets
- **Early Stopping:** Prevent overfitting on small datasets
- **Class Weights:** Handle imbalanced Myanmar news sources
- **Cross-Validation:** For robust performance estimation
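The pipeline above uses a single stratified split; cross-validation repeats training over several stratified folds. In practice `sklearn.model_selection.StratifiedKFold` would do this; the stdlib sketch below (hypothetical name `stratified_kfold_indices`, no shuffling) only shows the idea of dealing each class's indices round-robin across folds so every fold keeps roughly the overall class ratios:

```python
from collections import defaultdict


def stratified_kfold_indices(labels, n_splits=5):
    """Yield (train_idx, val_idx) pairs where each fold keeps roughly the overall class ratios.

    Indices of each class are dealt round-robin across folds. No shuffling is done
    here; shuffle within each class first in real use.
    """
    fold_of = [0] * len(labels)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    for indices in by_class.values():
        for rank, i in enumerate(indices):
            fold_of[i] = rank % n_splits
    for k in range(n_splits):
        val_idx = [i for i, f in enumerate(fold_of) if f == k]
        train_idx = [i for i, f in enumerate(fold_of) if f != k]
        yield train_idx, val_idx


labels = [0] * 10 + [1] * 5 + [2] * 5
for fold, (train_idx, val_idx) in enumerate(stratified_kfold_indices(labels, n_splits=5)):
    val_counts = [sum(1 for i in val_idx if labels[i] == c) for c in (0, 1, 2)]
    print(fold, len(train_idx), val_counts)  # each fold: 16 train, [2, 1, 1] validation
```

Averaging the per-fold validation scores gives a steadier estimate than one 80/20 split, at the cost of training the model `n_splits` times.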

### Production Deployment Considerations
- **Model Serialization:** Save complete model + tokenizer + parameters
- **Inference Speed:** The LSTM processes each 500-token sequence step by step, so latency grows with sequence length; batching requests improves throughput
- **Memory Usage:** Monitor embedding + LSTM memory requirements
- **Model Updates:** Retrain periodically with new Myanmar news data
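Serialization only helps if inference reuses the saved parameters exactly. A minimal sketch of the round trip for the parameter pickle and of mapping predicted probabilities back to label names; an in-memory buffer stands in for the `model_params_<timestamp>.pickle` file and hard-coded probability rows stand in for `model.predict` output. `save_params` and `decode_predictions` are hypothetical helpers.

```python
import io
import pickle


def save_params(fileobj, params):
    """Pickle the preprocessing/label parameters that accompany the model weights."""
    pickle.dump(params, fileobj)


def decode_predictions(probabilities, label_mapping):
    """Map each probability row to its label name via argmax."""
    return [label_mapping[max(range(len(row)), key=row.__getitem__)] for row in probabilities]


params = {"max_length": 500, "label_mapping": {0: "neutral", 1: "red", 2: "green"}}
buffer = io.BytesIO()  # stands in for the params pickle file on disk
save_params(buffer, params)
buffer.seek(0)
restored = pickle.load(buffer)
labels = decode_predictions([[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]], restored["label_mapping"])
print(labels)  # ['green', 'neutral']
```

The same `max_length` (and the pickled tokenizer) must be used when padding new articles; a mismatch silently changes what the model sees.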

### Expected Training Results
- **Training Time:** 15-30 minutes on GPU (depends on dataset size)
- **Target Accuracy:** 75-85% for 3-class Myanmar sentiment classification
- **Convergence:** Usually within 20-30 epochs with early stopping
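- **Model Size:** About 1.38M parameters (≈5.5 MB of float32 weights) with the default 10k vocabulary, dominated by the embedding table; the saved `.h5` file can be larger if optimizer state is included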