# In [1]:
# AI Traffic Violation Detector (Helmet / Triple Riding / Approx Speed)
# Jupyter + Gradio prototype
#
# How it works:
# - Uses Ultralytics YOLO (if available) for detection (person, motorbike, helmet models)
# - For helmet detection: either a dedicated helmet model or heuristic on head-crop with a classifier
# - Triple riding: counts persons overlapping same motorbike bbox
# - Speed estimation: user provides meters_per_pixel and fps -> speed (m/s -> km/h)
#
# Notes: For best results install ultralytics and provide appropriate models:
#   pip install ultralytics
#   -> use 'yolov8n.pt' for general detection, and a helmet model (helmet.pt) if available.
#
# Paste this cell in Jupyter and run. The app will open via Gradio.

import os, io, sys, tempfile, json, time
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
import gradio as gr
import subprocess

# ----------------- Optional installs (only if needed) -----------------
def safe_import_ultralytics():
    """Return the Ultralytics ``YOLO`` class, or ``None`` if it cannot be imported.

    Any exception raised during the import is swallowed on purpose: a missing
    or broken Ultralytics installation simply yields ``None``.
    """
    try:
        from ultralytics import YOLO as yolo_cls
    except Exception:
        return None
    return yolo_cls

# Resolve the YOLO class once at import time; it is None when the import failed.
YOLO = safe_import_ultralytics()
# Convenience flag checked before any attempt to load a YOLO model.
YOLO_AVAILABLE = YOLO is not None

# ----------------- CONFIG / DEFAULT MODELS -----------------
# If you have models, put them in working directory and update paths here:
# Weights file used when the caller does not supply a general detection model.
DEFAULT_YOLO_MODEL = "yolov8n.pt"   # general model (person, motorbike)
# No helmet model by default; Detector only loads one when a path is given.
DEFAULT_HELMET_MODEL = None         # e.g., "helmet.pt" if you trained or downloaded one

# COCO class names (standard yolov8) - used if model is coco-trained.
# NOTE(review): the list is intentionally truncated and is not read anywhere in
# the visible code (labels are taken from the model's own `names` mapping).
COCO_NAMES = [
 'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat','traffic light',
 # truncated... but we only need 'person' and 'motorcycle'
]

# ----------------- Utility helpers -----------------
def load_yolo_model(path):
    """Instantiate an Ultralytics YOLO model from the weights file at ``path``.

    Raises:
        RuntimeError: if Ultralytics could not be imported.
        FileNotFoundError: if ``path`` does not exist on disk.
    """
    if YOLO_AVAILABLE:
        if os.path.exists(path):
            return YOLO(path)
        raise FileNotFoundError(f"YOLO model not found at: {path}")
    raise RuntimeError("Ultralytics YOLO not installed. Install via `pip install ultralytics` for best results.")

