In [None]:
import os
import shutil
import logging
from flask import Flask, request, render_template, redirect, url_for, jsonify,session,flash
from sklearn.model_selection import train_test_split
import numpy as np
from datetime import datetime
from werkzeug.utils import secure_filename 
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from utils.add_class_to_dataset import add_class_to_dataset
from utils.augment_image import augment_image
from utils.generate_class_distribution_plot import generate_class_distribution_plot
from utils.get_class_distribution import get_class_distribution
from utils.remove_class_from_dataset import remove_class_from_dataset
from utils.resample_classes import resample_classes
from utils.split_dataset_by_ratio import split_dataset_by_ratio
from utils.focal_loss import focal_loss
from utils.create_model import create_model
from utils.prepare_image import prepare_image
from utils.get_available_models import get_available_models
from utils.load_model_info import load_model_info
from utils.merge_back_data_if_split import merge_back_data_if_split
from tensorflow.keras.models import Sequential


app = Flask(__name__)

# Application log: INFO and above, appended to app.log with timestamped records.
logging.basicConfig(
    filename='app.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Module-level image/batch parameters (224x224 images, batches of 32).
img_height, img_width = 224, 224
batch_size = 32

# Custom loss function
@tf.keras.utils.register_keras_serializable()
def focal_loss_fixed(gamma=2., alpha=0.25):
    """Build a focal-loss function for one-hot / multi-label targets.

    The returned loss is computed element-wise as
    ``-alpha_t * (1 - p_t) ** gamma * log(p_t)`` and averaged over all elements.
    Plain TensorFlow ops are used instead of ``tf.keras.backend.ones/shape/pow/
    log/mean``, which are not available from ``tf.keras.backend`` under Keras 3.

    Args:
        gamma: Focusing exponent; larger values down-weight entries whose
            predicted probability ``p_t`` is already high.
        alpha: Weight for entries where ``y_true`` is 1; ``1 - alpha`` is used
            where ``y_true`` is 0.

    Returns:
        A ``loss(y_true, y_pred)`` callable returning a scalar tensor.
        NOTE(review): assumes ``y_pred`` holds probabilities (e.g. softmax
        output) rather than logits — confirm against create_model.
    """
    def focal_loss(y_true, y_pred):
        # Cast so integer labels or float16 predictions do not break the arithmetic.
        y_true = tf.cast(y_true, y_pred.dtype)
        epsilon = tf.keras.backend.epsilon()
        # Clip away exact 0/1 so log(p_t) stays finite.
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
        alpha_t = y_true * alpha + (1. - y_true) * (1. - alpha)
        p_t = y_true * y_pred + (1. - y_true) * (1. - y_pred)
        fl = -alpha_t * tf.pow(1. - p_t, gamma) * tf.math.log(p_t)
        return tf.reduce_mean(fl)
    return focal_loss

# Name -> object mapping in the shape expected by the ``custom_objects=``
# argument of Keras model loading; it exposes the focal-loss factory under
# the key 'focal_loss_fixed'. (Not referenced by the code visible here.)
custom_objects = dict(focal_loss_fixed=focal_loss_fixed)



@app.route('/')
def index():
    """Serve the landing page, ``index.html``, which holds the navigation links."""
    landing_template = 'index.html'
    return render_template(landing_template)
    

@app.route('/dataset', methods=['GET', 'POST'])
def dataset():
    """Dataset management page.

    GET renders ``dataset.html`` with empty state. POST handles one of two forms:

    * ``set_dataset``: checks that ``dataset_path`` exists and shows its class
      distribution and distribution plot.
    * ``perform_operation``: applies ``operation`` (``add_class``,
      ``remove_class``, ``resample`` or ``split``) to the dataset whose path is
      carried in the hidden ``dataset_path_hidden`` field.

    Problems are reported through the ``error`` template variable; exceptions
    raised by an operation are logged with their traceback.
    """
    class_distribution = None
    message = None
    plot_url = None
    error = None
    dataset_path = None

    def refreshed_distribution(path):
        # Recompute class counts and their plot after the dataset has changed.
        distribution = get_class_distribution(path)
        return distribution, generate_class_distribution_plot(distribution)

    if request.method == 'POST':
        if 'set_dataset' in request.form:
            dataset_path = request.form.get('dataset_path', '').strip()
            if not os.path.exists(dataset_path):
                error = "Dataset path does not exist."
            else:
                message = "Dataset path set successfully!"
                class_distribution, plot_url = refreshed_distribution(dataset_path)

        elif 'perform_operation' in request.form:
            dataset_path = request.form.get('dataset_path_hidden')
            if not dataset_path:
                error = "Please set the dataset path first."
            elif not os.path.exists(dataset_path):
                # The hidden field is client-controlled and may be stale.
                error = "Dataset path does not exist."
            else:
                try:
                    operation = request.form.get('operation')

                    if operation == 'add_class':
                        new_class = request.form.get('new_class', '').strip()
                        class_images_dir = request.form.get('class_images_dir', '').strip()
                        if not new_class or not class_images_dir:
                            error = "Both class name and class images directory are required."
                        else:
                            message = add_class_to_dataset(dataset_path, new_class, class_images_dir)
                            class_distribution, plot_url = refreshed_distribution(dataset_path)

                    elif operation == 'remove_class':
                        class_to_remove = request.form.get('class_to_remove', '').strip()
                        if not class_to_remove:
                            error = "Class to remove is required."
                        else:
                            message = remove_class_from_dataset(dataset_path, class_to_remove)
                            class_distribution, plot_url = refreshed_distribution(dataset_path)

                    elif operation == 'resample':
                        target_count_str = (request.form.get('target_count') or '').strip()
                        if not target_count_str:
                            error = "Target count is required for resampling."
                        else:
                            try:
                                target_count = int(target_count_str)
                            except ValueError:
                                target_count = 0  # Treated as invalid below.
                            if target_count < 1:
                                error = "Target count must be a positive whole number."
                            else:
                                resample_classes(dataset_path, target_count)
                                message = "Resampling done!"
                                class_distribution, plot_url = refreshed_distribution(dataset_path)

                    elif operation == 'split':
                        try:
                            ratios = [float(request.form.get(field, 0))
                                      for field in ('train_ratio', 'validation_ratio', 'test_ratio')]
                        except ValueError:
                            ratios = None
                        already_split = any(
                            os.path.exists(os.path.join(dataset_path, split))
                            for split in ('train', 'validation', 'test')
                        )

                        # `0 <= r <= 1` is False for NaN, so NaN is rejected too.
                        if ratios is None or not all(0 <= r <= 1 for r in ratios):
                            error = "Split ratios must be numbers between 0 and 1."
                        elif sum(ratios) > 1 + 1e-9:
                            # Small tolerance so e.g. 0.1 + 0.2 + 0.7 is not rejected by float rounding.
                            error = "The sum of training, validation, and test ratios must not exceed 1."
                        elif already_split and request.form.get('merge_back', '').lower() != 'yes':
                            # Re-splitting without merging would treat the train/validation/test
                            # folders themselves as classes.
                            error = ("The dataset is already split into train/validation/test. "
                                     "Choose to merge it back before re-splitting.")
                        else:
                            if already_split:
                                merge_back_data_if_split(dataset_path)
                                # Kept if the split below fails, so the user knows the merge happened.
                                message = "Data merged back successfully. Ready to re-split."
                            split_dataset_by_ratio(dataset_path, *ratios)
                            message = "Dataset splitting done!"

                    else:
                        error = "Unknown dataset operation."

                except Exception as e:
                    logging.exception(f"An error occurred: {e}")
                    error = str(e)

    # Render the dataset management page for GET requests, results and errors alike.
    return render_template('dataset.html', message=message, error=error, class_distribution=class_distribution, plot_url=plot_url, dataset_path=dataset_path)



@app.route('/train', methods=['GET', 'POST'])
def train():
    """Train, save and report on an image-classification model.

    GET renders ``train.html``. POST reads the form fields ``dataset_path``,
    ``base_model``, ``epochs``, ``batch_size`` and ``model_type`` (required)
    plus ``fine_tune`` / ``fine_tune_epochs`` (optional), then:

    1. builds image generators over ``<dataset_path>/train`` (augmented) and
       ``<dataset_path>/validation`` (rescaled only);
    2. trains the model returned by ``create_model`` and, when ``fine_tune`` is
       ``'true'``, continues training with the whole model unfrozen at a 1e-5
       learning rate;
    3. saves the model as ``<stem>.keras`` plus a ``<stem>.txt`` summary under
       ``models/pest``, ``models/crop``, or ``models`` for any other type.

    Every POST response is JSON ``{'error', 'message', 'training_logs'}``;
    ``training_logs`` holds one metrics dict per completed epoch.
    """
    error = None
    message = None
    training_logs = []  # One dict of metrics per completed epoch.

    if request.method != 'POST':
        return render_template('train.html', message=message, error=error, training_logs=training_logs)

    # ---- Input validation ----------------------------------------------------
    dataset_path = request.form.get('dataset_path', '').strip()
    if not os.path.exists(dataset_path):
        error = "Dataset path does not exist."
        return jsonify({'error': error, 'message': message, 'training_logs': training_logs})

    base_model_name = request.form.get('base_model')
    epochs = request.form.get('epochs')
    fine_tune = request.form.get('fine_tune')
    fine_tune_epochs = request.form.get('fine_tune_epochs')
    batch_size = request.form.get('batch_size')
    model_type = request.form.get('model_type')  # 'pest' / 'crop' select the save subfolder.

    if not base_model_name or not epochs or not batch_size or not model_type:
        error = "Base model, number of epochs, batch size, and model type are required."
        return jsonify({'error': error, 'message': message, 'training_logs': training_logs})

    try:
        epochs = int(epochs)
        batch_size = int(batch_size)
        fine_tune_epochs = int(fine_tune_epochs) if fine_tune_epochs else 0
    except ValueError:
        error = "Epochs, fine-tune epochs, and batch size must be whole numbers."
        return jsonify({'error': error, 'message': message, 'training_logs': training_logs})
    if epochs < 1 or batch_size < 1 or fine_tune_epochs < 0:
        # epochs >= 1 also guarantees training_logs is non-empty after fit().
        error = "Epochs and batch size must be at least 1, and fine-tune epochs must not be negative."
        return jsonify({'error': error, 'message': message, 'training_logs': training_logs})

    try:
        # ---- GPU configuration ----------------------------------------------
        # Memory growth can only be changed before TensorFlow initializes the
        # GPUs, so it is configured before create_model() builds any layers.
        gpus = tf.config.list_physical_devices('GPU')
        if not gpus:
            message = "No GPU found. Training will proceed on CPU."
        else:
            try:
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as gpu_error:
                # Typically raised when an earlier request already initialized the
                # GPUs; they stay usable with their existing settings.
                logging.warning(f"Could not enable GPU memory growth: {gpu_error}")
            message = "Training phase will utilize GPU."
        logging.info(message)

        # ---- Data generators ------------------------------------------------
        train_dir = os.path.join(dataset_path, 'train')
        validation_dir = os.path.join(dataset_path, 'validation')
        # Augment only the training images; validation images are just rescaled by 1/255.
        train_datagen = ImageDataGenerator(rescale=1./255, rotation_range=20, width_shift_range=0.2,
                                           height_shift_range=0.2, zoom_range=0.2, horizontal_flip=True)
        val_datagen = ImageDataGenerator(rescale=1./255)
        train_generator = train_datagen.flow_from_directory(train_dir, target_size=(img_height, img_width),
                                                            batch_size=batch_size, class_mode='categorical')
        val_generator = val_datagen.flow_from_directory(validation_dir, target_size=(img_height, img_width),
                                                        batch_size=batch_size, class_mode='categorical')
        class_indices = train_generator.class_indices
        num_classes = len(class_indices)

        model = create_model(base_model_name, num_classes=num_classes)

        class LoggingCallback(tf.keras.callbacks.Callback):
            """Append each finished epoch's metrics to ``training_logs``."""

            def on_epoch_end(self, epoch, logs=None):
                logs = logs or {}
                training_logs.append({
                    "epoch": epoch + 1,
                    "accuracy": logs.get('accuracy'),
                    "val_accuracy": logs.get('val_accuracy'),
                    "loss": logs.get('loss'),
                    "val_loss": logs.get('val_loss')
                })

        # ---- Training -------------------------------------------------------
        history = model.fit(
            train_generator,
            validation_data=val_generator,
            epochs=epochs,
            callbacks=[LoggingCallback()]
        )

        if fine_tune == 'true' and fine_tune_epochs > 0:
            model.trainable = True  # Unfreeze every layer for fine-tuning.
            # NOTE(review): recompiles with categorical cross-entropy, which may differ
            # from the loss create_model() used — confirm this is intended.
            model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss='categorical_crossentropy', metrics=['accuracy'])
            # Continue the epoch counter so fine-tuning entries in training_logs
            # follow on from the first phase instead of restarting at 1.
            start_epoch = len(history.epoch)
            model.fit(
                train_generator,
                validation_data=val_generator,
                initial_epoch=start_epoch,
                epochs=start_epoch + fine_tune_epochs,
                callbacks=[LoggingCallback()]
            )

        # ---- Saving ---------------------------------------------------------
        # Accuracy of the final epoch, i.e. of the model that is saved below.
        final_logs = training_logs[-1]
        final_val_accuracy = final_logs.get('val_accuracy')
        if final_val_accuracy is None:
            raise ValueError("no validation accuracy was reported; check that the validation directory contains images")
        formatted_val_accuracy = f"{final_val_accuracy * 100:.2f}"
        training_date = datetime.now().strftime('%Y-%m-%d')

        models_folder = 'models'
        model_subfolder = os.path.join(models_folder, model_type) if model_type in ('pest', 'crop') else models_folder
        os.makedirs(model_subfolder, exist_ok=True)

        # secure_filename keeps form-supplied names from injecting path separators.
        file_stem = secure_filename(
            f"{model_type}_{base_model_name}_{num_classes}_classes_{formatted_val_accuracy}_acc_{training_date}"
        )
        model_path = os.path.join(model_subfolder, f"{file_stem}.keras")
        model_info_path = os.path.join(model_subfolder, f"{file_stem}.txt")

        def fmt_metric(value):
            return 'n/a' if value is None else f"{value:.4f}"

        try:
            model.save(model_path)
            logging.info(f"Model saved successfully at: {model_path}")
        except Exception as e:
            logging.error(f"Error saving model: {e}")
            error = f"Model trained but could not be saved: {e}"
        else:
            # Only describe a model that actually exists on disk.
            try:
                with open(model_info_path, 'w', encoding='utf-8') as f:
                    f.write(f"Model: {base_model_name}\n")
                    f.write(f"Number of Classes: {num_classes}\n")
                    f.write(f"Classes: {list(class_indices.keys())}\n")
                    f.write(f"Training Accuracy: {fmt_metric(final_logs.get('accuracy'))}\n")
                    f.write(f"Validation Accuracy: {fmt_metric(final_val_accuracy)}\n")
                logging.info(f"Model info file created successfully at: {model_info_path}")
            except OSError as e:
                logging.error(f"Error creating model info file: {e}")

        message = f"Model training completed successfully with validation accuracy of {formatted_val_accuracy}%!"
        return jsonify({'error': error, 'message': message, 'training_logs': training_logs})

    except Exception as e:
        logging.exception("Training failed")
        error = f"An error occurred during training: {str(e)}"
        return jsonify({'error': error, 'message': message, 'training_logs': training_logs})


