# Pokemon Classifier Training Notebook

Train a MobileNetV3-based classifier to identify Gen 1 Pokemon (151 classes).

**Run this notebook in Google Colab with a GPU runtime (Runtime → Change runtime type → GPU) for acceleration.**

**Note:** This notebook uses local Colab storage. Files will not persist between sessions.
Make sure to download your trained model before the session ends!

In [None]:
#@title 1. Setup and Configuration
#@markdown Install dependencies and configure paths (no Google Drive required)

# Install required packages
!pip install -q tensorflow tensorflowjs kaggle

# Configuration
CONFIG = {
    # Paths (local Colab storage - does not persist between sessions)
    'data_dir': '/content/data',
    'models_dir': '/content/models',
    
    # Dataset
    'kaggle_dataset': 'lantian773030/pokemonclassification',
    
    # Image settings
    'image_size': 224,
    'batch_size': 32,
    
    # Training
    'epochs_frozen': 10,
    'epochs_unfrozen': 15,
    'learning_rate_frozen': 1e-3,
    'learning_rate_unfrozen': 1e-5,
    'dropout_rate': 0.5,
    'validation_split': 0.15,
    'test_split': 0.15,
    
    # Fine-tuning
    'unfreeze_layers': 20,
    
    # Augmentation
    'rotation_range': 20,
    'zoom_range': 0.15,
    'horizontal_flip': True,
    
    # Callbacks
    'early_stopping_patience': 5,
    'reduce_lr_patience': 3,
}

import os
os.makedirs(CONFIG['data_dir'], exist_ok=True)
os.makedirs(CONFIG['models_dir'], exist_ok=True)

print("Configuration complete")
print(f"  Data directory: {CONFIG['data_dir']}")
print(f"  Models directory: {CONFIG['models_dir']}")
print("")
print("IMPORTANT: Files are stored locally and will be deleted when the session ends.")
print("Make sure to download your trained model before disconnecting!")

In [None]:
#@title 2. Download Dataset from Kaggle
#@markdown Upload your kaggle.json file when prompted

import os
from pathlib import Path

# Check if dataset already exists
data_path = Path(CONFIG['data_dir'])
existing_dirs = list(data_path.glob("*/"))

if len(existing_dirs) > 100:
    print(f"Dataset already downloaded ({len(existing_dirs)} directories found)")
else:
    # Upload kaggle.json
    from google.colab import files
    print("Upload your kaggle.json file:")
    uploaded = files.upload()
    
    # Setup Kaggle credentials
    os.makedirs('/root/.kaggle', exist_ok=True)
    with open('/root/.kaggle/kaggle.json', 'wb') as f:
        f.write(uploaded['kaggle.json'])
    os.chmod('/root/.kaggle/kaggle.json', 0o600)
    
    # Download dataset
    !kaggle datasets download -d {CONFIG['kaggle_dataset']} -p {CONFIG['data_dir']} --unzip
    
    print("Dataset downloaded")

# Verify dataset
def count_images(directory):
    count = 0
    for root, dirs, files in os.walk(directory):
        count += len([f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))])
    return count

total_images = count_images(CONFIG['data_dir'])
print(f"  Total images: {total_images}")

In [None]:
#@title 3. Explore Dataset
#@markdown Visualize sample images and class distribution

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from PIL import Image
import random


def list_class_images(class_dir):
    """Return the image files directly inside ``class_dir``, sorted by path.

    Suffixes are matched case-insensitively (``.jpg``, ``.jpeg``, ``.png``).
    A plain ``glob('*.jpg')`` is case-sensitive on Linux and would silently
    skip ``.JPG`` / ``.jpeg`` files.
    """
    return sorted(
        p for p in class_dir.iterdir()
        if p.is_file() and p.suffix.lower() in ('.jpg', '.jpeg', '.png')
    )


# Find Pokemon directories
data_path = Path(CONFIG['data_dir'])

# The class folders may sit one level below data_dir. Pick the first child
# directory (in sorted order, for determinism) that has more than 50
# subdirectories.
pokemon_dir = None
for item in sorted(data_path.iterdir()):
    if item.is_dir() and sum(1 for d in item.iterdir() if d.is_dir()) > 50:
        pokemon_dir = item
        break

