# In [None]:
from ultralytics import YOLO
import os
import cv2
from tqdm import tqdm
import json
from matplotlib import pyplot as plt

# Dataset configuration
dataset_path = 'dataset/data.yaml'  # Point this at your own dataset YAML file

# Start from pretrained YOLOv8 weights
# ('yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', ... - pick one that fits your compute budget)
model = YOLO("yolov8n.pt")

# 1. Real-Time Loss Visualization
class TrainingVisualizer:
    """Record the training loss of each epoch and redraw a live loss curve.

    A single figure is kept and redrawn on every update instead of opening a
    new figure per epoch, and the figure is shown without blocking so the
    caller (the training loop) keeps running.
    """

    def __init__(self):
        # Training loss of every epoch reported so far, in call order.
        self.epoch_loss = []
        # Figure that is redrawn on each update; created on first use.
        self._figure = None

    def on_epoch_end(self, epoch, logs):
        """Store this epoch's training loss and refresh the plot.

        Args:
            epoch: Index of the epoch that just finished. It is not used for
                plotting; the x-axis is the 1-based count of recorded losses.
            logs: Mapping that holds the epoch's loss under ``'train_loss'``.

        Raises:
            KeyError: If ``logs`` has no ``'train_loss'`` entry.
        """
        # float() stores scalar array/tensor losses as plain Python numbers.
        self.epoch_loss.append(float(logs['train_loss']))

        if self._figure is not None and plt.fignum_exists(self._figure.number):
            # The figure is still open: make it current and wipe the old curve.
            plt.figure(self._figure.number)
            plt.clf()
        else:
            # First update, or the previous figure was closed (by the user or
            # by a backend that closes figures once they have been shown).
            self._figure = plt.figure(figsize=(10, 5))

        plt.plot(range(1, len(self.epoch_loss) + 1), self.epoch_loss, label="Train Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training Loss Over Epochs")
        plt.legend()

        # Show without blocking, then give the GUI a moment to repaint.
        plt.show(block=False)
        plt.pause(0.1)


# Instantiate the visualizer
visualizer = TrainingVisualizer()


def _forward_epoch_loss(trainer):
    """Hand the trainer's loss for the finished epoch to ``visualizer``.

    NOTE(review): relies on the Ultralytics trainer exposing ``epoch``
    (0-based) and ``tloss`` (mean training loss items of the epoch) - confirm
    against the installed Ultralytics version.
    """
    if trainer.tloss is None:  # No batch was processed, so nothing to plot
        return
    # tloss may hold several loss components; plot their sum as one curve.
    visualizer.on_epoch_end(trainer.epoch + 1, {'train_loss': float(trainer.tloss.sum())})


# Ultralytics callbacks are registered per event on the model. ``callbacks`` is
# not a training argument, so passing it to ``train()`` is rejected.
model.add_callback("on_train_epoch_end", _forward_epoch_loss)

# Train the Model
results = model.train(
    data=dataset_path,       # Path to the dataset YAML file
    epochs=5,                # Number of epochs (increase for better results)
    batch=16,                # Batch size (adjust based on GPU memory)
    imgsz=1280,              # Image size (higher can increase accuracy but requires more memory)
    workers=4,               # Number of data loading workers
    optimizer='AdamW',       # Use AdamW optimizer for better convergence
    lr0=0.001,               # Initial learning rate
    patience=10,             # Early stopping patience (cannot trigger while epochs <= patience)
    augment=True,            # NOTE(review): Ultralytics documents `augment` as a prediction-time option - confirm it has the intended effect here
    val=True,                # Run validation during training
    project='chair_detection',  # Parent directory for this training run
    name='yolov8_chair',        # Run directory name inside the project directory
)

# Ask the trainer where it wrote the best checkpoint instead of hard-coding the
# run directory (Ultralytics may append a numeric suffix when the run directory
# already exists).
best_model_path = str(model.trainer.best)
print(f"Training complete. Best model saved at {best_model_path}")

# Load the trained YOLOv8 model
MODEL_PATH = best_model_path
model = YOLO(MODEL_PATH)


