<a href="https://colab.research.google.com/github/nghiemkhoa1235-boop/mafbj/blob/main/Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ===============================
# 🚀 ULTIMATE FOOD RECOGNITION TRAINING
# ===============================


import os
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.optimizers import Adam, AdamW
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, CSVLogger, LearningRateScheduler, TerminateOnNaN, LambdaCallback
from tensorflow.keras.utils import Sequence
from tensorflow.keras import preprocessing
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB4
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.mixed_precision import set_global_policy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix
import cv2
from google.colab import files, drive
import warnings
import json
import glob
import shutil
warnings.filterwarnings('ignore')


# ===============================
# 🚀 INITIAL SETUP & OPTIMIZATION
# ===============================


print("üöÄ STARTING ULTIMATE FOOD RECOGNITION TRAINING")
print("‚úÖ TensorFlow version:", tf.__version__)
# tf.test.is_gpu_available() is deprecated and initializes the GPUs, which
# makes the set_memory_growth call below fail. list_physical_devices only
# enumerates the devices.
gpus = tf.config.list_physical_devices('GPU')
print("‚úÖ GPU Available:", bool(gpus))


# GPU optimization: grow GPU memory on demand instead of reserving it all.
# TensorFlow raises RuntimeError if the GPUs were already initialized.
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)


# Mount Google Drive
drive.mount('/content/drive', force_remount=True)
data_root = '/content/drive/MyDrive/AI/do_an'


# Create the backup folder on Drive (plus its subfolders) if it does not exist.
backup_dir = os.path.join(data_root, 'training_backups')
os.makedirs(backup_dir, exist_ok=True)
for _subdir in ('logs', 'stage1_checkpoints', 'stage2_checkpoints', 'evaluation'):
    os.makedirs(os.path.join(backup_dir, _subdir), exist_ok=True)
print(f"üìÅ Backup directory created/verified: {backup_dir}")
print("   Subfolders: logs, stage1_checkpoints, stage2_checkpoints, evaluation")


# Test Drive write access: fail early if the backup folder is not writable.
test_file = os.path.join(backup_dir, 'test_write_access.txt')
with open(test_file, 'w', encoding='utf-8') as f:
    f.write('Test successful - Drive OK!')
print(f"‚úÖ Drive write test: {test_file} created. Check Drive to confirm.")


# ===============================
# 🔄 RESUME TRAINING FUNCTIONALITY
# ===============================


def get_latest_checkpoint(checkpoint_pattern):
    """Return the checkpoint matching a glob pattern with the highest epoch.

    The epoch is read from the digits after the last underscore of the file
    name (e.g. ``stage1_epoch_07.keras`` -> 7). Only the file name is parsed,
    so underscores or dots in directory names cannot affect the result.
    Matches whose name carries no such digits are ignored.

    Args:
        checkpoint_pattern: glob pattern, optionally including a directory.

    Returns:
        ``(path, epoch)`` of the newest checkpoint, or ``(None, 0)`` when no
        matching file carries an epoch number.
    """
    checkpoints = glob.glob(checkpoint_pattern)
    if not checkpoints:
        return None, 0
    print(f"üîç Found local checkpoints: {[os.path.basename(c) for c in checkpoints]}")
    epochs = []
    for cp in checkpoints:
        # 'stage1_epoch_07.keras' -> '07.keras' -> '07'
        stem = os.path.basename(cp).split('_')[-1].split('.')[0]
        digits = ''.join(filter(str.isdigit, stem))
        if digits:
            epochs.append((cp, int(digits)))
    if epochs:
        latest = max(epochs, key=lambda x: x[1])
        print(f"‚úÖ Latest local checkpoint: {os.path.basename(latest[0])} (epoch {latest[1]})")
        return latest[0], latest[1]
    return None, 0




def load_training_state():
    """Load the training-progress state, preferring the Drive backup.

    If ``training_state.json`` exists in ``backup_dir`` it is first copied
    over the local working copy, then the local file is read. Keys missing
    from the file (e.g. one written by an older version of this script) are
    filled with defaults, so callers can index every key safely.

    Returns:
        dict with keys ``current_stage``, ``stage1_epochs_trained``,
        ``stage2_epochs_trained``, ``best_val_acc_stage1`` and
        ``model_loaded``.
    """
    state_file = 'training_state.json'
    defaults = {
        'current_stage': 0,
        'stage1_epochs_trained': 0,
        'stage2_epochs_trained': 0,
        'best_val_acc_stage1': 0.0,
        'model_loaded': False
    }
    # Prefer the Drive copy when available: it survives Colab VM resets.
    drive_state = os.path.join(backup_dir, state_file)
    if os.path.exists(drive_state):
        shutil.copy(drive_state, state_file)
        print(f"üìÇ Loaded state from Drive: {drive_state}")
    if os.path.exists(state_file):
        with open(state_file, 'r', encoding='utf-8') as f:
            return {**defaults, **json.load(f)}
    print("‚ö†Ô∏è No state file found. Will infer from checkpoints if available.")
    return defaults


def save_training_state(stage, epochs_trained, best_val_acc=None):
    """Persist training progress to ``training_state.json`` and mirror it to Drive.

    Args:
        stage: 1 or 2; selects which stage's counters are updated. Any other
            value re-saves the current state unchanged.
        epochs_trained: number of epochs completed in that stage.
        best_val_acc: best Stage-1 validation accuracy observed. ``None``
            (the default) leaves the stored value untouched. Otherwise the
            stored value becomes the maximum of the old and new values, so
            per-epoch saves can never erase the recorded best.
    """
    state = load_training_state()
    if stage == 1:
        state['current_stage'] = 1
        state['stage1_epochs_trained'] = epochs_trained
        if best_val_acc is not None:
            previous_best = state.get('best_val_acc_stage1') or 0.0
            state['best_val_acc_stage1'] = max(previous_best, best_val_acc)
    elif stage == 2:
        state['current_stage'] = 2
        state['stage2_epochs_trained'] = epochs_trained
    with open('training_state.json', 'w', encoding='utf-8') as f:
        json.dump(state, f)
    # Copy to Drive so the state survives a Colab VM reset.
    drive_state = os.path.join(backup_dir, 'training_state.json')
    shutil.copy('training_state.json', drive_state)
    print(f"üíæ State saved to Drive: {drive_state} ‚úÖ")


