# Friendship Sloop Detector Model

## Overview

This notebook demonstrates how to build and train a deep learning model that detects Friendship Sloops in images. It uses transfer learning: a new classification head is trained on top of a frozen ResNet50 network pretrained on ImageNet, and data augmentation is applied to improve generalization. The notebook covers data preprocessing, model definition, training, evaluation, saving, and testing on new images.

## Prerequisites

Before running this notebook, make sure the following are installed:

- Python 3.6 or higher
- TensorFlow 2.x
- Keras
- OpenCV
- Matplotlib
- python-dotenv (imported as `dotenv`)
- SciPy
- pandas
- seaborn
- Pillow
- NumPy
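
A quick way to confirm the environment is ready is to check that the modules the notebook relies on can be found. The following is a minimal sketch using only the standard library; `missing_modules` and `NOTEBOOK_MODULES` are hypothetical names, and several import names differ from the package names (OpenCV is imported as `cv2`, python-dotenv as `dotenv`, Pillow as `PIL`):

```python
import importlib.util
import sys

# Import names (not PyPI package names) of the packages this notebook depends on
NOTEBOOK_MODULES = ["tensorflow", "keras", "cv2", "matplotlib", "dotenv",
                    "scipy", "pandas", "seaborn", "PIL", "numpy"]


def missing_modules(names):
    """Return the names of modules that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


print(f"Python {sys.version_info.major}.{sys.version_info.minor}")
print("Missing modules:", missing_modules(NOTEBOOK_MODULES) or "none")
```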

## Setup Instructions

### 1. Clone the Repository:

Clone the repository containing this notebook, navigate to its directory, and create and activate a virtual environment.

```bash
git clone <repository_url>
cd <repository_directory>

python3 -m venv venv
source venv/bin/activate
``` 

### 2. Install Dependencies:

Install the required Python packages using pip.

```bash
pip install --upgrade pip
pip install -r requirements.txt
```

### 3. Set Up Environment Variables:

Create a `.env` file in the root directory and define the following environment variables:

```bash
DATA_DIR=path/to/your/data
BATCH_SIZE=32
NUM_EPOCHS=10
MODEL_PATH=path/to/save/model.keras
```
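
The notebook's `CFG` class reads these variables directly with `os.environ[...]`, so a missing or malformed value surfaces as a bare `KeyError` or `ValueError`. Below is a minimal sketch of a stricter loader, assuming `load_dotenv()` has already been called; `Config` and `load_config` are hypothetical names, not part of the notebook:

```python
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Mapping


@dataclass(frozen=True)
class Config:
    data_dir: Path
    batch_size: int
    num_epochs: int
    model_path: Path


def load_config(env: Mapping[str, str] = os.environ) -> Config:
    """Build a Config from environment variables, failing with a clear message."""
    required = ("DATA_DIR", "BATCH_SIZE", "NUM_EPOCHS", "MODEL_PATH")
    missing = [name for name in required if not env.get(name)]  # empty counts as missing
    if missing:
        raise KeyError(f"Missing environment variables: {', '.join(missing)}")

    def positive_int(name: str) -> int:
        value = int(env[name])  # raises ValueError for non-numeric values
        if value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value}")
        return value

    return Config(
        data_dir=Path(env["DATA_DIR"]),
        batch_size=positive_int("BATCH_SIZE"),
        num_epochs=positive_int("NUM_EPOCHS"),
        model_path=Path(env["MODEL_PATH"]),
    )


cfg = load_config({"DATA_DIR": "data", "BATCH_SIZE": "32",
                   "NUM_EPOCHS": "10", "MODEL_PATH": "models/sloop.keras"})
print(cfg)
```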

From this point on, run the notebook in whichever notebook environment best suits your needs.

## 1. Import Libraries

Imports the libraries needed for data preprocessing, model building, metrics, and environment management, then loads the settings from `.env` into the `CFG` configuration class.

In [None]:
import os
import random
import pprint
from pathlib import Path

import cv2
import tensorflow as tf
from dotenv import load_dotenv
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.metrics import Precision, Recall
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Suppress urllib3's InsecureRequestWarning (and no other warnings)
import urllib3
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

# Load environment variables from the .env file (if present)
load_dotenv()

# Define the config class
class CFG:

# Root directory containing the image data
    DATA_DIR = Path(os.environ['DATA_DIR'])

    # Number of images per batch
    BATCH_SIZE = int(os.environ['BATCH_SIZE'])

    # Number of training epochs
    NUM_EPOCHS = int(os.environ['NUM_EPOCHS'])

    # Destination file for the trained model (e.g. a `.keras` file)
    MODEL_PATH = Path(os.environ['MODEL_PATH'])
    
pprint.pprint(CFG.__dict__)

## 2. Data Preprocessing

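Configures data augmentation and preprocessing with `ImageDataGenerator`. Images are loaded from `CFG.DATA_DIR` (each subdirectory is treated as one class), 20% are held out for validation, pixel values are rescaled to [0, 1], and the training images are randomly flipped, rotated, zoomed, and sheared.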
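
`flow_from_directory` treats each subdirectory of `data_dir` as a class and, when no `classes` argument is given, assigns label indices in alphanumeric order of the folder names. With `class_mode='binary'`, the model's sigmoid output is therefore the probability of whichever class sorts second. A minimal stdlib sketch of that mapping, using a hypothetical `other_boats` folder name (in the notebook itself, `train_gen.class_indices` holds the real mapping):

```python
import os
import tempfile


def infer_class_indices(directory):
    """Map each subdirectory name to a label index in sorted (alphanumeric) order.

    Assumption: this mirrors how flow_from_directory infers classes
    when its `classes` argument is not given.
    """
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(classes)}


# Example with a hypothetical two-class layout
with tempfile.TemporaryDirectory() as root:
    for name in ("other_boats", "friendship_sloop"):
        os.makedirs(os.path.join(root, name))
    print(infer_class_indices(root))  # {'friendship_sloop': 0, 'other_boats': 1}
```

In this hypothetical layout, `friendship_sloop` gets index 0, so a sigmoid output above 0.5 would mean *not* a Friendship Sloop; check `train_gen.class_indices` before interpreting predictions.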

In [None]:
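# Set up directories
data_dir = CFG.DATA_DIR  # Each subdirectory of data_dir is treated as one class
img_size = 224  # ResNet50's default input size (224x224)
batch_size = CFG.BATCH_SIZE

# Training data: rescaling plus random augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,  # Hold out 20% of the images for validation
    horizontal_flip=True,
    rotation_range=15,
    zoom_range=0.2,
    shear_range=0.2
)

# Validation data: same rescaling and split, but no augmentation, so that
# validation metrics are computed on unmodified images
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_gen = train_datagen.flow_from_directory(
    data_dir,
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode='binary',
    subset='training'
)

val_gen = val_datagen.flow_from_directory(
    data_dir,
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode='binary',
    subset='validation'
)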

## 3. Model Definition

Defines the model architecture: a ResNet50 base pretrained on ImageNet, with all of its layers frozen so their weights stay fixed during training, followed by custom classification layers. The model is compiled with accuracy, precision, and recall metrics.
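
Because every ResNet50 layer is frozen, only the two new `Dense` layers are trained. A short arithmetic check of how many weights that is; the 2048 is the channel count of ResNet50's final convolutional block, and the total should match the trainable parameter count reported by `model.summary()`. `dense_params` is a hypothetical helper:

```python
def dense_params(inputs: int, units: int) -> int:
    """Parameters in a Dense layer: one weight per input-unit pair plus one bias per unit."""
    return inputs * units + units


# ResNet50 without its top layers ends in 2048 feature channels;
# GlobalAveragePooling2D reduces them to a 2048-d vector and has no parameters.
RESNET50_FEATURES = 2048

hidden = dense_params(RESNET50_FEATURES, 128)  # Dense(128, activation='relu')
output = dense_params(128, 1)                  # Dense(1, activation='sigmoid')
print(hidden, output, hidden + output)  # 262272 129 262401
```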

In [None]:
# Load ResNet50 with pretrained weights and exclude top layers
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(img_size, img_size, 3))

# Freeze all layers in the base model
for layer in base_model.layers:
    layer.trainable = False

# Add custom classification layers
x = base_model.output                        # (7, 7, 2048) feature map for 224x224 inputs
x = GlobalAveragePooling2D()(x)              # Average each channel -> 2048-d vector
x = Dense(128, activation='relu')(x)
outputs = Dense(1, activation='sigmoid')(x)  # Probability of the class with index 1

# Final model
model = Model(inputs=base_model.input, outputs=outputs)

# Compile the model with additional metrics
model.compile(optimizer='adam', 
              loss='binary_crossentropy', 
              metrics=['accuracy', Precision(), Recall()])

# # Debug: Show the model summary
# model.summary()

## 4. Show a Sample of Loaded Images

Shows a 3x3 grid of randomly selected images from `CFG.DATA_DIR / 'train/friendship_sloop'`.

In [None]:
import os
import random
from PIL import Image
import matplotlib.pyplot as plt

# Directory containing the Friendship Sloop training images
sample_dir = CFG.DATA_DIR / 'train' / 'friendship_sloop'

# Get a list of image file paths (extension check is case-insensitive)
image_files = [os.path.join(sample_dir, f) for f in os.listdir(sample_dir)
               if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

# Randomly select up to 9 images (random.sample raises ValueError if asked for more than exist)
random_images = random.sample(image_files, min(9, len(image_files)))

# Display a 3x3 grid of images, hiding the axes of every cell
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
axes = axes.flatten()
for ax in axes:
    ax.axis('off')

for img_path, ax in zip(random_images, axes):
    print(f'Processing image: {img_path}')
    img = Image.open(img_path)
    ax.imshow(img)
    ax.set_title(os.path.basename(img_path))

plt.tight_layout()
plt.show()

## 5. Train the Model

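Trains the model for `CFG.NUM_EPOCHS` epochs on the training generator. After each epoch, Keras evaluates the model on the validation generator and records the loss and metrics, together with their `val_`-prefixed validation counterparts, in `history.history`.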
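
`model.fit` runs through every batch of `train_gen` once per epoch, so the number of steps in the progress bar is the image count divided by `BATCH_SIZE`, rounded up (the last batch may be partial). A small sketch with hypothetical image counts; `steps_per_epoch` is an illustrative helper, not a Keras function:

```python
import math


def steps_per_epoch(num_images: int, batch_size: int) -> int:
    """Batches per epoch for a directory iterator; the last batch may be partial."""
    return math.ceil(num_images / batch_size)


# Hypothetical counts: 400 training and 100 validation images, BATCH_SIZE=32
print(steps_per_epoch(400, 32), steps_per_epoch(100, 32))  # 13 4
```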


In [None]:
history = model.fit(
    train_gen,
    epochs=CFG.NUM_EPOCHS,
    validation_data=val_gen
)
print("Training completed.")

## 6. Evaluate the Model

Evaluates the model on the validation subset, then presents the resulting metrics as a table, a bar chart, and an HTML explanation of each metric.
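
Keras's `Precision` and `Recall` metrics (and `'accuracy'` for a single sigmoid output) threshold predictions at 0.5 by default and treat class index 1 as the positive class, while the loss is the mean binary cross-entropy. A stdlib sketch of those formulas on hypothetical confusion-matrix counts, useful for sanity-checking the values in the table; the function names are illustrative:

```python
import math


def classification_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Accuracy, precision, and recall from binary confusion-matrix counts."""
    return {
        "Accuracy": (tp + tn) / (tp + fp + fn + tn),
        "Precision": tp / (tp + fp) if tp + fp else 0.0,
        "Recall": tp / (tp + fn) if tp + fn else 0.0,
    }


def binary_crossentropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(labels)


# Hypothetical validation results: 40 TP, 10 FP, 5 FN, 45 TN
print(classification_metrics(40, 10, 5, 45))
# Two hypothetical predictions: a sloop scored 0.9 and a non-sloop scored 0.2
print(binary_crossentropy([1, 0], [0.9, 0.2]))
```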

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from IPython.display import HTML, display

# Evaluate the model
loss, accuracy, precision, recall = model.evaluate(val_gen)

# Create the metrics object
metrics = {
    'Accuracy': accuracy,
    'Precision': precision,
    'Recall': recall,
    'Loss': loss
}

# Convert metrics to a DataFrame
metrics_df = pd.DataFrame(metrics, index=[0])

# Display metrics as a table
print(metrics_df)

# Plot the score metrics (loss is omitted because it is not bounded by 1)
plt.figure(figsize=(10, 6))
sns.barplot(data=metrics_df.drop(columns='Loss'))
plt.title('Model Performance Metrics')
plt.ylabel('Score')
plt.ylim(0, 1)
plt.show()

# Function to generate HTML explanation
def generate_explanation_html(metrics):
    explanation_html = f"""
    <h3>Explanation of Metrics</h3>
    <ol>
        <li><strong>Accuracy</strong>:
            <ul>
                <li><strong>Definition</strong>: The proportion of correctly classified instances out of the total instances.</li>
                <li><strong>Good Value</strong>: Generally, an accuracy above 0.80 (80%) is considered good, but this can vary depending on the complexity of the task and the dataset.</li>
                <li><strong>Current Value</strong>: {metrics['Accuracy']:.2f}</li>
            </ul>
        </li>
        <li><strong>Precision</strong>:
            <ul>
                <li><strong>Definition</strong>: The proportion of true positive predictions out of all positive predictions (i.e., the accuracy of positive predictions).</li>
                <li><strong>Good Value</strong>: A precision above 0.75 (75%) is typically considered good. High precision indicates that the model has a low false positive rate.</li>
                <li><strong>Current Value</strong>: {metrics['Precision']:.2f}</li>
            </ul>
        </li>
        <li><strong>Recall</strong>:
            <ul>
                <li><strong>Definition</strong>: The proportion of true positive predictions out of all actual positives (i.e., the ability of the model to find all relevant instances).</li>
                <li><strong>Good Value</strong>: A recall above 0.75 (75%) is generally considered good. High recall indicates that the model has a low false negative rate.</li>
                <li><strong>Current Value</strong>: {metrics['Recall']:.2f}</li>
            </ul>
        </li>
        <li><strong>Loss</strong>:
            <ul>
                <li><strong>Definition</strong>: A measure of how well the model's predictions match the actual labels. Lower loss values indicate better performance.</li>
                <li><strong>Good Value</strong>: The acceptable loss value depends on the specific loss function used and the problem context. Generally, lower values are better.</li>
                <li><strong>Current Value</strong>: {metrics['Loss']:.2f}</li>
            </ul>
        </li>
    </ol>
    """
    return explanation_html

# Generate and display the explanation using HTML
explanation_html = generate_explanation_html(metrics)
display(HTML(explanation_html))

## 7. Save the Model

Saves the model to `CFG.MODEL_PATH`. When the path ends in `.keras`, Keras writes its native format: a single zip archive containing the model's architecture, weights, and training configuration.
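
Keras 3, the default in TensorFlow 2.16 and later, only accepts `.keras` (or legacy `.h5`) paths in `model.save`, while older `tf.keras` versions save an extensionless path as a SavedModel directory, whose size `os.path.getsize` would not report correctly. A hypothetical helper that normalizes the path before saving; it appends the extension rather than replacing an existing suffix, so a name such as `model.v1` is not clobbered:

```python
from pathlib import Path


def with_keras_suffix(path) -> Path:
    """Return `path` with a `.keras` extension, appending one if it is missing."""
    path = Path(path)
    return path if path.suffix == ".keras" else path.with_name(path.name + ".keras")


print(with_keras_suffix("models/friendship_sloop"))
print(with_keras_suffix("models/friendship_sloop.keras"))
```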

In [None]:
# Create the base directory if it doesn't exist
base_dir = CFG.MODEL_PATH.parent
os.makedirs(base_dir, exist_ok=True)

# Save model
model.save(CFG.MODEL_PATH)
print(f'saved { os.path.getsize(CFG.MODEL_PATH) } bytes to {CFG.MODEL_PATH}')

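## 8. Model Testing on New Images

Runs the trained model on every JPEG and PNG image in `images/prediction`, displaying each image and printing its prediction.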

In [None]:
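import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

def preprocess_image(image_path):
    """Load an image and preprocess it the same way as the training data."""
    img = cv2.imread(image_path)  # OpenCV returns BGR channel order, or None on failure
    if img is None:
        raise ValueError(f'Could not read image: {image_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Training images were loaded as RGB
    img = cv2.resize(img, (img_size, img_size))
    img = img / 255.0  # Same rescaling as the ImageDataGenerator
    img = img.reshape(-1, img_size, img_size, 3)  # Add a batch dimension
    return img

# The sigmoid output is the probability of class index 1 in train_gen.class_indices
# (indices follow the alphanumeric order of the class folders). The check below
# assumes the Friendship Sloop class has index 1; verify it with this printout.
print(f'Class indices: {train_gen.class_indices}')

# Directory containing test images
test_image_dir = 'images/prediction'

# Iterate over all JPEG and PNG images in the directory
for filename in sorted(os.listdir(test_image_dir)):
    if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
        test_image_path = os.path.join(test_image_dir, filename)

        # Show the test image
        plt.imshow(Image.open(test_image_path))
        plt.axis('off')
        plt.show()

        # Preprocess and predict
        img = preprocess_image(test_image_path)
        prediction = model.predict(img)

        print(f"Prediction for {filename}: {prediction}")
        if prediction[0][0] > 0.5:
            print("Friendship Sloop detected")
        else:
            print("No Friendship Sloop detected")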
