In [None]:
# %% [markdown]
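# # 📊 Model Evaluation & Testing
#
# Compares the saved BiLSTM and LSTM translation models on a few sample sentences and on BLEU scores computed over a sample of the test set.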

# %%
import pickle
import sys
from pathlib import Path

import numpy as np
import tensorflow as tf

# The project root must be on sys.path before the project-local imports below.
sys.path.append('..')

from config.model_config import ModelConfig
from src.evaluation.beam_search import BeamSearchDecoder
from src.evaluation.metrics import BLEUScore

# %% [markdown]
# ## 1. Load Models & Tokenizers

# %%
def _load_saved_model(path):
    """Load the Keras model stored at `path` without compiling it."""
    return tf.keras.models.load_model(path, compile=False)


def _load_pickled(path):
    """Return the object unpickled from the file at `path`."""
    # NOTE(review): pickle.load can execute arbitrary code; only use it on trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# The two saved translation models
bilstm_model = _load_saved_model('../models/saved_models/bilstm_model.h5')
lstm_model = _load_saved_model('../models/saved_models/lstm_model.h5')

# One pickled tokenizer per language
tokenizer_en = _load_pickled('../models/tokenizers/tokenizer_en.pkl')
tokenizer_vi = _load_pickled('../models/tokenizers/tokenizer_vi.pkl')

print("✅ Models and tokenizers loaded")

# %% [markdown]
# ## 2. Setup Decoders

# %%
# Settings shared by the decoders and the data loading further down
config = ModelConfig.to_dict()

# Both decoders receive the same tokenizers and length limits; only the model differs.
shared_decoder_args = (
    tokenizer_en,
    tokenizer_vi,
    config['max_length_en'],
    config['max_length_vi'],
)

bilstm_decoder = BeamSearchDecoder(bilstm_model, *shared_decoder_args)
lstm_decoder = BeamSearchDecoder(lstm_model, *shared_decoder_args)

# Scorer used for the BLEU evaluation below
bleu_scorer = BLEUScore()

# %% [markdown]
# ## 3. Test Translations

# %%
# A handful of inputs for a quick side-by-side check of both decoders
test_sentences = [
    "Hello, how are you?",
    "I love machine learning.",
    "The weather is beautiful today.",
    "Can you help me with this problem?",
    "Thank you for your time.",
]

banner = "="*80
print(banner, "TRANSLATION TESTS", banner, sep="\n")

for text in test_sentences:
    print(f"\n📝 Input: {text}")
    # Each model translates immediately before its own line is printed.
    print(f"🔵 BiLSTM: {bilstm_decoder.translate(text)}")
    print(f"🟢 LSTM: {lstm_decoder.translate(text)}")
    print("-"*80)

# %% [markdown]
# ## 4. BLEU Score Evaluation

# %%
# Rebuild the parallel dataset and keep only its test portion
from src.data.preprocessing import DataPreprocessor

# Vocabulary limits are taken from the shared config.
vocab_limits = {
    'max_vocab_en': config['max_vocab_size_en'],
    'max_vocab_vi': config['max_vocab_size_vi'],
}
preprocessor = DataPreprocessor(**vocab_limits)

# Raw corpus files plus the per-language length limits from the config.
dataset_args = {
    'path_en': '../data/raw/dataset_en.txt',
    'path_vi': '../data/raw/dataset_vi.txt',
    'max_length_en': config['max_length_en'],
    'max_length_vi': config['max_length_vi'],
}
df = preprocessor.load_data(**dataset_args)

# Only the third element of the split is used in this notebook.
# NOTE(review): presumably split_data is deterministic and matches the split used for training — confirm.
_, _, test_df = preprocessor.split_data(df)

# %%
# Evaluate both models on a random subset of the test set (kept small for speed).
EVAL_SAMPLE_SIZE = 100
EVAL_SEED = 42  # fixed seed so every run scores the same sentences

test_sample = test_df.sample(
    n=min(EVAL_SAMPLE_SIZE, len(test_df)),
    random_state=EVAL_SEED,
)

# One BLEU score per sampled sentence, in the same order for both models
bilstm_bleu_scores = []
lstm_bleu_scores = []

for en_text, vi_target in zip(test_sample['english'], test_sample['vietnamese']):
    # Remove the START/END markers from the reference before scoring.
    vi_ref = vi_target.replace('START ', '').replace(' END', '')

    # BiLSTM
    bilstm_trans = bilstm_decoder.translate(en_text)
    bilstm_bleu_scores.append(bleu_scorer.compute(vi_ref, bilstm_trans))

    # LSTM
    lstm_trans = lstm_decoder.translate(en_text)
    lstm_bleu_scores.append(bleu_scorer.compute(vi_ref, lstm_trans))

# %%
# Average each model's per-sentence BLEU scores once and reuse the values below.
bilstm_mean_bleu = np.mean(bilstm_bleu_scores)
lstm_mean_bleu = np.mean(lstm_bleu_scores)

print("\n" + "="*80)
# Report the number of sentences actually scored; it is below 100 when the test set is smaller.
print(f"BLEU SCORE EVALUATION ({len(bilstm_bleu_scores)} samples)")
print("="*80)
print(f"BiLSTM - Average BLEU: {bilstm_mean_bleu:.2f}")
print(f"LSTM - Average BLEU: {lstm_mean_bleu:.2f}")
print(f"Improvement: {bilstm_mean_bleu - lstm_mean_bleu:.2f} points")

# %% [markdown]
# ## 5. Visualization

# %%
import matplotlib.pyplot as plt

# Overlaid histograms of the per-sentence BLEU scores of both models.
fig, ax = plt.subplots(figsize=(10, 6))

# Use one set of bin edges for both models; separate bins=20 calls would give
# each histogram its own edges and make the overlaid bars incomparable.
shared_bins = np.histogram_bin_edges(
    np.concatenate([bilstm_bleu_scores, lstm_bleu_scores]),
    bins=20,
)

ax.hist(bilstm_bleu_scores, bins=shared_bins, alpha=0.7, label='BiLSTM', edgecolor='black')
ax.hist(lstm_bleu_scores, bins=shared_bins, alpha=0.7, label='LSTM', edgecolor='black')

ax.set_xlabel('BLEU Score')
ax.set_ylabel('Frequency')
ax.set_title('BLEU Score Distribution Comparison')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()

# Create the output directory first so savefig does not fail on a fresh checkout.
figure_path = Path('../logs/bleu_comparison.png')
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path, dpi=300)
plt.show()

# %% [markdown]
# ## 6. Summary Report

# %%
print("\n" + "="*80)
print("FINAL EVALUATION SUMMARY")
print("="*80)

# Same statistics for each model: BLEU average/min/max and the parameter count.
for model_name, scores, model in (
    ("BiLSTM", bilstm_bleu_scores, bilstm_model),
    ("LSTM", lstm_bleu_scores, lstm_model),
):
    print(f"\n{model_name} Model:")
    print(f"  - Average BLEU: {np.mean(scores):.2f}")
    print(f"  - Min BLEU: {np.min(scores):.2f}")
    print(f"  - Max BLEU: {np.max(scores):.2f}")
    print(f"  - Parameters: {model.count_params():,}")

# Positive when the BiLSTM average is higher, negative when the LSTM average is higher.
bleu_gap = np.mean(bilstm_bleu_scores) - np.mean(lstm_bleu_scores)

print(f"\nConclusion:")
if bleu_gap >= 0:
    print(f"  BiLSTM performs {bleu_gap:.2f} BLEU points better")
else:
    # Avoid reporting a negative number of points as "better".
    print(f"  BiLSTM performs {abs(bleu_gap):.2f} BLEU points worse")
print("="*80)