def backup_file_to_drive(local_file, drive_subdir=''):
    """Copy ``local_file`` into ``backup_dir/drive_subdir`` and report the result.

    Best effort: a missing source file or a failed copy is reported with a
    printed warning instead of raising. Existing backups are never cleaned
    up, so every checkpoint copied here is kept.
    """
    if not os.path.exists(local_file):
        print(f"‚ö†Ô∏è Local file not found: {local_file}")
        return
    target = os.path.join(backup_dir, drive_subdir, os.path.basename(local_file))
    os.makedirs(os.path.dirname(target), exist_ok=True)
    try:
        shutil.copy(local_file, target)
        if not os.path.exists(target):
            print(f"‚ö†Ô∏è Backup copied but file not found after: {target}")
            return
        print(f"‚úÖ Backup success: {os.path.basename(local_file)} ‚Üí {target}")
    except Exception as err:
        print(f"‚ö†Ô∏è Backup failed for {local_file}: {err}")




# Callback that backs up each epoch's checkpoint to Drive.
def drive_backup_callback(stage_name):
    """Build a ``LambdaCallback`` that mirrors per-epoch artifacts to Drive.

    After every epoch it:
      * copies ``{stage_name}_epoch_NN.keras`` (if present) into the
        ``{stage_name}_checkpoints`` backup subfolder;
      * copies the stage's CSV log (if present) into ``logs``;
      * records the completed epoch count with ``save_training_state``
        (stage 1 when ``stage_name == "stage1"``, otherwise stage 2).
    """
    is_stage1 = stage_name == "stage1"
    stage_number = 1 if is_stage1 else 2
    log_name = 'ultimate_training_log.csv' if is_stage1 else 'ultimate_fine_tuning_log.csv'

    def _after_epoch(epoch, logs):
        finished = epoch + 1  # Keras epochs are 0-based; file names are 1-based.
        checkpoint = f'{stage_name}_epoch_{finished:02d}.keras'
        if os.path.exists(checkpoint):
            backup_file_to_drive(checkpoint, f'{stage_name}_checkpoints')
        if os.path.exists(log_name):
            backup_file_to_drive(log_name, 'logs')
        save_training_state(stage_number, finished)
        print(f"üîÑ Epoch {finished} backup completed to Drive!")

    return LambdaCallback(on_epoch_end=_after_epoch)


def resume_stage1(model, csv_logger_file='ultimate_training_log.csv'):
    """Restore Stage-1 progress and return the epoch index to resume from.

    The newest ``stage1_epoch_*.keras`` checkpoint is looked up locally first,
    then under ``backup_dir/stage1_checkpoints``; a Drive checkpoint is copied
    into the working directory. The search also runs when the saved state
    reports 0 epochs, so progress is recovered if ``training_state.json`` was
    lost. When a checkpoint is found, its weights are loaded into ``model``
    and the epoch count becomes the larger of the state value and the
    checkpoint epoch. Without a checkpoint, the row count of the Stage-1 CSV
    log (Drive copy preferred) can raise the epoch count instead.

    Args:
        model: Keras model whose weights are restored in place.
        csv_logger_file: file name used by the Stage-1 ``CSVLogger``.

    Returns:
        Number of Stage-1 epochs already completed, suitable as
        ``initial_epoch`` for ``model.fit``.
    """
    state = load_training_state()
    epochs_trained = state['stage1_epochs_trained']
    print(f"üîç Checking resume for Stage 1: {epochs_trained} epochs already trained")

    # Newest checkpoint: local working copy first, then the Drive backup.
    cp_file, cp_epoch = get_latest_checkpoint('stage1_epoch_*.keras')
    if not cp_file:
        drive_pattern = os.path.join(backup_dir, 'stage1_checkpoints', 'stage1_epoch_*.keras')
        drive_cp, drive_epoch = get_latest_checkpoint(drive_pattern)
        if drive_cp:
            cp_file, cp_epoch = os.path.basename(drive_cp), drive_epoch
            shutil.copy(drive_cp, cp_file)
            print(f"üìÇ Loaded checkpoint from Drive: {cp_file} (epoch {cp_epoch})")

    if not cp_file and epochs_trained <= 0:
        print("‚úÖ No previous training found for Stage 1. Starting from epoch 0.")
        return 0

    if cp_file:
        model.load_weights(cp_file)
        print(f"‚úÖ Loaded weights from {cp_file} (epoch {cp_epoch})")
        # Trust the checkpoint epoch when it is newer than the recorded state.
        if cp_epoch > epochs_trained:
            print(f"üîÑ Updating state from checkpoint: {epochs_trained} ‚Üí {cp_epoch}")
            epochs_trained = cp_epoch
            save_training_state(1, cp_epoch)
    else:
        print(f"⚠️ No Stage 1 checkpoint found: weights are NOT restored although "
              f"{epochs_trained} epochs are recorded.")
        # Count completed epochs from the CSV log. Refresh the local copy from
        # Drive first: on a fresh VM the local file does not exist yet.
        drive_csv = os.path.join(backup_dir, 'logs', csv_logger_file)
        if os.path.exists(drive_csv):
            shutil.copy(drive_csv, csv_logger_file)
            print(f"üìÇ Loaded CSV log from Drive: {drive_csv}")
        if os.path.exists(csv_logger_file):
            try:
                csv_epochs = len(pd.read_csv(csv_logger_file))
            except pd.errors.EmptyDataError:
                csv_epochs = 0
            print(f"üìä CSV log shows {csv_epochs} epochs trained")
            if csv_epochs > epochs_trained:
                epochs_trained = csv_epochs
                save_training_state(1, epochs_trained)

    print(f"üìÇ Resuming Stage 1 from epoch {epochs_trained + 1}")
    return epochs_trained


