In [1]:
import time
import cv2
import easyocr
import re
import csv
import torch
import os
import pandas as pd
import random
import numpy as np
import json
from ultralytics import YOLO
from tqdm import tqdm  # tqdm 라이브러리 추가
import gc  # 가비지 컬렉션 추가

# Precompiled regular expression used to validate recognized license plate text:
# 2-3 digits, then one Hangul syllable or digit, then exactly 4 digits.
license_plate_pattern = re.compile(r'[0-9]{2,3}[가-힣0-9]{1}[0-9]{4}')

# 1. Configuration loading and configuration values
def load_config(config_path="C:/Users/PC/Desktop/caffein/montana/config.json"):
    """Read the JSON configuration file and return its parsed content.

    Args:
        config_path: Location of the JSON file. The default is the path that
            used to be hard-coded, so calls without arguments behave as before.

    Returns:
        The object produced by ``json.load`` for that file.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)

# The configuration is read once, when this line runs; the getter functions
# below all read from this module-level object.
_config = load_config()

def get_video_folders():
    """Return the sorted paths of the directories directly under the video root.

    The root is read from the ``video_root_folder`` configuration entry; entries
    of that folder that are not directories are ignored.
    """
    root_folder = _config['video_root_folder']
    candidates = (os.path.join(root_folder, name) for name in os.listdir(root_folder))
    folders = [path for path in candidates if os.path.isdir(path)]
    folders.sort()
    return folders

def get_video_files_in_folder(folder, extensions=('.avi',)):
    """Return the sorted paths of the video files directly inside ``folder``.

    Args:
        folder: Directory to list (not searched recursively).
        extensions: File name suffixes that count as videos. Matching is
            case-insensitive, so ``clip.AVI`` is picked up as well as
            ``clip.avi``.

    Returns:
        A sorted list of ``os.path.join(folder, name)`` for every matching entry.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(suffixes)
    )

def get_results_folder():
    """Return the ``results_folder`` entry of the loaded configuration."""
    results_folder = _config['results_folder']
    return results_folder

def get_model_paths():
    """Return the ``model_paths`` entry of the loaded configuration."""
    model_paths = _config['model_paths']
    return model_paths

# 2. Font configuration
def set_malgun_gothic_font():
    """Select "Malgun Gothic" as the font family by calling ``rc``.

    NOTE(review): ``rc`` is not imported anywhere in the visible code, so calling
    this function as-is raises ``NameError``. It is presumably ``matplotlib.rc``;
    confirm and add the import before using this helper.
    """
    rc('font', family="Malgun Gothic")

# 3. License plate correction
def correct_plate_number(plate):
    """Replace the fifth character from the end of an all-digit plate string.

    Only strings made of exactly 7 or 8 digits are touched: when the character
    at index ``-5`` is a key of the corrections mapping, it is swapped for the
    mapped value. Every other input is returned unchanged.

    NOTE(review): ``get_corrections`` is not defined in the visible code;
    presumably it returns a dict mapping misread digits to replacement
    characters - confirm against its definition.
    """
    corrections = get_corrections()
    if not re.fullmatch(r'\d{7,8}', plate):
        return plate

    suspect = plate[-5]
    if suspect in corrections:
        return plate[:-5] + corrections[suspect] + plate[-4:]
    return plate

# 4. IoU-based tracking
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    The result is ``0.0`` when the union area is zero (e.g. two empty boxes).
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Width/height of the overlap rectangle, floored at 0 for disjoint boxes.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter_area = overlap_w * overlap_h

    union_area = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter_area
    return inter_area / union_area if union_area != 0 else 0.0

