In [None]:
# %% [markdown]
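# # 🎯 Model Training - LSTM & BiLSTM
#
# Trains two attention-based sequence models — a BiLSTM and a plain LSTM — on the same English/Vietnamese data split, saves both models and the tokenizers, and compares their loss and accuracy curves.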

# %%
import sys
sys.path.append('..')

import tensorflow as tf
from config.model_config import ModelConfig
from src.utils.gpu_utils import GPUMemoryManager
from src.data.preprocessing import DataPreprocessor
from src.models.bilstm_attention import BiLSTMAttentionModel
from src.models.lstm_attention import LSTMAttentionModel
from src.training.trainer import ModelTrainer
from src.utils.helpers import save_tokenizer
import os

# %% [markdown]
# ## 1. GPU Setup

# %%
# Start from a clean session, then apply the GPU memory policy from ModelConfig.
GPUMemoryManager.clear_session()

gpu_options = {
    'memory_limit_mb': ModelConfig.GPU_MEMORY_LIMIT,
    'allow_growth': ModelConfig.GPU_MEMORY_GROWTH,
}
GPUMemoryManager.setup_gpu(**gpu_options)

# Mixed precision is opt-in through the config flag.
if ModelConfig.USE_MIXED_PRECISION:
    GPUMemoryManager.enable_mixed_precision()

# Left as the cell's last expression so Jupyter displays whatever it returns.
GPUMemoryManager.get_memory_info()

# %% [markdown]
# ## 2. Configuration

# %%
config = ModelConfig.to_dict()

# Seed Python's `random`, NumPy and TensorFlow in one call so that weight
# initialisation and any global-RNG shuffling/splitting performed by the project
# helpers are repeatable under Restart & Run All. Without this, the BiLSTM vs.
# LSTM comparison below cannot be reproduced.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

print("📋 Configuration:")
for key, value in config.items():
    print(f"   {key}: {value}")

print("\n💾 Memory Estimate:")
for key, value in ModelConfig.estimate_memory().items():
    print(f"   {key}: {value}")

# %% [markdown]
# ## 3. Data Preprocessing

# %%
RAW_EN_PATH = '../data/raw/dataset_en.txt'
RAW_VI_PATH = '../data/raw/dataset_vi.txt'

# Vocabulary caps come from the config dict; the frequency cut-off from ModelConfig.
preprocessor = DataPreprocessor(
    max_vocab_en=config['max_vocab_size_en'],
    max_vocab_vi=config['max_vocab_size_vi'],
    min_frequency=ModelConfig.MIN_WORD_FREQUENCY,
)

# Load the English/Vietnamese files (presumably line-aligned) with per-language length caps.
length_limits = {
    'max_length_en': config['max_length_en'],
    'max_length_vi': config['max_length_vi'],
}
df = preprocessor.load_data(path_en=RAW_EN_PATH, path_vi=RAW_VI_PATH, **length_limits)

print(f"Dataset: {df.shape}")

# %%
# Split into train / validation / test frames.
train_df, val_df, test_df = preprocessor.split_data(df)
split_sizes = (len(train_df), len(val_df), len(test_df))
print("Train: {}, Val: {}, Test: {}".format(*split_sizes))

# Tokenizers are built from the training split only (only train_df is passed in).
tokenizer_en, tokenizer_vi = preprocessor.build_tokenizers(train_df)

# Persist both tokenizers, presumably for later inference/evaluation notebooks.
os.makedirs(ModelConfig.TOKENIZER_PATH, exist_ok=True)
for lang_code, lang_tokenizer in (('en', tokenizer_en), ('vi', tokenizer_vi)):
    save_tokenizer(lang_tokenizer, f'{ModelConfig.TOKENIZER_PATH}/tokenizer_{lang_code}.pkl')

# %%
# Turn the train/val frames into model inputs. Each call returns three arrays:
# English sequences plus two Vietnamese arrays (presumably decoder input and
# shifted decoder target — confirm in DataPreprocessor.prepare_sequences).
seq_len_en, seq_len_vi = config['max_length_en'], config['max_length_vi']

en_train, vi_in_train, vi_out_train = preprocessor.prepare_sequences(train_df, seq_len_en, seq_len_vi)
en_val, vi_in_val, vi_out_val = preprocessor.prepare_sequences(val_df, seq_len_en, seq_len_vi)

print(f"Training sequences: {en_train.shape}")

# %% [markdown]
# ## 4. Build BiLSTM Model

# %%
model_builder = BiLSTMAttentionModel(config)

# NOTE(review): vocab sizes are the config caps, not the fitted tokenizers' sizes;
# confirm the tokenizers never emit an id >= these caps (e.g. extra special tokens).
bilstm_build_args = dict(
    vocab_size_en=config['max_vocab_size_en'],
    vocab_size_vi=config['max_vocab_size_vi'],
    max_len_en=config['max_length_en'],
    max_len_vi=config['max_length_vi'],
)
bilstm_model = model_builder.build(**bilstm_build_args)

bilstm_model.summary()

# %% [markdown]
# ## 5. Train BiLSTM

# %%
# Steps per epoch, rounded UP. Keras `fit` on in-memory arrays (and a `tf.data`
# pipeline batched without `drop_remainder`) also runs the final partial batch.
# Floor division therefore under-counts the real number of optimizer steps, and
# it yields 0 when the training set is smaller than one batch.
# NOTE(review): assumes ModelTrainer does not drop the remainder batch — confirm.
batch_size = config['batch_size']
steps_per_epoch = (len(en_train) + batch_size - 1) // batch_size

# Written into the shared `config` dict (presumably consumed by ModelTrainer,
# e.g. to size a learning-rate schedule). The LSTM trainer later reuses this dict.
config['total_steps'] = steps_per_epoch * config['epochs']

trainer = ModelTrainer(model=bilstm_model, config=config)

bilstm_history = trainer.train(
    train_data=(en_train, vi_in_train, vi_out_train),
    val_data=(en_val, vi_in_val, vi_out_val)
)

# %%
# Save the trained BiLSTM, creating the output directory if needed.
bilstm_model_path = f'{ModelConfig.MODEL_SAVE_PATH}/bilstm_model.h5'
os.makedirs(ModelConfig.MODEL_SAVE_PATH, exist_ok=True)
trainer.save_model(bilstm_model_path)

# %% [markdown]
# ## 6. Build LSTM Model (for comparison)

# %%
# Reset session state before building the second model
# (presumably wraps tf.keras.backend.clear_session — confirm).
GPUMemoryManager.clear_session()

lstm_builder = LSTMAttentionModel(config)

# Same shape parameters as the BiLSTM, so the two models are directly comparable.
lstm_build_args = {
    'vocab_size_en': config['max_vocab_size_en'],
    'vocab_size_vi': config['max_vocab_size_vi'],
    'max_len_en': config['max_length_en'],
    'max_len_vi': config['max_length_vi'],
}
lstm_model = lstm_builder.build(**lstm_build_args)

lstm_model.summary()

# %% [markdown]
# ## 7. Train LSTM

# %%
# Train on exactly the same arrays as the BiLSTM. `config` still carries the
# 'total_steps' entry written before BiLSTM training, so that cell must run first.
lstm_trainer = ModelTrainer(model=lstm_model, config=config)

lstm_train_data = (en_train, vi_in_train, vi_out_train)
lstm_val_data = (en_val, vi_in_val, vi_out_val)
lstm_history = lstm_trainer.train(train_data=lstm_train_data, val_data=lstm_val_data)

# %%
# Save the trained LSTM. The output directory is created here as well, so this
# cell does not depend on the BiLSTM save cell having run first.
os.makedirs(ModelConfig.MODEL_SAVE_PATH, exist_ok=True)
lstm_trainer.save_model(f'{ModelConfig.MODEL_SAVE_PATH}/lstm_model.h5')

# %% [markdown]
# ## 8. Compare Results

# %%
import matplotlib.pyplot as plt

LOGS_DIR = '../logs'

# Curves to overlay, in plotting order: (Keras History, label prefix, line style).
# The BiLSTM uses matplotlib's default solid line; the LSTM is dashed.
runs = [
    (bilstm_history, 'BiLSTM', '-'),
    (lstm_history, 'LSTM', '--'),
]
# One panel per metric: (history key, human-readable name). The validation
# curve is read from the matching 'val_<key>' entry.
panels = [('loss', 'Loss'), ('accuracy', 'Accuracy')]

fig, axes = plt.subplots(1, len(panels), figsize=(14, 5))

for ax, (metric, metric_name) in zip(axes, panels):
    for history, model_name, linestyle in runs:
        ax.plot(history.history[metric], label=f'{model_name} Train', linestyle=linestyle)
        ax.plot(history.history[f'val_{metric}'], label=f'{model_name} Val', linestyle=linestyle)
    ax.set_title(f'Model {metric_name} Comparison')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric_name)
    ax.legend()
    ax.grid(True, alpha=0.3)

fig.tight_layout()
# Nothing else in this notebook creates the logs directory; without this,
# savefig raises FileNotFoundError on a fresh checkout.
os.makedirs(LOGS_DIR, exist_ok=True)
fig.savefig(f'{LOGS_DIR}/comparison.png', dpi=300)
plt.show()

# %%
# Validation loss at the last recorded epoch of each run (not necessarily the best epoch).
for run_name, run_history in (('BiLSTM', bilstm_history), ('LSTM', lstm_history)):
    last_val_loss = run_history.history['val_loss'][-1]
    print(f"{run_name} - Final Val Loss: {last_val_loss:.4f}")
