In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from ultralytics import YOLO
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np

# Load the YOLO detection model from its weights file.
YOLO_WEIGHTS_PATH = r"C:\Users\ime203\Desktop\Graduation\runs\detect\Epochs80test\weights\best.pt"
yolo_model = YOLO(YOLO_WEIGHTS_PATH)

# ImageClassifier model definition
class ImageClassifier(nn.Module):
    """ResNet-18 wrapper whose final fully connected layer is replaced.

    All parameters of the pretrained ResNet-18 are frozen
    (``requires_grad = False``); the replacement ``fc`` layer, created
    afterwards, maps the backbone features to ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Freeze first: the head created below is a fresh layer and is
        # therefore not affected by this loop.
        for weight in backbone.parameters():
            weight.requires_grad = False
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.resnet = backbone

    def forward(self, x):
        """Return the output of the wrapped ResNet for input ``x``."""
        return self.resnet(x)

# Instantiate the classifier and restore its saved weights.
num_classes = 3  # Adjust based on your dataset
image_classifier = ImageClassifier(num_classes=num_classes)
# map_location="cpu": this script never moves the model off the CPU, so
# deserialize the checkpoint straight onto the CPU. Without it, a checkpoint
# that was saved from a GPU fails to load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
classifier_state = torch.load(
    r"C:\Users\ime203\Desktop\Graduation\resnet18_image_classifier.pth",
    map_location="cpu",
)
image_classifier.load_state_dict(classifier_state)
image_classifier.eval()  # switch the module to evaluation mode for inference

# Preprocessing applied to every image before it is classified.
# ColorJitter is intentionally not part of this pipeline: it is a random
# augmentation, and applying it at inference time makes the prediction for
# the very same frame vary from one call to the next.
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.Lambda(lambda img: img.convert('RGB')),  # convert the image to RGB
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Image classification
def classify_image(image):
    """Classify one image and return the label of the highest-scoring class.

    Args:
        image: Image accepted by the module-level ``transform`` pipeline
            (presumably a PIL image, since the pipeline calls
            ``img.convert('RGB')`` - confirm against callers).

    Returns:
        One of the strings ``"0"``, ``"1"`` or ``"2"``.
    """
    class_names = ["0", "1", "2"]  # Adjust based on your dataset
    batch = transform(image).unsqueeze(0)  # add the batch dimension
    # Inference only: disable autograd so no graph is recorded for the
    # trainable head on every call.
    with torch.no_grad():
        outputs = image_classifier(batch)
    _, preds = torch.max(outputs, 1)
    return class_names[preds.item()]

# Object detection
def detect_objects(image):
    """Run the YOLO model on ``image`` and return the annotated result.

    Args:
        image: Input image, passed unchanged to ``yolo_model``.

    Returns:
        Tuple ``(annotated, boxes)``: a PIL image of the first result with
        its detections drawn on it, and that result's ``boxes`` attribute.
    """
    results = yolo_model(image)
    first_result = results[0]  # only the first result is visualized
    # Results.plot() returns the annotated frame in BGR channel order, while
    # Image.fromarray interprets a 3-channel array as RGB. Reverse the
    # channel axis so red and blue are not swapped in the returned image.
    annotated_bgr = first_result.plot()
    return Image.fromarray(annotated_bgr[..., ::-1]), first_result.boxes


# Caption drawn on the frame for each classification result; results without
# an entry (e.g. "0") get no caption.
_ACCIDENT_CAPTIONS = {"1": "Level 1 Accident", "2": "Level 2 Accident"}
_caption_font_cache = {}


def _get_caption_font(size=36):
    """Return the caption font for ``size``, loading it only on first use."""
    font = _caption_font_cache.get(size)
    if font is None:
        try:
            font = ImageFont.truetype("arial.ttf", size)
        except OSError:
            # arial.ttf is not installed on every system; fall back to
            # Pillow's built-in font so the caption is still drawn.
            font = ImageFont.load_default()
        _caption_font_cache[size] = font
    return font


def detect_and_classify(image):
    """Detect objects in ``image``, classify it, and caption accident levels.

    The image is annotated by ``detect_objects`` and classified by
    ``classify_image``. When the classification is ``"1"`` or ``"2"`` a red
    caption is drawn horizontally centred near the top of the annotated
    image. Drawing is best-effort: a failure there is printed and the
    annotated image is still returned without the caption.

    Args:
        image: Input image handed to both ``detect_objects`` and
            ``classify_image``.

    Returns:
        Tuple ``(annotated_image, classification_result)``. If detection or
        classification raises, the error is printed and
        ``(image, "Error")`` is returned instead.
    """
    try:
        # Object detection, then image classification.
        detected_image, _ = detect_objects(image)
        classification_result = classify_image(image)
    except Exception as e:
        print(f"Error in detect_and_classify: {e}")
        return image, "Error"

    text = _ACCIDENT_CAPTIONS.get(classification_result)
    if text is not None:
        try:
            draw = ImageDraw.Draw(detected_image)
            font = _get_caption_font()

            bbox = draw.textbbox((0, 0), text, font=font)
            textwidth, textheight = bbox[2] - bbox[0], bbox[3] - bbox[1]
            width, height = detected_image.size
            # Centre horizontally; vertical centre at one tenth of the height.
            x = width // 2 - textwidth // 2
            y = height // 10 - textheight // 2
            draw.text((x, y), text, font=font, fill=(255, 0, 0))
        except Exception as e:
            print(f"Error drawing text: {e}")

    return detected_image, classification_result


# Video processing
def process_video(video_path):
    """Annotate every frame of a video and write the result as an MP4 file.

    Each frame is converted to an RGB PIL image, passed through
    ``detect_and_classify``, converted back to BGR and written immediately
    to ``output_video.mp4``. A frame whose processing raises is reported on
    stdout and skipped.

    Args:
        video_path: Path of the input video, as accepted by
            ``cv2.VideoCapture``.

    Returns:
        ``'output_video.mp4'`` when at least one frame was written,
        otherwise the string ``"Error: No frames were processed."``.
        NOTE(review): the error string is returned where the Gradio "file"
        output expects a path - confirm how the UI should report this.
    """
    output_path = 'output_video.mp4'
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        # Keep the source frame rate so playback speed is preserved; fall
        # back to 30 when no usable value is reported (0, negative or NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            try:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame_rgb)
                detected_image, _ = detect_and_classify(pil_image)
                annotated = cv2.cvtColor(np.array(detected_image), cv2.COLOR_RGB2BGR)

                # Open the writer lazily: the output size is only known once
                # the first annotated frame exists. Writing frame by frame
                # avoids holding the whole video in memory.
                if out is None:
                    height, width = annotated.shape[:2]
                    out = cv2.VideoWriter(
                        output_path,
                        cv2.VideoWriter_fourcc(*'mp4v'),
                        fps,
                        (width, height),
                    )
                out.write(annotated)
            except Exception as e:
                print(f"Error processing frame: {e}")
    finally:
        # Release the capture and writer even if something unexpected raises.
        cap.release()
        if out is not None:
            out.release()

    if out is None:
        return "Error: No frames were processed."
    return output_path

# Gradio UI: takes an uploaded video and returns the processed video file.
APP_TITLE = "YOLO Object Detection and Image Classification in Videos"
APP_DESCRIPTION = "Upload a video to detect objects and classify images within the video."

interface = gr.Interface(
    fn=process_video,
    inputs=gr.Video(),
    outputs="file",  # provide the result as a downloadable file
    title=APP_TITLE,
    description=APP_DESCRIPTION,
)

interface.launch()




Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


