# Coffee Bean Annotation Visualizer

This notebook visualizes YOLO-format annotations by drawing their bounding boxes on coffee bean images.

## Setup Instructions

1. Upload your dataset to Google Colab or mount Google Drive
2. Run the cells below to visualize your annotations

## Step 1: Install Dependencies

In [None]:
# Install required packages (if not already installed)
!pip install opencv-python-headless matplotlib numpy

## Step 2: Upload Dataset (Choose One Method)

### Method A: Upload from Local Computer

In [None]:
# Upload the kaggle_dataset_round1.zip file
from google.colab import files
uploaded = files.upload()

# Unzip the dataset
!unzip -q kaggle_dataset_round1.zip
print("Dataset uploaded and extracted!")

### Method B: Mount Google Drive (if dataset is in Drive)

In [None]:
# Mount Google Drive under /content/drive
from google.colab import drive
drive.mount('/content/drive')

# Update this path to your dataset location in Google Drive
import os
# Later cells use relative 'kaggle_dataset_round1/...' paths, which are
# resolved against the working directory set here.
os.chdir('/content/drive/MyDrive/CoffeeBeanDataset')

## Step 3: Define Visualization Functions

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import random

# Ripeness class names keyed by the integer class id stored in the label files
CLASSES = dict(enumerate((
    "barely-riped",
    "over-riped",
    "riped",
    "semi-riped",
    "unriped",
)))

# Box color for each class id, as (blue, green, red) tuples for OpenCV drawing
COLORS = dict(enumerate((
    (71, 99, 255),    # barely-riped: orange/red
    (44, 44, 44),     # over-riped: dark gray
    (157, 166, 196),  # riped: brown/pink
    (61, 217, 255),   # semi-riped: yellow
    (144, 238, 144),  # unriped: light green
)))

def load_yolo_annotation(label_path):
    """Parse a YOLO label file into a list of box dictionaries.

    Every line that splits into exactly five whitespace-separated fields
    yields one dict with the keys 'class_id' (int) and 'x_center',
    'y_center', 'width', 'height' (floats), in that order. Lines with any
    other number of fields (including blank lines) are ignored.
    """
    float_keys = ('x_center', 'y_center', 'width', 'height')
    boxes = []
    with open(label_path, 'r') as handle:
        for raw_line in handle:
            tokens = raw_line.strip().split()
            if len(tokens) != 5:
                continue
            # Class id is parsed first, then the four floats left to right
            box = {'class_id': int(tokens[0])}
            box.update(zip(float_keys, map(float, tokens[1:])))
            boxes.append(box)
    return boxes

def yolo_to_corners(x_center, y_center, width, height, img_width, img_height):
    """Turn a normalized YOLO box into integer pixel corners.

    The center and size are scaled by the image dimensions, then the
    corners are truncated with int(). Returns (x1, y1, x2, y2) where
    (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner.
    """
    center_x = x_center * img_width
    center_y = y_center * img_height
    half_w = width * img_width / 2
    half_h = height * img_height / 2

    left, right = int(center_x - half_w), int(center_x + half_w)
    top, bottom = int(center_y - half_h), int(center_y + half_h)

    return left, top, right, bottom