if pokemon_dir is None:
    # Data might be directly in data_dir
    pokemon_dir = data_path

class_dirs = sorted(d for d in pokemon_dir.iterdir() if d.is_dir())
print(f"Found {len(class_dirs)} Pokemon classes")
if not class_dirs:
    raise RuntimeError(f"No class directories found under {pokemon_dir}; re-run the download cell.")

# Count images per class
class_counts = {class_dir.name: len(list_class_images(class_dir)) for class_dir in class_dirs}

# Plot class distribution
plt.figure(figsize=(15, 5))
counts = list(class_counts.values())
plt.hist(counts, bins=20, edgecolor='black')
plt.xlabel('Images per class')
plt.ylabel('Number of classes')
plt.title('Class Distribution')
plt.axvline(np.mean(counts), color='r', linestyle='--', label=f'Mean: {np.mean(counts):.1f}')
plt.legend()
plt.show()

print(f"\nStatistics:")
print(f"  Min images: {min(counts)}")
print(f"  Max images: {max(counts)}")
print(f"  Mean images: {np.mean(counts):.1f}")
print(f"  Median images: {np.median(counts):.1f}")

# Show sample images (seeded so the grid is the same on every run)
rng = random.Random(CONFIG.get('seed', 42))
fig, axes = plt.subplots(4, 6, figsize=(15, 10))
# Never ask for more classes than exist (random.sample raises otherwise).
sample_classes = rng.sample(class_dirs, min(axes.size, len(class_dirs)))

for ax in axes.flat:
    ax.axis('off')
for ax, class_dir in zip(axes.flat, sample_classes):
    images = list_class_images(class_dir)
    if images:
        # Close the file handle once matplotlib has read the pixels.
        with Image.open(rng.choice(images)) as img:
            ax.imshow(img)
        ax.set_title(class_dir.name[:12], fontsize=8)

plt.suptitle('Sample Pokemon Images', fontsize=14)
plt.tight_layout()
plt.show()

# Store for later use
POKEMON_DIR = pokemon_dir
CLASS_NAMES = sorted([d.name for d in class_dirs])
NUM_CLASSES = len(CLASS_NAMES)
print(f"\nFound {NUM_CLASSES} Pokemon classes")

In [None]:
#@title 4. Prepare Data Generators
#@markdown Create disjoint train/validation/test splits (augmentation on train only)

import json
import random
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

SPLIT_SEED = CONFIG.get('seed', 42)

# --- Explicit, reproducible 3-way split ------------------------------------
# ImageDataGenerator's `validation_split` only yields two subsets, and both
# subsets inherit that generator's augmentation. Splitting the file list here
# instead gives disjoint train/val/test sets, keeps val/test un-augmented, and
# stratifies per class. A seeded shuffle of a sorted file list makes the split
# identical on every run.
split_rng = random.Random(SPLIT_SEED)
split_records = {'train': [], 'val': [], 'test': []}
for class_name in CLASS_NAMES:
    class_files = sorted(
        str(p) for p in (Path(POKEMON_DIR) / class_name).iterdir()
        if p.is_file() and p.suffix.lower() in ('.jpg', '.jpeg', '.png')
    )
    split_rng.shuffle(class_files)
    n_test = int(len(class_files) * CONFIG['test_split'])
    n_val = int(len(class_files) * CONFIG['validation_split'])
    subsets = {
        'test': class_files[:n_test],
        'val': class_files[n_test:n_test + n_val],
        'train': class_files[n_test + n_val:],
    }
    for subset_name, subset_files in subsets.items():
        split_records[subset_name].extend(
            {'filename': filename, 'class': class_name} for filename in subset_files
        )

split_frames = {
    subset_name: pd.DataFrame(records, columns=['filename', 'class'])
    for subset_name, records in split_records.items()
}

