In [None]:
# Kiểm tra setup
try:
    from shared_variables import *
except:
    # Tải các biến và hàm từ notebook trước
    %run "1_Setup_and_Utils.ipynb"
    %run "3_Model_Definition.ipynb"

# Lấy SignLanguageModel class
SignLanguageModel = load_variable('SignLanguageModel')

# Model training function and its helpers.


def _count_class_images(data_dir):
    """Return {class_name: image_count} for each sub-directory of *data_dir*.

    Only direct entries of each class directory whose name ends in .jpg, .jpeg
    or .png (case-insensitive) are counted. Class names are sorted so the
    printed summary has a stable order.
    """
    counts = {}
    for item in sorted(os.listdir(data_dir)):
        item_path = os.path.join(data_dir, item)
        if os.path.isdir(item_path):
            counts[item] = sum(
                1 for f in os.listdir(item_path)
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            )
    return counts


def _auto_batch_size(min_class_images):
    """Pick a batch size that grows with the size of the smallest class."""
    for limit, size in ((8, 1), (16, 4), (32, 8), (64, 16)):
        if min_class_images <= limit:
            return size
    return 32


def _auto_epochs(total_images):
    """Pick an epoch count: smaller datasets get more epochs."""
    for limit, n_epochs in ((100, 30), (500, 20), (1000, 15)):
        if total_images < limit:
            return n_epochs
    return 10


def _plot_history(history):
    """Plot training vs. validation accuracy and loss from a Keras History."""
    plt.figure(figsize=(12, 4))
    panels = (
        ('accuracy', 'Model Accuracy', 'Accuracy'),
        ('loss', 'Model Loss', 'Loss'),
    )
    for position, (metric, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[metric], label='Training')
        plt.plot(history.history[f'val_{metric}'], label='Validation')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
    plt.tight_layout()
    plt.show()


def train_model(data_dir, epochs=None, batch_size=None):
    """
    Train the sign language model with automatic parameter tuning.

    Args:
        data_dir: Directory containing one sub-directory of images per class.
        epochs: Number of training epochs; chosen from the total image count
            when None.
        batch_size: Batch size; chosen from the smallest class's image count
            when None.

    Returns:
        The trained SignLanguageModel instance, or None if data_dir is not an
        existing directory or contains no class sub-directories.

    Side effects:
        Prints a dataset summary, shows training-history plots, saves the Keras
        model to models/sign_model.h5 (creating models/ if needed) and stores
        the model under the shared name 'trained_model' via save_variable.
    """
    print(f"Training model with data from {data_dir}")

    # isdir (not exists): a regular file here would make os.listdir raise.
    if not os.path.isdir(data_dir):
        print(f"Error: Directory {data_dir} does not exist")
        return None

    class_counts = _count_class_images(data_dir)
    for class_name, img_count in class_counts.items():
        print(f"  - Class {class_name}: {img_count} images")

    total_images = sum(class_counts.values())
    print(f"Total classes: {len(class_counts)}")
    print(f"Total images: {total_images}")

    # Guard before taking the minimum: it is undefined without any class.
    if not class_counts:
        print("Error: No class directories found!")
        return None

    min_class_images = min(class_counts.values())
    print(f"Minimum images per class: {min_class_images}")

    if batch_size is None:
        batch_size = _auto_batch_size(min_class_images)
        print(f"Auto-selected batch_size = {batch_size}")

    if epochs is None:
        epochs = _auto_epochs(total_images)
        print(f"Auto-selected epochs = {epochs}")

    model = SignLanguageModel()

    # Keep these arguments identical for both subsets: the train/validation
    # split is derived from the seeded file ordering, so mismatched values
    # could make the two subsets overlap.
    split_kwargs = dict(
        validation_split=0.2,
        seed=123,
        image_size=(64, 64),
        batch_size=batch_size,
        color_mode='grayscale',
        shuffle=True,
    )
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="training", **split_kwargs
    )
    validation_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="validation", **split_kwargs
    )

    # Scale pixel values from [0, 255] to [0, 1].
    normalization_layer = tf.keras.layers.Rescaling(1./255)
    train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
    validation_ds = validation_ds.map(lambda x, y: (normalization_layer(x), y))

    # Stronger augmentation for smaller datasets.
    if total_images > 1000:
        augmentation_strength = 0.2
    elif total_images > 500:
        augmentation_strength = 0.3
    else:
        augmentation_strength = 0.4
    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.RandomRotation(augmentation_strength),
        tf.keras.layers.RandomTranslation(augmentation_strength, augmentation_strength),
        tf.keras.layers.RandomZoom(augmentation_strength),
        tf.keras.layers.RandomContrast(augmentation_strength)
    ])

    # Keras random preprocessing layers only transform data in training mode;
    # outside model.fit() that mode must be requested explicitly, otherwise
    # the images may pass through unaugmented.
    train_ds = train_ds.map(lambda x, y: (data_augmentation(x, training=True), y))

    # Overlap data loading with training.
    AUTOTUNE = tf.data.AUTOTUNE
    train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
    validation_ds = validation_ds.prefetch(buffer_size=AUTOTUNE)

    # Be more patient on small datasets, whose validation curves are noisier.
    early_stopping_patience = 5 if total_images < 500 else 3
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            patience=early_stopping_patience,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            factor=0.2,
            patience=2,
            verbose=1
        )
    ]

    print("Starting model training...")
    history = model.model.fit(
        train_ds,
        validation_data=validation_ds,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )

    _plot_history(history)

    model_file = "models/sign_model.h5"
    # Saving into a missing directory fails, so create it first.
    os.makedirs(os.path.dirname(model_file), exist_ok=True)
    model.model.save(model_file)
    print(f"Model saved to {model_file}")

    # Publish the model so later notebooks can load it.
    save_variable(model, 'trained_model')

    return model

def _read_optional_positive_int(prompt):
    """Print *prompt* and read a positive integer from stdin.

    Returns None (meaning "choose automatically") for empty input, and also,
    after printing a notice, for anything that is not a positive integer.
    """
    print(prompt)
    raw = input().strip()
    if not raw:
        return None
    try:
        value = int(raw)
    except ValueError:
        value = None
    if value is None or value <= 0:
        print(f"Invalid value '{raw}', using automatic selection.")
        return None
    return value


# Dataset directory shared by an earlier notebook; fall back to "dataset".
dataset_dir = load_variable('dataset_dir')
if dataset_dir is None:
    dataset_dir = "dataset"  # Default

# Ask for training parameters (empty input = automatic selection).
print(f"Dataset directory: {dataset_dir}")
epochs = _read_optional_positive_int("Enter epochs (press Enter for automatic):")
batch_size = _read_optional_positive_int("Enter batch size (press Enter for automatic):")

# Train the model.
trained_model = train_model(dataset_dir, epochs, batch_size)

# Offer the download only when training succeeded; train_model returns None
# (and saves nothing) when the dataset directory is missing or empty.
if trained_model is not None:
    print("Do you want to download the trained model? (y/n)")
    download_choice = input().strip().lower()
    if download_choice in ('y', 'yes'):
        files.download("models/sign_model.h5")