def resume_stage2(model, csv_logger_file='ultimate_fine_tuning_log.csv'):
    """Restore Stage-2 (fine-tuning) progress and return the epoch to resume from.

    Mirrors ``resume_stage1``. The newest ``stage2_epoch_*.keras`` checkpoint
    is looked up locally, then under ``backup_dir/stage2_checkpoints``, and
    its weights are loaded into ``model``. The epoch count is the larger of
    the saved state and the checkpoint epoch. Only when no checkpoint exists
    is the row count of the Stage-2 CSV log (Drive copy preferred) used.
    The count never decreases below the saved state.

    Args:
        model: compiled Keras model whose weights are restored in place.
        csv_logger_file: file name used by the Stage-2 ``CSVLogger``.

    Returns:
        Number of Stage-2 epochs already completed, suitable as
        ``initial_epoch`` for ``model.fit``.
    """
    state = load_training_state()
    epochs_trained = state['stage2_epochs_trained']
    print(f"üîç Checking resume for Stage 2: {epochs_trained} epochs already trained")

    # Newest checkpoint: local working copy first, then the Drive backup.
    # get_latest_checkpoint skips oddly named files instead of raising.
    cp_file, cp_epoch = get_latest_checkpoint('stage2_epoch_*.keras')
    if not cp_file:
        drive_pattern = os.path.join(backup_dir, 'stage2_checkpoints', 'stage2_epoch_*.keras')
        drive_cp, drive_epoch = get_latest_checkpoint(drive_pattern)
        if drive_cp:
            cp_file, cp_epoch = os.path.basename(drive_cp), drive_epoch
            shutil.copy(drive_cp, cp_file)
            print(f"üìÇ Loaded checkpoint from Drive: {cp_file}")

    if not cp_file and epochs_trained <= 0:
        print("‚úÖ No previous training found for Stage 2. Starting from epoch 0.")
        return 0

    if cp_file:
        model.load_weights(cp_file)
        print(f"‚úÖ Loaded weights from {cp_file}")
        if cp_epoch > epochs_trained:
            print(f"🔄 Updating state from checkpoint: {epochs_trained} → {cp_epoch}")
            epochs_trained = cp_epoch
            save_training_state(2, cp_epoch)
    else:
        print(f"⚠️ No Stage 2 checkpoint found: weights are NOT restored although "
              f"{epochs_trained} epochs are recorded.")
        # Count completed epochs from the CSV log. Refresh the local copy from
        # Drive first: on a fresh VM the local file does not exist yet.
        drive_csv = os.path.join(backup_dir, 'logs', csv_logger_file)
        if os.path.exists(drive_csv):
            shutil.copy(drive_csv, csv_logger_file)
            print(f"üìÇ Loaded CSV log from Drive: {drive_csv}")
        if os.path.exists(csv_logger_file):
            try:
                csv_epochs = len(pd.read_csv(csv_logger_file))
            except pd.errors.EmptyDataError:
                csv_epochs = 0
            print(f"üìä CSV log shows {csv_epochs} epochs for Stage 2")
            if csv_epochs > epochs_trained:
                epochs_trained = csv_epochs
                save_training_state(2, epochs_trained)

    print(f"üìÇ Resuming Stage 2 from epoch {epochs_trained + 1}")
    return epochs_trained


# ===============================
# 🎯 DATA COLLECTION
# ===============================


print("\n=== COLLECTING DATASET ===")


def collect_datasets_advanced():
    """Discover class folders under ``data_root`` and list their images per split.

    A sub-directory of ``data_root`` is treated as a class if it contains a
    train-like sub-folder (``train``/``training`` in the casings below) or
    image files directly. For each class and split, images are read from the
    first existing split folder. For the ``train`` split only, a class with
    no such folder falls back to the images stored directly in its directory.

    Returns:
        ``(df_train, df_val, df_test, classes)``. Each DataFrame has
        ``filename`` and ``class`` columns; ``classes`` is the sorted list
        of class names.
    """
    # One extension list for both detection and collection, so a class whose
    # images are e.g. only .webp is not silently dropped.
    image_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
    train_dir_names = ('train', 'training', 'Train', 'Training')

    classes = []
    if os.path.exists(data_root):
        for item in os.listdir(data_root):
            item_path = os.path.join(data_root, item)
            if not os.path.isdir(item_path):
                continue
            has_train_dir = any(os.path.exists(os.path.join(item_path, d)) for d in train_dir_names)
            if has_train_dir or any(f.lower().endswith(image_exts) for f in os.listdir(item_path)):
                classes.append(item)

    classes = sorted(classes)
    print(f"üìã Detected classes: {classes}")
    print(f"üéØ Total classes: {len(classes)}")

    def list_images(directory):
        """Return full paths of the image files directly inside ``directory``, sorted."""
        return [os.path.join(directory, f) for f in sorted(os.listdir(directory))
                if f.lower().endswith(image_exts)]

    def collect_split_data_advanced(split):
        """Collect ``(filename, class)`` rows of one split across all classes."""
        # 'training' folders count as a train split, matching class detection.
        aliases = [split, 'training'] if split == 'train' else [split]
        split_variants = []
        for name in aliases:
            for variant in (name, name.lower(), name.upper(), name.capitalize()):
                if variant not in split_variants:
                    split_variants.append(variant)

        filepaths = []
        labels = []
        for class_name in classes:
            class_files = []
            for variant in split_variants:
                class_dir = os.path.join(data_root, class_name, variant)
                if os.path.exists(class_dir):
                    class_files = list_images(class_dir)
                    break

            # Decide the fallback per class. Testing substrings of earlier
            # labels would wrongly treat 'b' as found when 'ab' had images.
            if split == 'train' and not class_files:
                class_files = list_images(os.path.join(data_root, class_name))

            filepaths.extend(class_files)
            labels.extend([class_name] * len(class_files))

        return pd.DataFrame({'filename': filepaths, 'class': labels})

    df_train = collect_split_data_advanced('train')
    df_val = collect_split_data_advanced('val')
    df_test = collect_split_data_advanced('test')

    return df_train, df_val, df_test, classes