# Sanity checks: subsets must not share files and training must be non-empty.
train_files = set(split_frames['train']['filename'])
assert train_files, "Training split is empty"
assert train_files.isdisjoint(split_frames['val']['filename'])
assert train_files.isdisjoint(split_frames['test']['filename'])
assert set(split_frames['val']['filename']).isdisjoint(split_frames['test']['filename'])

# Training data generator with augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=CONFIG['rotation_range'],
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=CONFIG['zoom_range'],
    horizontal_flip=CONFIG['horizontal_flip'],
    brightness_range=(0.8, 1.2),
    fill_mode='nearest',
)

# Validation/test data generator (no augmentation, same rescaling as training)
val_test_datagen = ImageDataGenerator(rescale=1./255)

flow_kwargs = dict(
    x_col='filename',
    y_col='class',
    classes=CLASS_NAMES,  # pins label indices to the sorted class-folder order
    target_size=(CONFIG['image_size'], CONFIG['image_size']),
    batch_size=CONFIG['batch_size'],
    class_mode='categorical',
)

# Create generators
train_generator = train_datagen.flow_from_dataframe(
    split_frames['train'], shuffle=True, seed=SPLIT_SEED, **flow_kwargs
)
# shuffle=False keeps prediction order aligned with `.classes` for evaluation.
val_generator = val_test_datagen.flow_from_dataframe(
    split_frames['val'], shuffle=False, **flow_kwargs
)
test_generator = val_test_datagen.flow_from_dataframe(
    split_frames['test'], shuffle=False, **flow_kwargs
)

# flow_from_dataframe exposes `.classes` as a Python list, while
# flow_from_directory uses an ndarray. Normalise to an ndarray so downstream
# boolean-mask indexing (e.g. `classes[mask]`) keeps working.
for generator in (train_generator, val_generator, test_generator):
    generator.classes = np.asarray(generator.classes)

# Store class indices
class_indices = train_generator.class_indices
index_to_class = {v: k for k, v in class_indices.items()}

print(f"Data generators created")
print(f"  Training samples: {train_generator.samples}")
print(f"  Validation samples: {val_generator.samples}")
print(f"  Test samples: {test_generator.samples}")
print(f"  Number of classes: {len(class_indices)}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Steps per epoch: {len(train_generator)}")

# Save class mapping (JSON turns the integer keys of index_to_class into strings)
labels_path = Path(CONFIG['models_dir']) / 'labels.json'
with open(labels_path, 'w') as f:
    json.dump({
        'class_names': list(class_indices.keys()),
        'class_indices': class_indices,
        'index_to_class': index_to_class,
        'num_classes': len(class_indices)
    }, f, indent=2)
print(f"  Labels saved to: {labels_path}")

In [None]:
#@title 5. Visualize Augmentations
#@markdown See how training augmentations transform images

import matplotlib.pyplot as plt

# Draw one batch from the (augmenting) training generator.
sample_batch = next(train_generator)
images, labels = sample_batch

fig, axes = plt.subplots(3, 5, figsize=(15, 9))
# Hide every panel's axes up front; panels beyond the batch size stay blank.
for ax in axes.flat:
    ax.axis('off')
# zip() stops at whichever runs out first: the 15 panels or the batch.
for ax, image, one_hot in zip(axes.flat, images, labels):
    ax.imshow(image)
    ax.set_title(index_to_class[np.argmax(one_hot)][:15], fontsize=9)

plt.suptitle('Augmented Training Images', fontsize=14)
plt.tight_layout()
plt.show()

# Reset the generator's position after pulling the preview batch.
train_generator.reset()

In [None]:
#@title 6. Build Model Architecture
#@markdown Create MobileNetV3 with custom classification head

import tensorflow as tf
from tensorflow.keras.applications import MobileNetV3Small
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (
    Dense, Dropout, GlobalAveragePooling2D, 
    BatchNormalization, Input, Rescaling
)
from tensorflow.keras.optimizers import Adam