def draw_bbox_cv2(image, annotations, show_labels=True, thickness=2):
    """Draw bounding boxes (and optional class labels) using OpenCV.

    Args:
        image: Image array in OpenCV (BGR) channel order. It is copied, so
            the caller's array is not modified.
        annotations: Iterable of dicts with the keys 'class_id', 'x_center',
            'y_center', 'width' and 'height', as produced by
            load_yolo_annotation.
        show_labels: When True, draw the class name on a filled tag next to
            each box.
        thickness: Line thickness of the box outline.

    Returns:
        A copy of ``image`` with the boxes drawn on it.
    """
    img = image.copy()
    img_height, img_width = img.shape[:2]

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    text_thickness = 1
    tag_padding = 4  # extra pixels of tag height above the text

    for ann in annotations:
        class_id = ann['class_id']
        x1, y1, x2, y2 = yolo_to_corners(
            ann['x_center'], ann['y_center'],
            ann['width'], ann['height'],
            img_width, img_height
        )

        # Unknown class ids fall back to white
        color = COLORS.get(class_id, (255, 255, 255))
        cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)

        if not show_labels:
            continue

        label = CLASSES.get(class_id, f"Class {class_id}")
        (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, text_thickness)

        # Keep the tag horizontally inside the image
        tag_left = min(max(x1, 0), max(img_width - text_w, 0))

        # Prefer the tag above the box; if that would leave the image
        # (box touching the top edge), place it just inside the box instead.
        tag_bottom = y1
        if tag_bottom - text_h - tag_padding < 0:
            tag_bottom = max(y1, 0) + text_h + tag_padding
        tag_top = tag_bottom - text_h - tag_padding

        cv2.rectangle(img,
                      (tag_left, tag_top),
                      (tag_left + text_w, tag_bottom),
                      color, -1)

        # Pick black or white text depending on how bright the tag fill is,
        # so the label stays readable on light colors (including the white
        # fallback used for unknown classes).
        blue, green, red = color
        luminance = 0.114 * blue + 0.587 * green + 0.299 * red
        text_color = (0, 0, 0) if luminance > 128 else (255, 255, 255)

        cv2.putText(img, label, (tag_left, tag_bottom - 2),
                    font, font_scale, text_color, text_thickness)

    return img

def visualize_image(image_path, label_path, figsize=(15, 10), show_labels=True):
    """Show one image with its YOLO boxes drawn and print per-class counts.

    Args:
        image_path: Path (str or Path) of the image to read with OpenCV.
        label_path: Path of the matching YOLO label file. If it does not
            exist, a warning is printed and the image is shown without boxes.
        figsize: Matplotlib figure size in inches.
        show_labels: Forwarded to draw_bbox_cv2 to toggle the class-name tags.

    Returns:
        None. If the image cannot be read, an error is printed and the
        function returns without plotting.
    """
    # Read image
    img = cv2.imread(str(image_path))
    if img is None:
        print(f"Error: Could not read image {image_path}")
        return

    # Load annotations
    if Path(label_path).exists():
        annotations = load_yolo_annotation(label_path)
    else:
        print(f"Warning: No label file found at {label_path}")
        annotations = []

    # Draw on the BGR image first, then convert once to RGB for matplotlib
    img_with_boxes = draw_bbox_cv2(img, annotations, show_labels=show_labels)
    img_with_boxes_rgb = cv2.cvtColor(img_with_boxes, cv2.COLOR_BGR2RGB)

    # Plot
    plt.figure(figsize=figsize)
    plt.imshow(img_with_boxes_rgb)
    plt.axis('off')
    plt.title(f"{Path(image_path).name} - {len(annotations)} beans detected",
              fontsize=14, pad=10)
    plt.tight_layout()
    plt.show()

    # Count annotations per class name
    class_counts = {}
    for ann in annotations:
        class_id = ann['class_id']
        class_name = CLASSES.get(class_id, f"Class {class_id}")
        class_counts[class_name] = class_counts.get(class_name, 0) + 1

    # \U0001F4CA is the bar-chart emoji, written as an escape so the header
    # survives files saved or read with the wrong encoding.
    print("\n\U0001F4CA Bean counts:")
    for class_name, count in sorted(class_counts.items()):
        print(f"   {class_name}: {count}")
    print(f"   Total: {len(annotations)}")