# Load datasets
df_train, df_val, df_test, classes = collect_datasets_advanced()
num_classes = len(classes)


print(f"\nüìä Dataset Summary:")
print(f"Train: {len(df_train)} images")
print(f"Val: {len(df_val)} images")
print(f"Test: {len(df_test)} images")


# ===============================
# ⚖️ CLASS BALANCE ANALYSIS
# ===============================


print("\n=== CLASS BALANCE ANALYSIS ===")
train_class_counts = df_train['class'].value_counts()
print("Training class distribution:")
for class_name, count in train_class_counts.items():
    print(f"  {class_name}: {count} samples")


# Compute 'balanced' class weights. sklearn only accepts classes that occur
# in y, so weights are computed for the classes present in the training set.
# They are then mapped onto model output indices by class *name*: Keras
# expects {output_index: weight}, output indices follow `classes`, and
# `classes` may contain a class with no training images (weight 1.0).
present_classes = np.unique(df_train['class'])
class_weights = compute_class_weight(
    'balanced',
    classes=present_classes,
    y=df_train['class']
)
weight_by_name = dict(zip(present_classes, class_weights))
class_weight_dict = {i: weight_by_name.get(name, 1.0) for i, name in enumerate(classes)}
print("Class weights:", class_weight_dict)


# ===============================
# 🚀 ULTIMATE DATA GENERATOR
# ===============================


