In [None]:
!pip install mlflow
!pip install opencv-python-headless
!pip install numpy
!pip install pyspark


In [9]:
import mlflow
import mlflow.pyfunc
import cv2
import numpy as np

class YOLOModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper around a Darknet YOLO network run through OpenCV's DNN module."""

    def load_context(self, context):
        """Load the network and the class labels from the model artifacts.

        Args:
            context: MLflow model context; its ``artifacts`` mapping must provide
                'config_path', 'weights_path' and 'classes_path'.
        """
        self.net = cv2.dnn.readNetFromDarknet(context.artifacts['config_path'], context.artifacts['weights_path'])
        with open(context.artifacts['classes_path'], 'r') as f:
            self.classes = f.read().strip().split('\n')
        # Output layer names depend only on the network; reset the cache for the new net.
        self._output_layers = None

    def predict(self, context, model_input):
        """Detect objects in the frame described by `model_input`.

        Args:
            context: MLflow model context (unused here).
            model_input: Mapping with a required 'frame' entry and optional
                'confidence_threshold' / 'nms_threshold' entries. Missing
                thresholds fall back to the defaults of `detect_objects`
                (0.5 and 0.4) instead of raising KeyError.

        Returns:
            The list produced by `detect_objects`.
        """
        frame = model_input['frame']
        confidence_threshold = model_input.get('confidence_threshold', 0.5)
        nms_threshold = model_input.get('nms_threshold', 0.4)
        return self.detect_objects(frame, confidence_threshold, nms_threshold)

    def _get_output_layers(self):
        """Return the names of the network's unconnected output layers, computed once per network."""
        output_layers = getattr(self, '_output_layers', None)
        if output_layers is None:
            layer_names = self.net.getLayerNames()
            # getUnconnectedOutLayers() yields 1-based layer indices; flatten()
            # copes with both the 1-D and the (N, 1) shaped return value.
            output_layers = [
                layer_names[i - 1]
                for i in np.asarray(self.net.getUnconnectedOutLayers()).flatten()
            ]
            self._output_layers = output_layers
        return output_layers

    def detect_objects(self, frame, confidence_threshold=0.5, nms_threshold=0.4):
        """Run the network on `frame` and return the detections kept after NMS.

        Args:
            frame: Image array; only ``frame.shape[:2]`` (height, width) is read
                here before the frame is handed to ``cv2.dnn.blobFromImage``.
            confidence_threshold: Minimum best-class score for a detection to be kept.
            nms_threshold: Overlap threshold passed to ``cv2.dnn.NMSBoxes``.

        Returns:
            A list of ``(x, y, w, h, label, confidence)`` tuples, where (x, y) is
            the top-left corner of the box in frame pixels and `label` is taken
            from ``self.classes``. Empty when nothing passes the thresholds.
        """
        (H, W) = frame.shape[:2]

        blob = cv2.dnn.blobFromImage(frame, 1 / 255.0, (416, 416), swapRB=True, crop=False)
        self.net.setInput(blob)

        outputs = self.net.forward(self._get_output_layers())

        boxes = []
        confidences = []
        class_ids = []

        # The first four values of a detection are scaled by this to frame pixels.
        scale = np.array([W, H, W, H])

        for output in outputs:
            for detection in output:
                # Values from index 5 onwards are the per-class scores.
                scores = detection[5:]
                class_id = int(np.argmax(scores))
                confidence = scores[class_id]
                if confidence > confidence_threshold:
                    box = detection[0:4] * scale
                    (centerX, centerY, width, height) = box.astype("int")
                    # Convert the centre-based box to a top-left-based one.
                    x = int(centerX - (width / 2))
                    y = int(centerY - (height / 2))
                    boxes.append([x, y, int(width), int(height)])
                    confidences.append(float(confidence))
                    class_ids.append(class_id)

        # Nothing passed the confidence filter: skip NMS altogether.
        if not boxes:
            return []

        idxs = cv2.dnn.NMSBoxes(boxes, confidences, confidence_threshold, nms_threshold)

        results = []
        # Normalise the NMS result so that tuples, 1-D and (N, 1) arrays are all handled.
        for i in np.asarray(idxs).flatten():
            (x, y, w, h) = boxes[i]
            results.append((x, y, w, h, self.classes[class_ids[i]], confidences[i]))

        return results

# Paths to the required files
weights_path = 'Assets/Trained_Model/yolov3/yolov3.weights'
config_path = 'Assets/Trained_Model/yolov3/yolov3.cfg'
names_path = 'Assets/Trained_Model/coco.names'

# Read the class labels from coco.names
with open(names_path, 'r') as f:
    classes = f.read().strip().split('\n')

# Save the classes to a file
classes_path = 'Assets/Trained_Model/classes.txt'
with open(classes_path, 'w') as f:
    f.write('\n'.join(classes))