def draw_boxes(img, detections, labels=None, colors=None, line_thickness=2):
    """Return a copy of ``img`` with one labelled rectangle per detection.

    Args:
        img: image as a numpy array in BGR channel order; it is not modified.
        detections: iterable of dicts with key ``'xyxy'`` (x1, y1, x2, y2) and
            optional keys ``'label'`` (str) and ``'conf'`` (float).
        labels: accepted but not used by this function.
        colors: optional mapping label -> BGR colour; labels missing from the
            mapping (or every label when ``colors`` is None) are drawn green.
        line_thickness: rectangle outline thickness in pixels.

    Returns:
        The annotated copy of ``img``.
    """
    fallback_color = (0,255,0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    canvas = img.copy()
    for det in detections:
        left, top, right, bottom = (int(v) for v in det['xyxy'])
        name = det.get('label', '')
        score = det.get('conf', None)
        if colors is None:
            box_color = fallback_color
        else:
            box_color = colors.get(name, fallback_color)
        cv2.rectangle(canvas, (left, top), (right, bottom), box_color, line_thickness)
        # Caption is the label, followed by the confidence when one is present.
        caption = f"{name}" if score is None else f"{name}" + f" {score:.2f}"
        (text_w, text_h), _baseline = cv2.getTextSize(caption, font, 0.5, 1)
        # Filled strip just above the box so the black caption stays readable.
        cv2.rectangle(canvas, (left, top - text_h - 6), (left + text_w + 6, top), box_color, -1)
        cv2.putText(canvas, caption, (left + 3, top - 4), font, 0.5, (0,0,0), 1)
    return canvas

def bbox_iou(boxA, boxB):
    """Intersection-over-union of two axis-aligned boxes given as (x1, y1, x2, y2).

    Returns 0.0 when the boxes do not overlap (touching edges count as no
    overlap). A tiny epsilon in the denominator avoids division by zero.
    """
    overlap_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
    overlap_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
    overlap = overlap_w * overlap_h
    if overlap == 0:
        return 0.0
    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = area_a + area_b - overlap + 1e-9
    return overlap / union

# ----------------- Detection pipeline -----------------
class Detector:
    """Runs the available object detectors on single BGR frames.

    A general YOLO model is used when Ultralytics is importable and the weights
    file exists; otherwise an OpenCV MobileNet-SSD person detector is used if
    its files are present in the working directory. An optional second YOLO
    model (helmet model) is run on the same frame and its detections are
    merged into the result.
    """

    # Files of the optional OpenCV fallback detector, looked up in the working directory.
    _SSD_PROTOTXT = "MobileNetSSD_deploy.prototxt"
    _SSD_CAFFEMODEL = "MobileNetSSD_deploy.caffemodel"
    # MobileNetSSD class 15 is person (depends on prototxt).
    _SSD_PERSON_CLASS_ID = 15

    def __init__(self, yolo_model_path=None, helmet_model_path=None, conf_thres=0.25):
        """Load whichever models are available; failures are printed, not raised.

        Args:
            yolo_model_path: weights for the general model; falls back to
                ``DEFAULT_YOLO_MODEL`` when falsy.
            helmet_model_path: optional weights for a helmet model; falls back
                to ``DEFAULT_HELMET_MODEL`` when falsy.
            conf_thres: minimum confidence for a detection to be kept.
        """
        self.yolo_model_path = yolo_model_path or DEFAULT_YOLO_MODEL
        self.helmet_model_path = helmet_model_path or DEFAULT_HELMET_MODEL
        self.conf_thres = conf_thres
        self.yolo_model = None
        self.helmet_model = None
        # Fallback network is loaded lazily on first use and then reused.
        self._ssd_net = None
        if YOLO_AVAILABLE and os.path.exists(self.yolo_model_path):
            try:
                self.yolo_model = load_yolo_model(self.yolo_model_path)
            except Exception as e:
                print("Could not load YOLO model:", e)
        else:
            if not YOLO_AVAILABLE:
                print("Ultralytics not installed; detection will not use YOLO.")
            else:
                print(f"YOLO model not found at {self.yolo_model_path}; detection will try fallback.")
        # helmet model optional (if you have a custom model that predicts helmet vs no-helmet)
        if YOLO_AVAILABLE and self.helmet_model_path and os.path.exists(self.helmet_model_path):
            try:
                self.helmet_model = load_yolo_model(self.helmet_model_path)
            except Exception as e:
                print("Could not load helmet model:", e)

    def _detect_with_yolo(self, model, frame_bgr):
        """Run one YOLO model on the frame and convert its boxes to detection dicts."""
        results = model.predict(frame_bgr, imgsz=640, conf=self.conf_thres, verbose=False)
        if not results:
            return []
        r = results[0]
        boxes = r.boxes  # ultralytics Boxes object
        if boxes is None:
            return []
        has_names = hasattr(r, "names")
        dets = []
        for i in range(len(boxes)):
            cls_id = int(boxes.cls[i].item())
            conf = float(boxes.conf[i].item())
            xyxy = boxes.xyxy[i].cpu().numpy().tolist()
            # Prefer the model's own class names over hard-coded COCO indices.
            label = str(r.names[cls_id]) if has_names else str(cls_id)
            dets.append({"xyxy": xyxy, "label": label, "conf": conf})
        return dets

    def _detect_with_ssd(self, frame_bgr):
        """Coarse person detection with OpenCV DNN MobileNet-SSD, if its files exist."""
        if self._ssd_net is None:
            if not (os.path.exists(self._SSD_PROTOTXT) and os.path.exists(self._SSD_CAFFEMODEL)):
                # no detector available
                return []
            # Load once; re-reading the network for every frame is wasted work.
            self._ssd_net = cv2.dnn.readNetFromCaffe(self._SSD_PROTOTXT, self._SSD_CAFFEMODEL)
        frame_h, frame_w = frame_bgr.shape[:2]
        blob = cv2.dnn.blobFromImage(frame_bgr, 0.007843, (300,300), 127.5)
        self._ssd_net.setInput(blob)
        raw = self._ssd_net.forward()
        # Network outputs are scaled by the frame size to obtain pixel coordinates.
        scale = np.array([frame_w, frame_h, frame_w, frame_h])
        dets = []
        for i in range(raw.shape[2]):
            conf = float(raw[0,0,i,2])
            if conf < self.conf_thres:
                continue
            if int(raw[0,0,i,1]) != self._SSD_PERSON_CLASS_ID:
                continue
            x1,y1,x2,y2 = (raw[0,0,i,3:7] * scale).astype("int").tolist()
            dets.append({"xyxy":(x1,y1,x2,y2), "label":"person", "conf":conf})
        return dets

    def detect_frame(self, frame_bgr):
        """
        returns detections: list of dicts {'xyxy':(x1,y1,x2,y2), 'label':str, 'conf':float}
        labels of interest: 'person', 'motorcycle', 'helmet' (if available)

        The general model (or the MobileNet-SSD fallback) provides the base
        detections. When a helmet model is loaded it is run as well; its
        'person' / 'motor*' detections are dropped so riders and bikes are not
        counted twice, and everything else it reports is appended.

        Raises:
            ValueError: if ``frame_bgr`` is None or has no ``shape`` (e.g. a
                failed image read).
        """
        if frame_bgr is None or not hasattr(frame_bgr, "shape"):
            raise ValueError("detect_frame expects a decoded image array, got: %r" % type(frame_bgr))
        if self.yolo_model is not None:
            dets = self._detect_with_yolo(self.yolo_model, frame_bgr)
        else:
            # fallback: use OpenCV DNN MobileNet-SSD to detect persons (coarse)
            dets = self._detect_with_ssd(frame_bgr)
        if self.helmet_model is not None:
            for det in self._detect_with_yolo(self.helmet_model, frame_bgr):
                label = det["label"].lower()
                if label == "person" or "motor" in label:
                    continue
                dets.append(det)
        return dets

# ----------------- Violation logic -----------------
def _helmet_covers_head(helmet_box, person_box, head_fraction=1.0 / 3.0, min_inside=0.5):
    """True when at least ``min_inside`` of the helmet box lies in the person's head region.

    The head region is the top ``head_fraction`` of the person box. Comparing
    against this region (instead of IoU with the whole person box) matters
    because a helmet covers only a small part of a rider's bounding box.
    """
    hx1, hy1, hx2, hy2 = helmet_box[0], helmet_box[1], helmet_box[2], helmet_box[3]
    helmet_area = (hx2 - hx1) * (hy2 - hy1)
    if helmet_area <= 0:
        return False
    px1, py1, px2, py2 = person_box[0], person_box[1], person_box[2], person_box[3]
    head_bottom = py1 + (py2 - py1) * head_fraction
    inter_w = max(0, min(hx2, px2) - max(hx1, px1))
    inter_h = max(0, min(hy2, head_bottom) - max(hy1, py1))
    return (inter_w * inter_h) / helmet_area >= min_inside


def analyze_detections(dets, iou_thresh=0.3, rider_iou_thresh=0.01, helmet_overlap_thresh=0.5):
    """
    From detections create associations:
      - motorbike boxes -> riders (persons overlapping)
      - helmet check: if a helmet detection overlaps with a person head region -> helmet present
    Returns violations list and annotated detection info

    Args:
        dets: list of dicts {'xyxy': (x1,y1,x2,y2), 'label': str, ...}.
        iou_thresh: kept for backward compatibility; not used. Rider
            association uses the much lower ``rider_iou_thresh``.
        rider_iou_thresh: a person is a rider of a bike when their IoU exceeds
            this value. Each person is attached only to the bike they overlap
            most, so neighbouring bikes do not share riders.
        helmet_overlap_thresh: minimum fraction of a helmet box that must fall
            inside the top third of a rider's box to count as worn.

    Returns:
        (violations, associations) where violations is a list of dicts with
        'type' in {'triple_riding', 'no_helmet'} and associations has one dict
        per bike: {'bike': det, 'riders': [person dets], 'helmets': [bool per rider]}.

    Note: when ``dets`` contains no helmet detections at all (e.g. no helmet
    model is loaded), every rider is reported as 'no_helmet'.
    """
    # Label names differ between datasets ('motorcycle' / 'motorbike'), so bikes
    # are matched by substring and all comparisons are case-insensitive.
    bikes = [d for d in dets if 'motor' in d['label'].lower()]
    persons = [d for d in dets if d['label'].lower() == 'person']
    helmets = [d for d in dets if d['label'].lower() in ('helmet', 'hardhat')]

    # Attach every person to the single bike with the highest overlap.
    riders_by_bike = [[] for _ in bikes]
    for person in persons:
        best_idx = None
        best_iou = rider_iou_thresh
        for idx, bike in enumerate(bikes):
            iou = bbox_iou(bike['xyxy'], person['xyxy'])
            if iou > best_iou:
                best_idx, best_iou = idx, iou
        if best_idx is not None:
            riders_by_bike[best_idx].append(person)

    violations = []
    associations = []
    for bike, riders in zip(bikes, riders_by_bike):
        helmet_flags = [
            any(
                _helmet_covers_head(h['xyxy'], rider['xyxy'], min_inside=helmet_overlap_thresh)
                for h in helmets
            )
            for rider in riders
        ]
        associations.append({"bike":bike, "riders":riders, "helmets":helmet_flags})
        # triple riding violation
        if len(riders) >= 3:
            violations.append({"type":"triple_riding", "bike":bike, "count":len(riders)})
        # helmet violations: one per rider without a matching helmet
        for rider, has_helmet in zip(riders, helmet_flags):
            if not has_helmet:
                violations.append({"type":"no_helmet", "bike":bike, "rider":rider})
    # persons not associated to any bike are ignored
    return violations, associations

# ----------------- Speed estimation -----------------
def estimate_speeds(track_history, meters_per_pixel, fps):
    """
    Estimate the current speed of every track from its last two positions.

    track_history: dict per track_id -> list of centers [(x,y,frame_idx), ...]
    meters_per_pixel: real world meters per pixel (user must provide)
    fps: frames per second of video
    returns speeds dict track_id->speed_kmph (based on last two positions)

    A speed of 0.0 is reported when it cannot be estimated: fewer than two
    positions, a non-increasing frame index, or a missing / non-positive
    ``fps`` or ``meters_per_pixel``.
    """
    # Without a positive scale and frame rate no physical speed can be derived.
    can_estimate = (
        fps is not None and fps > 0
        and meters_per_pixel is not None and meters_per_pixel > 0
    )
    speeds = {}
    for tid, pts in track_history.items():
        if not can_estimate or len(pts) < 2:
            speeds[tid] = 0.0
            continue
        # use last two positions
        x1, y1, f1 = pts[-2]
        x2, y2, f2 = pts[-1]
        # time delta in seconds between the two observations
        dt = (f2 - f1) / fps
        if dt <= 0:
            speeds[tid] = 0.0
            continue
        dx = x2 - x1
        dy = y2 - y1
        meters = float(np.sqrt(dx*dx + dy*dy)) * meters_per_pixel
        # m/s -> km/h
        speeds[tid] = (meters / dt) * 3.6
    return speeds

# ----------------- Simple multi-object tracker (centroid + greedy) -----------------
def simple_tracker_assign(prev_centroids, curr_centroids, max_dist=50):
    """
    Match current centroids to existing track ids by nearest distance.

    prev_centroids: dict id -> (x,y)
    curr_centroids: list of (x,y)
    max_dist: maximum distance (pixels) for a centroid to continue a track
    returns mapping new_id_for_curr_index, updated prev dict

    Candidate pairs are matched globally from the closest pair upwards, so the
    result does not depend on the order of ``curr_centroids``. Unmatched
    centroids receive consecutive new ids above the largest id in
    ``prev_centroids``. The function is stateless: an id may be issued again
    once the track holding the largest id has disappeared.
    """
    # Collect every (distance, current index, previous id) pair within range.
    candidates = []
    for idx, c in enumerate(curr_centroids):
        for pid, pc in prev_centroids.items():
            d = float(np.hypot(c[0]-pc[0], c[1]-pc[1]))
            if d <= max_dist:
                candidates.append((d, idx, pid))
    # Stable sort on distance only: ties keep detection order.
    candidates.sort(key=lambda item: item[0])

    assigned = {}
    new_prev = {}
    used_prev = set()
    for _d, idx, pid in candidates:
        if idx in assigned or pid in used_prev:
            continue
        assigned[idx] = pid
        used_prev.add(pid)
        new_prev[pid] = curr_centroids[idx]

    # Everything left over starts a new track.
    next_id = max(prev_centroids, default=0) + 1
    for idx, c in enumerate(curr_centroids):
        if idx not in assigned:
            assigned[idx] = next_id
            new_prev[next_id] = c
            next_id += 1
    return assigned, new_prev

# ----------------- Main processing for video -----------------
def process_video(video_path, detector:Detector, meters_per_pixel=None, fps=None, speed_threshold_kmph=None, save_out="annotated_out.mp4"):
    """
    Reads video, applies detection per frame, builds simple tracking for motorbikes, detects violations, writes annotated video.
    Returns path to annotated video and report (list of violation records).

    Args:
        video_path: path of the input video.
        detector: object whose ``detect_frame(frame)`` returns detection dicts.
        meters_per_pixel: scale used for speed estimation; speed estimation is
            skipped when falsy.
        fps: frame rate used for speed estimation; the frame rate reported by
            the video (or 25.0) is used when falsy.
        speed_threshold_kmph: speeds above this value are reported as
            'speeding'; no speeding check when falsy.
        save_out: path of the annotated output video (mp4v codec).

    Returns:
        (save_out, reports, meta) where reports is a list of per-frame
        violation dicts and meta is {'fps': video fps, 'frames': frame count}.

    Raises:
        RuntimeError: if the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open video: {video_path}")

    color_map = {"person":(0,200,255), "motorcycle":(0,255,0), "motorbike":(0,255,0), "helmet":(0,128,255)}
    reports = []  # list of violation dicts with frame index, type, details
    out_video = None
    try:
        video_fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out_video = cv2.VideoWriter(save_out, fourcc, video_fps, (frame_w, frame_h))
        # Fall back to the container's frame rate when the caller gave none.
        speed_fps = fps or video_fps

        bike_tracks = {}     # track_id -> list of (cx,cy,frame_idx)
        prev_centroids = {}  # track_id -> last known (cx,cy)
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_idx += 1
            dets = detector.detect_frame(frame)

            # Track motorbikes by the centre of their bounding boxes.
            curr_centroids = []
            for b in dets:
                if 'motor' not in b['label'].lower():
                    continue
                x1,y1,x2,y2 = map(int, b['xyxy'])
                curr_centroids.append((int((x1+x2)/2), int((y1+y2)/2)))
            assigned, prev_centroids = simple_tracker_assign(prev_centroids, curr_centroids, max_dist=60)
            # Only tracks seen in this frame are kept; others lose their history.
            new_bike_tracks = {}
            for idx, tid in assigned.items():
                cx,cy = curr_centroids[idx]
                hist = bike_tracks.get(tid, [])
                hist.append((cx,cy,frame_idx))
                new_bike_tracks[tid] = hist
            bike_tracks = new_bike_tracks

            violations, _associations = analyze_detections(dets)

            # Speed estimates
            speed_info = {}
            if meters_per_pixel and speed_fps:
                speed_info = estimate_speeds(bike_tracks, meters_per_pixel, speed_fps)
                if speed_threshold_kmph:
                    for tid, speed in speed_info.items():
                        if speed > speed_threshold_kmph:
                            reports.append({"frame":frame_idx, "type":"speeding", "track_id":tid, "speed_kmph":round(speed,2)})

            frame_annot = draw_boxes(frame, dets, colors=color_map, line_thickness=2)
            # overlay violations as text, one line each from the top-left corner
            y0 = 30
            for v in violations:
                if v['type']=='triple_riding':
                    txt = f"Triple Riding detected (count={v['count']})"
                    cv2.putText(frame_annot, txt, (10,y0), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,255), 2)
                    y0 += 25
                    reports.append({"frame":frame_idx, "type":"triple_riding", "count":v['count']})
                elif v['type']=='no_helmet':
                    txt = f"No Helmet detected!"
                    cv2.putText(frame_annot, txt, (10,y0), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,255), 2)
                    y0 += 25
                    reports.append({"frame":frame_idx, "type":"no_helmet"})
            # show speed info; rows are numbered by position, not by track id,
            # so large ids cannot push the text outside the frame
            for row, (tid, sp) in enumerate(speed_info.items(), start=1):
                txt = f"Track {tid}: {sp:.1f} km/h"
                cv2.putText(frame_annot, txt, (frame_w-220, 20 + 18*row), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,0), 2)
            out_video.write(frame_annot)
    finally:
        # Release capture and writer even when detection raises mid-video.
        cap.release()
        if out_video is not None:
            out_video.release()
    return save_out, reports, {"fps":video_fps, "frames": total_frames}

# ----------------- Gradio App -----------------
def _upload_to_path(upload):
    """Return the filesystem path of a Gradio upload, or None if it has none.

    Handles both forms an upload can take here: a path (str / PathLike) or a
    file-like object exposing the path as ``.name``.
    """
    if upload is None:
        return None
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload) or None
    return getattr(upload, "name", None) or None


def _float_or_none(value):
    """Convert a UI value to float; None or unparsable input gives None."""
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _probe_video_fps(video_path, default=25.0):
    """Read the frame rate reported by the video file, or ``default`` if unavailable."""
    cap = cv2.VideoCapture(video_path)
    try:
        return cap.get(cv2.CAP_PROP_FPS) or default
    finally:
        cap.release()


def run_on_input(image_file, video_file, yolo_model_file, helmet_model_file, meters_per_pixel, fps_input, speed_limit_kmph):
    """
    If video_file provided, process video; else if image_file provided, run detection on single image.

    Returns a 3-tuple matching the UI outputs: (annotated PIL image or None,
    annotated video path or None, report text). The report is a JSON list of
    violations, or a plain message when nothing could be processed.
    """
    # Detector falls back to its default model paths when given None.
    detector = Detector(yolo_model_path=_upload_to_path(yolo_model_file),
                        helmet_model_path=_upload_to_path(helmet_model_file))
    meters_per_pixel = _float_or_none(meters_per_pixel)
    fps_val = _float_or_none(fps_input)

    if video_file is not None:
        video_path = _upload_to_path(video_file)
        if not (video_path and os.path.exists(video_path)):
            # No usable path: copy the stream to a local file instead.
            if not hasattr(video_file, "read"):
                return None, None, "Could not read the uploaded video."
            video_path = "input_video.mp4"
            with open(video_path, "wb") as f:
                f.write(video_file.read())
        # Use the user's FPS when given, otherwise what the video itself reports.
        fps_for_speed = fps_val or _probe_video_fps(video_path)
        out_path, reports, _meta = process_video(video_path, detector, meters_per_pixel=meters_per_pixel, fps=fps_for_speed, speed_threshold_kmph=speed_limit_kmph or None, save_out="annotated_out.mp4")
        # return annotated video path and reports
        return None, out_path, json.dumps(reports, indent=2)

    if image_file is not None:
        # single image detection
        image_path = _upload_to_path(image_file)
        if image_path:
            img = cv2.imread(image_path)
        elif hasattr(image_file, "read"):
            arr = np.frombuffer(image_file.read(), np.uint8)
            img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        else:
            img = None
        if img is None:
            # imread / imdecode return None for unreadable or non-image data
            return None, None, "Could not read the uploaded image."
        dets = detector.detect_frame(img)
        violations, _associations = analyze_detections(dets)
        ann = draw_boxes(img, dets)
        # mark violations text
        y0 = 30
        repr_reports = []
        for v in violations:
            if v['type']=='triple_riding':
                cv2.putText(ann, f"Triple Riding (count={v['count']})", (10,y0), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,255), 2)
                y0 += 25
                repr_reports.append({"type":"triple_riding", "count":v['count']})
            elif v['type']=='no_helmet':
                cv2.putText(ann, "No Helmet!", (10,y0), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,255), 2)
                y0 += 25
                repr_reports.append({"type":"no_helmet"})
        # convert ann to PIL for gr.Image
        ann_rgb = cv2.cvtColor(ann, cv2.COLOR_BGR2RGB)
        pil_ann = Image.fromarray(ann_rgb)
        return pil_ann, None, json.dumps(repr_reports, indent=2)

    return None, None, "Please upload an image or a video."

# Build Gradio UI
# Build the Gradio UI: inputs on the left, annotated outputs and the JSON report on the right.
with gr.Blocks(title="AI Traffic Violation Detector (Helmet / Triple / Speed)") as demo:
    # Emoji and dash are written as escapes so the heading survives re-encoding of this file.
    gr.Markdown("## \U0001F6A8 AI Traffic Violation Detector \u2014 Helmet, Triple Riding & Approx Speed")
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.File(label="Upload Image (jpg/png)", file_types=[".png",".jpg",".jpeg"])
            video_in = gr.File(label="Or Upload Video (mp4)", file_types=[".mp4",".avi"], interactive=True)
            yolo_model_in = gr.File(label="YOLO model (optional .pt) e.g., yolov8n.pt", file_count="single")
            helmet_model_in = gr.File(label="Helmet model (optional .pt)", file_count="single")
            meters_per_pixel = gr.Number(label="Meters per pixel (for speed estimation, e.g., 0.02)", value=None)
            fps_val = gr.Number(label="FPS of video (if known)", value=25)
            speed_limit = gr.Number(label="Speed limit (km/h) to flag speeding (optional)", value=60)
            run_btn = gr.Button("Run Detection")
        with gr.Column(scale=1):
            out_image = gr.Image(label="Annotated Image (if image uploaded)", type="pil")
            out_video = gr.Video(label="Annotated Video (if video uploaded)")
            report_box = gr.Textbox(label="Violation Report (JSON)", lines=12)

    # The input order must match the parameter order of run_on_input.
    run_btn.click(fn=run_on_input, inputs=[image_in, video_in, yolo_model_in, helmet_model_in, meters_per_pixel, fps_val, speed_limit],
                  outputs=[out_image, out_video, report_box])

demo.launch()


# * Running on local URL:  http://127.0.0.1:7864
# * To create a public link, set `share=True` in `launch()`.


