<a href="https://colab.research.google.com/github/didulabhanuka/colab/blob/main/yoloV8_tomato_ripeness_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Cell 1: Install Required Libraries

In [None]:
# Install the third-party packages for this notebook (IPython shell magics, run once per runtime).
# NOTE(review): pycocotools, fiftyone and flask are not imported in the cells visible here —
# presumably they are used further down; verify before trimming the list.
!pip install ultralytics albumentations torchvision pycocotools fiftyone
!pip install flask


Cell 2: Import Required Libraries & Mount Google Drive

In [None]:
import os
import cv2
import zipfile
import random
import shutil
from albumentations import Compose, RandomBrightnessContrast, HueSaturationValue, GaussianBlur, MotionBlur, Normalize
from ultralytics import YOLO
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import functional as F
import torch
import numpy as np
from sklearn.metrics import precision_recall_curve, precision_recall_fscore_support
import matplotlib.pyplot as plt
from tqdm import tqdm
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image # Import the Image class from the PIL library
import torchvision.transforms as transforms # Import transforms

# Mount Google Drive at /content/drive; the dataset archive is read from there and the
# training run is written back under the same mount point by later cells.
from google.colab import drive
drive.mount('/content/drive')


Cell 3: Create YAML File for Dataset Configuration

In [None]:
# Dataset description handed to the trainer: where the train/val splits live,
# how many classes there are (nc) and what each class index is called.
yaml_content = """
train: /content/dataset_augmented/train
val: /content/dataset_augmented/val

nc: 6
names: ["b_fully_ripened", "b_half_ripened", "b_green", "l_fully_ripened", "l_half_ripened", "l_green"]
"""

yaml_path = "/content/tomato_ripeness_classifier.yaml"

# Persist the configuration text verbatim so it can be referenced by path.
with open(yaml_path, mode="w") as yaml_file:
    print(yaml_content, end="", file=yaml_file)

print("YAML file created at {}".format(yaml_path))


Cell 4: Extract Dataset from Google Drive

In [None]:
# Zipped dataset on Drive, and the folder the rest of the notebook reads from.
dataset_zip = "/content/drive/MyDrive/tomato_ripeness_classifier/tomato_dataset.zip"
dataset_dir = "/content/dataset/dataset"

# The archive is unpacked into the parent of dataset_dir ("/content/dataset").
# NOTE(review): this presumes the zip has a top-level "dataset/" folder — verify.
extract_root = os.path.dirname(dataset_dir)
os.makedirs(extract_root, exist_ok=True)

archive = zipfile.ZipFile(dataset_zip, mode="r")
with archive:
    archive.extractall(path=extract_root)

print("Dataset unzipped successfully.")


Cell 5: Data Augmentation

In [None]:
# Root folder that receives the augmented copy of the dataset.
augmented_dir = "/content/dataset_augmented"

# Build the {train,val} x {images,labels} layout; existing folders are left alone.
for split_name in ("train", "val"):
    for content_kind in ("images", "labels"):
        os.makedirs(f"{augmented_dir}/{split_name}/{content_kind}", exist_ok=True)

