In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16, ResNet50, MobileNet, InceptionV3, EfficientNetB0
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import os

In [2]:
class FishClassifier:
    """Build, train, and compare fish-image classifiers (a custom CNN and ImageNet backbones).

    Every model built by the ``build_*`` methods is kept in ``models`` and every
    Keras training history produced by ``train_model`` is kept in ``histories``,
    both keyed by model name.
    """

    def __init__(self, input_shape=(224, 224, 3), num_classes=None):
        # Per-instance registries, filled in by the build/train methods.
        self.models, self.histories = {}, {}
        self.input_shape = input_shape
        # May remain None until setup_data_generators() infers it from the data.
        self.num_classes = num_classes

In [6]:
    def setup_data_generators(self, train_dir, valid_dir, test_dir, batch_size=32):
        """Set up data generators with augmentation for training"""
        # Training data generator with augmentation
        self.train_datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True,
            zoom_range=0.2,
            fill_mode='nearest'
        )
        
        # Validation/Test data generator with only rescaling
        self.valid_datagen = ImageDataGenerator(rescale=1./255)
        
        # Create generators
        self.train_generator = self.train_datagen.flow_from_directory(
            train_dir,
            target_size=self.input_shape[:2],
            batch_size=batch_size,
            class_mode='categorical'
        )
        
        self.valid_generator = self.valid_datagen.flow_from_directory(
            valid_dir,
            target_size=self.input_shape[:2],
            batch_size=batch_size,
            class_mode='categorical'
        )
        
        self.test_generator = self.valid_datagen.flow_from_directory(
            test_dir,
            target_size=self.input_shape[:2],
            batch_size=batch_size,
            class_mode='categorical'
        )
        
        self.num_classes = len(self.train_generator.class_indices)

In [None]:



        
    
    
    def build_custom_cnn(self):
        """Build a custom CNN architecture"""
        model = Sequential([
            Conv2D(32, (3, 3), activation='relu', input_shape=self.input_shape),
            MaxPooling2D(2, 2),
            Conv2D(64, (3, 3), activation='relu'),
            MaxPooling2D(2, 2),
            Conv2D(128, (3, 3), activation='relu'),
            MaxPooling2D(2, 2),
            Flatten(),
            Dense(512, activation='relu'),
            Dropout(0.5),
            Dense(self.num_classes, activation='softmax')
        ])
        
        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        
        self.models['custom_cnn'] = model
        return model
    
    def build_transfer_learning_model(self, base_model_name):
        """Build a transfer learning model using pre-trained architectures"""
        base_models = {
            'vgg16': VGG16,
            'resnet50': ResNet50,
            'mobilenet': MobileNet,
            'inception': InceptionV3,
            'efficientnet': EfficientNetB0
        }
        
        # Get the base model
        base_model = base_models[base_model_name](
            weights='imagenet',
            include_top=False,
            input_shape=self.input_shape
        )
        
        # Freeze the base model layers
        base_model.trainable = False
        
        # Add custom layers
        x = Flatten()(base_model.output)
        x = Dense(512, activation='relu')(x)
        x = Dropout(0.5)(x)
        outputs = Dense(self.num_classes, activation='softmax')(x)
        
        model = Model(base_model.input, outputs)
        
        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        
        self.models[base_model_name] = model
        return model
    
    def train_model(self, model_name, epochs=20):
        """Train the specified model"""
        model = self.models[model_name]
        
        # Set up callbacks
        checkpoint = ModelCheckpoint(
            f'best_{model_name}.h5',
            monitor='val_accuracy',
            save_best_only=True
        )
        early_stopping = EarlyStopping(
            monitor='val_accuracy',
            patience=5,
            restore_best_weights=True
        )
        
        # Train the model
        history = model.fit(
            self.train_generator,
            epochs=epochs,
            validation_data=self.valid_generator,
            callbacks=[checkpoint, early_stopping]
        )
        
        self.histories[model_name] = history
        return history
    
    def evaluate_model(self, model_name):
        """Evaluate the model and return metrics"""
        model = self.models[model_name]
        
        # Get predictions
        predictions = model.predict(self.test_generator)
        y_pred = np.argmax(predictions, axis=1)
        y_true = self.test_generator.classes
        
        # Calculate metrics
        report = classification_report(y_true, y_pred, target_names=self.test_generator.class_indices.keys())
        conf_matrix = confusion_matrix(y_true, y_pred)
        
        return {
            'classification_report': report,
            'confusion_matrix': conf_matrix,
            'history': self.histories[model_name].history
        }
    
    def plot_training_history(self, model_name):
        """Plot training history for the specified model"""
        history = self.histories[model_name].history
        
        plt.figure(figsize=(12, 4))
        
        # Plot accuracy
        plt.subplot(1, 2, 1)
        plt.plot(history['accuracy'], label='Training Accuracy')
        plt.plot(history['val_accuracy'], label='Validation Accuracy')
        plt.title(f'{model_name} - Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        
        # Plot loss
        plt.subplot(1, 2, 2)
        plt.plot(history['loss'], label='Training Loss')
        plt.plot(history['val_loss'], label='Validation Loss')
        plt.title(f'{model_name} - Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        
        plt.tight_layout()
        plt.show()