<a href="https://colab.research.google.com/github/juhumkwon/YOLO/blob/main/SimpleYOLO(coco_2017_%EC%82%AC%EC%9A%A9).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# TensorFlow-based YOLO model

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt

# Colab-only: mount Google Drive at /content/drive so the inference image path
# used in the __main__ block below ("/content/drive/My Drive/...") is readable.
from google.colab import drive
drive.mount('/content/drive')

# Simple YOLO model definition
class SimpleYOLO(tf.keras.Model):
    """Small YOLO-style detector: a three-stage CNN backbone and a dense head.

    The head outputs ``num_classes * 5`` linear (unbounded) values per image,
    laid out as one ``[x, y, w, h, confidence]`` group per class.

    Args:
        num_classes: Number of object classes.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``), so the model
            can be named and re-created from a Keras config.
    """

    def __init__(self, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        # Each conv keeps the spatial size ("same" padding) and each pool halves
        # it, so a 64x64 input reaches the head as an 8x8x64 feature map.
        self.backbone = tf.keras.Sequential([
            tf.keras.layers.Conv2D(16, (3, 3), strides=1, padding="same", activation="relu"),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(32, (3, 3), strides=1, padding="same", activation="relu"),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), strides=1, padding="same", activation="relu"),
            tf.keras.layers.MaxPooling2D((2, 2)),
        ])
        self.flatten = tf.keras.layers.Flatten()
        self.fc = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dense(num_classes * 5, activation="linear"),  # 4 bbox + 1 confidence per class
        ])

    def call(self, inputs):
        """Map a batch of images to flat per-class predictions.

        Args:
            inputs: Image batch of shape (batch, H, W, C). The Dense head fixes
                H and W to whatever size the model was first built with.

        Returns:
            Tensor of shape (batch, num_classes * 5).
        """
        x = self.backbone(inputs)
        x = self.flatten(x)
        return self.fc(x)

# COCO dataset loading and preprocessing
def preprocess_coco(image, annotations, num_classes=80):
    """Resize/normalize a COCO image and encode its boxes as a per-class target.

    The target layout matches what ``SimpleYOLO`` emits and what
    ``predict_image`` decodes: ``num_classes`` rows of
    ``[cx, cy, w, h, confidence]`` flattened into one vector. Coordinates are
    fractions of the image size, so resizing the image does not change them.
    For every class present in the image, the row holds that class's largest
    box with confidence 1. Absent classes are all zeros.

    Args:
        image: Image tensor, assumed uint8 in [0, 255] with shape (H, W, C).
        annotations: The TFDS ``objects`` dict. Reads ``bbox`` (N, 4) as
            normalized ``[ymin, xmin, ymax, xmax]`` and ``label`` (N,) class ids.
        num_classes: Number of classes (80 for COCO).

    Returns:
        ``(image, target)``. ``image`` is float32 (64, 64, C) in [0, 1].
        ``target`` is float32 of shape (num_classes * 5,).
    """
    image = tf.image.resize(image, (64, 64)) / 255.0

    bboxes = tf.reshape(tf.cast(annotations["bbox"], tf.float32), [-1, 4])
    labels = tf.cast(annotations["label"], tf.int32)
    ymin, xmin, ymax, xmax = tf.unstack(bboxes, num=4, axis=-1)
    width = xmax - xmin
    height = ymax - ymin
    boxes = tf.stack([(xmin + xmax) / 2.0, (ymin + ymax) / 2.0, width, height], axis=-1)  # (N, 4)

    # Score is > 0 only for (box, its own class). The +1 keeps zero-area boxes
    # strictly positive, and larger boxes win.
    scores = tf.one_hot(labels, num_classes) * (width * height + 1.0)[:, tf.newaxis]  # (N, C)
    # Prepend an all-zero dummy row. It makes argmax well-defined for images
    # with no objects, and argmax lands on it (index 0) for absent classes.
    scores = tf.concat([tf.zeros([1, num_classes]), scores], axis=0)
    boxes = tf.concat([tf.zeros([1, 4]), boxes], axis=0)
    best = tf.argmax(scores, axis=0)  # (C,) chosen row per class
    present = tf.cast(best > 0, tf.float32)[:, tf.newaxis]  # (C, 1) confidence target

    target = tf.concat([tf.gather(boxes, best), present], axis=-1)  # (C, 5)
    return image, tf.reshape(target, [num_classes * 5])

