# BiLSTM Text Classifier with Leakage Detection & Proper Evaluation

This notebook trains a bidirectional LSTM classifier on Sinhala text with:
- **Leakage detection** to catch data contamination
- **Proper data split pipeline** (label encoder and text vectorizer are fitted on the training split only)
- **Generator-holdout evaluation** (chunked inference over the held-out test set) to simulate production conditions


In [None]:
# Optional: install dependencies if running in a fresh environment
!pip install -q pandas scikit-learn tensorflow matplotlib seaborn joblib

: 

## Imports and Dependencies

In [None]:
import json
import os
import unicodedata
import warnings
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import LabelEncoder, label_binarize

# Record which TensorFlow build produced this run's results.
print("TensorFlow version: " + tf.__version__)


## Configuration and Paths

In [None]:
# Mount Google Drive (if running in Google Colab); otherwise fall back to local paths.
# Only the import is guarded: an ImportError there means "not running in Colab",
# whereas a failure inside drive.mount() is a real error and must not be
# misreported as "not in Colab".
try:
    from google.colab import drive
except ImportError:
    DRIVE_MOUNTED = False
    print("Not running in Google Colab. Using local paths.")
else:
    drive.mount('/content/drive')
    DRIVE_MOUNTED = True
    print("Google Drive mounted successfully!")

In [None]:
# Paths and basic settings
# Set base directory based on environment (DRIVE_MOUNTED is set by the Drive-mount cell).
if DRIVE_MOUNTED:
    # Google Drive path - Update this to match your Google Drive folder structure
    BASE_DIR = Path('/content/drive/MyDrive/Colab Notebooks/')
    DATA_DIR = BASE_DIR / 'dataset'
    MODEL_DIR = BASE_DIR / 'ml/models/bilstm_sinhala'
else:
    # Local paths, relative to the notebook's working directory
    DATA_DIR = Path('ml/dataset')
    MODEL_DIR = Path('ml/models/bilstm_sinhala')

TRAIN_PATH = DATA_DIR / 'train.jsonl'
VAL_PATH = DATA_DIR / 'val.jsonl'
TEST_PATH = DATA_DIR / 'test.jsonl'

# Validate every input before creating any output directory, so a bad path
# does not leave an empty model directory behind.
assert DATA_DIR.exists(), f"Dataset directory not found at {DATA_DIR.absolute()}"
assert TRAIN_PATH.exists(), f'Missing train.jsonl at {TRAIN_PATH.absolute()}'
assert VAL_PATH.exists(), f'Missing val.jsonl at {VAL_PATH.absolute()}'
assert TEST_PATH.exists(), f'Missing test.jsonl at {TEST_PATH.absolute()}'
MODEL_DIR.mkdir(parents=True, exist_ok=True)

SEED = 42
MAX_TOKENS = 30000  # vocab size for TextVectorization
SEQ_LEN = 400       # truncate/pad length (tune as needed)
BATCH_SIZE = 64
EPOCHS = 6
EMBED_DIM = 128
LSTM_UNITS = 128

# Seed Python's `random`, NumPy and TensorFlow in one call. Seeding only NumPy
# and TensorFlow leaves Python's `random` unseeded, which Keras may use to draw
# default initializer seeds, so weight initialization would not be reproducible.
tf.keras.utils.set_random_seed(SEED)

print(f"Data directory: {DATA_DIR.absolute()}")
print(f"Model directory: {MODEL_DIR.absolute()}")
print(f"Train path: {TRAIN_PATH.absolute()}")
print(f"Val path: {VAL_PATH.absolute()}")
print(f"Test path: {TEST_PATH.absolute()}")

## Data Loading

In [None]:
# Load JSONL files into DataFrames
def read_jsonl(path: Path) -> pd.DataFrame:
    """Read a JSON Lines file (one JSON object per line) into a DataFrame.

    Each object becomes one row; its keys become the columns.
    """
    frame = pd.read_json(path_or_buf=path, lines=True)
    return frame

# Read the three splits in train -> val -> test order.
train_df, val_df, test_df = (read_jsonl(p) for p in (TRAIN_PATH, VAL_PATH, TEST_PATH))

# Report the size and schema of each split.
for name, df in zip(('train', 'val', 'test'), (train_df, val_df, test_df)):
    print(f'{name}: {len(df):,} rows | columns: {list(df.columns)}')

train_df.head()

## Leakage Detection

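Check for duplicate texts within each split and for texts shared across the train/val/test splits. (The vectorizer's fit scope is enforced separately below: it is adapted on the training split only.)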


In [None]:
# Leakage detection
def _leakage_key(text) -> str:
    """Return the canonical form of ``text`` used for duplicate/overlap checks.

    Applies Unicode NFC normalization and collapses whitespace runs, so copies that
    differ only in code-point composition or spacing compare equal. Zero-width
    joiners are not whitespace and are left intact (Sinhala uses them in conjuncts).
    """
    return " ".join(unicodedata.normalize("NFC", str(text)).split())


def _split_keys(df: pd.DataFrame, text_col: str, normalize: bool) -> pd.Series:
    """Return the non-missing texts of ``df[text_col]``, canonicalized if ``normalize``."""
    texts = df[text_col].dropna()
    return texts.map(_leakage_key) if normalize else texts


def detect_leakage(train_df, val_df, test_df, text_col='text', normalize=True):
    """Detect data leakage across splits and print a report.

    Reports duplicate texts within each split and texts shared between each
    pair of splits. Missing (NaN/None) texts are ignored.

    Args:
        train_df, val_df, test_df: DataFrames holding the three splits.
        text_col: Name of the column containing the text. Defaults to 'text'.
        normalize: If True (default), compare texts after Unicode NFC
            normalization and whitespace collapsing, so near-identical copies
            are caught. If False, compare the raw strings exactly.

    Returns:
        True if no text is shared between any pair of splits, else False.
        Within-split duplicates are reported but do not affect the result.
    """
    print("=" * 60)
    print("LEAKAGE DETECTION REPORT")
    print("=" * 60)

    splits = {'train': train_df, 'val': val_df, 'test': test_df}
    keys = {name: _split_keys(df, text_col, normalize) for name, df in splits.items()}

    # Check for duplicates within each split
    for name, df in splits.items():
        dup_count = int(keys[name].duplicated().sum())
        print(f"\n{name.upper()} SET:")
        print(f"  Total samples: {len(df)}")
        print(f"  Duplicates within split: {dup_count}")
        if dup_count > 0:
            print(f"  ⚠️  WARNING: {dup_count} duplicate texts in {name} set!")

    # Check for overlaps between splits
    unique_keys = {name: set(split_keys) for name, split_keys in keys.items()}
    print("\nCROSS-SPLIT OVERLAPS:")
    total_leakage = 0
    for left, right, pair_label in (
        ('train', 'val', 'Train-Val'),
        ('train', 'test', 'Train-Test'),
        ('val', 'test', 'Val-Test'),
    ):
        overlap = len(unique_keys[left] & unique_keys[right])
        total_leakage += overlap
        print(f"  {pair_label} overlap: {overlap} samples")
        if overlap > 0:
            print(f"  ⚠️  CRITICAL LEAKAGE DETECTED in {pair_label}!")

    if total_leakage == 0:
        print("\n✓ NO LEAKAGE DETECTED")
    else:
        print(f"\n✗ TOTAL LEAKAGE: {total_leakage} samples across splits")

    print("=" * 60)
    return total_leakage == 0

# Run leakage detection. Besides the printed report, raise a visible warning on
# leakage, because any cross-split overlap inflates the metrics reported below.
is_clean = detect_leakage(train_df, val_df, test_df)
if not is_clean:
    warnings.warn(
        "Cross-split text overlap detected: validation/test metrics in this notebook "
        "may be optimistically biased. Deduplicate the splits before trusting them.",
        stacklevel=1,
    )


## Data Preprocessing

In [None]:
# Encode labels: fit on the training split only (LabelEncoder.fit returns the encoder itself).
label_encoder = LabelEncoder().fit(train_df['label'])

def encode_labels(df: pd.DataFrame) -> np.ndarray:
    """Map ``df['label']`` to integer class ids using the train-fitted ``label_encoder``.

    sklearn's LabelEncoder.transform raises ValueError for labels not seen in training.
    """
    raw_labels = df['label']
    return label_encoder.transform(raw_labels)

# Integer targets for each split, in train -> val -> test order.
y_train, y_val, y_test = (encode_labels(split) for split in (train_df, val_df, test_df))

NUM_CLASSES = label_encoder.classes_.size
print('Classes:', label_encoder.classes_)

In [None]:
# Integer-sequence vectorizer. Its vocabulary is learned from the training split
# only, so validation/test tokens never influence it (prevents leakage).
text_vectorizer = tf.keras.layers.TextVectorization(
    standardize='lower_and_strip_punctuation',
    max_tokens=MAX_TOKENS,
    output_sequence_length=SEQ_LEN,
    output_mode='int',
)

print("Adapting vectorizer on TRAINING data only...")
text_vectorizer.adapt(train_df['text'].to_numpy())
print("Vectorizer vocabulary size: {}".format(text_vectorizer.vocabulary_size()))

def make_dataset(texts, labels, training=False):
    """Create a batched, vectorized tf.data pipeline.

    The vectorizer was adapted on the training split only, so applying it here to
    validation/test text does not leak information.

    Args:
        texts: Raw strings (pandas Series, NumPy array or list).
        labels: Integer class ids aligned with ``texts``.
        training: If True, shuffle the examples (reshuffled every epoch).

    Returns:
        A ``tf.data.Dataset`` yielding ``(token_ids, labels)`` batches of size
        ``BATCH_SIZE`` (the last batch may be smaller).
    """
    texts = np.asarray(texts)
    ds = tf.data.Dataset.from_tensor_slices((texts, np.asarray(labels)))
    if training:
        # Buffer spans the whole split so each epoch is a full uniform shuffle;
        # a fixed 10k buffer only partially shuffles larger splits.
        ds = ds.shuffle(max(1, len(texts)), seed=SEED)
    # Vectorize whole batches in parallel, then prefetch LAST so vectorization
    # overlaps with training (prefetching before map() leaves it on the critical path).
    ds = ds.batch(BATCH_SIZE).map(
        lambda x, y: (text_vectorizer(x), y),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return ds.prefetch(tf.data.AUTOTUNE)

# Build the three pipelines. Only the training set is shuffled, so val/test
# batches stay in the same order as y_val / y_test.
train_ds = make_dataset(train_df['text'], y_train, training=True)
val_ds = make_dataset(val_df['text'], y_val, training=False)
test_ds = make_dataset(test_df['text'], y_test, training=False)

# Sanity check: inspect the shapes of one vectorized training batch.
for batch_x, batch_y in train_ds.take(1):
    print('Vectorized batch shape:', batch_x.shape, '| labels shape:', batch_y.shape)


## Model Architecture

The network uses a randomly initialized embedding layer (padding positions are masked). It is followed by two stacked bidirectional LSTM layers, dropout regularization, and a 128-unit ReLU dense layer. The output layer is a dense softmax over the label set.

In [None]:
def build_model():
    """Build and compile the BiLSTM classifier.

    Architecture: token ids -> masked embedding -> BiLSTM (full sequence)
    -> BiLSTM (final state) -> dropout -> 128-unit ReLU dense -> softmax over
    NUM_CLASSES. Compiled with Adam (lr 2e-4) and sparse categorical
    cross-entropy, so targets are integer class ids.
    """
    layers = tf.keras.layers
    token_ids = tf.keras.Input(shape=(None,), dtype=tf.int64, name='tokens')
    # mask_zero=True: id 0 is the padding id emitted by TextVectorization, so
    # downstream LSTMs skip padded positions.
    hidden = layers.Embedding(MAX_TOKENS, EMBED_DIM, mask_zero=True)(token_ids)
    hidden = layers.Bidirectional(layers.LSTM(LSTM_UNITS, return_sequences=True))(hidden)
    hidden = layers.Bidirectional(layers.LSTM(LSTM_UNITS // 2))(hidden)
    hidden = layers.Dropout(0.3)(hidden)
    hidden = layers.Dense(128, activation='relu')(hidden)
    class_probs = layers.Dense(NUM_CLASSES, activation='softmax')(hidden)

    classifier = tf.keras.Model(token_ids, class_probs, name='bilstm_classifier')
    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier

# Build a fresh compiled model and print its layer/parameter summary.
model = build_model()
model.summary()

## Model Training

In [None]:
# Both callbacks track validation accuracy (higher is better): the checkpoint keeps
# the best epoch on disk, and early stopping halts after 2 epochs without
# improvement and restores the best weights in memory.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath=str(MODEL_DIR / 'checkpoint.keras'),
        monitor='val_accuracy',
        mode='max',
        save_best_only=True,
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        mode='max',
        patience=2,
        restore_best_weights=True,
    ),
]

history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# Plot training curves: per-epoch train vs. validation accuracy.
fig, ax = plt.subplots(figsize=(8, 4))
for metric_key, curve_label in (('accuracy', 'train_acc'), ('val_accuracy', 'val_acc')):
    ax.plot(history.history[metric_key], label=curve_label)
ax.set(xlabel='Epoch', ylabel='Accuracy', title='Training vs validation accuracy')
ax.legend()
ax.grid(True)
plt.show()

## Model Evaluation

In [None]:
# Evaluate on the held-out test set
test_probs = model.predict(test_ds)
test_pred = np.argmax(test_probs, axis=1)

print('Test accuracy:', (test_pred == y_test).mean())
print('Classification report')
# Pass every class id explicitly so target_names stays aligned with the report rows.
# Without it, sklearn raises when a class is absent from both y_test and test_pred.
print(classification_report(
    y_test,
    test_pred,
    labels=np.arange(NUM_CLASSES),
    target_names=label_encoder.classes_,
))

In [None]:
# Confusion matrix over ALL classes. Passing labels explicitly keeps the matrix
# the same size as label_encoder.classes_, so tick labels stay aligned even if
# a class never appears in y_test or test_pred.
cm = confusion_matrix(y_test, test_pred, labels=np.arange(NUM_CLASSES))
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_,
            ax=ax)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion matrix')
plt.show()

In [None]:
# ROC Curve (one-vs-rest for multi-class).
# A ROC curve is only defined when the test set has both positive and negative
# samples for a class, so degenerate classes are reported and skipped instead of
# being drawn as NaN curves.
y_test_bin = label_binarize(y_test, classes=range(NUM_CLASSES))

plt.figure(figsize=(8, 6))

if NUM_CLASSES == 2:
    # Binary classification: score each sample by its probability of class 1.
    if np.unique(y_test).size < 2:
        print('Skipping ROC curve: the test set contains a single class.')
    else:
        fpr, tpr, _ = roc_curve(y_test, test_probs[:, 1])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
else:
    # Multi-class: plot ROC for each class
    colors = plt.cm.tab10(np.linspace(0, 1, NUM_CLASSES))
    for i in range(NUM_CLASSES):
        class_targets = y_test_bin[:, i]
        if class_targets.min() == class_targets.max():
            print(f'Skipping ROC curve for {label_encoder.classes_[i]}: '
                  'needs both positive and negative test samples.')
            continue
        fpr, tpr, _ = roc_curve(class_targets, test_probs[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color=colors[i], lw=2,
                 label=f'{label_encoder.classes_[i]} (AUC = {roc_auc:.2f})')

plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()


## Model Persistence

In [None]:
# Save model and preprocessing assets.
# Use the native `.keras` format (the same format as the training checkpoint).
# Keras 3 rejects model.save() paths that have no `.keras`/`.h5` extension.
model.save(str(MODEL_DIR / 'model.keras'))
joblib.dump(label_encoder, MODEL_DIR / 'label_encoder.joblib')

# Save the vectorizer so it can be recreated later. get_config() does not include
# the vocabulary learned by adapt(), and get_weights() does not expose the lookup
# table, so the vocabulary (in index order) is also saved explicitly.
vectorizer_config = text_vectorizer.get_config()
vectorizer_weights = text_vectorizer.get_weights()
vectorizer_vocab = text_vectorizer.get_vocabulary()
with open(MODEL_DIR / 'vectorizer_config.json', 'w', encoding='utf-8') as f:
    json.dump(vectorizer_config, f)
# Kept for existing loaders; this archive may be empty for an 'int'-mode vectorizer.
np.savez_compressed(MODEL_DIR / 'vectorizer_weights.npz', *vectorizer_weights)
with open(MODEL_DIR / 'vectorizer_vocab.json', 'w', encoding='utf-8') as f:
    # ensure_ascii=False keeps Sinhala tokens human-readable in the file.
    json.dump(vectorizer_vocab, f, ensure_ascii=False)
print('Saved to', MODEL_DIR)

## Inference and Testing

In [None]:
# Inference helper: scores raw texts in fixed-size chunks.
def predict_texts(texts, batch_size=32):
    """Predict labels for raw texts, scoring at most ``batch_size`` texts per model call.

    Args:
        texts: A single string, or any iterable of strings (list, tuple,
            NumPy array, pandas Series).
        batch_size: Maximum number of texts per model call; must be >= 1.

    Returns:
        A list with one ``(label, confidence, probs)`` tuple per input text, in
        input order: the decoded label, its softmax probability, and the full
        probability vector. Empty input returns an empty list.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Materialize as a plain list so slicing is positional for every input type
    # (slicing a pandas Series can otherwise depend on its index).
    texts = [texts] if isinstance(texts, str) else list(texts)

    results = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # A direct model call is cheaper than model.predict() for one small batch.
        token_ids = text_vectorizer(tf.constant(batch))
        probs = model(token_ids, training=False).numpy()
        preds = np.argmax(probs, axis=1)
        labels = label_encoder.inverse_transform(preds)
        confidences = np.max(probs, axis=1)
        results.extend(zip(labels, confidences, probs))

    return results

# Spot-check up to three random test examples.
sample_indices = np.random.choice(len(test_df), size=min(3, len(test_df)), replace=False)
sample_texts = test_df['text'].iloc[sample_indices].tolist()
predictions = predict_texts(sample_texts)

for i, (text, prediction) in enumerate(zip(sample_texts, predictions), start=1):
    pred_label, confidence, probs = prediction
    print(f"\nSample {i}:")
    print(f"Text: {text[:100]}...")
    print(f"Predicted: {pred_label} (confidence: {confidence:.4f})")
    print(f"Class probabilities: {dict(zip(label_encoder.classes_, probs))}")


## Generator-Holdout Evaluation

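Simulate production-style inference by scoring the test set in small, independent chunks (32 texts per model call by default) instead of through the batched `tf.data` pipeline. The metrics should match the batched evaluation above; a mismatch would indicate batch-size-dependent behavior.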


In [None]:
def evaluate_generator_holdout(model, test_df, y_test, batch_size=32):
    """Chunked holdout evaluation of ``model`` on ``test_df``.

    Scores the test texts in independent chunks of ``batch_size`` (as a serving
    system would) and prints accuracy, a classification report and ROC AUC.
    Uses the notebook-level ``text_vectorizer``, ``label_encoder`` and ``NUM_CLASSES``.

    Args:
        model: Trained Keras classifier taking vectorized token ids.
        test_df: DataFrame with a 'text' column.
        y_test: Integer class ids aligned with the rows of ``test_df``.
        batch_size: Chunk size; must be >= 1.

    Returns:
        ``(test_probs, test_pred)``: the probability matrix, shape
        (len(test_df), NUM_CLASSES), and the predicted class ids.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``test_df`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(test_df) == 0:
        raise ValueError("test_df is empty; nothing to evaluate")
    y_test = np.asarray(y_test)
    texts = test_df['text'].tolist()

    print("\n" + "=" * 60)
    print("GENERATOR-HOLDOUT EVALUATION")
    print("=" * 60)

    # Score each chunk independently; a direct model call avoids model.predict()
    # overhead for these small batches.
    chunk_probs = []
    for start in range(0, len(texts), batch_size):
        token_ids = text_vectorizer(tf.constant(texts[start:start + batch_size]))
        chunk_probs.append(model(token_ids, training=False).numpy())
    test_probs = np.vstack(chunk_probs)
    test_pred = np.argmax(test_probs, axis=1)

    # Compute metrics
    accuracy = (test_pred == y_test).mean()

    print(f"\nTest Accuracy: {accuracy:.4f}")
    print("\nClassification Report:")
    # Pass every class id so target_names stays aligned with the report rows even
    # when a class is absent from this test set (otherwise sklearn raises).
    print(classification_report(
        y_test,
        test_pred,
        labels=np.arange(NUM_CLASSES),
        target_names=label_encoder.classes_,
    ))

    # AUC needs both positive and negative samples; report degenerate cases
    # instead of printing NaN with an sklearn warning.
    y_test_bin = label_binarize(y_test, classes=range(NUM_CLASSES))
    if NUM_CLASSES == 2:
        if np.unique(y_test).size < 2:
            print("\nROC AUC (Binary): undefined (test set contains a single class)")
        else:
            fpr, tpr, _ = roc_curve(y_test, test_probs[:, 1])
            print(f"\nROC AUC (Binary): {auc(fpr, tpr):.4f}")
    else:
        print("\nPer-class ROC AUC:")
        for i, class_name in enumerate(label_encoder.classes_):
            class_targets = y_test_bin[:, i]
            if class_targets.min() == class_targets.max():
                print(f"  {class_name}: undefined (needs positive and negative samples)")
                continue
            fpr, tpr, _ = roc_curve(class_targets, test_probs[:, i])
            print(f"  {class_name}: {auc(fpr, tpr):.4f}")

    print("=" * 60)

    return test_probs, test_pred

# Run generator-holdout evaluation, then check it against the batched evaluation.
# Any disagreement would indicate batch-size-dependent inference behavior.
test_probs_gen, test_pred_gen = evaluate_generator_holdout(model, test_df, y_test)
pred_agreement = np.mean(test_pred_gen == test_pred)
max_prob_diff = np.abs(test_probs_gen - test_probs).max()
print(f"Agreement with batched predictions: {pred_agreement:.4%} | "
      f"max |prob diff|: {max_prob_diff:.2e}")