def visualize_grid(image_dir, label_dir, num_images=4, cols=2, figsize=(15, 15),
                   show_labels=True, random_sample=True):
    """Visualize multiple annotated images in a grid.

    Args:
        image_dir: Directory containing the images (.jpg, .jpeg or .png,
            matched case-insensitively).
        label_dir: Directory containing YOLO label files named
            ``<image stem>.txt``. Images without a label file are shown
            without boxes.
        num_images: Maximum number of images to show.
        cols: Number of grid columns (must be at least 1).
        figsize: Matplotlib figure size in inches.
        show_labels: Forwarded to draw_bbox_cv2 to toggle the class-name tags.
        random_sample: If True, pick the images at random; otherwise take the
            first ``num_images`` in sorted filename order.

    Returns:
        None. Prints a message and returns without plotting when there is
        nothing to show.

    Raises:
        ValueError: If ``cols`` is smaller than 1.
    """
    if cols < 1:
        raise ValueError(f"cols must be at least 1, got {cols}")

    image_dir = Path(image_dir)
    label_dir = Path(label_dir)

    # Sort so the selection does not depend on directory-listing order
    image_suffixes = {".jpg", ".jpeg", ".png"}
    image_files = sorted(
        path for path in image_dir.glob("*")
        if path.suffix.lower() in image_suffixes
    )

    wanted = max(num_images, 0)
    if random_sample:
        image_files = random.sample(image_files, min(wanted, len(image_files)))
    else:
        image_files = image_files[:wanted]

    # plt.subplots cannot build a grid with zero rows, so stop early
    if not image_files:
        print(f"No images to display from {image_dir}")
        return

    rows = (len(image_files) + cols - 1) // cols
    # squeeze=False always yields a 2-D array of axes, whatever rows/cols are
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)

    # Hide every axis up front: covers unused cells and unreadable images
    for ax in axes.flat:
        ax.axis('off')

    for ax, image_path in zip(axes.flat, image_files):
        # Read image
        img = cv2.imread(str(image_path))
        if img is None:
            ax.set_title(f"{image_path.name}\n(could not be read)", fontsize=10)
            continue

        # Load annotations
        label_path = label_dir / f"{image_path.stem}.txt"
        if label_path.exists():
            annotations = load_yolo_annotation(label_path)
        else:
            annotations = []

        # Draw bounding boxes, then convert BGR -> RGB for matplotlib
        img_with_boxes = draw_bbox_cv2(img, annotations, show_labels=show_labels)
        img_with_boxes_rgb = cv2.cvtColor(img_with_boxes, cv2.COLOR_BGR2RGB)

        # Display
        ax.imshow(img_with_boxes_rgb)
        ax.set_title(f"{image_path.name}\n{len(annotations)} beans", fontsize=10)

    fig.tight_layout()
    plt.show()

def show_class_legend():
    """Display a color legend with one swatch per entry in CLASSES.

    Swatch colors are derived from COLORS (stored as blue, green, red) so the
    legend always matches the boxes drawn by draw_bbox_cv2. Rows are laid out
    from the number of classes, so every swatch stays inside the axes.
    """
    class_ids = sorted(CLASSES.keys())
    n_rows = max(len(class_ids), 1)

    fig, ax = plt.subplots(figsize=(8, 3))
    ax.axis('off')
    ax.set_xlim(0, 1)
    # One unit of y per class keeps all rows within the visible range
    ax.set_ylim(0, n_rows)

    for row, class_id in enumerate(class_ids):
        # Vertical center of this row, counted from the top
        y_mid = n_rows - row - 0.5

        # Reverse (blue, green, red) to RGB and scale to 0..1 for matplotlib;
        # white is the same fallback draw_bbox_cv2 uses for unknown ids
        blue, green, red = COLORS.get(class_id, (255, 255, 255))
        face_color = (red / 255, green / 255, blue / 255)

        # Draw colored box
        swatch = plt.Rectangle((0.1, y_mid - 0.3), 0.1, 0.6,
                               facecolor=face_color, edgecolor='black', linewidth=2)
        ax.add_patch(swatch)

        # Add text
        ax.text(0.25, y_mid, CLASSES[class_id], fontsize=14, va='center')

    ax.set_title("Coffee Bean Ripeness Classes", fontsize=16, pad=20)
    fig.tight_layout()
    plt.show()

# \u2705 is the check-mark emoji, written as an escape so the message survives
# files saved or read with the wrong encoding.
print("\u2705 Visualization functions loaded!")

## Step 4: Show Class Legend

In [None]:
# Render the class/color legend defined in Step 3
show_class_legend()

## Step 5: Visualize Single Image