# Function to detect and count chairs (and annotate images)
def detect_and_count_chairs(image, model, class_name="chair"):
    """Detect objects in an image, annotate the matching ones and count them.

    Args:
        image: Image handed to ``model`` and drawn on with OpenCV. It is
            modified in place.
        model: Callable detector. Its first result must expose ``boxes`` and
            the model itself must expose ``names`` mapping class id to label.
        class_name: Label of the class to count and annotate.

    Returns:
        A tuple ``(image, count)``: the same image object with a box and a
        "<class_name> <confidence>" label drawn for every match, and the
        number of matches.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 2
    color = (0, 255, 0)
    label_margin = 10

    results = model(image)  # Run detection
    detections = results[0].boxes  # Boxes of the first (only) result
    chair_count = 0

    for detection in detections:
        class_id = int(detection.cls[0])
        if model.names[class_id] != class_name:
            continue  # Not the target class
        chair_count += 1

        x1, y1, x2, y2 = map(int, detection.xyxy[0])
        conf = float(detection.conf[0])
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)

        label = f"{class_name} {conf:.2f}"
        (_, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
        # Prefer the label above the box; when the box touches the top of the
        # image there is no room there, so draw it just inside the box instead
        # of letting it be clipped.
        label_y = y1 - label_margin
        if label_y < text_height:
            label_y = y1 + text_height + label_margin
        cv2.putText(image, label, (x1, label_y), font, font_scale, color, thickness)

    return image, chair_count


# Main function to process and save images with results summary
def process_images(input_dir, output_dir, model, class_name="chair"):
    """Annotate every image in a directory and write a JSON count summary.

    Each supported image in ``input_dir`` is run through
    ``detect_and_count_chairs``; the annotated image is written to
    ``output_dir`` under the same file name. A ``results_summary.json`` file
    listing ``{"image": ..., "chair_count": ...}`` per processed image is
    written to ``output_dir``. Files that fail are reported on stdout and
    skipped.

    Args:
        input_dir: Directory containing the input images.
        output_dir: Directory for annotated images and the summary; created
            if it does not exist.
        model: Detector passed through to ``detect_and_count_chairs``.
        class_name: Label of the class to count.
    """
    results_summary = []
    supported_formats = (".jpg", ".png", ".jpeg")

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Match extensions case-insensitively (".JPG" is as valid as ".jpg") and
    # sort so the processing order and the summary are reproducible.
    image_files = sorted(
        name for name in os.listdir(input_dir)
        if name.lower().endswith(supported_formats)
    )

    for img_file in tqdm(image_files, desc="Processing images"):
        img_path = os.path.join(input_dir, img_file)
        try:
            image = cv2.imread(img_path)
            if image is None:
                raise ValueError("Invalid image format or file.")

            # Detect and count chairs
            annotated_image, count = detect_and_count_chairs(image, model, class_name)

            # imwrite reports failure through its return value, not an exception.
            output_path = os.path.join(output_dir, img_file)
            if not cv2.imwrite(output_path, annotated_image):
                raise OSError(f"Could not write {output_path}")

            results_summary.append({"image": img_file, "chair_count": count})

        except Exception as e:
            # Best effort: one bad file must not abort the whole batch.
            print(f"Error processing {img_file}: {e}")

    # Save results summary to a JSON file
    summary_path = os.path.join(output_dir, "results_summary.json")
    with open(summary_path, "w") as json_file:
        json.dump(results_summary, json_file, indent=4)

    print(f"Processing complete. Results saved to {summary_path}")


# Live Detection Display
def live_detection_display(input_dir, model, class_name="chair"):
    """Show annotated detections for each image in a directory, one by one.

    Every supported image in ``input_dir`` is run through
    ``detect_and_count_chairs`` and displayed in an OpenCV window for 500 ms.
    Pressing 'q' stops the slideshow. Files that fail are reported on stdout
    and skipped.

    Args:
        input_dir: Directory containing the input images.
        model: Detector passed through to ``detect_and_count_chairs``.
        class_name: Label of the class to count and annotate.
    """
    supported_formats = (".jpg", ".png", ".jpeg")
    window_name = "Live Detection Results"

    # Case-insensitive extension match, sorted for a stable display order.
    image_files = sorted(
        name for name in os.listdir(input_dir)
        if name.lower().endswith(supported_formats)
    )

    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(window_name, 800, 600)

    try:
        for img_file in tqdm(image_files, desc="Live Detection"):
            img_path = os.path.join(input_dir, img_file)
            try:
                image = cv2.imread(img_path)
                if image is None:
                    raise ValueError("Invalid image format or file.")

                # Detect and count chairs
                annotated_image, count = detect_and_count_chairs(image, model, class_name)

                # Display annotated image
                cv2.imshow(window_name, annotated_image)
                if cv2.waitKey(500) & 0xFF == ord('q'):  # Press 'q' to exit
                    break

            except Exception as e:
                # Best effort: one bad file must not end the slideshow.
                print(f"Error processing {img_file}: {e}")
    finally:
        # Close the window even when the loop is interrupted (e.g. Ctrl+C).
        cv2.destroyAllWindows()


# Save Detection as Video
def save_detection_video(input_dir, output_video_path, model, class_name="chair"):
    """Write the annotated detections of a directory of images to a video.

    Images are processed in sorted file-name order at 10 frames per second.
    The frame size is taken from the first readable image; frames of any other
    size are resized to match. Files that fail are reported on stdout and
    skipped.

    Args:
        input_dir: Directory containing the input images.
        output_video_path: Path of the video file to create ('mp4v' codec).
        model: Detector passed through to ``detect_and_count_chairs``.
        class_name: Label of the class to count and annotate.

    Raises:
        ValueError: If ``input_dir`` contains no readable supported image.
        OSError: If the video writer cannot be opened for
            ``output_video_path``.
    """
    supported_formats = (".jpg", ".png", ".jpeg")
    frame_rate = 10  # Adjust as needed

    # List once, case-insensitively, and sort so the frame order is stable.
    image_files = sorted(
        name for name in os.listdir(input_dir)
        if name.lower().endswith(supported_formats)
    )

    # Take the frame size from the first image that can actually be read.
    frame_size = None
    for img_file in image_files:
        probe = cv2.imread(os.path.join(input_dir, img_file))
        if probe is not None:
            height, width = probe.shape[:2]
            frame_size = (width, height)
            break
    if frame_size is None:
        raise ValueError("No valid image files in directory.")

    # Define video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_video_path, fourcc, frame_rate, frame_size)
    if not video_writer.isOpened():
        video_writer.release()
        raise OSError(f"Could not open video writer for {output_video_path}")

    try:
        for img_file in tqdm(image_files, desc="Creating Video"):
            img_path = os.path.join(input_dir, img_file)
            try:
                image = cv2.imread(img_path)
                if image is None:
                    raise ValueError("Invalid image format or file.")

                # Detect and annotate image
                annotated_image, count = detect_and_count_chairs(image, model, class_name)

                # Every frame must match the size the writer was opened with.
                if annotated_image.shape[:2] != (height, width):
                    annotated_image = cv2.resize(annotated_image, frame_size)

                video_writer.write(annotated_image)

            except Exception as e:
                # Best effort: one bad file must not abort the video.
                print(f"Error processing {img_file}: {e}")
    finally:
        # Always finalize the file so the frames written so far are playable.
        video_writer.release()

    print(f"Video saved at {output_video_path}")


# Where images are read from and where annotated results are written
INPUT_IMAGES_DIR = 'path/to/your/images'  # Set this to your image directory
OUTPUT_IMAGES_DIR = 'path/to/save/detections'  # Set this to your output directory
os.makedirs(OUTPUT_IMAGES_DIR, exist_ok=True)

# Step 1: write annotated copies of every image plus the JSON summary
process_images(
    input_dir=INPUT_IMAGES_DIR,
    output_dir=OUTPUT_IMAGES_DIR,
    model=model,
)

# Step 2: show the detections in a window, one image at a time
live_detection_display(input_dir=INPUT_IMAGES_DIR, model=model)

# Step 3: assemble the annotated images into a video file
save_detection_video(
    input_dir=INPUT_IMAGES_DIR,
    output_video_path='detections_output.mp4',
    model=model,
)