def denormalize_image(image):
    """
    Reverts normalization to bring pixel values back to [0,255].

    Inverts a per-channel ``(x / 255 - mean) / std`` normalization using the
    mean/std constants hard-coded below (the same values passed to
    ``Normalize`` in the augmentation pipeline of this notebook).

    Args:
        image: Array-like of normalized values whose last axis has 3 channels,
            so that it broadcasts against the 3-element mean/std vectors.

    Returns:
        ``np.uint8`` array of the same shape with values in [0, 255].
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Undo normalization and scale back to the 0-255 range.
    restored = (np.asarray(image) * std + mean) * 255.0

    # Round to nearest before casting: astype(np.uint8) on its own truncates,
    # so ordinary float error (e.g. 199.99998) would lower a pixel by one level.
    return np.clip(np.rint(restored), 0, 255).astype(np.uint8)

def _clip_axis(center, size):
    """Clip the 1-D extent [center - size/2, center + size/2] to [0, 1].

    Returns ``(center, size)`` unchanged when the extent already lies inside
    [0, 1], the re-centred clipped extent when it only partly overlaps, and
    ``None`` when the size is not positive or nothing remains after clipping.
    """
    if not size > 0:
        return None

    low = center - size / 2
    high = center + size / 2
    if low >= 0.0 and high <= 1.0:
        return center, size  # already valid: keep the values untouched

    low = max(low, 0.0)
    high = min(high, 1.0)
    if high <= low:
        return None  # extent lies completely outside the image
    return (low + high) / 2, high - low


def validate_and_clip_bbox(bbox, img_w, img_h):
    """
    Ensures bounding box values stay within valid ranges.

    Args:
        bbox: ``[x_center, y_center, width, height]`` in pixels.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        ``[x_center, y_center, width, height]`` normalized to [0, 1], with any
        part of the box that crosses the image border clipped away, or
        ``None`` when the box is degenerate, lies fully outside the image, or
        the image size is not positive.
    """
    if img_w <= 0 or img_h <= 0:
        return None  # cannot normalize against an empty image

    x_center, y_center, width, height = bbox

    # Clip each axis on the box *extent* (its two edges). Clipping centre and
    # size independently does not keep the edges inside the image, which made
    # every box that merely touched or crossed the border get discarded.
    horizontal = _clip_axis(x_center / img_w, width / img_w)
    vertical = _clip_axis(y_center / img_h, height / img_h)
    if horizontal is None or vertical is None:
        return None  # Invalid bbox

    return [horizontal[0], vertical[0], horizontal[1], vertical[1]]

def _read_valid_yolo_labels(label_path, img_w, img_h):
    """Parse a YOLO label file into ``(bboxes, class_labels)``.

    Only boxes accepted by ``validate_and_clip_bbox`` are kept. Blank lines are
    ignored and malformed lines are reported and skipped, so a single bad line
    no longer aborts the whole augmentation run. A missing file yields two
    empty lists.
    """
    bboxes = []
    class_labels = []
    if not os.path.exists(label_path):
        return bboxes, class_labels

    with open(label_path, "r") as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                continue  # blank line

            if len(fields) != 5:
                print(f"Skipping line {line_no} in {label_path}: expected 5 values, found {len(fields)}.")
                continue

            try:
                cls, x_center, y_center, width, height = map(float, fields)
            except ValueError:
                print(f"Skipping line {line_no} in {label_path}: non-numeric value.")
                continue

            # The validator takes pixel units, so scale the normalized values up first.
            valid_bbox = validate_and_clip_bbox(
                [x_center * img_w, y_center * img_h, width * img_w, height * img_h], img_w, img_h
            )
            if valid_bbox is not None:
                bboxes.append(valid_bbox)
                class_labels.append(int(cls))

    return bboxes, class_labels


def advanced_augmentations(image_folder, label_folder, output_image_folder, output_label_folder):
    """
    Applies augmentations while keeping bounding boxes correctly aligned.

    For every readable image in ``image_folder`` that has at least one valid
    box in its matching ``<stem>.txt`` label file, runs the augmentation
    pipeline and writes the result (same file name) to ``output_image_folder``
    plus the YOLO labels to ``output_label_folder``. Unreadable images and
    images without valid boxes are reported and skipped.

    Args:
        image_folder: Directory with the source images.
        label_folder: Directory with the YOLO ``.txt`` labels.
        output_image_folder: Existing directory that receives the images.
        output_label_folder: Existing directory that receives the labels.
    """
    augmentations = Compose(
        [
            RandomBrightnessContrast(p=0.2),
            HueSaturationValue(p=0.2),
            GaussianBlur(p=0.1),
            MotionBlur(p=0.1),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # Applied for training
        ],
        bbox_params={"format": "yolo", "label_fields": ["class_labels"]},
    )

    for image_file in os.listdir(image_folder):
        img_path = os.path.join(image_folder, image_file)
        label_stem = os.path.splitext(image_file)[0]
        label_path = os.path.join(label_folder, label_stem + ".txt")

        # Read the image
        image = cv2.imread(img_path)
        if image is None:
            print(f"Skipping {image_file}: Unable to read image.")
            continue

        h, w, _ = image.shape

        # Read the bounding boxes
        bboxes, class_labels = _read_valid_yolo_labels(label_path, w, h)
        if not bboxes:
            print(f"Skipping image {image_file} due to no valid bounding boxes.")
            continue

        # Apply Augmentations
        augmented = augmentations(image=image, bboxes=bboxes, class_labels=class_labels)
        augmented_bboxes = augmented["bboxes"]
        augmented_class_labels = augmented["class_labels"]

        # The pipeline ends with Normalize, so the image must be converted back
        # to uint8 before saving (otherwise the file comes out black).
        augmented_image = denormalize_image(augmented["image"])

        # Save Augmented Image
        output_img_path = os.path.join(output_image_folder, image_file)
        cv2.imwrite(output_img_path, augmented_image)

        # Save Updated Labels
        output_label_path = os.path.join(output_label_folder, label_stem + ".txt")
        with open(output_label_path, "w") as f:
            for bbox, cls in zip(augmented_bboxes, augmented_class_labels):
                x_center, y_center, width, height = bbox
                f.write(f"{cls} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")

        print(f"Saved Augmented Image: {output_img_path}")

# Run the augmentation pipeline over both dataset splits, train first, then val.
dataset_dir = "/content/dataset/dataset"
for split in ("train", "val"):
    advanced_augmentations(
        image_folder=f"{dataset_dir}/{split}/images",
        label_folder=f"{dataset_dir}/{split}/labels",
        output_image_folder=f"{augmented_dir}/{split}/images",
        output_label_folder=f"{augmented_dir}/{split}/labels",
    )

print("✅ Augmentation completed successfully.")


Cell 6: Test Augmentation on a Random Image

In [None]:
def check_random_augmented_image(image_folder, label_folder):
    """
    Selects a random augmented image, loads its bounding boxes, and visualizes it.

    Args:
        image_folder: Directory to pick the image from.
        label_folder: Directory holding the matching ``<stem>.txt`` YOLO labels.

    Returns:
        None. Prints a message and returns early when the folder is empty or
        the chosen file cannot be read as an image; otherwise shows a
        matplotlib figure with one rectangle and class tag per label line.
    """
    image_files = os.listdir(image_folder)
    if not image_files:
        print("No augmented images found!")
        return

    # Select a random image
    random_image = random.choice(image_files)
    img_path = os.path.join(image_folder, random_image)
    label_path = os.path.join(label_folder, os.path.splitext(random_image)[0] + ".txt")

    # Load the image
    image = cv2.imread(img_path)
    if image is None:
        print(f"Error loading image: {random_image}")
        return

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert from BGR to RGB
    h, w, _ = image.shape  # Get image dimensions

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(image)

    # Load bounding boxes and draw them
    if os.path.exists(label_path):
        with open(label_path, "r") as f:
            for line in f:
                fields = line.split()
                if len(fields) != 5:
                    continue  # blank or malformed line: nothing to draw

                try:
                    cls, x_center, y_center, width, height = map(float, fields)
                except ValueError:
                    continue  # non-numeric entry

                # Convert YOLO format to pixel coordinates
                x_center *= w
                y_center *= h
                width *= w
                height *= h

                # Rectangle is anchored at its top-left corner
                x_min = x_center - width / 2
                y_min = y_center - height / 2

                rect = plt.Rectangle((x_min, y_min), width, height, linewidth=2, edgecolor='r', facecolor='none')
                ax.add_patch(rect)
                ax.text(x_min, y_min, f"Class: {int(cls)}", bbox=dict(facecolor='yellow', alpha=0.5))

    plt.title(f"Augmented Image: {random_image}")
    plt.show()

# Visual sanity check: display one randomly chosen training image with its boxes.
augmented_train_root = "/content/dataset_augmented/train"
check_random_augmented_image(
    image_folder=f"{augmented_train_root}/images",
    label_folder=f"{augmented_train_root}/labels",
)


Cell 7: Train YOLOv8

In [None]:
# Start from the "yolov8n.pt" weights.
model = YOLO("yolov8n.pt")

# Training settings gathered in one place; the run is stored on Drive under
# <project>/<name>, and exist_ok=True lets a re-run reuse that folder.
train_settings = dict(
    data="/content/tomato_ripeness_classifier.yaml",
    epochs=20,
    batch=16,
    imgsz=640,
    project="/content/drive/MyDrive/tomato_ripeness_classifier/",
    name="yolov8_tomato_ripeness",
    workers=4,
    exist_ok=True,
)

model.train(**train_settings)


Cell 8: Evaluate YOLOv8 Model

In [None]:
# Load the trained YOLOv8 model
model = YOLO("/content/drive/MyDrive/tomato_ripeness_classifier/yolov8_tomato_ripeness/weights/best.pt")

# Run evaluation on the validation set
metrics = model.val()

# Extract key metrics
precision = metrics.box.p   # Precision per evaluated class
recall = metrics.box.r      # Recall per evaluated class
map50 = metrics.box.map50   # Mean mAP@50 over all classes (a single number)
map50_95 = metrics.box.map  # Mean mAP@50-95 over all classes (a single number)
classes = metrics.names     # Class names, keyed by class index

# All class names known to the model
class_labels = list(classes.values())
num_classes = len(class_labels)

# Per-class arrays (p, r, ap50, ap) only cover classes that occur in the
# validation data, in the order given by ap_class_index, so the labels are
# built from that index rather than from the full class list. Plotting against
# the full list fails on a length mismatch when a class is absent.
evaluated_labels = [classes[int(class_id)] for class_id in metrics.box.ap_class_index]
ap50_per_class = metrics.box.ap50   # AP@50 per evaluated class
ap50_95_per_class = metrics.box.ap  # AP@50-95 per evaluated class

# Side-by-side bars so that neither series hides the other
bar_positions = np.arange(len(evaluated_labels))
bar_width = 0.4

# Plot Precision & Recall
plt.figure(figsize=(10, 5))
plt.bar(bar_positions - bar_width / 2, precision, bar_width, color="blue", label="Precision")
plt.bar(bar_positions + bar_width / 2, recall, bar_width, color="red", label="Recall")
plt.xlabel("Classes")
plt.ylabel("Score")
plt.title("Precision & Recall per Class")
plt.legend()
plt.xticks(bar_positions, evaluated_labels, rotation=45)
plt.tight_layout()  # keep the rotated tick labels inside the saved image
plt.savefig("/content/precision_recall_plot.png")
plt.show()

# Plot mAP@50 and mAP@50-95
plt.figure(figsize=(10, 5))
plt.bar(bar_positions - bar_width / 2, ap50_per_class, bar_width, color="green", label="mAP@50")
plt.bar(bar_positions + bar_width / 2, ap50_95_per_class, bar_width, color="orange", label="mAP@50-95")
plt.xlabel("Classes")
plt.ylabel("Score")
plt.title("mAP Scores per Class")
plt.legend()
plt.xticks(bar_positions, evaluated_labels, rotation=45)
plt.tight_layout()
plt.savefig("/content/map_plot.png")
plt.show()

print("Evaluation Complete. Saved plots as images.")


Cell 9: Test YOLOv8 on a Random Image

In [None]:
# Load the trained weights for a quick qualitative check.
model = YOLO("/content/drive/MyDrive/tomato_ripeness_classifier/yolov8_tomato_ripeness/weights/best.pt")

test_image_folder = "/content/dataset_augmented/val/images"

# Only consider image files, so a stray non-image entry in the folder cannot
# be picked at random and passed to the model.
image_extensions = (".jpg", ".jpeg", ".png", ".bmp")
image_files = [
    file_name
    for file_name in sorted(os.listdir(test_image_folder))
    if file_name.lower().endswith(image_extensions)
]

if not image_files:
    print("No images found in the validation set!")
else:
    random_image = random.choice(image_files)
    image_path = os.path.join(test_image_folder, random_image)

    # The call returns a list of Results; a single input image gives one entry.
    results = model(image_path, conf=0.5)[0]

    results.show()

    # Build the output name from the file stem; using the full file name
    # produced names such as "img.jpg_pred.jpg".
    image_stem = os.path.splitext(random_image)[0]
    output_image_path = f"/content/{image_stem}_pred.jpg"
    results.save(filename=output_image_path)

    print(f"Inference complete. Saved result to {output_image_path}.")