# Save the model to a local directory.
# mlflow.pyfunc.save_model refuses to write into an existing, non-empty path,
# so remove the output of a previous run first; this keeps the cell re-runnable.
import os
import shutil

model_path = "yolo_model"
if os.path.isdir(model_path):
    shutil.rmtree(model_path)

mlflow.pyfunc.save_model(
    path=model_path,
    python_model=YOLOModel(),
    artifacts={
        "weights_path": weights_path,
        "config_path": config_path,
        "classes_path": classes_path
    }
)

In [None]:
import mlflow
import mlflow.pyfunc
import cv2
import os
from pyspark.sql import SparkSession

# Create (or reuse) the Spark session that distributes the frame processing
spark = (
    SparkSession.builder
    .appName("Traffic Violation Detection with MLflow")
    .getOrCreate()
)

# Load the YOLO pyfunc model from the "yolo_model" path
yolo_model = mlflow.pyfunc.load_model("yolo_model")

def process_frame(frame_count_frame):
    """Run YOLO detection on one frame, annotate it and flag red-light violations.

    A vehicle ('car', 'motorbike' or 'bus') is marked as violating when a
    'RedLight' detection is present in the same frame. Every 15th frame that
    contains a violation is written to the 'Violated' directory, which is
    created on demand.

    NOTE(review): this function reads the module-level `yolo_model`, so Spark
    has to serialize that object together with the function when it is used in
    `rdd.map`; the wrapped OpenCV network may not be picklable. Confirm, and if
    so load the model on the worker instead.

    Args:
        frame_count_frame: ``(frame_count, frame)`` tuple. `frame` is passed to
            the model and drawn on in place.

    Returns:
        The ``(frame_count, frame)`` tuple with the annotated frame.
    """
    frame_count, frame = frame_count_frame

    model_input = {
        'frame': frame,
        'confidence_threshold': 0.5,
        'nms_threshold': 0.4
    }

    # Predict using the YOLO model; each detection is (x, y, w, h, label, confidence)
    detections = list(yolo_model.predict(model_input))

    # The light state is a property of the whole frame, so resolve it before
    # looking at vehicles. Comparing a vehicle's own box colour against red can
    # never succeed, because only 'RedLight' boxes are drawn in red.
    red_light_on = any(label == 'RedLight' for (_, _, _, _, label, _) in detections)

    violation_detected = False

    for (x, y, w, h, label, confidence) in detections:
        if label == 'RedLight':
            color = [0, 0, 255]
        elif label == 'GreenLight':
            color = [0, 128, 0]
        else:
            color = [0, 255, 0]

        cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
        cv2.putText(frame, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        if red_light_on and label in ('car', 'motorbike', 'bus'):
            violation_detected = True
            cv2.putText(frame, "Violated!", (x, y - 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    # Only every 15th frame is persisted to limit the number of saved images.
    if violation_detected and frame_count % 15 == 0:
        # Make sure the target directory exists before handing the path to imwrite.
        os.makedirs('Violated', exist_ok=True)
        violation_frame_path = os.path.join('Violated', f"violation_frame_{frame_count}.jpg")
        cv2.imwrite(violation_frame_path, frame)

    return (frame_count, frame)

# Path to the video
video_path = r'C:\School\Traffic Violation Detection Project\Assets\Traffic_Data\Traffic_Violations\No_Violation_Video\No Violation - Made with Clipchamp.mp4'

try:
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        # Raise instead of calling exit(): in a notebook, exit() takes down the
        # kernel rather than just stopping this cell.
        raise RuntimeError("Error: Cannot open video.")

    frames = []
    frame_count = 0
    try:
        # Read the stream properties while the capture is still open; they are
        # no longer available once it has been released.
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        source_fps = cap.get(cv2.CAP_PROP_FPS)

        # NOTE: every decoded frame is kept in driver memory before being
        # handed to Spark, so long videos need a correspondingly large heap.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append((frame_count, frame))
            frame_count += 1
    finally:
        cap.release()

    # Fall back to the decoded frame size if the container did not report one.
    if frames and (frame_width <= 0 or frame_height <= 0):
        frame_height, frame_width = frames[0][1].shape[:2]

    # Keep the source frame rate; use the previous fixed 20 FPS when it is unknown.
    output_fps = source_fps if source_fps and source_fps > 0 else 20.0

    frames_rdd = spark.sparkContext.parallelize(frames)
    processed_frames = frames_rdd.map(process_frame).collect()

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter('processed_video.avi', fourcc, output_fps, (frame_width, frame_height))
    try:
        # Distinct loop names so the total in `frame_count` is not overwritten.
        for _frame_index, processed_frame in processed_frames:
            out.write(processed_frame)
    finally:
        out.release()

    # cv2.destroyAllWindows() is intentionally not called: no windows are
    # opened here, and the notebook installs opencv-python-headless, which
    # ships without the GUI functions.
finally:
    # Always stop the session, even when reading or processing fails.
    spark.stop()