def get_coco_dataset(batch_size=8, split="train", data_dir="./tensorflow_datasets", shuffle_buffer=0):
    """Build a batched, prefetched COCO 2017 pipeline of (image, target) pairs.

    The first call downloads and prepares the dataset into ``data_dir``.

    Args:
        batch_size: Examples per batch.
        split: TFDS split name to load.
        data_dir: Directory where TFDS stores the dataset.
        shuffle_buffer: If > 0, shuffle the preprocessed (small, 64x64) examples
            with a buffer of this size. 0 (default) keeps TFDS order.

    Returns:
        A ``tf.data.Dataset`` yielding batches produced by ``preprocess_coco``.
    """
    dataset = tfds.load("coco/2017", split=split, data_dir=data_dir)
    dataset = dataset.map(
        lambda example: preprocess_coco(example["image"], example["objects"]),
        # Parallel map still yields elements in order (deterministic by default).
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    if shuffle_buffer > 0:
        # Shuffle after resizing so the buffer holds small images, not full-size ones.
        dataset = dataset.shuffle(shuffle_buffer)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Model training loop
def train_model(epochs=5, batch_size=8, learning_rate=0.001, weights_path="simple_yolo.weights.h5"):
    """Train a SimpleYOLO model on COCO 2017 with an MSE loss and save its weights.

    Args:
        epochs: Number of passes over the dataset.
        batch_size: Batch size forwarded to ``get_coco_dataset``.
        learning_rate: Adam learning rate.
        weights_path: File the trained weights are written to. Keras 3 requires
            ``save_weights`` filenames to end in ``.weights.h5``.

    Returns:
        The trained model.
    """
    num_classes = 80  # COCO class count; must match the 80 * 5 targets from preprocess_coco
    model = SimpleYOLO(num_classes)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.MeanSquaredError(),
    )
    model.fit(get_coco_dataset(batch_size=batch_size), epochs=epochs)

    # model.save("simple_yolo.h5") raises for a subclassed model, because
    # full-model HDF5 saving needs a Functional or Sequential model. Saving the
    # weights works for any built model.
    model.save_weights(weights_path)
    return model

# Image inference function
def predict_image(model, image_path, threshold=0.5, target_size=(64, 64)):
    """Run ``model`` on one image file and keep the confident per-class boxes.

    Args:
        model: A model whose output reshapes to rows of
            ``[x, y, w, h, confidence]`` (e.g. ``SimpleYOLO``).
        image_path: Path of the image to load.
        threshold: Keep rows whose confidence is strictly greater than this.
        target_size: ``(height, width)`` the image is resized to. Must match the
            input size the model was trained on (64x64 for ``train_model``).

    Returns:
        ``(filtered_boxes, image)``. ``filtered_boxes`` is an ndarray of shape
        (K, 5). ``image`` is the resized PIL image that was fed to the model.
    """
    # Load and preprocess the same way as training: resize, scale to [0, 1].
    image = tf.keras.utils.load_img(image_path, target_size=target_size)
    input_image = tf.keras.utils.img_to_array(image) / 255.0
    input_image = tf.expand_dims(input_image, axis=0)  # add batch dimension

    # One row per class: [x, y, w, h, confidence].
    predictions = model(input_image).numpy().reshape(-1, 5)

    filtered_boxes = predictions[predictions[:, 4] > threshold]
    return filtered_boxes, image

# Draw bounding boxes on an image
def draw_boxes(image, boxes):
    """Display ``image`` with each box outlined in red and labelled with its confidence.

    Args:
        image: A PIL-style image whose ``.size`` is ``(width, height)``, such as
            the one returned by ``predict_image``.
        boxes: Iterable of ``(x, y, w, h, confidence)`` rows. ``x, y`` is the
            box centre and ``w, h`` its size, all as fractions of the image
            size.
    """
    plt.imshow(image)
    ax = plt.gca()
    for cx, cy, bw, bh, conf in boxes:
        img_w, img_h = image.size[0], image.size[1]
        half_w, half_h = bw / 2, bh / 2
        left = int((cx - half_w) * img_w)
        top = int((cy - half_h) * img_h)
        right = int((cx + half_w) * img_w)
        bottom = int((cy + half_h) * img_h)
        outline = plt.Rectangle(
            (left, top), right - left, bottom - top,
            linewidth=2, edgecolor='r', facecolor='none',
        )
        ax.add_patch(outline)
        # Label sits just above the box's top-left corner.
        ax.text(left, top - 10, f"Conf: {conf:.2f}", color="red", fontsize=8)
    plt.axis("off")
    plt.show()

# Run: train on COCO, then detect objects in one image and display the result.
if __name__ == "__main__":
    model = train_model()

    # Path of the image to run inference on (on the Google Drive mounted at /content/drive).
    image_path = "/content/drive/My Drive/Data/bird.jpg"  # image to run inference on
    boxes, processed_image = predict_image(model, image_path)
    draw_boxes(processed_image, boxes)


Mounted at /content/drive




Downloading and preparing dataset 25.20 GiB (download: 25.20 GiB, generated: Unknown size, total: 25.20 GiB) to tensorflow_datasets/coco/2017/1.1.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]