class UltimateDataGenerator(Sequence):
    """Keras ``Sequence`` that loads, augments and batches food images.

    Each item is ``(X, y)``. ``X`` is a float32 array of shape
    ``(batch, height, width, 3)`` produced by ``preprocess_input``. ``y`` is a
    float32 label matrix, one-hot or softened by MixUp/CutMix, whose columns
    follow the module-level ``classes`` list.

    Args:
        dataframe: table with ``filename`` and ``class`` columns.
        target_size: ``(height, width)`` of the returned images.
        batch_size: images per batch; the last batch may be smaller.
        shuffle: reshuffle the sample order at the end of every epoch.
        augment: enable geometric/colour augmentation plus MixUp/CutMix.
        mixup_alpha: Beta-distribution parameter for MixUp (0 disables it).
        cutmix_alpha: Beta-distribution parameter for CutMix (0 disables it).
    """

    def __init__(self, dataframe, target_size=(300, 300), batch_size=16, shuffle=True, augment=False, mixup_alpha=0.2, cutmix_alpha=1.0):
        # Let the Keras base class initialize itself (Keras 3's PyDataset,
        # exposed as Sequence, warns when this call is skipped).
        super().__init__()
        self.dataframe = dataframe.reset_index(drop=True)
        self.target_size = target_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augment = augment
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        # Label order is taken from the module-level `classes` list.
        self.classes = classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.num_classes = len(self.classes)
        self.on_epoch_end()

        # Data augmentation (only when augment=True)
        if self.augment:
            self.augmentation_pipeline = ImageDataGenerator(
                rotation_range=45,
                width_shift_range=0.3,
                height_shift_range=0.3,
                shear_range=0.3,
                zoom_range=0.4,
                horizontal_flip=True,
                vertical_flip=True,
                brightness_range=[0.7, 1.3],
                channel_shift_range=0.3,
                fill_mode='reflect'
            )

            # Extra photometric jitter; one of these is picked at random.
            self.advanced_augmentations = {
                'random_contrast': lambda x: tf.image.random_contrast(x, 0.8, 1.2),
                'random_saturation': lambda x: tf.image.random_saturation(x, 0.8, 1.2),
                'random_hue': lambda x: tf.image.random_hue(x, 0.1),
            }
        else:
            self.augmentation_pipeline = None
            self.advanced_augmentations = None

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.dataframe) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as a pair of float32 arrays ``(X, y)``."""
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_data = self.dataframe.iloc[batch_indices]

        X = np.zeros((len(batch_data), *self.target_size, 3), dtype=np.float32)
        y = np.zeros((len(batch_data), self.num_classes), dtype=np.float32)

        for i, (_, row) in enumerate(batch_data.iterrows()):
            img = self.load_and_preprocess_image(row['filename'])

            # Random geometric/colour transform for ~80% of training images.
            if self.augment and self.augmentation_pipeline and np.random.random() > 0.2:
                img = self.augmentation_pipeline.random_transform(img.astype(np.float32))

                # ...plus one extra photometric op for about half of those.
                if np.random.random() > 0.5:
                    aug_name = np.random.choice(list(self.advanced_augmentations.keys()))
                    try:
                        img = np.asarray(self.advanced_augmentations[aug_name](img))
                    except Exception:
                        # Best effort: keep the image without this extra op.
                        pass

            X[i] = img
            y[i, self.class_to_idx[row['class']]] = 1.0

        # MixUp Augmentation (probability 0.3)
        if self.augment and self.mixup_alpha > 0 and np.random.random() > 0.7:
            X, y = self.mixup(X, y, alpha=self.mixup_alpha)

        # CutMix Augmentation (probability 0.3 of the batches not mixed up)
        elif self.augment and self.cutmix_alpha > 0 and np.random.random() > 0.7:
            X, y = self.cutmix(X, y, alpha=self.cutmix_alpha)

        return X, y

    def mixup(self, batch_x, batch_y, alpha=0.2):
        """Blend each sample with a random partner using per-sample Beta weights."""
        batch_size = batch_x.shape[0]
        indices = np.random.permutation(batch_size)

        lam = np.random.beta(alpha, alpha, batch_size)
        # Keep the original sample dominant (lam >= 0.5).
        lam = np.maximum(lam, 1 - lam)
        lam = lam.reshape((batch_size, 1, 1, 1))

        mixed_x = lam * batch_x + (1 - lam) * batch_x[indices]
        mixed_y = lam.reshape((batch_size, 1)) * batch_y + (1 - lam.reshape((batch_size, 1))) * batch_y[indices]

        # lam is float64; cast back so batches have a consistent dtype.
        return mixed_x.astype(np.float32), mixed_y.astype(np.float32)

    def cutmix(self, batch_x, batch_y, alpha=1.0):
        """Paste a random rectangle from partner samples; mix labels by area."""
        batch_size, H, W, C = batch_x.shape
        indices = np.random.permutation(batch_size)

        lam = np.random.beta(alpha, alpha)
        cut_ratio = np.sqrt(1 - lam)
        cut_w = int(W * cut_ratio)
        cut_h = int(H * cut_ratio)

        cx = np.random.randint(W)
        cy = np.random.randint(H)

        x1 = np.clip(cx - cut_w // 2, 0, W)
        y1 = np.clip(cy - cut_h // 2, 0, H)
        x2 = np.clip(cx + cut_w // 2, 0, W)
        y2 = np.clip(cy + cut_h // 2, 0, H)

        batch_x_copy = batch_x.copy()
        batch_x_copy[:, y1:y2, x1:x2, :] = batch_x[indices, y1:y2, x1:x2, :]

        # Recompute lam from the clipped box so labels match the pasted area.
        lam = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
        mixed_y = lam * batch_y + (1 - lam) * batch_y[indices]

        return batch_x_copy, mixed_y.astype(np.float32)

    def load_and_preprocess_image(self, filepath):
        """Read an image as RGB, resize it to ``target_size`` and apply ``preprocess_input``.

        Unreadable files are replaced by a random-noise placeholder.
        """
        try:
            img = cv2.imread(filepath)
            if img is None:
                return self._create_placeholder_image()
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            height, width = self.target_size
            # INTER_AREA when shrinking, INTER_CUBIC when enlarging.
            if img.shape[0] > height or img.shape[1] > width:
                interpolation = cv2.INTER_AREA
            else:
                interpolation = cv2.INTER_CUBIC

            # cv2.resize takes dsize as (width, height), but target_size is
            # (height, width) like the X buffer, so swap the order.
            img = cv2.resize(img, (width, height), interpolation=interpolation)
            img = preprocess_input(img.astype(np.float32))

            return img
        except Exception as e:
            print(f"‚ùå Error processing {filepath}: {str(e)}")
            return self._create_placeholder_image()

    def _create_placeholder_image(self):
        """Return standard-normal noise with the shape of a real image."""
        return np.random.normal(0, 1, (*self.target_size, 3)).astype(np.float32)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``shuffle`` is set."""
        self.indices = np.arange(len(self.dataframe))
        if self.shuffle:
            np.random.shuffle(self.indices)


# ===============================
# 🛠️ DATA PREPARATION
# ===============================


print("\n=== CREATING DATA GENERATORS ===")
target_size = (300, 300)  # (height, width)
batch_size = 16


def fast_clean_dataframe(df):
    """Drop rows whose ``filename`` does not exist on disk.

    Progress is printed every 100 rows. Rows are selected by *position*, so
    any index works, not only a 0..n-1 RangeIndex.

    Args:
        df: DataFrame with a ``filename`` column.

    Returns:
        New DataFrame containing only existing files, re-indexed 0..n-1.
    """
    print(f"üîç Cleaning {len(df)} images...")
    valid_positions = []

    for pos, filename in enumerate(df['filename']):
        if pos % 100 == 0:
            print(f"   Progress: {pos}/{len(df)}")
        if os.path.exists(filename):
            valid_positions.append(pos)

    result = df.iloc[valid_positions].reset_index(drop=True)
    print(f"‚úÖ Cleaned: {len(result)}/{len(df)} images valid")
    return result


# Clean data: drop rows whose image file no longer exists.
df_train = fast_clean_dataframe(df_train)
df_val = fast_clean_dataframe(df_val)
df_test = fast_clean_dataframe(df_test)


print(f"üìä Final Dataset Summary:")
print(f"Train: {len(df_train)} images")
print(f"Val: {len(df_val)} images")
print(f"Test: {len(df_test)} images")


# If no validation images exist, hold out a stratified 20% of the training
# data. This must happen BEFORE the training generator is built; otherwise
# the held-out images would still be trained on.
if len(df_val) == 0:
    print("‚ö†Ô∏è Warning: Val generator empty! Splitting from train.")
    from sklearn.model_selection import train_test_split
    df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=42, stratify=df_train['class'])


# Create generators
train_gen = UltimateDataGenerator(df_train, target_size, batch_size, shuffle=True, augment=True, mixup_alpha=0.2, cutmix_alpha=1.0)
val_gen = UltimateDataGenerator(df_val, target_size, batch_size, shuffle=False, augment=False)
test_gen = UltimateDataGenerator(df_test, target_size, batch_size, shuffle=False, augment=False)


