In [None]:
import os
import csv
import datetime
from collections import defaultdict
import cv2
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSort
from dateutil import parser
import matplotlib.pyplot as plt
from IPython.display import clear_output

from .config import *
from .utils import traffic_situation

def _ensure_csv(csv_path):
    """Create *csv_path* with the header row if it does not exist yet.

    The parent directory is created first when *csv_path* has one.
    """
    csv_dir = os.path.dirname(csv_path)
    # os.path.dirname() returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)

    if not os.path.exists(csv_path):
        with open(csv_path, 'w', newline='') as f:
            csv.writer(f).writerow(['time', 'date', 'day_of_week',
                                    'car_count', 'motorcycle_count', 'truck_count', 'bus_count',
                                    'total', 'avg_exit_time_seconds', 'road_status', 'traffic_situation'])


def _detect_vehicles(model, frame, class_names, vehicle_classes):
    """Run *model* on *frame* and return DeepSort-style detections of vehicles.

    Each detection is ``([left, top, width, height], confidence, class_name)``
    for every box whose class name is in *vehicle_classes*.
    """
    boxes = model(frame, verbose=False)[0].boxes
    detections = []
    # Convert each tensor to Python lists once rather than calling
    # .tolist()/.item() per box, which costs a device-to-host copy per box.
    for (x1, y1, x2, y2), cls, conf in zip(boxes.xyxy.tolist(),
                                           boxes.cls.tolist(),
                                           boxes.conf.tolist()):
        class_name = class_names[int(cls)]
        if class_name in vehicle_classes:
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            detections.append(([x1, y1, x2 - x1, y2 - y1], conf, class_name))
    return detections


def _build_row(frame_time, frame_num, vehicle_counts, exit_times, road_status):
    """Build the CSV row summarising the interval that ends at *frame_time*."""
    avg_exit_time = sum(exit_times) / len(exit_times) if exit_times else 0
    total = sum(vehicle_counts.values())
    return [
        frame_time.strftime('%H:%M:%S'),
        frame_time.strftime('%Y-%m-%d'),
        frame_time.strftime('%A'),
        vehicle_counts['car'],
        vehicle_counts['motorcycle'],
        vehicle_counts['truck'],
        vehicle_counts['bus'],
        total,
        avg_exit_time,
        road_status,
        traffic_situation(frame_time, frame_num, SITUATION_ALL),
    ]


def _append_row(csv_path, row):
    """Append a single *row* to the CSV file at *csv_path*."""
    with open(csv_path, 'a', newline='') as f:
        csv.writer(f).writerow(row)


