# Fish Classification - Data Preprocessing & Augmentation

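This notebook covers:
1. Image preprocessing (resizing and pixel normalization)
2. Data augmentation strategy for the training split
3. Handling class imbalance with class weights
4. Setting up the train/validation/test data generators used for training

It reads the dataset paths, class list and per-class training counts from `models/cnn/config.json`, writes the computed class weights back into that file, and saves the preprocessing settings to `models/cnn/preprocessing_config.json`.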

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils.class_weight import compute_class_weight

# Report the TensorFlow build and whether any GPU device is visible to this kernel.
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU available: {bool(gpu_devices)}")

In [None]:
# Paths are resolved relative to the kernel's working directory, so the
# notebook must be started from the project root.
BASE_DIR = os.path.abspath('.')
MODELS_DIR = os.path.join(BASE_DIR, 'models', 'cnn')
FIGURES_DIR = os.path.join(BASE_DIR, 'reports', 'figures')

# config.json is not created here; it must already exist from an earlier step.
config_path = os.path.join(MODELS_DIR, 'config.json')
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

TRAIN_DIR = config['train_dir']
VAL_DIR = config['val_dir']
TEST_DIR = config['test_dir']
NUM_CLASSES = config['num_classes']
CLASSES = config['classes']
IMG_SIZE = tuple(config['image_size'])  # passed to Keras as target_size, i.e. (height, width)
BATCH_SIZE = config['batch_size']

# Fail fast on an inconsistent config: NUM_CLASSES is the divisor in the
# class-weight formula, so a mismatch would silently skew every weight.
assert NUM_CLASSES == len(CLASSES), (
    f"config num_classes={NUM_CLASSES} but {len(CLASSES)} class names are listed"
)

print(f"Image size: {IMG_SIZE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Number of classes: {NUM_CLASSES}")

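## 1. Compute Class Weights for Imbalance

Each class $c$ gets weight $w_c = \dfrac{N}{K \cdot n_c}$, where $N$ is the total number of training images, $K$ the number of classes and $n_c$ the number of training images in class $c$. When passed as `class_weight` during training, this up-weights under-represented classes.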

In [None]:
class_counts = config['class_counts_train']

# flow_from_directory assigns label indices in alphanumeric order of the class
# sub-directory names, so the weight dict must be keyed in that same order.
# Enumerating CLASSES as listed in the config would mis-assign weights whenever
# that list is not already sorted.
class_order = sorted(CLASSES)

# A missing or zero count would otherwise surface as a bare KeyError / ZeroDivisionError.
empty_classes = [name for name in class_order if class_counts.get(name, 0) <= 0]
if empty_classes:
    raise ValueError(f"No training samples recorded for class(es): {empty_classes}")

total_samples = sum(class_counts.values())
# "Balanced" weighting: w_c = N / (K * n_c) -> rarer classes get larger weights.
class_weights = {
    idx: total_samples / (NUM_CLASSES * class_counts[class_name])
    for idx, class_name in enumerate(class_order)
}

print("Class Weights (for handling imbalance):")
print("=" * 50)
for idx, class_name in enumerate(class_order):
    print(f"{class_name}: {class_weights[idx]:.3f} (samples: {class_counts[class_name]})")

# Persist the weights back into the shared config (JSON keys must be strings).
config['class_weights'] = {str(k): float(v) for k, v in class_weights.items()}
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=2)

## 2. Data Augmentation Strategy

In [None]:
# Single source of truth for the training-time augmentation parameters; the
# printed summary below is generated from this dict so it cannot drift.
# NOTE(review): Keras interprets shear_range in DEGREES, so 0.2 is a negligible
# shear (the old summary printed it as "20%"); confirm whether ~20 was intended.
AUGMENTATION_PARAMS = {
    'rotation_range': 20,
    'width_shift_range': 0.2,
    'height_shift_range': 0.2,
    'shear_range': 0.2,
    'zoom_range': 0.2,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
    'brightness_range': [0.8, 1.2],
}

# Training images: scale pixels to [0, 1] and apply random augmentation.
train_datagen = ImageDataGenerator(rescale=1./255, **AUGMENTATION_PARAMS)

# Validation/test images: scaling only, so evaluation inputs are deterministic.
val_test_datagen = ImageDataGenerator(
    rescale=1./255
)

flip_label = 'Yes' if AUGMENTATION_PARAMS['horizontal_flip'] else 'No'
print("Data Augmentation Configuration:")
print("=" * 50)
print("Training augmentations:")
print(f"  - Rotation: ±{AUGMENTATION_PARAMS['rotation_range']}°")
print(f"  - Width/Height shift: ±{AUGMENTATION_PARAMS['width_shift_range']:.0%}")
print(f"  - Shear: {AUGMENTATION_PARAMS['shear_range']}°")
print(f"  - Zoom: ±{AUGMENTATION_PARAMS['zoom_range']:.0%}")
print(f"  - Horizontal flip: {flip_label}")
print(f"  - Brightness: {AUGMENTATION_PARAMS['brightness_range']}")
print("\nValidation/Test: Only rescaling (no augmentation)")

In [None]:
# Arguments shared by all three splits.
common_flow_args = dict(
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

# Training split: shuffled with a fixed seed for reproducible shuffling/augmentation.
train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR,
    shuffle=True,
    seed=42,
    **common_flow_args
)

# The class folders found on disk must agree with the config's class list
# (stray directories such as .ipynb_checkpoints would otherwise become classes).
assert set(train_generator.class_indices) == set(CLASSES), (
    f"Class folders in {TRAIN_DIR} {sorted(train_generator.class_indices)} "
    f"do not match config classes {sorted(CLASSES)}"
)

# Pin validation/test label indices to the training mapping. Without `classes=`
# each split infers its own alphabetical mapping, so a split that lacks a class
# folder would silently shift every later label index.
train_class_order = sorted(train_generator.class_indices, key=train_generator.class_indices.get)

# Validation/test: unshuffled so predictions line up with `.classes` / `.filenames`.
val_generator = val_test_datagen.flow_from_directory(
    VAL_DIR,
    shuffle=False,
    classes=train_class_order,
    **common_flow_args
)

test_generator = val_test_datagen.flow_from_directory(
    TEST_DIR,
    shuffle=False,
    classes=train_class_order,
    **common_flow_args
)

print(f"\nTraining samples: {train_generator.samples}")
print(f"Validation samples: {val_generator.samples}")
print(f"Test samples: {test_generator.samples}")
print(f"\nClass indices: {train_generator.class_indices}")

## 3. Visualize Augmented Images

In [None]:
def to_displayable_image(arr):
    """Convert one generator output image into data that ``imshow`` renders faithfully.

    Un-rescaled generators emit float images in 0-255; matplotlib clips float RGB
    data to [0, 1] (rendering them almost white), so those are cast to uint8.
    Images already in [0, 1] are returned clipped to that range.
    """
    arr = np.asarray(arr)
    if arr.max() > 1.0:
        return np.clip(arr, 0, 255).astype(np.uint8)
    return np.clip(arr, 0.0, 1.0)


def visualize_augmentations(image_path, datagen, n_augmentations=8, n_cols=5, img_size=None):
    """Visualize an original image next to ``n_augmentations`` augmented versions.

    Args:
        image_path: Path to an image file readable by PIL.
        datagen: Configured ``ImageDataGenerator``; augmented samples are drawn
            from ``datagen.flow``.
        n_augmentations: Number of augmented variants to plot (must be >= 1).
        n_cols: Number of columns in the subplot grid.
        img_size: ``(height, width)`` to resize to; defaults to the global ``IMG_SIZE``.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: If ``n_augmentations`` < 1 or ``n_cols`` < 1.
    """
    if n_augmentations < 1:
        raise ValueError("n_augmentations must be >= 1")
    if n_cols < 1:
        raise ValueError("n_cols must be >= 1")
    height, width = img_size if img_size is not None else IMG_SIZE

    # Load eagerly and release the file handle; force RGB so grayscale/RGBA
    # files give the 3-channel array the generator expects.
    with Image.open(image_path) as src:
        img = src.convert('RGB').resize((width, height))  # PIL takes (width, height)
    img_array = np.expand_dims(np.asarray(img), axis=0)  # batch of one: (1, H, W, 3)

    n_panels = n_augmentations + 1
    n_rows = -(-n_panels // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    axes = axes.ravel()

    axes[0].imshow(img)
    axes[0].set_title('Original')

    # flow() yields batches forever, so draw exactly n_augmentations of them.
    augmented_batches = datagen.flow(img_array, batch_size=1)
    for i in range(1, n_panels):
        batch = next(augmented_batches)
        axes[i].imshow(to_displayable_image(batch[0]))
        axes[i].set_title(f'Augmented {i}')

    # Hide ticks on every panel, including unused trailing grid cells.
    for ax in axes:
        ax.axis('off')

    plt.suptitle('Data Augmentation Examples', fontsize=14)
    plt.tight_layout()
    return fig

# Pick a deterministic sample: the alphabetically first image of the first class.
sample_class = CLASSES[0]
sample_class_dir = os.path.join(TRAIN_DIR, sample_class)
# os.listdir order is arbitrary and a class folder may contain non-image files
# (e.g. .DS_Store), so sort and keep only extensions flow_from_directory reads.
sample_candidates = sorted(
    name for name in os.listdir(sample_class_dir)
    if name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff'))
)
if not sample_candidates:
    raise FileNotFoundError(f"No image files found in {sample_class_dir}")
sample_img_name = sample_candidates[0]
sample_img_path = os.path.join(sample_class_dir, sample_img_name)

# Same transforms as train_datagen but WITHOUT rescale, so the augmented
# samples stay in 0-255 for display.
augment_datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    brightness_range=[0.8, 1.2]
)

fig = visualize_augmentations(sample_img_path, augment_datagen)
os.makedirs(FIGURES_DIR, exist_ok=True)  # savefig does not create missing directories
fig.savefig(os.path.join(FIGURES_DIR, '04_augmentation_examples.png'), dpi=150, bbox_inches='tight')
plt.show()

## 4. Visualize Sample Batches

In [None]:
def show_batch(generator, class_names, n_images=16):
    """Plot up to ``n_images`` images from the next batch produced by ``generator``.

    Each panel is titled with the class name at the argmax of its one-hot label.
    Calling this advances ``generator`` by one batch.

    Args:
        generator: Iterator yielding ``(images, one_hot_labels)`` batches.
        class_names: Sequence mapping label index -> display name.
        n_images: Maximum number of panels to draw in a 4-column grid.

    Returns:
        The matplotlib ``Figure``.
    """
    batch_images, batch_labels = next(generator)

    cols = 4
    rows = -(-n_images // cols)  # ceiling division
    fig, grid = plt.subplots(rows, cols, figsize=(12, rows * 3))
    panels = grid.flatten()

    n_shown = min(n_images, len(batch_images))
    for pos, ax in enumerate(panels):
        if pos < n_shown:
            ax.imshow(batch_images[pos])
            ax.set_title(class_names[np.argmax(batch_labels[pos])], fontsize=9)
        # Every panel, used or not, is drawn without ticks.
        ax.axis('off')

    plt.suptitle('Sample Training Batch (with augmentation)', fontsize=14)
    plt.tight_layout()
    return fig

# Class names ordered by the label index flow_from_directory assigned.
class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)
# NOTE: show_batch calls next() on train_generator, consuming one training batch.
fig = show_batch(train_generator, class_names)
os.makedirs(FIGURES_DIR, exist_ok=True)  # savefig does not create missing directories
fig.savefig(os.path.join(FIGURES_DIR, '05_sample_batch.png'), dpi=150, bbox_inches='tight')
plt.show()

## 5. Model-Specific Data Generator Factory (Transfer Learning)

In [None]:
def create_generators_for_model(model_name, train_dir, val_dir, test_dir, img_size, batch_size,
                                seed=42):
    """Create train/val/test generators with model-specific preprocessing.

    Args:
        model_name: One of 'resnet50', 'efficientnet', 'mobilenet', 'inception',
            'vgg16' or 'custom' (case-insensitive). Pretrained backbones use the
            matching Keras ``preprocess_input``; 'custom' rescales pixels by 1/255
            and additionally applies brightness jitter during training.
        train_dir, val_dir, test_dir: Split directories with one sub-directory per class.
        img_size: ``(height, width)`` target size.
        batch_size: Batch size for all three generators.
        seed: Random seed for the training generator (default 42).

    Returns:
        ``(train_gen, val_gen, test_gen)``. Val/test are unshuffled and reuse the
        training generator's class-index mapping.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    key = model_name.lower()
    supported = ('resnet50', 'efficientnet', 'mobilenet', 'inception', 'vgg16', 'custom')
    # Validate before the (slow) TensorFlow import. Previously an unknown name such
    # as 'resnet' silently fell back to plain 1/255 rescaling, mis-preprocessing
    # inputs for a pretrained backbone.
    if key not in supported:
        raise ValueError(f"Unsupported model_name {model_name!r}; expected one of {supported}")

    from tensorflow.keras.applications import resnet50, efficientnet, mobilenet_v2, inception_v3, vgg16

    preprocessing_functions = {
        'resnet50': resnet50.preprocess_input,
        'efficientnet': efficientnet.preprocess_input,
        'mobilenet': mobilenet_v2.preprocess_input,
        'inception': inception_v3.preprocess_input,
        'vgg16': vgg16.preprocess_input,
        'custom': None,
    }
    preprocess_fn = preprocessing_functions[key]

    # Augmentations shared by both preprocessing paths.
    geometric_aug = dict(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )

    if preprocess_fn is not None:
        # NOTE(review): brightness jitter is only applied on the 'custom' path;
        # confirm this asymmetry is intended.
        train_datagen = ImageDataGenerator(preprocessing_function=preprocess_fn, **geometric_aug)
        val_test_datagen = ImageDataGenerator(preprocessing_function=preprocess_fn)
    else:
        train_datagen = ImageDataGenerator(rescale=1./255, brightness_range=[0.8, 1.2], **geometric_aug)
        val_test_datagen = ImageDataGenerator(rescale=1./255)

    flow_args = dict(target_size=img_size, batch_size=batch_size, class_mode='categorical')

    train_gen = train_datagen.flow_from_directory(train_dir, shuffle=True, seed=seed, **flow_args)

    # Pin val/test label indices to the training mapping so a split missing a
    # class folder cannot shift the one-hot indices.
    class_order = sorted(train_gen.class_indices, key=train_gen.class_indices.get)

    val_gen = val_test_datagen.flow_from_directory(
        val_dir, shuffle=False, classes=class_order, **flow_args
    )
    test_gen = val_test_datagen.flow_from_directory(
        test_dir, shuffle=False, classes=class_order, **flow_args
    )

    return train_gen, val_gen, test_gen

# Announce the factory and the model names it accepts.
print("Generator factory function created!",
      "Supported models: ResNet50, EfficientNet, MobileNet, Inception, VGG16, Custom",
      sep="\n")

## 6. Save Preprocessing Configuration

In [None]:
# Snapshot of the preprocessing settings used in this notebook; presumably read
# by the training/inference code to reproduce inputs exactly (TODO confirm consumer).
preprocessing_config = {
    'image_size': list(IMG_SIZE),
    'batch_size': BATCH_SIZE,
    # Pixel scaling used by both the training and the validation/test generators;
    # a consumer that skips it would feed 0-255 inputs to the model.
    'rescale': 1. / 255,
    'augmentation': {
        'rotation_range': 20,
        'width_shift_range': 0.2,
        'height_shift_range': 0.2,
        'shear_range': 0.2,
        'zoom_range': 0.2,
        'horizontal_flip': True,
        'fill_mode': 'nearest',
        'brightness_range': [0.8, 1.2]
    },
    'class_weights': {str(k): float(v) for k, v in class_weights.items()},
    'class_indices': train_generator.class_indices
}

preprocess_config_path = os.path.join(MODELS_DIR, 'preprocessing_config.json')
with open(preprocess_config_path, 'w', encoding='utf-8') as f:
    json.dump(preprocessing_config, f, indent=2)

print(f"Preprocessing configuration saved to: {preprocess_config_path}")

In [None]:
# Human-readable recap of the settings above (emoji restored from mis-encoded bytes).
print("\n" + "=" * 60)
print("DATA PREPROCESSING SUMMARY")
print("=" * 60)
print("\n📐 Image Configuration:")
print(f"   - Target size: {IMG_SIZE[0]}x{IMG_SIZE[1]}")
print(f"   - Batch size: {BATCH_SIZE}")
print("\n🔄 Data Augmentation:")
print("   - Rotation, shifts, shear, zoom, flip")
print("   - Brightness adjustment")
print("\n⚖️ Class Imbalance Handling:")
print("   - Class weights computed")
print(f"   - Min weight: {min(class_weights.values()):.3f}")
print(f"   - Max weight: {max(class_weights.values()):.3f}")
print("\n📦 Data Generators:")
print(f"   - Training: {train_generator.samples} samples")
print(f"   - Validation: {val_generator.samples} samples")
print(f"   - Test: {test_generator.samples} samples")
print("\n✅ Ready for model training!")
print("=" * 60)