print("‚úÖ Data generators created successfully!")
print(f"Train batches: {len(train_gen)}")
print(f"Val batches: {len(val_gen)}")
print(f"Test batches: {len(test_gen)}")


# ===============================
# 🧠 ULTIMATE MODEL ARCHITECTURE
# ===============================


def create_ultimate_model():
    """Build the EfficientNetB4 classifier and return ``(model, base_model)``.

    The ImageNet-pretrained backbone (no top, global-average pooled, frozen)
    feeds eight Dropout -> Dense(ReLU) -> BatchNormalization blocks of
    decreasing width, followed by a softmax over ``num_classes``. The input
    size comes from the module-level ``target_size``.
    """
    backbone = EfficientNetB4(
        weights='imagenet',
        include_top=False,
        input_shape=(*target_size, 3),
        pooling='avg',
    )
    backbone.trainable = False

    def _l1_l2():
        return regularizers.l1_l2(l1=1e-5, l2=1e-4)

    def _l2():
        return regularizers.l2(1e-4)

    # (dropout rate, dense units, kernel-regularizer factory or None)
    head_spec = (
        (0.3, 1536, _l1_l2),
        (0.4, 1024, _l1_l2),
        (0.35, 896, _l2),
        (0.3, 768, _l2),
        (0.25, 512, _l2),
        (0.2, 384, None),
        (0.15, 256, None),
        (0.1, 128, None),
    )

    stack = [backbone]
    for rate, units, make_regularizer in head_spec:
        stack.append(layers.Dropout(rate))
        stack.append(layers.Dense(
            units,
            activation='relu',
            kernel_regularizer=make_regularizer() if make_regularizer else None,
        ))
        stack.append(layers.BatchNormalization())

    # Output pinned to float32 (relevant if a mixed-precision policy is enabled).
    stack.append(layers.Dense(num_classes, activation='softmax', dtype='float32'))

    return tf.keras.Sequential(stack), backbone


print("üöÄ Creating ultimate model...")
model, base_model = create_ultimate_model()


# ===============================
# ⚙️ MODEL COMPILATION
# ===============================


initial_learning_rate = 0.001


# AdamW optimizer for the head-only stage.
optimizer = AdamW(
    learning_rate=initial_learning_rate,
    weight_decay=1e-4,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-7,
)


# Categorical cross-entropy matches the one-hot / MixUp-CutMix soft labels
# produced by UltimateDataGenerator.
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc'),
    ],
)
print("‚úÖ Model compiled successfully!")
model.summary()


# ===============================
# 📈 TRAINING SETUP
# ===============================


def advanced_lr_schedule(epoch):
    """Return the learning rate for a 0-based ``epoch``.

    Relative to ``initial_learning_rate``:
      * epochs 0-4: linear warm-up, (epoch + 1) / 5;
      * epochs 5-29: cosine decay from 1x towards 0;
      * epochs 30-59: second cosine cycle from 0.2x towards 0;
      * epochs 60-89: constant 0.01x;
      * epoch 90 onwards: constant 0.001x.
    """
    base = initial_learning_rate
    if epoch >= 90:
        return base * 0.001
    if epoch >= 60:
        return base * 0.01
    if epoch >= 30:
        return base * 0.1 * (1 + np.cos(np.pi * (epoch - 30) / 30))
    if epoch >= 5:
        return base * 0.5 * (1 + np.cos(np.pi * (epoch - 5) / 25))
    return base * (epoch + 1) / 5

# Baseline callbacks.
# NOTE(review): this list is not passed to any fit() call in this script
# (Stage 1 builds its own `callbacks_stage1`). LearningRateScheduler also
# sets the LR at the start of every epoch, which appears to override any
# reduction made by ReduceLROnPlateau — confirm before relying on the latter.
callbacks = [
    # Per-epoch learning rate from advanced_lr_schedule.
    LearningRateScheduler(advanced_lr_schedule, verbose=1),
    # Halve the LR after 6 epochs without a val_accuracy improvement.
    ReduceLROnPlateau(
        monitor='val_accuracy',
        mode='max',
        factor=0.5,
        patience=6,
        min_lr=1e-9,
        verbose=1,
    ),
    # Stop after 35 stagnant epochs and roll back to the best weights.
    EarlyStopping(
        monitor='val_accuracy',
        mode='max',
        patience=35,
        restore_best_weights=True,
        verbose=1,
    ),
    # Keep only the best model seen so far.
    ModelCheckpoint(
        'best_ultimate_model.keras',
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1,
    ),
    CSVLogger('ultimate_training_log.csv', append=True),
    TerminateOnNaN(),
]


print("‚úÖ Training callbacks ready!")


# ===============================
# 🏋️ TRAINING EXECUTION
# ===============================


print("\n" + "üî•" * 60)
print("üî• STAGE 1: TRAINING HEAD LAYERS")
print("üî•" * 60)


print("üöÄ Starting Stage 1 training...")
initial_epoch_stage1 = resume_stage1(model)


# Save the full model every epoch (not only the best) so training can resume.
stage1_checkpoints = ModelCheckpoint(
    'stage1_epoch_{epoch:02d}.keras',
    monitor='val_accuracy',
    save_best_only=False,  # Save every epoch for resume
    save_weights_only=False,  # Save full model for optimizer state
    verbose=1
)


# Back up each epoch's checkpoint, CSV log and state to Drive.
stage1_backup = drive_backup_callback('stage1')