def build_model(num_classes, image_size=224, dropout_rate=0.5, input_rescale=255.0):
    """Build a frozen MobileNetV3Small backbone with a custom classification head.

    Keras' MobileNetV3 applications include their own preprocessing (a
    Rescaling layer) and expect raw pixel values in [0, 255]. This notebook's
    data generators use ``rescale=1./255`` and produce [0, 1] inputs.
    Feeding those directly would squash the pretrained features' input range.
    A ``Rescaling(input_rescale)`` layer therefore restores the expected range
    in front of the backbone. The resulting model (and its TF.js export) still
    takes [0, 1] inputs.

    Args:
        num_classes: Number of output classes (softmax units).
        image_size: Square input side length in pixels.
        dropout_rate: Dropout applied before each Dense layer in the head.
        input_rescale: Factor mapping model inputs to the [0, 255] range the
            backbone expects. Use 1.0 if inputs are already in [0, 255].

    Returns:
        (model, base_model): the full model and the backbone, returned so later
        cells can unfreeze backbone layers for fine-tuning.
    """
    # Load pre-trained MobileNetV3
    base_model = MobileNetV3Small(
        weights='imagenet',
        include_top=False,
        input_shape=(image_size, image_size, 3)
    )
    
    # Freeze base model
    base_model.trainable = False
    
    # Build model
    inputs = Input(shape=(image_size, image_size, 3))
    x = inputs
    if input_rescale != 1.0:
        x = Rescaling(input_rescale)(x)
    # training=False keeps the backbone's BatchNorm layers in inference mode,
    # even after layers are unfrozen for fine-tuning.
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)
    x = Dropout(dropout_rate)(x)
    x = Dense(256, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(dropout_rate)(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    
    model = Model(inputs, outputs)
    
    return model, base_model

# Seed Python, NumPy and TF so head initialisation is repeatable.
tf.keras.utils.set_random_seed(CONFIG.get('seed', 42))

# Build the model
model, base_model = build_model(
    num_classes=NUM_CLASSES,
    image_size=CONFIG['image_size'],
    dropout_rate=CONFIG['dropout_rate']
)

# Compile for Stage 1 (frozen base)
model.compile(
    optimizer=Adam(learning_rate=CONFIG['learning_rate_frozen']),
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top5_accuracy')]
)

# Model summary
model.summary()

print(f"\nModel built")
print(f"  Total parameters: {model.count_params():,}")
print(f"  Trainable parameters: {sum([tf.keras.backend.count_params(w) for w in model.trainable_weights]):,}")

In [None]:
#@title 7. Define Training Callbacks
#@markdown Set up checkpointing, early stopping, and learning rate scheduling

from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, ReduceLROnPlateau,
    TensorBoard, CSVLogger
)
import datetime
from pathlib import Path

# Create directories for logs (one timestamped TensorBoard dir per run)
log_dir = Path(CONFIG['models_dir']) / 'logs' / datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir.mkdir(parents=True, exist_ok=True)

checkpoint_path = Path(CONFIG['models_dir']) / 'best_model.keras'
csv_log_path = Path(CONFIG['models_dir']) / 'training_log.csv'

# CSVLogger must run with append=True so the Stage 2 fit extends the Stage 1
# log; with append=False it would overwrite the file at the start of each fit.
# The downside is that re-running this cell would append to an earlier run's
# rows, so delete the stale log here to start each training setup clean.
csv_log_path.unlink(missing_ok=True)

# NOTE(review): the checkpoint keeps the best *val_accuracy* model, while
# EarlyStopping restores the best *val_loss* weights; the two can differ.
callbacks = [
    # Save best model
    ModelCheckpoint(
        filepath=str(checkpoint_path),
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    
    # Early stopping
    EarlyStopping(
        monitor='val_loss',
        patience=CONFIG['early_stopping_patience'],
        restore_best_weights=True,
        verbose=1
    ),
    
    # Reduce learning rate when stuck
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=CONFIG['reduce_lr_patience'],
        min_lr=1e-7,
        verbose=1
    ),
    
    # TensorBoard logging
    TensorBoard(
        log_dir=str(log_dir),
        histogram_freq=1
    ),
    
    # CSV logging (append=True: see note above)
    CSVLogger(
        str(csv_log_path),
        append=True
    )
]

print(f"Callbacks configured")
print(f"  Checkpoints: {CONFIG['models_dir']}/best_model.keras")
print(f"  TensorBoard logs: {log_dir}")

In [None]:
#@title 8. Stage 1: Train Classification Head (Frozen Base)
#@markdown Train only the new classification layers

banner = "=" * 60
print(banner)
print("STAGE 1: Training classification head (base model frozen)")
print(banner)

# Only the head is trainable here: build_model set base_model.trainable = False.
history_frozen = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=CONFIG['epochs_frozen'],
    callbacks=callbacks,
    verbose=1,
)

