# Notebook cell In [2]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import cv2
from PIL import Image, ImageTk
import numpy as np 
import threading
from queue import Queue
from datetime import datetime
import json
import imutils
import os
import inspect

def load_calibration_data(calibration_file_path=None):
    """Load camera calibration parameters from a JSON file.

    The file must contain the keys ``camera_matrix`` (3x3 nested list) and
    ``distortion_coefficients`` (list of numbers).

    Args:
        calibration_file_path: Path of the JSON file. When ``None`` (the
            default) a file-open dialog asks the user to pick one.

    Returns:
        ``(camera_matrix, distortion_coeffs)`` as float64 NumPy arrays, or
        ``(None, None)`` when no file was chosen or the file could not be
        read or parsed.
    """
    if calibration_file_path is None:
        calibration_file_path = filedialog.askopenfilename(
            title="Select Calibration Data File",
            filetypes=[("JSON files", "*.json")],
        )
    if not calibration_file_path:
        print("No file selected")
        return None, None

    try:
        with open(calibration_file_path, 'r', encoding='utf-8') as f:
            calibration_data = json.load(f)
        # Force float64 so JSON files holding only integers still yield
        # floating-point matrices for the OpenCV calls that consume them.
        camera_matrix = np.array(calibration_data['camera_matrix'], dtype=np.float64)
        distortion_coeffs = np.array(
            calibration_data['distortion_coefficients'], dtype=np.float64
        )
        if camera_matrix.shape != (3, 3):
            raise ValueError(
                f"camera_matrix must be 3x3, got shape {camera_matrix.shape}"
            )
    except (OSError, ValueError, KeyError, TypeError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        print(f"Error loading calibration data: {e}")
        return None, None

    return camera_matrix, distortion_coeffs


class ObjectDetectionGUI:
    """Tkinter front end for video playback, live feed and image stitching.

    Frames are pulled from a ``MyVideoCapture`` instance (``self.vid``) by a
    periodic ``window.after`` callback and drawn onto ``self.canvas``.
    """

    def __init__(self, window, window_title):
        """Build the widgets and initialise the playback state.

        Args:
            window: The Tk root (or toplevel) that hosts the GUI.
            window_title: Text shown in the window's title bar.
        """
        self.window = window
        self.window.title(window_title)

        # Style configuration
        self.style = ttk.Style()
        self.style.theme_use('default')

        # Main frame (holds the control bar)
        main_frame = ttk.Frame(window)
        main_frame.pack(fill=tk.BOTH, expand=True)

        # Canvas that shows video frames and stitched images. It is parented
        # to the window itself and packed after main_frame.
        self.canvas = tk.Canvas(window, width=800, height=600)
        self.canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Control frame
        control_frame = ttk.Frame(main_frame)
        control_frame.pack(side=tk.BOTTOM, fill=tk.X)

        # (attribute name, label, width, callback, pack padding)
        button_specs = [
            ("btn_select", "Select Video", 15, self.open_file, {"padx": 5, "pady": 5}),
            ("btn_play", "Play", 15, self.play_video, {"padx": 5}),
            ("btn_pause", "Pause", 15, self.pause_video, {"padx": 5}),
            ("btn_resume", "Resume", 15, self.resume_video, {"padx": 5}),
            ("btn_stop", "Stop", 15, self.stop_video, {"padx": 5}),
            ("btn_capture", "Capture Frame", 15, self.capture_frame, {"padx": 5}),
            ("btn_live", "Live Feed", 15, self.start_live_feed, {"padx": 5}),
            ("btn_stitch_select", "Select Images for Stitching", 20,
             self.select_images_for_stitching, {"padx": 5, "pady": 5}),
            ("btn_stitch", "Stitch Images", 15, self.stitch_images,
             {"padx": 5, "pady": 5}),
            ("btn_test_distortion", "Test Distortion Correction", 25,
             self.test_distortion_correction, {"padx": 5, "pady": 5}),
        ]
        for attr_name, label, width, callback, padding in button_specs:
            button = tk.Button(control_frame, text=label, width=width, command=callback)
            button.pack(side=tk.LEFT, **padding)
            setattr(self, attr_name, button)

        # Playback state
        self.paused = False
        self.stopped = False

        self.delay = 15  # milliseconds between frame updates
        self.video_source = None
        self.vid = None
        self.photo = None  # keeps the current PhotoImage alive
        # Kept for backward compatibility; frames are only read from the Tk
        # callback, so no worker thread is started.
        self.detection_thread = None

        self._after_id = None    # id of the pending window.after callback
        self._last_frame = None  # most recently displayed RGB frame

        # List to store selected image paths
        self.stitching_images = []

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _cancel_pending_update(self):
        """Cancel the scheduled frame update, if any, so only one loop runs."""
        pending = getattr(self, "_after_id", None)
        if pending is not None:
            self.window.after_cancel(pending)
        self._after_id = None

    def _show_image(self, photo):
        """Replace the canvas content with *photo* (a ``PhotoImage``)."""
        # Deleting first keeps the canvas from accumulating one item per frame.
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=photo, anchor=tk.NW)

    def _start_source(self, source):
        """Create the capture object for *source* and start playback."""
        try:
            self.vid = MyVideoCapture(source)
        except (ValueError, OSError, cv2.error) as exc:
            # Unopenable source, missing model/label files, or an OpenCV error.
            self.vid = None
            messagebox.showerror("Error", f"Unable to open video source: {exc}")
            return
        self.play_video()

    # ------------------------------------------------------------------
    # Stitching
    # ------------------------------------------------------------------
    def select_images_for_stitching(self):
        """Ask the user for the image files to stitch and remember the paths."""
        self.stitching_images = list(filedialog.askopenfilenames())

    def stitch_images(self):
        """Stitch the selected images and show the result on the canvas."""
        if len(self.stitching_images) < 2:
            print("Need at least two images to stitch")
            return

        images = []
        unreadable = []
        for image_path in self.stitching_images:
            image = cv2.imread(image_path)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                unreadable.append(image_path)
            else:
                images.append(image)
        if unreadable:
            messagebox.showerror(
                "Stitching failed",
                "Could not read image file(s):\n" + "\n".join(unreadable),
            )
            return

        # Create a stitcher object and stitch the images
        stitcher = cv2.createStitcher() if imutils.is_cv3() else cv2.Stitcher_create()
        status, stitched = stitcher.stitch(images)

        if status == 0:  # 0 means the stitching succeeded
            self.display_stitched_image(stitched)
        else:
            messagebox.showerror(
                "Stitching failed", f"Stitching failed. Status code: {status}"
            )

    def display_stitched_image(self, stitched_image):
        """Scale *stitched_image* (BGR array) to fit the canvas and show it."""
        original_height, original_width = stitched_image.shape[:2]
        if original_height == 0 or original_width == 0:
            return

        # winfo_* can report tiny values before the window is laid out;
        # clamp so the ratios and the resize target stay valid.
        canvas_width = max(self.canvas.winfo_width(), 1)
        canvas_height = max(self.canvas.winfo_height(), 1)

        image_aspect = original_width / original_height
        canvas_aspect = canvas_width / canvas_height

        if image_aspect > canvas_aspect:
            # Image is wider than the canvas, so scale by width
            new_width = canvas_width
            new_height = int(canvas_width / image_aspect)
        else:
            # Image is taller than the canvas, so scale by height
            new_height = canvas_height
            new_width = int(canvas_height * image_aspect)
        new_width = max(new_width, 1)
        new_height = max(new_height, 1)

        resized_image = cv2.resize(
            stitched_image, (new_width, new_height), interpolation=cv2.INTER_AREA
        )

        # Convert BGR -> RGB for PIL/Tkinter
        im_pil = Image.fromarray(cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB))
        imgtk = ImageTk.PhotoImage(image=im_pil)

        self._show_image(imgtk)
        # Keep a reference so the image is not garbage collected
        self.canvas.image = imgtk

    # ------------------------------------------------------------------
    # Playback control
    # ------------------------------------------------------------------
    def pause_video(self):
        """Pause playback; the update loop stops at its next tick."""
        self.paused = True

    def resume_video(self):
        """Clear the pause flag and restart the update loop."""
        self.paused = False
        self.play_video()

    def stop_video(self):
        """Stop playback, release the video source and clear the canvas."""
        self.stopped = True
        self._cancel_pending_update()
        if self.vid:
            self.vid.release()
        self.vid = None
        self._last_frame = None
        self.canvas.delete("all")

    def capture_frame(self):
        """Save the currently displayed frame as a timestamped PNG file.

        Falls back to reading a new frame when nothing has been shown yet.
        """
        if not self.vid:
            return
        frame = getattr(self, "_last_frame", None)
        if frame is None:
            ret, frame = self.vid.get_frame()
            if not ret:
                return
        # Get the current timestamp and format it to create a unique filename
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S%f")
        filename = f"captured_frame_{timestamp}.png"
        # get_frame returns RGB, while cv2.imwrite expects BGR channel order.
        cv2.imwrite(filename, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    def start_live_feed(self):
        """Switch to the default camera (device 0) and start playback."""
        # Stop any existing video playback and release its source
        self.stop_video()
        self.paused = False
        self.stopped = False

        # Frames are read only by the play_video callback; a second reader on
        # another thread would race on the same capture object.
        self._start_source(0)

    def open_file(self):
        """Let the user choose a video file and start playing it."""
        # Release the previous video capture object if it exists
        self.stop_video()
        self.paused = False
        self.stopped = False

        self.video_source = filedialog.askopenfilename()
        if self.video_source:  # If a file is selected, create a new capture
            self._start_source(self.video_source)

    def play_video(self):
        """Show the next frame and schedule the following update.

        Safe to call while a loop is already running: any pending update is
        cancelled first, so repeated presses of Play/Resume do not stack loops.
        """
        self._cancel_pending_update()
        if self.stopped or not self.vid or self.paused:
            return

        ret, frame = self.vid.get_frame()
        if not ret:
            # End of stream or read failure: stop polling the source.
            return

        self._last_frame = frame
        self.photo = ImageTk.PhotoImage(image=Image.fromarray(frame))
        self._show_image(self.photo)
        # Schedule the next frame update
        self._after_id = self.window.after(self.delay, self.play_video)

    def test_distortion_correction(self):
        """Show one raw frame next to its undistorted version in OpenCV windows.

        The frame is read straight from the underlying ``cv2.VideoCapture``
        (``self.vid.vid``) because ``get_frame`` already undistorts, annotates
        and converts frames to RGB. Blocks until a key is pressed in one of
        the OpenCV windows.
        """
        if not self.vid:
            messagebox.showerror("Error", "Video source not initialized")
            return

        capture = getattr(self.vid, "vid", None)
        ret, frame = capture.read() if capture is not None else (False, None)
        if not ret:
            messagebox.showerror("Error", "Failed to capture frame from video source")
            return

        # Show the original image in a new window
        cv2.imshow('Original Image', frame)

        # Apply distortion correction if calibration data is available
        if self.vid.camera_matrix is not None and self.vid.distortion_coeffs is not None:
            corrected_frame = cv2.undistort(
                frame, self.vid.camera_matrix, self.vid.distortion_coeffs, None
            )
            cv2.imshow('Corrected Image', corrected_frame)
        else:
            messagebox.showinfo(
                "Info", "Calibration data not found, showing original image only."
            )

        # Wait for a key press to close the windows
        cv2.waitKey(0)
        cv2.destroyAllWindows()

class MyVideoCapture:
    """Video source that annotates each frame with detections, IDs and speeds.

    Wraps a ``cv2.VideoCapture``, runs the network loaded from
    ``yolov4-tiny.weights`` / ``yolov4-tiny.cfg`` on every frame, matches
    detections to the previous frame by centroid distance and estimates a
    speed from the centroid displacement.
    """

    CONFIDENCE_THRESHOLD = 0.5  # minimum class score for a detection
    NMS_THRESHOLD = 0.4         # overlap threshold for non-maximum suppression

    def __init__(self, video_source):
        """Open *video_source* and load the detector, labels and calibration.

        Args:
            video_source: Anything ``cv2.VideoCapture`` accepts (file path or
                camera index).

        Raises:
            ValueError: If the video source cannot be opened.
        """
        # Open the video source
        self.vid = cv2.VideoCapture(video_source)
        if not self.vid.isOpened():
            raise ValueError("Unable to open video source", video_source)

        try:
            self.frame_rate = self.vid.get(cv2.CAP_PROP_FPS)
            # Fall back to 30 fps when the source reports no frame rate.
            self.frame_time = 1 / self.frame_rate if self.frame_rate > 0 else 1 / 30

            # Get video source width and height
            self.width = self.vid.get(cv2.CAP_PROP_FRAME_WIDTH)
            self.height = self.vid.get(cv2.CAP_PROP_FRAME_HEIGHT)

            # Loading YOLO for object detection
            self.net = cv2.dnn.readNet('yolov4-tiny.weights', 'yolov4-tiny.cfg')
            with open("coco.names", "r") as f:
                self.classes = [line.strip() for line in f]

            self.layer_names = self.net.getLayerNames()
            # Depending on the OpenCV version this is a flat or an Nx1 array;
            # flatten so each entry is a scalar 1-based layer index.
            out_layer_ids = np.array(self.net.getUnconnectedOutLayers()).flatten()
            self.output_layers = [self.layer_names[i - 1] for i in out_layer_ids]
        except Exception:
            # Do not leak the opened capture when model/label loading fails.
            self.vid.release()
            raise

        self.positions = {}
        self.previous_positions = {}
        self.object_ids = {}  # Maps a unique ID to each object
        self.next_object_id = 0  # Next ID to assign
        self.speeds = {}  # Recent instantaneous speeds per object ID
        self.scale = 10  # divisor turning pixel distance into metres
        self.speed_threshold = 50
        self.movement_category = "Unknown"
        self.annotations_list = []
        self.frame_buffer = []
        self.buffer_size = 30

        # Load calibration data
        self.camera_matrix, self.distortion_coeffs = load_calibration_data()

    def update_buffer(self, frame):
        """Append *frame* and keep only the newest ``buffer_size`` frames."""
        self.frame_buffer.append(frame)
        overflow = len(self.frame_buffer) - self.buffer_size
        if overflow > 0:
            del self.frame_buffer[:overflow]

    def save_segment(self, start_index, end_index, filename):
        """Write buffered frames ``[start_index, end_index)`` to an XVID video.

        Indices are clamped to the current buffer bounds.
        """
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        # Sources without a reported frame rate use the same fallback as frame_time.
        fps = self.frame_rate if self.frame_rate > 0 else 1 / self.frame_time
        out = cv2.VideoWriter(filename, fourcc, fps, (int(self.width), int(self.height)))
        first = max(start_index, 0)
        last = min(end_index, len(self.frame_buffer))
        try:
            for frame in self.frame_buffer[first:last]:
                out.write(frame)
        finally:
            out.release()

    def compute_centroids(self, boxes):
        """Return the integer centre ``(cx, cy)`` of each ``(x, y, w, h)`` box."""
        return [(int(x + w / 2), int(y + h / 2)) for x, y, w, h in boxes]

    def match_and_update_centroids(self, centroids):
        """Assign an object ID to every centroid and store them in ``positions``.

        Each centroid takes the nearest not-yet-claimed object from
        ``previous_positions`` that lies within the match threshold; otherwise
        it receives a fresh ID.

        Returns:
            The assigned IDs, in the same order as *centroids*.
        """
        MATCH_THRESHOLD = 100

        new_positions = {}
        assigned_ids = []
        for centroid in centroids:
            best_id = None
            best_distance = MATCH_THRESHOLD
            for obj_id, position in self.previous_positions.items():
                if obj_id in new_positions:
                    continue  # already claimed by another detection this frame
                distance = float(np.linalg.norm(np.array(centroid) - np.array(position)))
                if distance < best_distance:
                    best_id, best_distance = obj_id, distance
            if best_id is None:
                best_id = self.next_object_id
                self.next_object_id += 1
            new_positions[best_id] = centroid
            assigned_ids.append(best_id)
        self.positions = new_positions
        return assigned_ids

    def _detect_objects(self, frame):
        """Run the network on *frame*; return ``(boxes, confidences, class_ids)``.

        Boxes are ``[x, y, w, h]`` in pixels. Detections scoring below the
        confidence threshold or with box values outside ``[0, 1]`` are dropped.
        """
        frame_height, frame_width = frame.shape[:2]
        blob = cv2.dnn.blobFromImage(frame, 0.00392, (416, 416), (0, 0, 0), True, crop=False)
        self.net.setInput(blob)
        outs = self.net.forward(self.output_layers)

        boxes, confidences, class_ids = [], [], []
        for out in outs:
            for detection in out:
                scores = detection[5:]
                class_id = int(np.argmax(scores))
                confidence = float(scores[class_id])
                # Suppression discards these anyway; skipping early keeps the
                # candidate lists small.
                if confidence < self.CONFIDENCE_THRESHOLD:
                    continue
                if not all(0 <= v <= 1 for v in detection[:4]):
                    continue

                center_x = int(detection[0] * frame_width)
                center_y = int(detection[1] * frame_height)
                w = int(detection[2] * frame_width)
                h = int(detection[3] * frame_height)
                x = int(center_x - w / 2)
                y = int(center_y - h / 2)

                boxes.append([x, y, w, h])
                confidences.append(confidence)
                class_ids.append(class_id)
        return boxes, confidences, class_ids

    def get_frame(self):
        """Read, undistort, detect on and annotate the next frame.

        Returns:
            ``(True, frame_rgb)`` on success, ``(False, None)`` when the source
            is released, closed or has no more frames.
        """
        if self.vid is None or not self.vid.isOpened():
            return (False, None)

        ret, frame = self.vid.read()
        if not ret:
            return (ret, None)

        # Apply distortion correction if calibration data is available
        if self.camera_matrix is not None and self.distortion_coeffs is not None:
            frame = cv2.undistort(frame, self.camera_matrix, self.distortion_coeffs, None)

        boxes, confidences, class_ids = self._detect_objects(frame)

        kept = []
        if boxes:
            raw_indexes = cv2.dnn.NMSBoxes(
                boxes, confidences, self.CONFIDENCE_THRESHOLD, self.NMS_THRESHOLD
            )
            # May be a flat array, an Nx1 array or an empty tuple.
            kept = [int(i) for i in np.array(raw_indexes).flatten()]

        centroids = self.compute_centroids([boxes[i] for i in kept])
        object_ids = self.match_and_update_centroids(centroids)

        for index, centroid, object_id in zip(kept, centroids, object_ids):
            x, y, w, h = boxes[index]
            label = str(self.classes[class_ids[index]])
            color = (0, 255, 0)
            cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
            cv2.putText(frame, f'{label} {object_id}', (x, y - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

            self.calculate_and_display_speed(frame, object_id, centroid, x, y, w, h)

        # Roll the tracking state forward once per frame, after all objects
        # have been compared with the previous frame's positions.
        self.previous_positions = self.positions.copy()
        self.positions = {}
        # Forget speed history of objects that are no longer tracked.
        self.speeds = {
            obj_id: history for obj_id, history in self.speeds.items()
            if obj_id in self.previous_positions
        }
        return (ret, cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    # Speed calculation
    def calculate_and_display_speed(self, frame, object_id, centroid, x, y, w, h):
        """Draw the speed of one object on *frame* and log it if over the limit.

        Objects not present in ``previous_positions`` are only outlined in
        green. ``previous_positions`` is left untouched so that every object
        of the frame is compared against the previous frame.
        """
        self.speeds.setdefault(object_id, [])

        if object_id in self.previous_positions:
            # Calculate speed
            speed = self.calculate_speed(centroid, object_id)
            formatted_speed = "{:.1f}".format(speed)

            # Display speed on the frame, below the box top if it would be cut off
            speed_text = f"{formatted_speed} km/h"
            text_position = (x, y - 10)
            if text_position[1] < 0:
                text_position = (x, y + 20)

            cv2.putText(frame, speed_text, text_position,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

            # Determine movement category based on speed
            if speed > self.speed_threshold:
                self.movement_category = "Speeding"
            elif speed > 10:
                self.movement_category = "Moving"
            else:
                self.movement_category = "Stationary"

            # Speed threshold check and display
            if speed > self.speed_threshold:
                color = (0, 0, 255)  # Red color for over-speeding
                cv2.rectangle(frame, (x, y), (x + w, y + h), color, 3)
                self.log_and_save_speed_info(object_id, formatted_speed, frame)
        else:
            # For new objects, just display them without speed calculation
            color = (0, 255, 0)  # Green color for new objects
            cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)

        self.positions[object_id] = centroid

    def calculate_speed(self, centroid, object_id):
        """Return the moving-average speed (km/h) of *object_id*.

        The instantaneous speed comes from the centroid displacement since the
        previous frame, ``self.scale`` and ``self.frame_time``; values above
        300 km/h are recorded as 0. Returns 0 for objects without a previous
        position.
        """
        MAX_REALISTIC_SPEED = 300  # Maximum realistic speed in km/h
        MOVING_AVERAGE_WINDOW = 5

        if object_id not in self.previous_positions:
            return 0

        prev_x, prev_y = self.previous_positions[object_id]
        dx, dy = centroid[0] - prev_x, centroid[1] - prev_y
        distance_pixels = np.sqrt(dx ** 2 + dy ** 2)
        distance_meters = distance_pixels / self.scale
        instant_speed = (distance_meters / self.frame_time) * 3.6  # m/s -> km/h

        if instant_speed > MAX_REALISTIC_SPEED:
            instant_speed = 0  # Resetting unrealistic speed to 0

        history = self.speeds.setdefault(object_id, [])
        history.append(instant_speed)
        if len(history) > MOVING_AVERAGE_WINDOW:
            history.pop(0)

        return sum(history) / len(history)

    def log_and_save_speed_info(self, object_id, formatted_speed, frame):
        """Record an over-speed event: annotation entry, JPEG and JSON files."""
        annotation_info = {
            "timestamp": datetime.now().isoformat(),
            "object_type": self.movement_category,
            "speed": formatted_speed
        }
        self.annotations_list.append(annotation_info)
        cv2.imwrite(f'annotated_frame_{object_id}.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 90])

        # Create JSON Annotation
        json_data = self.create_json_annotation(object_id, formatted_speed, annotation_info)
        self.save_json_annotation(json_data, f'annotation_{object_id}.json')

        # Drawn after the JPEG is written, so it appears only on the returned frame.
        annotation = f"Speed: {formatted_speed} km/h, Category: {self.movement_category}"
        cv2.putText(frame, annotation, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 255, 255), 2, cv2.LINE_AA)

    def create_json_annotation(self, object_id, speed, frame_info):
        """Return a JSON string describing one object's speed annotation."""
        data = {
            "object_id": object_id,
            "speed": speed,
            "frame_info": frame_info,
        }
        return json.dumps(data)

    def save_json_annotation(self, json_data, filename):
        """Write the JSON string *json_data* to *filename*."""
        with open(filename, 'w') as file:
            file.write(json_data)

    def release(self):
        """Release the video source; safe to call more than once."""
        if self.vid is not None and self.vid.isOpened():
            self.vid.release()
        self.vid = None

# Launch the GUI only when run as a script (or notebook cell), so that
# importing this module does not open a window and block in mainloop().
if __name__ == "__main__":
    root = tk.Tk()
    app = ObjectDetectionGUI(root, "Object Detection")
    root.mainloop()

# Cell output: No file selected