# Keras runs callbacks in list order, so the checkpoint and the CSVLogger
# write their files before stage1_backup copies them.
# NOTE(review): LearningRateScheduler sets the LR at the start of every epoch,
# which appears to override reductions made by ReduceLROnPlateau — confirm.
callbacks_stage1 = [
    LearningRateScheduler(advanced_lr_schedule, verbose=1),
    ReduceLROnPlateau(
        monitor='val_accuracy',
        factor=0.5,
        patience=6,
        min_lr=1e-9,
        verbose=1,
        mode='max'
    ),
    EarlyStopping(
        monitor='val_accuracy',
        patience=35,
        restore_best_weights=True,
        verbose=1,
        mode='max'
    ),
    stage1_checkpoints,
    ModelCheckpoint(  # Keep the best model as well
        'best_ultimate_model.keras',
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    ),
    CSVLogger('ultimate_training_log.csv', append=True),
    TerminateOnNaN(),
    stage1_backup  # Back up every epoch to Drive
]


history_stage1 = model.fit(
    train_gen,
    epochs=100,
    initial_epoch=initial_epoch_stage1,
    validation_data=val_gen,
    callbacks=callbacks_stage1,
    verbose=1,
    class_weight=class_weight_dict
)


# history.epoch lists the epochs run in this call; it is empty when the stage
# had already reached its epoch limit on a previous run.
epochs_trained_stage1 = initial_epoch_stage1 + len(history_stage1.epoch)
stage1_val_accs = history_stage1.history.get('val_accuracy', [])
if stage1_val_accs:
    best_val_acc_stage1 = max(stage1_val_accs)
else:
    # Nothing ran this session: keep the best accuracy already on record
    # instead of overwriting it with 0.
    best_val_acc_stage1 = load_training_state().get('best_val_acc_stage1') or 0.0
save_training_state(1, epochs_trained_stage1, best_val_acc_stage1)


# Backup stage 1 files to Drive (final)
backup_file_to_drive('ultimate_training_log.csv', 'logs')
for cp in glob.glob('stage1_epoch_*.keras'):
    backup_file_to_drive(cp, 'stage1_checkpoints')
backup_file_to_drive('best_ultimate_model.keras')


print("üéâ Stage 1 training completed!")


# ===============================
# 🔧 FINE-TUNING STAGE
# ===============================


val_accs = history_stage1.history.get('val_accuracy', [])
if val_accs:
    best_stage1_acc = max(val_accs)
else:
    # No Stage-1 epoch ran in this session (e.g. the stage had already reached
    # its epoch limit). Fall back to the best validation accuracy recorded in
    # the training state, then to training accuracy, instead of 0.
    best_stage1_acc = (load_training_state().get('best_val_acc_stage1')
                       or max(history_stage1.history.get('accuracy', [0.0])))
print(f"\nüìä Stage 1 Best Accuracy: {best_stage1_acc:.4f}")
if best_stage1_acc >= 0.75:
    print("üöÄ PROCEEDING TO STAGE 2: FINE-TUNING")

    # Unfreeze the backbone except its first 150 layers.
    base_model.trainable = True
    for layer in base_model.layers[:150]:
        layer.trainable = False
    # Keep BatchNormalization layers frozen. With trainable=False Keras runs
    # them in inference mode, so small fine-tuning batches do not overwrite
    # the pretrained moving statistics (Keras transfer-learning guide).
    for layer in base_model.layers[150:]:
        if isinstance(layer, layers.BatchNormalization):
            layer.trainable = False

    print(f"üîì Trainable layers: {sum([layer.trainable for layer in base_model.layers])}/{len(base_model.layers)}")

    # Much smaller LR (1% of the Stage-1 LR) to avoid wrecking pretrained features.
    fine_tune_optimizer = AdamW(
        learning_rate=initial_learning_rate * 0.01,
        weight_decay=0.00001,
        beta_1=0.9,
        beta_2=0.999
    )

    # Re-compile so the trainable-flag changes take effect.
    model.compile(
        optimizer=fine_tune_optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy',
                 tf.keras.metrics.Precision(name='precision'),
                 tf.keras.metrics.Recall(name='recall'),
                 tf.keras.metrics.AUC(name='auc')]
    )

    # Save the full model every epoch so fine-tuning can resume.
    stage2_checkpoints = ModelCheckpoint(
        'stage2_epoch_{epoch:02d}.keras',
        monitor='val_accuracy',
        save_best_only=False,  # Save every epoch
        save_weights_only=False,
        verbose=1
    )

    # Back up each epoch's checkpoint, CSV log and state to Drive.
    stage2_backup = drive_backup_callback('stage2')

    fine_tune_callbacks = [
        ReduceLROnPlateau(
            monitor='val_accuracy',
            factor=0.5,
            patience=8,
            min_lr=1e-10,
            verbose=1,
            mode='max'
        ),
        EarlyStopping(
            monitor='val_accuracy',
            patience=30,
            restore_best_weights=True,
            verbose=1,
            mode='max'
        ),
        stage2_checkpoints,
        ModelCheckpoint(
            'best_ultimate_fine_tuned.keras',
            monitor='val_accuracy',
            save_best_only=True,
            verbose=1
        ),
        CSVLogger('ultimate_fine_tuning_log.csv', append=True),
        stage2_backup  # Back up every epoch to Drive
    ]

    print("\n" + "üî•" * 60)
    print("üî• STAGE 2: ADVANCED FINE-TUNING")
    print("üî•" * 60)

    initial_epoch_stage2 = resume_stage2(model)

    history_stage2 = model.fit(
        train_gen,
        epochs=150,
        initial_epoch=initial_epoch_stage2,
        validation_data=val_gen,
        callbacks=fine_tune_callbacks,
        verbose=1,
        class_weight=class_weight_dict
    )

    # history.epoch is empty when no epoch ran in this call.
    additional_epochs_stage2 = initial_epoch_stage2 + len(history_stage2.epoch)
    save_training_state(2, additional_epochs_stage2)

    # Backup stage 2 files to Drive (final)
    backup_file_to_drive('ultimate_fine_tuning_log.csv', 'logs')
    for cp in glob.glob('stage2_epoch_*.keras'):
        backup_file_to_drive(cp, 'stage2_checkpoints')
    backup_file_to_drive('best_ultimate_fine_tuned.keras')

    final_history = history_stage2