# Train-vs-validation curves side by side: accuracy (left) and loss (right).
stage1_curves = history_frozen.history
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, metric_key, metric_label in zip(axes, ('accuracy', 'loss'), ('Accuracy', 'Loss')):
    ax.plot(stage1_curves[metric_key], label='Train')
    ax.plot(stage1_curves[f'val_{metric_key}'], label='Validation')
    ax.set_title(f'Stage 1: {metric_label}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric_label)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig(str(Path(CONFIG['models_dir']) / 'stage1_training.png'))
plt.show()

print(f"\nStage 1 complete")
print(f"  Final train accuracy: {stage1_curves['accuracy'][-1]:.4f}")
print(f"  Final val accuracy: {stage1_curves['val_accuracy'][-1]:.4f}")

In [None]:
#@title 9. Stage 2: Fine-tune Base Model
#@markdown Unfreeze top layers and continue training with lower learning rate

print("=" * 60)
print("STAGE 2: Fine-tuning (unfreezing top layers)")
print("=" * 60)

# Unfreeze the top layers of the base model
base_model.trainable = True

# Freeze every layer except the top N. Compute the cut point explicitly:
# `layers[:-N]` with N == 0 is `layers[:0]` (empty), which would leave the
# whole backbone trainable instead of none of it.
base_layers = base_model.layers
num_to_unfreeze = max(0, min(CONFIG['unfreeze_layers'], len(base_layers)))
for layer in base_layers[:len(base_layers) - num_to_unfreeze]:
    layer.trainable = False

# Recompile with lower learning rate (required for trainable changes to apply)
model.compile(
    optimizer=Adam(learning_rate=CONFIG['learning_rate_unfrozen']),
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top5_accuracy')]
)

# Print trainable status
trainable_count = sum([tf.keras.backend.count_params(w) for w in model.trainable_weights])
print(f"Trainable parameters after unfreezing: {trainable_count:,}")

# Continue the epoch counter from Stage 1 so CSVLogger/TensorBoard record one
# continuous run instead of restarting at epoch 0 on top of Stage 1's entries.
initial_epoch = len(history_frozen.epoch)
history_unfrozen = model.fit(
    train_generator,
    initial_epoch=initial_epoch,
    epochs=initial_epoch + CONFIG['epochs_unfrozen'],
    validation_data=val_generator,
    callbacks=callbacks,
    verbose=1
)

# Plot training curves: accuracy (left) and loss (right)
stage2_curves = history_unfrozen.history
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, metric_key, metric_label in zip(axes, ('accuracy', 'loss'), ('Accuracy', 'Loss')):
    ax.plot(stage2_curves[metric_key], label='Train')
    ax.plot(stage2_curves[f'val_{metric_key}'], label='Validation')
    ax.set_title(f'Stage 2: {metric_label}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric_label)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig(str(Path(CONFIG['models_dir']) / 'stage2_training.png'))
plt.show()

print(f"\nStage 2 complete")
print(f"  Final train accuracy: {stage2_curves['accuracy'][-1]:.4f}")
print(f"  Final val accuracy: {stage2_curves['val_accuracy'][-1]:.4f}")

In [None]:
#@title 10. Evaluate on Test Set
#@markdown Get final performance metrics

TARGET_ACCURACY = 0.80  # Top-1 accuracy goal for this project

print("=" * 60)
print("EVALUATION")
print("=" * 60)

# Load the checkpoint with the best val_accuracy (written by ModelCheckpoint)
best_model = tf.keras.models.load_model(
    str(Path(CONFIG['models_dir']) / 'best_model.keras')
)

# Evaluate on test set. return_dict=True reads metrics by name instead of
# relying on the positional order of the returned list.
test_generator.reset()
results = best_model.evaluate(test_generator, verbose=1, return_dict=True)
test_loss = results['loss']
test_top1 = results['accuracy']
test_top5 = results['top5_accuracy']

print(f"\nTest Results:")
print(f"  Loss: {test_loss:.4f}")
print(f"  Top-1 Accuracy: {test_top1:.4f} ({test_top1*100:.1f}%)")
print(f"  Top-5 Accuracy: {test_top5:.4f} ({test_top5*100:.1f}%)")

# Check if we hit our target
if test_top1 >= TARGET_ACCURACY:
    print(f"\nSUCCESS! Achieved {test_top1*100:.1f}% accuracy (target: {TARGET_ACCURACY*100:.0f}%)")
else:
    print(f"\nBelow target. Achieved {test_top1*100:.1f}% (target: {TARGET_ACCURACY*100:.0f}%)")
    print("   Consider: more epochs, data augmentation, or larger model")

In [None]:
#@title 11. Confusion Matrix & Error Analysis
#@markdown Understand where the model makes mistakes

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report

# Get predictions. Row order matches `classes` only because the test generator
# is assumed to have been created with shuffle=False.
test_generator.reset()
predictions = best_model.predict(test_generator, verbose=1)
predicted_classes = np.argmax(predictions, axis=1)
# asarray: some Keras iterators expose `.classes` as a list; mask indexing
# below needs an ndarray.
true_classes = np.asarray(test_generator.classes)

# Classification report over *all* class indices. Passing `labels` keeps
# `target_names` aligned even when some classes never occur in the test set.
# Without it, sklearn raises on the length mismatch.
all_labels = list(range(len(class_indices)))
print("Classification Report (showing worst performing classes):\n")
report = classification_report(
    true_classes,
    predicted_classes,
    labels=all_labels,
    target_names=[index_to_class[i] for i in all_labels],
    output_dict=True,
    zero_division=0,
)

# Rank classes by F1. Skip classes with no test samples: their F1 is 0 only
# because there was nothing to score, not because the model failed on them.
class_f1_scores = {
    name: report[name]['f1-score']
    for name in class_indices
    if report[name]['support'] > 0
}
worst_classes = sorted(class_f1_scores.items(), key=lambda x: x[1])[:10]

print("Worst performing Pokemon:")
for name, f1 in worst_classes:
    print(f"  {name}: F1={f1:.3f}")

# Confusion matrix restricted to samples whose true class is one of the worst
worst_indices = [class_indices[name] for name, _ in worst_classes]
mask = np.isin(true_classes, worst_indices)
cm_subset = confusion_matrix(
    true_classes[mask],
    predicted_classes[mask],
    labels=worst_indices
)

worst_tick_labels = [name[:10] for name, _ in worst_classes]
plt.figure(figsize=(12, 10))
sns.heatmap(
    cm_subset,
    annot=True,
    fmt='d',
    xticklabels=worst_tick_labels,
    yticklabels=worst_tick_labels,
    cmap='Blues'
)
plt.title('Confusion Matrix (Worst Performing Classes)')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.tight_layout()
plt.savefig(str(Path(CONFIG['models_dir']) / 'confusion_matrix.png'))
plt.show()

# Show misclassified examples
print("\n" + "=" * 60)
print("Sample Misclassified Images")
print("=" * 60)

misclassified_indices = np.where(predicted_classes != true_classes)[0]
rng = np.random.default_rng(CONFIG.get('seed', 42))
sample_mistakes = rng.choice(
    misclassified_indices, size=min(9, len(misclassified_indices)), replace=False
)

# Load only the sampled files from disk instead of materialising the whole
# test set as float arrays in memory just to display nine of them.
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for ax in axes.flat:
    ax.axis('off')
for ax, idx in zip(axes.flat, sample_mistakes):
    with Image.open(test_generator.filepaths[idx]) as img:
        ax.imshow(img)
    true_name = index_to_class[true_classes[idx]]
    pred_name = index_to_class[predicted_classes[idx]]
    conf = predictions[idx][predicted_classes[idx]]
    ax.set_title(f"True: {true_name[:12]}\nPred: {pred_name[:12]} ({conf:.2f})", fontsize=9)

plt.suptitle('Misclassified Examples', fontsize=14)
plt.tight_layout()
plt.savefig(str(Path(CONFIG['models_dir']) / 'misclassified_examples.png'))
plt.show()

In [None]:
#@title 12. Export Model for Browser (TensorFlow.js)
#@markdown Convert model to TensorFlow.js format with quantization

import shutil
from pathlib import Path

import tensorflowjs as tfjs

# Export directory. Start from an empty directory so files left by a previous
# export (e.g. a different number of weight shards) are not mixed in.
tfjs_dir = Path(CONFIG['models_dir']) / 'tfjs_model'
shutil.rmtree(tfjs_dir, ignore_errors=True)
tfjs_dir.mkdir(parents=True)

# Convert to TensorFlow.js with quantization
print("Converting to TensorFlow.js format...")
tfjs.converters.save_keras_model(
    best_model,
    str(tfjs_dir),
    quantization_dtype_map={'uint8': '*'}  # Quantize all layers to uint8
)

# Copy labels.json next to the model before measuring, so it is counted
shutil.copy(
    Path(CONFIG['models_dir']) / 'labels.json',
    tfjs_dir / 'labels.json'
)

# Check output size
total_size = sum(f.stat().st_size for f in tfjs_dir.glob('**/*') if f.is_file())
print(f"\nModel exported to: {tfjs_dir}")
print(f"  Total size: {total_size / 1024 / 1024:.2f} MB")

# List files
print("\nExported files:")
for f in sorted(tfjs_dir.glob('*')):
    size = f.stat().st_size / 1024
    print(f"  {f.name}: {size:.1f} KB")

# Create a zip for easy download. make_archive rewrites the archive from
# scratch (unlike `zip -r`, which updates an existing one in place), and
# base_dir keeps the 'tfjs_model/' prefix inside the zip.
shutil.make_archive(
    str(Path(CONFIG['models_dir']) / 'tfjs_model'),
    'zip',
    root_dir=str(CONFIG['models_dir']),
    base_dir='tfjs_model',
)
print(f"\nCreated tfjs_model.zip for download")

In [None]:
#@title 13. Download Model Files (IMPORTANT!)
#@markdown Download the exported model before the session ends!

from google.colab import files
from pathlib import Path

# Warning banner, one printed line per entry.
header_lines = [
    "=" * 60,
    "DOWNLOAD YOUR MODEL",
    "=" * 60,
    "",
    "IMPORTANT: Files are stored locally and will be DELETED",
    "when this Colab session ends. Download now!",
    "",
]
for line in header_lines:
    print(line)

# Trigger a browser download of the archive produced by the export cell.
zip_path = Path(CONFIG['models_dir']) / 'tfjs_model.zip'
if not zip_path.exists():
    print("Model zip not found. Run the export cell first.")
else:
    print("Downloading tfjs_model.zip...")
    files.download(str(zip_path))
    print("")
    print("Download started! Check your browser's download folder.")

print("")
print("The zip file contains:")
for entry in (
    "model.json (model architecture)",
    "group1-shard*.bin (model weights)",
    "labels.json (class name mapping)",
):
    print(f"  - {entry}")

In [None]:
#@title 14. Download Keras Model (Optional)
#@markdown Download the full Keras model for further training

from google.colab import files
from pathlib import Path

# Offer the best checkpoint saved during training as a browser download.
keras_path = Path(CONFIG['models_dir']) / 'best_model.keras'
if not keras_path.exists():
    print("Keras model not found. Training may not have completed.")
else:
    print("Downloading best_model.keras...")
    files.download(str(keras_path))
    print("")
    print("This is the full Keras model. Use it if you want to:")
    for use_case in (
        "Continue training later",
        "Export to other formats (ONNX, TFLite, etc.)",
        "Run inference in Python",
    ):
        print(f"  - {use_case}")

In [None]:
#@title 15. Test Inference
#@markdown Sanity-check the best Keras model on one test image (the TF.js export itself is not run here)

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Take the first image of the first test batch
test_generator.reset()
sample_batch = next(test_generator)
sample_image = sample_batch[0][0]
sample_label = int(np.argmax(sample_batch[1][0]))

# Run inference on a batch of one
prediction = best_model.predict(np.expand_dims(sample_image, 0), verbose=0)
class_probs = prediction[0]
predicted_class = int(np.argmax(class_probs))
confidence = class_probs[predicted_class]

# Top-k predictions, highest first (k is capped by the number of classes)
top_k = min(5, class_probs.shape[-1])
top5_indices = np.argsort(class_probs)[::-1][:top_k]
top5_probs = class_probs[top5_indices]

# Display
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.imshow(sample_image)
plt.title(f"True: {index_to_class[sample_label]}")
plt.axis('off')

plt.subplot(1, 2, 2)
y_pos = np.arange(top_k)
plt.barh(y_pos, top5_probs)
plt.yticks(y_pos, [index_to_class[i][:15] for i in top5_indices])
plt.xlabel('Confidence')
plt.title(f'Top {top_k} Predictions')
plt.gca().invert_yaxis()

plt.tight_layout()
plt.show()

print(f"\nInference test complete")
print(f"  Predicted: {index_to_class[predicted_class]} ({confidence:.2%})")
print(f"  Actual: {index_to_class[sample_label]}")
correct_symbol = 'Correct' if predicted_class == sample_label else 'Incorrect'
print(f"  Result: {correct_symbol}")

In [None]:
#@title 16. Summary & Next Steps
#@markdown Review what was accomplished and plan next steps

import pandas as pd

RULE = "=" * 60

print(RULE)
print("TRAINING SUMMARY")
print(RULE)

# Best validation epoch, read from the CSVLogger output (row index is 0-based).
log_path = Path(CONFIG['models_dir']) / 'training_log.csv'
if log_path.exists():
    log_df = pd.read_csv(log_path)
    val_accuracy_series = log_df['val_accuracy']
    best_epoch = val_accuracy_series.idxmax()
    best_val_acc = val_accuracy_series.max()
    
    print("\nTraining Results:")
    print(f"  Best validation accuracy: {best_val_acc:.4f} ({best_val_acc*100:.1f}%)")
    print(f"  Best epoch: {best_epoch + 1}")
    print(f"  Total epochs trained: {len(log_df)}")

# Size of everything written to the TF.js export directory
print("\nExported Model:")
tfjs_dir = Path(CONFIG['models_dir']) / 'tfjs_model'
if tfjs_dir.exists():
    total_size = sum(entry.stat().st_size for entry in tfjs_dir.glob('**/*') if entry.is_file())
    print(f"  Location: {tfjs_dir}")
    print(f"  Size: {total_size / 1024 / 1024:.2f} MB")
    print("  Format: TensorFlow.js (quantized)")

reminder_lines = [
    "\n" + RULE,
    "IMPORTANT REMINDER",
    RULE,
    "",
    "This notebook uses LOCAL storage. All files will be DELETED",
    "when the Colab session ends!",
    "",
    "Make sure you have downloaded:",
    "  [  ] tfjs_model.zip (for browser deployment)",
    "  [  ] best_model.keras (optional, for further training)",
    "",
    "Run cells 13 and 14 to download these files.",
]
print("\n".join(reminder_lines))

print("\n" + RULE)
print("NEXT STEPS")
print(RULE)
print("""
Day 2: Browser Integration
  1. Extract tfjs_model.zip
  2. Create React/Next.js app
  3. Load model with TensorFlow.js
  4. Add camera capture
  5. Build Pokedex UI

Day 3+: Improvements
  - Create benchmark dataset
  - Add Pokemon card images to training data
  - Test on real-world photos
  - Iterate based on failures
""")