@app.route('/inference', methods=['GET', 'POST'])
def inference():
    """Classify an uploaded image with a saved model.

    GET renders ``inference.html`` with the models from ``get_available_models``
    for the ``pest`` and ``crop`` types. POST expects the form fields
    ``model_type`` and ``model_choice`` (a model file under
    ``models/<model_type>/``) plus an uploaded ``image``. It saves the image to
    ``static/uploads`` and predicts with the chosen model. The result is
    rendered into ``result.html`` as ``label``.

    Missing input re-renders ``inference.html`` with an ``error``. Failures while
    preparing the image, loading the model or labels, or predicting return a
    plain-text 500 response.
    """
    def available_models():
        # Models offered in the selection list: pest models, then crop models.
        return get_available_models('pest') + get_available_models('crop')

    def is_plain_name(name):
        # A non-empty single path component; rejects absolute paths and "../"
        # so form values cannot reach files outside models/<type>/.
        return bool(name) and name not in ('.', '..') and os.path.basename(name) == name

    if request.method != 'POST':
        return render_template('inference.html', models=available_models(), error=None)

    model_type = request.form.get('model_type')
    model_choice = request.form.get('model_choice')

    if 'image' not in request.files or request.files['image'].filename == '':
        return render_template('inference.html', models=available_models(),
                               error="No image uploaded. Please upload an image.")

    if not (is_plain_name(model_type) and is_plain_name(model_choice)):
        return render_template('inference.html', models=available_models(),
                               error="Please select a model.")

    # Save the uploaded file.
    file = request.files['image']
    filename = secure_filename(file.filename)
    if not filename:
        # secure_filename() returns '' when the name has no safe characters left.
        return render_template('inference.html', models=available_models(),
                               error="Invalid image file name. Please rename the file and upload it again.")
    upload_dir = 'static/uploads'
    os.makedirs(upload_dir, exist_ok=True)
    file_path = os.path.join(upload_dir, filename)
    file.save(file_path)

    try:
        img = prepare_image(file_path)
    except Exception as e:
        logging.error(f"Error preparing image: {e}")
        return f"Error preparing image: {e}", 500

    model_path = os.path.join('models', model_type, model_choice)
    try:
        # compile=False: prediction only needs the architecture and weights, and it
        # avoids having to deserialize the training loss/optimizer (e.g. custom losses).
        model = tf.keras.models.load_model(model_path, compile=False)
        logging.info(f"Loaded model: {model_choice}")
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        return f"Error loading model: {e}", 500

    try:
        class_labels = load_model_info(model_choice, model_type)
    except Exception as e:
        logging.error(f"Error loading class labels: {e}")
        return f"Error loading class labels: {e}", 500

    try:
        prediction = model.predict(img)
        # Index of the highest score for the first (only) image in the batch.
        predicted_class = int(np.argmax(prediction, axis=1)[0])
        predicted_label = class_labels[predicted_class]
    except Exception as e:
        logging.error(f"Error during prediction: {e}")
        return f"Error during prediction: {e}", 500

    return render_template('result.html', label=predicted_label)
    
@app.route('/get_models/<model_type>', methods=['GET'])
def get_models(model_type):
    """Return ``{"models": ...}`` JSON for *model_type*, matched case-insensitively."""
    normalized_type = model_type.lower()
    available = get_available_models(normalized_type)
    return jsonify(models=available)

@app.route('/result')
def result():
    """Render ``result.html`` directly, without a prediction label."""
    result_template = 'result.html'
    return render_template(result_template)

if __name__ == '__main__':
    # Launch Flask's built-in server on the fixed port 4269 (debug not enabled).
    port = 4269
    app.run(port=port)


 * Serving Flask app '__main__'
 * Debug mode: off