else:
    print("‚ö†Ô∏è Stage 1 accuracy below threshold. Skipping fine-tuning.")
    final_history = history_stage1


# ===============================
# 📊 EVALUATION & RESULTS
# ===============================


print("\n=== FINAL EVALUATION ===")


# Standard evaluation
test_loss, test_accuracy, test_precision, test_recall, test_auc = model.evaluate(test_gen, verbose=0)
print(f"üéØ TEST ACCURACY: {test_accuracy:.4f}")
print(f"üéØ TEST PRECISION: {test_precision:.4f}")
print(f"üéØ TEST RECALL: {test_recall:.4f}")
print(f"üéØ TEST AUC: {test_auc:.4f}")


# Ensemble prediction
def ensemble_predict(generator, model, n_rounds=5):
    """Average ``model`` predictions over several passes of ``generator``.

    Extra passes only add information when the generator is stochastic
    (``augment=True``). For a non-augmenting generator every pass yields
    (essentially) the same predictions, so a single pass is made. The sample
    order is not reshuffled between passes, so per-sample predictions line up.

    Returns:
        Array of averaged class probabilities, one row per sample.
    """
    rounds = n_rounds if getattr(generator, 'augment', False) else 1
    all_predictions = []

    for i in range(rounds):
        print(f"üîÑ Ensemble round {i+1}/{rounds}")
        all_predictions.append(model.predict(generator, verbose=0))

    return np.mean(all_predictions, axis=0)


Y_pred_ensemble = ensemble_predict(test_gen, model)
y_pred_ensemble = np.argmax(Y_pred_ensemble, axis=1)


# True labels in the generator's sample order, read from its dataframe
# instead of decoding every test image a second time.
true_labels = (
    test_gen.dataframe['class']
    .iloc[test_gen.indices]
    .map(test_gen.class_to_idx)
    .to_numpy()
)


ensemble_accuracy = np.sum(y_pred_ensemble == true_labels) / len(true_labels)
print(f"üéØ ENSEMBLE ACCURACY: {ensemble_accuracy:.4f}")


final_accuracy = max(test_accuracy, ensemble_accuracy)


# Pass every class index explicitly so the report and the confusion matrix
# stay aligned with `classes` even if a class is missing from the test set.
label_ids = list(range(len(classes)))


# Classification report
print("\nüìä DETAILED CLASSIFICATION REPORT:")
print(classification_report(true_labels, y_pred_ensemble, labels=label_ids, target_names=classes, digits=4))


# Confusion Matrix
plt.figure(figsize=(12, 10))
cm = confusion_matrix(true_labels, y_pred_ensemble, labels=label_ids)
sns.heatmap(cm, annot=True, fmt='d', cmap='YlOrRd', xticklabels=classes, yticklabels=classes)
plt.title(f'Confusion Matrix - Accuracy: {final_accuracy:.4f}', fontsize=16)
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()


# Backup evaluation files
backup_file_to_drive('confusion_matrix.png', 'evaluation')


# ===============================
# 💾 SAVE MODEL
# ===============================


# Save the final model locally, then mirror it into the Drive backup folder.
model.save('ULTIMATE_FOOD_RECOGNITION_MODEL.keras')


drive_model_path = os.path.join(backup_dir, 'ULTIMATE_FOOD_RECOGNITION_MODEL.keras')
try:
    shutil.copy('ULTIMATE_FOOD_RECOGNITION_MODEL.keras', drive_model_path)
except Exception as err:
    print(f"‚ö†Ô∏è Could not save model to Drive: {err}")
else:
    print(f"‚úÖ Model saved to Drive: {drive_model_path}")


# ===============================
# 🎉 FINAL RESULTS
# ===============================


# Closing banner.
print("\n" + "üéâ" * 40)
print("üéâ TRAINING COMPLETED SUCCESSFULLY!")
print("üéâ" * 40)


# One line per metric, four decimals each.
print(f"\nüìä FINAL PERFORMANCE SUMMARY:")
for _metric_name, _metric_value in (
    ("FINAL ACCURACY", final_accuracy),
    ("ENSEMBLE ACCURACY", ensemble_accuracy),
    ("STANDARD ACCURACY", test_accuracy),
    ("PRECISION", test_precision),
    ("RECALL", test_recall),
    ("AUC", test_auc),
):
    print(f"üéØ {_metric_name}: {_metric_value:.4f}")


# Performance assessment: the first tier whose floor is reached wins; anything
# below 0.75 (or a NaN accuracy) gets the default message.
_performance_tiers = (
    (0.90, ("\nüèÜ EXCEPTIONAL PERFORMANCE! üèÜ", "üöÄ Model is production-ready!")),
    (0.85, ("\nüéØ EXCELLENT PERFORMANCE!", "üöÄ Model is highly accurate!")),
    (0.75, ("\n‚úÖ VERY GOOD PERFORMANCE!", "üí™ Model is reliable!")),
)
_headline, _detail = next(
    (messages for floor, messages in _performance_tiers if final_accuracy >= floor),
    ("\n‚ö†Ô∏è GOOD PERFORMANCE", "üìà Consider further optimization"),
)
print(_headline)
print(_detail)


print(f"\n‚úÖ MODEL SAVED: ULTIMATE_FOOD_RECOGNITION_MODEL.keras")
print("üöÄ READY FOR DEPLOYMENT!")


print("\n" + "="*60)
print("üéØ ULTIMATE FOOD RECOGNITION SYSTEM READY!")
print("="*60)

