# In [None]:
import cv2
import numpy as np
import time
import os
from pyspark.sql import SparkSession

# Initialize Spark session
spark = (
    SparkSession.builder
    .appName("Traffic Violation Detection")
    .getOrCreate()
)

# Paths to the required files
weights_path = 'Assets/Trained_Model/yolov3/yolov3.weights'
config_path = 'Assets/Trained_Model/yolov3/yolov3.cfg'
names_path = 'Assets/Trained_Model/coco.names'
traffic_light_weights_path = 'Assets/Trained_Model/yolov3/yolov3_10000.weights'
# NOTE(review): the custom traffic-light weights reuse the stock COCO
# yolov3.cfg. A Darknet model trained on a different number of classes
# normally needs its own cfg (matching `classes=` / `filters=` in the YOLO
# layers); confirm this cfg really matches yolov3_10000.weights and obj.names.
traffic_light_config_path = 'Assets/Trained_Model/yolov3/yolov3.cfg'
traffic_light_names_path = 'Assets/Structured files/obj.names'


def _read_class_names(path):
    """Return the class names listed one per line in a Darknet ``.names`` file.

    Blank lines at the very start/end of the file are dropped, but interior
    lines are kept (even if empty) so list positions keep matching the class
    ids the network predicts. Each name is stripped individually so that a
    padded entry such as ``'car '`` still compares equal to ``'car'``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f.read().strip().split('\n')]


def _load_darknet_net(cfg_file, weights_file):
    """Load a Darknet cfg/weights pair with OpenCV DNN and request CUDA inference."""
    model = cv2.dnn.readNetFromDarknet(cfg_file, weights_file)
    # These are preferences only; OpenCV is expected to warn and fall back to
    # the CPU if the installed build lacks CUDA support.
    model.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
    model.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
    return model


# Read the class labels from coco.names and obj.names
classes = _read_class_names(names_path)
traffic_light_classes = _read_class_names(traffic_light_names_path)

# Load the YOLO models with CUDA backend preference
net = _load_darknet_net(config_path, weights_path)
traffic_light_net = _load_darknet_net(traffic_light_config_path, traffic_light_weights_path)

def detect_objects(frame, net, classes, confidence_threshold=0.5, nms_threshold=0.4,
                   allowed_labels=('car', 'motorbike', 'bus', 'RedLight', 'GreenLight'),
                   input_size=(416, 416)):
    """Run a Darknet YOLO network on one frame and return the kept detections.

    Args:
        frame: Image array whose first two dimensions are (height, width);
            presumably BGR as read by ``cv2.VideoCapture`` (the blob is built
            with ``swapRB=True``).
        net: ``cv2.dnn.Net`` loaded from a Darknet cfg/weights pair.
        classes: Class names, indexed by the class id the network predicts.
        confidence_threshold: A candidate is kept only if its best class score
            exceeds this value; also passed to NMS as the score threshold.
        nms_threshold: IoU threshold for non-maximum suppression.
        allowed_labels: Labels to report; candidates of any other class are
            dropped before NMS. Defaults to the vehicle and traffic-light
            labels used by this project.
        input_size: ``(width, height)`` the frame is resized to before it is
            fed to the network. Defaults to 416x416.

    Returns:
        List of ``(x, y, w, h, label, confidence, color)`` tuples, where
        ``(x, y)`` is the box's top-left corner in frame pixels and ``color``
        is a BGR list used for drawing.
    """
    frame_h, frame_w = frame.shape[:2]
    allowed = set(allowed_labels)

    # Create a blob and pass it through the network's output layers.
    blob = cv2.dnn.blobFromImage(frame, 1 / 255.0, tuple(input_size), swapRB=True, crop=False)
    net.setInput(blob)
    outputs = net.forward(net.getUnconnectedOutLayersNames())

    # One row per candidate: columns 0-3 are the box centre/size relative to
    # the frame (hence the scaling below), columns 5+ are per-class scores;
    # column 4 is not used here.
    detections = np.vstack(outputs)
    scores = detections[:, 5:]
    best_ids = scores.argmax(axis=1)
    best_scores = scores[np.arange(len(scores)), best_ids]
    scale = np.array([frame_w, frame_h, frame_w, frame_h])

    boxes, confidences, labels = [], [], []
    # Only the few rows above threshold are handled in Python.
    for row in np.flatnonzero(best_scores > confidence_threshold):
        label = classes[best_ids[row]]
        if label not in allowed:
            continue
        center_x, center_y, width, height = (detections[row, 0:4] * scale).astype(int)
        boxes.append([int(center_x - width / 2), int(center_y - height / 2), int(width), int(height)])
        confidences.append(float(best_scores[row]))
        labels.append(label)

    if not boxes:
        return []

    # NMS runs over all kept labels together, so overlapping boxes of
    # different classes (e.g. a car and a bus) can suppress each other.
    kept = cv2.dnn.NMSBoxes(boxes, confidences, confidence_threshold, nms_threshold)

    colors = {
        'GreenLight': [0, 128, 0],  # Green color for GreenLight
        'RedLight': [0, 0, 255],    # Red color for RedLight
    }
    default_color = [0, 255, 0]     # Default color for other labels

    results = []
    # NMSBoxes returns an (N,) or (N, 1) array, or an empty tuple, depending
    # on the OpenCV version; normalise to a flat int array.
    for i in np.asarray(kept, dtype=int).reshape(-1):
        x, y, w, h = boxes[i]
        label = labels[i]
        color = list(colors.get(label, default_color))
        results.append((x, y, w, h, label, confidences[i], color))

    return results

# Vehicle labels that can be flagged for running a red light.
_VEHICLE_LABELS = frozenset({'car', 'motorbike', 'bus'})

# While the light is red, a vehicle is flagged if its box starts above pixel
# row 400 and extends below pixel row 440.
# NOTE(review): this band does not line up with the stop line drawn on the
# frame (y=580 at x=0 to y=540 at x=800); confirm which one matches the
# camera view before trusting the violation results.
_VIOLATION_BAND_TOP_Y = 400
_VIOLATION_BAND_BOTTOM_Y = 440

# Stop line drawn on every frame, as (x, y) end points in pixels.
_STOP_LINE_START = (0, 580)
_STOP_LINE_END = (800, 540)

# When a violation is seen, only frames whose index is a multiple of this are
# saved, to avoid writing many near-identical images.
_SAVE_EVERY_N_FRAMES = 15


class _UnshippedCache(dict):
    """Dict that always pickles as an empty instance.

    Spark serializes ``process_frame``, together with the globals it
    references, and ships it to the executors. ``cv2.dnn.Net`` objects cannot
    be pickled, so loaded networks must never travel with it: every
    deserialized copy of this cache starts empty and is filled locally.
    """

    def __reduce__(self):
        return (self.__class__, ())


_NETWORK_CACHE = _UnshippedCache()


def _get_networks():
    """Return ``(main_net, traffic_light_net)``, loading them on first use.

    Models are loaded at most once per deserialized copy of this code, i.e.
    on the executor side -- presumably once per Spark task; verify if model
    load time becomes significant.
    """
    if not _NETWORK_CACHE:
        for key, cfg_file, weights_file in (
            ('main', config_path, weights_path),
            ('traffic_light', traffic_light_config_path, traffic_light_weights_path),
        ):
            model = cv2.dnn.readNetFromDarknet(cfg_file, weights_file)
            model.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
            model.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
            _NETWORK_CACHE[key] = model
    return _NETWORK_CACHE['main'], _NETWORK_CACHE['traffic_light']


def process_frame(frame_count_frame):
    """Detect vehicles and traffic lights in one frame and mark red-light violations.

    Intended to run on Spark executors via ``rdd.map``. Reads the module-level
    ``classes``, ``traffic_light_classes``, model paths and ``violated_dir``
    (which must exist on the executor's filesystem).

    Args:
        frame_count_frame: ``(frame_count, frame)`` pair -- the frame's index
            in the video and the image array to annotate.

    Returns:
        The same ``(frame_count, frame)`` pair; ``frame`` is drawn on in place.
    """
    frame_count, frame = frame_count_frame
    main_net, light_net = _get_networks()

    start_time = time.time()  # Excludes the one-off model loading above

    # Vehicles from the main model, lights from the custom-trained model.
    objects = detect_objects(frame, main_net, classes)
    traffic_lights = detect_objects(frame, light_net, traffic_light_classes)

    # The light counts as red if the traffic-light model reports any RedLight box.
    red_light = any(detection[4] == 'RedLight' for detection in traffic_lights)

    violation_detected = False
    for (x, y, w, h, label, _, color) in objects + traffic_lights:
        violation = (
            red_light
            and label in _VEHICLE_LABELS
            and y + h > _VIOLATION_BAND_BOTTOM_Y
            and y < _VIOLATION_BAND_TOP_Y
        )
        if violation:
            color = [0, 0, 255]  # Red color for violation
            violation_detected = True

        cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
        cv2.putText(frame, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        if violation:
            cv2.putText(frame, "Violated!", (x, y - 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    cv2.line(frame, _STOP_LINE_START, _STOP_LINE_END, (0, 0, 255), 2)

    if violation_detected and frame_count % _SAVE_EVERY_N_FRAMES == 0:
        violation_frame_path = os.path.join(violated_dir, f"violation_frame_{frame_count}.jpg")
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(violation_frame_path, frame):
            print(f"Warning: could not save {violation_frame_path}")

    elapsed_time = time.time() - start_time
    print(f"Processing Time for frame {frame_count}: {elapsed_time:.2f} seconds")

    return (frame_count, frame)

# Path to the video
video_path = r'C:\School\Traffic Violation Detection Project\Assets\Traffic_Data\Traffic_Violations\No_Violation_Video\No Violation - Made with Clipchamp.mp4'
# Directory to save violated frames
violated_dir = 'Violated'
# Directory for the annotated output video
processed_dir = 'Processed'
os.makedirs(violated_dir, exist_ok=True)
os.makedirs(processed_dir, exist_ok=True)

# Open the video
cap = cv2.VideoCapture(video_path)

# Check if the video opened successfully
if not cap.isOpened():
    print("Error: Cannot open video.")
    spark.stop()
    raise SystemExit(1)

# Query stream properties while the capture is still open: after
# cap.release(), cap.get() returns 0 (previously this gave the VideoWriter a
# 0x0 frame size, so nothing was written).
source_fps = cap.get(cv2.CAP_PROP_FPS)
output_fps = source_fps if source_fps > 0 else 20.0  # Some containers report 0

# NOTE(review): every decoded frame is held in driver memory here, shipped to
# the executors and collected back below; long videos may exhaust memory.
frames = []
frame_count = 0
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    frames.append((frame_count, frame))
    frame_count += 1

cap.release()

try:
    # Distribute the frames across the Spark cluster
    frames_rdd = spark.sparkContext.parallelize(frames)

    # Process frames in parallel; collect() returns them in the original order.
    processed_frames = frames_rdd.map(process_frame).collect()

    if not processed_frames:
        print("Warning: no frames were read from the video; nothing to write.")
    else:
        # Size the writer from the real frames: VideoWriter typically drops
        # frames whose size differs from the one it was opened with.
        frame_height, frame_width = processed_frames[0][1].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(
            os.path.join(processed_dir, 'processed_video.avi'),
            fourcc,
            output_fps,
            (frame_width, frame_height),
        )
        if not out.isOpened():
            raise RuntimeError("Error: Cannot open output video for writing.")
        try:
            # Write the processed frames to the output video
            for _, frame in processed_frames:
                out.write(frame)
        finally:
            out.release()
finally:
    # No OpenCV windows are created, so destroyAllWindows() is not needed
    # (it also raises on headless OpenCV builds).
    spark.stop()