In [None]:
import cv2
import torch
import time
import random
import numpy as np

# Load the YOLOv5 model ('yolov5s' weights) through torch.hub.
# Any failure while loading is treated as fatal for the script.
try:
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
except Exception as e:
    print(f"Error loading YOLOv5 model: {e}")
    # The site-provided exit() is intended for interactive shells and, called
    # without arguments, reports success (status 0). Raise SystemExit with a
    # non-zero status so callers can tell that the run failed.
    raise SystemExit(1) from e

# Path to the traffic video you want to test
video_path = r'traffic_video_test_1.mp4'

# Open the video file
cap = cv2.VideoCapture(video_path)

# Abort when the video cannot be opened: nothing downstream can run without
# frames. The capture handle is released first, and the process ends with a
# non-zero status (the interactive-only exit() helper would report success).
if not cap.isOpened():
    print(f"Error: Could not open video file: {video_path}")
    cap.release()
    raise SystemExit(1)

# Get video frame width and height so the output matches the source geometry
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Reuse the source frame rate so the saved video plays at the original speed.
# Fall back to 30 fps when the backend reports nothing usable (0, negative or
# NaN all fail the `> 0` test).
output_fps = cap.get(cv2.CAP_PROP_FPS)
if not output_fps > 0:
    output_fps = 30

# Define the codec and create VideoWriter object to save the output video
output_path = r'output_traffic_detection_rl.mp4'
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), output_fps,
                      (frame_width, frame_height))

# A writer that failed to open drops frames silently; surface the problem but
# keep going so the live preview still works.
if not out.isOpened():
    print(f"Warning: Could not open output video for writing: {output_path}")

# Tabular Q-learning configuration.
# Rows of the Q-table follow `states`, columns follow `actions`.
states = ['low_traffic', 'medium_traffic', 'high_traffic']  # coarse traffic-density buckets
actions = ['GREEN', 'ORANGE', 'RED']  # light colours the agent can select

# Hyper-parameters: step size, discount on future value, exploration probability.
learning_rate, discount_factor, epsilon = 0.1, 0.95, 0.1

# Every state/action value starts at zero.
n_states = len(states)
n_actions = len(actions)
q_table = np.zeros(shape=(n_states, n_actions))

# Helper function to determine the state based on vehicle count
def get_state(vehicle_count: float, *, low_threshold: float = 5, high_threshold: float = 15) -> int:
    """Map a vehicle count to a traffic-density state index.

    Args:
        vehicle_count: Number of vehicles counted in the current frame.
        low_threshold: Counts strictly below this value are low traffic.
        high_threshold: Counts from ``low_threshold`` up to, but excluding,
            this value are medium traffic; anything else is high traffic.

    Returns:
        0 for low traffic, 1 for medium traffic, 2 for high traffic.
    """
    if vehicle_count < low_threshold:
        return 0  # low_traffic
    # The lower bound is already guaranteed by the check above.
    if vehicle_count < high_threshold:
        return 1  # medium_traffic
    return 2  # high_traffic

# Starting conditions: the light shows RED, the state index is 0, and the
# light timer is stamped with the current wall-clock time.
current_light, current_state = "RED", 0
light_timer = time.time()

# Main loop: run detection on every frame, let the epsilon-greedy Q-learning
# agent pick a light colour, annotate the frame, then display and save it.
#
# A (state, action, reward) transition can only be scored once the *next*
# frame's state is known, so it is parked in `pending_transition` and the
# Q-table is updated one frame later. The final transition of the video has
# no successor state and is therefore never applied.
pending_transition = None

# BGR overlay colour for each light.
light_colors = {"GREEN": (0, 255, 0), "ORANGE": (0, 165, 255), "RED": (0, 0, 255)}

try:
    while True:
        # Capture frame-by-frame
        ret, frame = cap.read()

        if not ret:
            print("End of video reached or failed to grab frame.")
            break

        # OpenCV frames are BGR, while the YOLOv5 hub model expects RGB for
        # numpy input, so convert before running inference.
        results = model(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # render() returns the annotated image in the input channel order
        # (RGB here); convert back to BGR for OpenCV display and writing.
        # cvtColor allocates a new array, so the result is writable.
        frame_with_detections = cv2.cvtColor(results.render()[0], cv2.COLOR_RGB2BGR)

        # Count detections whose class id is 2 (car).
        # NOTE(review): other vehicle classes (bus, truck, motorcycle) are not
        # counted even though the overlay says "Vehicles" - confirm intent.
        vehicle_count = sum(1 for detection in results.xyxy[0] if int(detection[-1]) == 2)

        # Determine the state observed in this frame
        new_state = get_state(vehicle_count)

        # Complete the previous frame's Q-learning update now that its
        # successor state (this frame) is known.
        if pending_transition is not None:
            prev_state, prev_action_index, prev_reward = pending_transition
            q_table[prev_state, prev_action_index] = (
                (1 - learning_rate) * q_table[prev_state, prev_action_index]
                + learning_rate * (prev_reward + discount_factor * np.max(q_table[new_state]))
            )

        # RL decision making (epsilon-greedy) for the state just observed
        if random.uniform(0, 1) < epsilon:
            # Explore: choose a random action
            action = random.choice(actions)
        else:
            # Exploit: choose the best-valued action for this state
            action = actions[np.argmax(q_table[new_state])]

        # Apply action: action names are the light colours themselves, and the
        # timer is only restarted when the colour actually changes.
        if current_light != action:
            current_light = action
            light_timer = time.time()

        # Reward calculation
        reward = -1  # Default negative reward
        if current_light == "GREEN" and vehicle_count > 10:
            reward = 10  # Reward for reducing congestion
        elif current_light == "RED" and vehicle_count < 5:
            reward = 5  # Reward for efficient light usage

        # Park this transition until the next frame reveals its successor state.
        pending_transition = (new_state, actions.index(action), reward)

        # Update current state
        current_state = new_state

        # Display traffic light status and vehicle count on the frame
        # (any unexpected light value falls back to red, as before).
        cv2.putText(frame_with_detections, f"Light: {current_light}", (50, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1,
                    light_colors.get(current_light, (0, 0, 255)), 2)
        cv2.putText(frame_with_detections, f"Vehicles: {vehicle_count}", (50, 100),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        # Display the frame with detections
        cv2.imshow('Traffic Video Object Detection', frame_with_detections)

        # Save the frame to the output video
        out.write(frame_with_detections)

        # Press 'q' to exit the real-time detection
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always release the capture and writer, even when the loop is interrupted
    # or raises; otherwise the output file may be left unfinalised.
    cap.release()
    out.release()

    # Close all OpenCV windows
    cv2.destroyAllWindows()