In [None]:
# Visualize one training image together with its label file
# (the image and label share the same file stem)
visualize_image(
    image_path='kaggle_dataset_round1/images/train/020a77e7-106.jpg',
    label_path='kaggle_dataset_round1/labels/train/020a77e7-106.txt',
    figsize=(15, 10),
    show_labels=True
)

## Step 6: Visualize Multiple Images in Grid

In [None]:
# Visualize 6 random training images in a 2x3 grid (2 rows of 3 columns).
# random_sample=True picks a different set on each run unless `random` is seeded.
visualize_grid(
    image_dir='kaggle_dataset_round1/images/train',
    label_dir='kaggle_dataset_round1/labels/train',
    num_images=6,
    cols=3,
    figsize=(20, 12),
    show_labels=True,
    random_sample=True
)

## Step 7: Visualize Validation Images

In [None]:
# Visualize up to 5 validation images without random sampling.
# Increase num_images if the validation set holds more images than that.
visualize_grid(
    image_dir='kaggle_dataset_round1/images/val',
    label_dir='kaggle_dataset_round1/labels/val',
    num_images=5,
    cols=3,
    figsize=(20, 10),
    show_labels=True,
    random_sample=False
)

## Step 8: Browse Specific Images

List all available training images, then choose which one to visualize.

In [None]:
# List all training images, numbered with the 0-based position that
# `image_index` in the next cell uses to select one
from pathlib import Path

train_images = sorted(Path('kaggle_dataset_round1/images/train').glob('*.jpg'))
print(f"Found {len(train_images)} training images:\n")
for idx, img in enumerate(train_images):
    print(f"{idx:2d}. {img.name}")

In [None]:
# Choose an image by index (change the number below)
image_index = 0  # Change this to view different images (0 to len-1)

# train_images is the sorted list built in the previous cell; run that cell first
selected_image = train_images[image_index]
# The label file is looked up by the image's stem under labels/train
label_path = Path('kaggle_dataset_round1/labels/train') / f"{selected_image.stem}.txt"

visualize_image(
    image_path=selected_image,
    label_path=label_path,
    figsize=(15, 10),
    show_labels=True
)

## Step 9: Visualize Specific Classes

Find and visualize images that contain specific bean classes.

In [None]:
# Find images with rare classes (over-riped or barely-riped)
def find_images_with_class(label_dir, target_class_id):
    """Find all label files that contain at least one box of a given class.

    Args:
        label_dir: Directory holding YOLO ``*.txt`` label files.
        target_class_id: Integer class id to look for.

    Returns:
        List of file stems (label file names without ``.txt``), in sorted
        filename order, for every label file with a matching annotation.
    """
    label_dir = Path(label_dir)

    # Sorting makes the result order independent of directory-listing order,
    # so indexing into the returned list is reproducible.
    return [
        label_file.stem
        for label_file in sorted(label_dir.glob('*.txt'))
        if any(ann['class_id'] == target_class_id
               for ann in load_yolo_annotation(label_file))
    ]

# Find images with over-riped beans (class 1).
# The result is a list of file stems; the next cell reads over_riped_images.
over_riped_images = find_images_with_class('kaggle_dataset_round1/labels/train', 1)
print(f"Found {len(over_riped_images)} images with over-riped beans:")
for img in over_riped_images:
    print(f"  - {img}")

# Find images with barely-riped beans (class 0)
barely_riped_images = find_images_with_class('kaggle_dataset_round1/labels/train', 0)
print(f"\nFound {len(barely_riped_images)} images with barely-riped beans:")
for img in barely_riped_images:
    print(f"  - {img}")

In [None]:
# Visualize the first image found with over-riped beans (skipped if none were found).
# over_riped_images holds file stems, so the extensions are added here;
# the image is assumed to be a .jpg file.
if over_riped_images:
    img_name = over_riped_images[0]
    visualize_image(
        image_path=f'kaggle_dataset_round1/images/train/{img_name}.jpg',
        label_path=f'kaggle_dataset_round1/labels/train/{img_name}.txt',
        figsize=(15, 10),
        show_labels=True
    )