def process_video(road_status, device='cuda'):
    """Count vehicles crossing a virtual line in VIDEO_PATH and log stats to CSV_PATH.

    Cars, motorcycles, buses and trucks are detected with the YOLO model at
    MODEL_PATH and tracked with DeepSort. A tracked vehicle is counted once,
    the first time its box centre is within 10 px of the line at
    ``VDZ_POSITION * frame height`` or moves across that line between two
    consecutive frames.

    Frame timestamps start at START_TIME_STR and advance by ``frame_num / fps``.
    Whenever 60 s or more of video time have passed since the last row (or on
    the frame at index ``frame_count - 1``), a row is appended to CSV_PATH.
    Each row holds:

    - the per-class counts for the interval and their total;
    - the mean duration of tracks that ended during the interval;
    - *road_status*;
    - the value of ``traffic_situation(...)``.

    Videos of 30 s or less are skipped. Processing stops right after a row is
    logged once 40 s or less of the video remain.

    Args:
        road_status: Value written verbatim into the ``road_status`` column.
        device: Torch device the YOLO model is moved to (default ``'cuda'``).

    Raises:
        OSError: If VIDEO_PATH cannot be opened.
        ValueError: If the video reports a non-positive frame rate.
    """
    model = YOLO(MODEL_PATH).to(device)
    tracker = DeepSort(max_age=30)
    class_names = model.model.names
    vehicle_classes = {'car', 'motorcycle', 'bus', 'truck'}
    half_band = 10  # pixels either side of the counting line

    _ensure_csv(CSV_PATH)

    cap = cv2.VideoCapture(VIDEO_PATH)
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {VIDEO_PATH}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Every timestamp below divides by fps.
            raise ValueError(f"Video reports an invalid frame rate ({fps}): {VIDEO_PATH}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration_seconds = frame_count / fps

        if duration_seconds <= 30:
            print("⏩ Video duration too short. Skipping processing.")
            return

        # NOTE(review): line_y (and width) come from the full frame, but when
        # USE_ROI is set the frame is cropped before tracking, so box
        # coordinates are relative to the crop. Confirm whether VDZ_POSITION
        # is meant relative to the full frame (then subtract CROP_TOP) or to
        # the cropped frame (then use the cropped height).
        line_y = int(height * VDZ_POSITION)

        start_datetime = parser.parse(START_TIME_STR)
        last_log_time = start_datetime

        vehicle_counts = defaultdict(int)  # class name -> vehicles counted this interval
        # Ids already counted. Kept across log intervals so a vehicle that is
        # still in the band when a row is written is not counted again.
        counted_ids = set()
        prev_cy = {}      # track id -> box-centre y in the previous frame
        first_seen = {}   # track id -> time of its first matched detection
        last_seen = {}    # track id -> time of its latest matched detection
        exit_times = []   # durations (s) of tracks that ended this interval

        frame_num = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_time = start_datetime + datetime.timedelta(seconds=frame_num / fps)

            if USE_ROI:
                h, w, _ = frame.shape
                frame = frame[CROP_TOP:h - CROP_BOTTOM, CROP_LEFT:w - CROP_RIGHT]

            detections = _detect_vehicles(model, frame, class_names, vehicle_classes)
            tracks = tracker.update_tracks(detections, frame=frame)
            confirmed = [track for track in tracks if track.is_confirmed()]
            current_ids = {track.track_id for track in confirmed}

            # Tracks that left the confirmed set have ended: record how long
            # each one was seen.
            for gone_id in set(first_seen) - current_ids:
                seen_for = last_seen.pop(gone_id) - first_seen.pop(gone_id)
                exit_times.append(seen_for.total_seconds())
            # Forget ids that are no longer tracked; this keeps the set small
            # without allowing a still-tracked vehicle to be counted twice.
            counted_ids &= current_ids

            next_prev_cy = {}
            for track in confirmed:
                track_id = track.track_id
                # update_tracks() also returns confirmed tracks that were only
                # predicted this frame (coasting for up to max_age frames).
                # Only matched frames count as "seen", so exit times are not
                # inflated by the coasting period.
                if track.time_since_update == 0:
                    first_seen.setdefault(track_id, frame_time)
                    last_seen[track_id] = frame_time

                l, t, r, b = track.to_ltrb()
                cy = int((t + b) / 2)
                class_name = track.get_det_class()

                if SHOW_DEBUG:
                    cv2.rectangle(frame, (int(l), int(t)), (int(r), int(b)), (0, 255, 255), 2)
                    cv2.putText(frame, f"{class_name} ID:{track_id}", (int(l), int(t)-5),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)

                in_band = line_y - half_band < cy < line_y + half_band
                # A fast vehicle can jump over the band between two frames, so
                # also count it when its centre switches side of the line.
                prev = prev_cy.get(track_id)
                crossed = prev is not None and (prev - line_y) * (cy - line_y) <= 0
                if (in_band or crossed) and track_id not in counted_ids:
                    counted_ids.add(track_id)
                    vehicle_counts[class_name] += 1
                next_prev_cy[track_id] = cy
            prev_cy = next_prev_cy

            if SHOW_DEBUG:
                cv2.line(frame, (0, line_y), (width, line_y), (0, 255, 0), 2)
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                clear_output(wait=True)
                plt.imshow(rgb_frame)
                plt.axis('off')
                plt.show()

            if (frame_time - last_log_time).total_seconds() >= 60 or frame_num == frame_count - 1:
                _append_row(CSV_PATH, _build_row(frame_time, frame_num, vehicle_counts,
                                                 exit_times, road_status))
                vehicle_counts.clear()
                exit_times.clear()
                last_log_time = frame_time

                remaining_time = duration_seconds - (frame_num / fps)
                if remaining_time <= 40:
                    print("⏹️ Stopping: Only 40 seconds or less remain in the video.")
                    break

            frame_num += 1
    finally:
        # Release the capture on every exit path, including early returns and errors.
        cap.release()