def assign_ids_to_boxes(boxes, state, best_ocr_results, csv_filename, max_frames_missing=20, iou_threshold=0.2):
    """Match the detections of one frame to existing tracks and update the tracker.

    Each detection is greedily matched to the previous track with the highest
    IoU. A previous track can be claimed by at most one detection per frame;
    without that restriction two overlapping detections would both receive the
    same ID and the second would overwrite the first, losing one vehicle.
    Detections whose best IoU does not exceed ``iou_threshold`` start a new
    track.

    Tracks that received no detection are carried over with ``frames_missing``
    incremented until it reaches ``max_frames_missing``; at that point the track
    is dropped and, if it has an entry in ``best_ocr_results``, that entry is
    appended to the CSV file and removed from the dict.

    Args:
        boxes: Iterable of detections exposing ``xyxy[0]``, ``conf[0]`` and
            ``cls[0]``.
        state: Tracker state with ``'tracked_objects'`` and ``'next_id'``; it is
            updated in place.
        best_ocr_results: Dict of per-ID OCR records; expired IDs are removed.
        csv_filename: CSV file that receives the records of expired tracks.
        max_frames_missing: Number of consecutive missed frames a track survives.
        iou_threshold: Minimum IoU (exclusive) for a detection to continue a track.

    Returns:
        The same ``state`` object, with ``'tracked_objects'`` replaced.
    """
    previous_objects = state['tracked_objects']
    new_tracked_objects = {}
    unmatched_previous_objects = set(previous_objects.keys())

    for bbox in boxes:
        x1, y1, x2, y2 = map(int, bbox.xyxy[0])
        confidence = bbox.conf[0]
        cls = bbox.cls[0]
        center = ((x1 + x2) // 2, (y1 + y2) // 2)

        best_iou = 0
        best_id = None
        for obj_id, obj in previous_objects.items():
            if obj_id not in unmatched_previous_objects:
                continue  # already claimed by an earlier detection of this frame
            iou = calculate_iou((x1, y1, x2, y2), obj['bbox'])
            if iou > best_iou:
                best_iou = iou
                best_id = obj_id

        if best_iou > iou_threshold:
            obj = previous_objects[best_id]
            new_tracked_objects[best_id] = {
                'bbox': (x1, y1, x2, y2),
                'center': center,
                'color': cls,
                'confidence': confidence,
                'best_ocr': obj['best_ocr'],
                'frames_missing': 0,
                'trajectory': obj['trajectory'] + [center],
                'direction': obj['direction'],
            }
            unmatched_previous_objects.discard(best_id)
        else:
            new_tracked_objects[state['next_id']] = {
                'bbox': (x1, y1, x2, y2),
                'center': center,
                'color': cls,
                'confidence': confidence,
                'best_ocr': None,
                'frames_missing': 0,
                'trajectory': [center],
                'direction': None,
            }
            state['next_id'] += 1

    # Age the tracks that were not seen in this frame and retire the stale ones.
    any_expired = False
    for obj_id in unmatched_previous_objects:
        obj = previous_objects[obj_id]
        if obj['frames_missing'] >= max_frames_missing:
            any_expired = True
            if obj_id in best_ocr_results:
                # Persist the best reading before forgetting the track.
                save_to_csv({obj_id: best_ocr_results[obj_id]}, csv_filename)
                del best_ocr_results[obj_id]
        else:
            new_tracked_objects[obj_id] = {
                'bbox': obj['bbox'],
                'center': obj['center'],
                'color': obj['color'],
                'confidence': obj['confidence'],
                'best_ocr': obj['best_ocr'],
                'frames_missing': obj['frames_missing'] + 1,
                'trajectory': obj['trajectory'],
                'direction': obj['direction'],
            }

    if any_expired:
        # One collection per call instead of one per expired track.
        gc.collect()

    state['tracked_objects'] = new_tracked_objects
    return state

def create_tracking_state():
    """Return a fresh tracker state: no tracked objects, ID counter at 0."""
    state = {}
    state['tracked_objects'] = {}
    state['next_id'] = 0
    return state

# 5. Miscellaneous utilities
def ensure_folder_exists(folder):
    """Create ``folder`` (including missing parents); do nothing if it exists."""
    os.makedirs(folder, exist_ok=True)

def initialize_csv(csv_filename):
    """Create (or truncate) the CSV file and write the column header row."""
    header = ['video', 'ID', 'color', 'ocr', 'accuracy', 'direction', 'frame']
    with open(csv_filename, mode='w', newline='', encoding='utf-8') as out:
        csv.writer(out).writerow(header)


def save_to_csv(best_ocr_results, csv_filename):
    """Append one CSV row per record in ``best_ocr_results``.

    Each record is a dict with the keys ``video``, ``ID``, ``color``, ``ocr``,
    ``accuracy``, ``direction`` and ``frame``; they are written in that order.
    The dict keys of ``best_ocr_results`` themselves are not written.
    """
    columns = ('video', 'ID', 'color', 'ocr', 'accuracy', 'direction', 'frame')
    with open(csv_filename, mode='a', newline='', encoding='utf-8') as out:
        csv.writer(out).writerows(
            [record[column] for column in columns]
            for record in best_ocr_results.values()
        )

# 6. Entry / exit decision
def determine_direction(trajectory):
    """Classify a trajectory by the vertical displacement between its ends.

    Returns '입차' when the last point's y is more than 50 greater than the first
    point's y, '출차' when it is more than 50 smaller, and ``None`` otherwise
    (including trajectories with fewer than two points).
    """
    if len(trajectory) < 2:
        return None

    first_y = trajectory[0][1]
    last_y = trajectory[-1][1]
    if last_y - first_y > 50:
        return '입차'
    if first_y - last_y > 50:
        return '출차'
    return None

# 7. License plate image preprocessing
def preprocess_plate_image(img):
    """Prepare a cropped plate image for OCR.

    The image is converted from BGR to grayscale, denoised with
    ``cv2.fastNlMeansDenoising`` and histogram-equalized; the equalized
    grayscale image is returned.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    denoised = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)
    return cv2.equalizeHist(denoised)

# 8. Model loading
def load_model(model_path, device):
    """Load a YOLO model from ``model_path`` and move it to ``device``.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if os.path.exists(model_path):
        return YOLO(model_path).to(device)
    raise FileNotFoundError(f"모델 파일을 찾을 수 없습니다: {model_path}")

# 9. Video processing
def _update_directions(state, best_ocr_results):
    """Determine the direction of tracks that have none yet and mirror it into the OCR records."""
    for obj_id, obj in state['tracked_objects'].items():
        if obj['direction'] is not None:
            continue
        obj['direction'] = determine_direction(obj['trajectory'])
        if obj['direction'] is not None and obj_id in best_ocr_results:
            best_ocr_results[obj_id]['direction'] = obj['direction']


def _find_object_containing(point_x, point_y, tracked_objects):
    """Return the ID of the first tracked object whose box strictly contains the point, else None."""
    for obj_id, obj in tracked_objects.items():
        x1, y1, x2, y2 = obj['bbox']
        if x1 < point_x < x2 and y1 < point_y < y2:
            return obj_id
    return None


def _read_plates(img, plate_boxes, state, best_ocr_results, video_path, color_model, reader, frame_number):
    """Run OCR on each detected plate of one frame and keep the best reading per tracked object."""
    frame_height, frame_width = img.shape[:2]
    tracked_objects = state['tracked_objects']

    for bbox in plate_boxes:
        px1, py1, px2, py2 = map(int, bbox.xyxy[0])

        matched_obj_id = _find_object_containing(
            (px1 + px2) // 2, (py1 + py2) // 2, tracked_objects
        )
        if matched_obj_id is None:
            continue

        # Clamp to the frame: a negative start index would wrap around in the
        # slice, and an empty crop makes cv2.cvtColor raise.
        cx1, cy1 = max(px1, 0), max(py1, 0)
        cx2, cy2 = min(px2, frame_width), min(py2, frame_height)
        if cx2 <= cx1 or cy2 <= cy1:
            continue

        ocr_result = reader.readtext(preprocess_plate_image(img[cy1:cy2, cx1:cx2]))
        if not ocr_result:
            continue

        # Only the first detected text region is used.
        _, text, prob = ocr_result[0]
        text = re.sub('[^가-힣0-9]', '', text)
        if not license_plate_pattern.fullmatch(text):
            continue

        previous = best_ocr_results.get(matched_obj_id)
        if previous is not None and prob <= previous['accuracy']:
            continue

        tracked = tracked_objects[matched_obj_id]
        best_ocr_results[matched_obj_id] = {
            'video': video_path,
            'ID': matched_obj_id,
            'color': color_model.names[int(tracked['color'])],
            'ocr': text,
            'accuracy': prob,
            'direction': tracked['direction'],
            'frame': frame_number,
        }
        tracked['best_ocr'] = (text, prob)


def _process_batch(batch_frames, first_frame_number, state, best_ocr_results,
                   video_path, csv_filename, plate_model, color_model, reader):
    """Run both detectors on a batch of frames, then track and OCR each frame in order."""
    plate_results = plate_model.predict(source=batch_frames, imgsz=416, verbose=False)
    color_results = color_model.predict(source=batch_frames, imgsz=640, conf=0.3, verbose=False)

    for offset, img in enumerate(batch_frames):
        state = assign_ids_to_boxes(
            color_results[offset].boxes, state, best_ocr_results, csv_filename, max_frames_missing=15
        )
        _update_directions(state, best_ocr_results)
        _read_plates(
            img, plate_results[offset].boxes, state, best_ocr_results,
            video_path, color_model, reader, first_frame_number + offset,
        )
    return state


def process_video(video_path, csv_filename, plate_model, color_model, reader, device, batch_size=64):
    """Detect, track and read license plates in one video and write the results to CSV.

    Frames are read in batches of ``batch_size`` and passed to both models.
    For every frame the color-model detections update the tracker, directions
    are determined, and each plate detection whose center lies inside a tracked
    box is OCR-ed. For each tracked ID only the reading with the highest OCR
    probability that matches ``license_plate_pattern`` is kept.

    OCR runs on the unmodified frame. Nothing is drawn onto the frames: they
    are not saved or displayed by this function, and drawing before cropping
    would put overlay pixels into the OCR input.

    Args:
        video_path: Path of the video file; also stored in the ``video`` column.
        csv_filename: Output CSV file. It is (re)initialized with a header here.
        plate_model: Model whose ``predict`` yields plate boxes.
        color_model: Model whose ``predict`` yields the boxes that are tracked;
            its ``names`` mapping supplies the ``color`` column.
        reader: OCR reader exposing ``readtext``.
        device: Unused here; kept so existing callers keep working.
        batch_size: Number of frames sent to the models per call.

    Returns:
        The number of frames that were read and processed.
    """
    initialize_csv(csv_filename)
    best_ocr_results = {}
    state = create_tracking_state()
    frame_count = 0

    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(total=total_frames, desc=f'Processing {os.path.basename(video_path)}') as pbar:
            batch_frames = []
            while True:
                ret, img = cap.read()
                if ret:
                    batch_frames.append(img)

                # Flush a full batch, or whatever is left once the video ends.
                if batch_frames and (not ret or len(batch_frames) >= batch_size):
                    state = _process_batch(
                        batch_frames, frame_count, state, best_ocr_results,
                        video_path, csv_filename, plate_model, color_model, reader,
                    )
                    frame_count += len(batch_frames)
                    pbar.update(len(batch_frames))
                    batch_frames = []

                if not ret:
                    break
    finally:
        # Release the capture and keep the readings gathered so far even when
        # processing is interrupted part-way through a long video.
        cap.release()
        save_to_csv(best_ocr_results, csv_filename)

    return frame_count

def main():
    """Process every video under the configured root and write one CSV per video.

    Both YOLO models and the OCR reader are loaded once and reused. The results
    folder is created if it does not exist, because opening the CSV files would
    otherwise fail. Each video ``<name>.<ext>`` produces
    ``<results_folder>/<name>_res.csv``.
    """
    results_folder = get_results_folder()
    ensure_folder_exists(results_folder)
    model_paths = get_model_paths()

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    plate_model = load_model(model_paths['plate_model'], device)
    color_model = load_model(model_paths['color_model'], device)

    reader = easyocr.Reader(['ko'], gpu=use_cuda)

    for folder in get_video_folders():
        for video_path in get_video_files_in_folder(folder):
            # splitext strips only the final extension, so names that contain
            # dots (e.g. "cam.1.avi") keep their full stem.
            video_filename = os.path.splitext(os.path.basename(video_path))[0]
            csv_filename = os.path.join(results_folder, f'{video_filename}_res.csv')

            # process_video initializes the CSV file (header row) itself.
            process_video(video_path, csv_filename, plate_model, color_model, reader, device)


# Run the pipeline only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()


Processing 0901_1.avi:  40%|████      | 256680/638020 [1:23:59<2:04:46, 50.94it/s]
Processing 0901_2.avi:  38%|███▊      | 250654/657930 [1:19:07<2:08:33, 52.80it/s]
Processing 0902_1.avi:  25%|██▌       | 250208/993203 [1:24:20<4:10:28, 49.44it/s]
Processing 0902_2.avi:  44%|████▎     | 132912/304560 [38:36<49:50, 57.39it/s]  
Processing 0903_1.avi:  33%|███▎      | 254298/768281 [1:20:49<2:43:22, 52.44it/s]
Processing 0903_2.avi:  33%|███▎      | 176071/527670 [52:31<1:44:53, 55.87it/s]
Processing 0904_1.avi:  39%|███▉      | 256275/660187 [1:20:47<2:07:20, 52.86it/s]
Processing 0904_2.avi:  40%|███▉      | 252635/635340 [1:16:55<1:56:32, 54.73it/s]
Processing 0905_1.avi:  35%|███▌      | 255040/727574 [1:20:45<2:29:38, 52.63it/s]
Processing 0905_2.avi:  33%|███▎      | 184253/566130 [54:22<1:52:41, 56.48it/s]
Processing 0906_1.avi:  28%|██▊       | 251672/911614 [1:21:29<3:33:41, 51.47it/s] 
Processing 0906_2.avi:  42%|████▏     | 162457/384330 [48:24<1:06:06, 55.94it/s]
Processing 