<a href="https://colab.research.google.com/github/dhckdduf/first-repository/blob/main/vgg16%26bbox.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import os
import zipfile
import urllib.request
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers, models
import numpy as np
from PIL import Image
import time

In [None]:
# Google Colab setup: report available GPUs and enable on-demand memory
# growth so TensorFlow does not reserve all GPU memory up front.
gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", len(gpus))
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth can only be configured before the GPUs are
        # initialized; TensorFlow raises RuntimeError otherwise.
        print(e)

In [None]:
# Download the filtered Cats vs. Dogs dataset and unpack it under /content.
dataset_url = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
dataset_path = "/content/cats_and_dogs_filtered.zip"
# Extracting the archive into /content produces this directory (the train /
# validation paths used later live inside it). Checking for it, rather than
# an unrelated path, lets re-runs skip the download.
dataset_extract_path = "/content/cats_and_dogs_filtered"

if not os.path.exists(dataset_extract_path):
    print("Downloading dataset...")
    urllib.request.urlretrieve(dataset_url, dataset_path)
    with zipfile.ZipFile(dataset_path, 'r') as zip_ref:
        zip_ref.extractall(os.path.dirname(dataset_extract_path))
    print("Dataset extracted successfully.")
else:
    print("Dataset already exists.")

In [None]:
# Print the assignment's evaluation rubric, one item per line.
print(
    "평가 문항",
    "1. VGG16 모델을 구현할 수 있는가? - 이미지로 제시된 VGG16 모델을 코드로 구현하였다.",
    "2. 다양한 방법을 사용하여 성능을 향상시켰는가? - 다양한 방법을 사용하여 accuracy 53% 이상을 달성하였다.",
    "3. 다양한 이미지와 모델을 사용하여 Object Detection을 수행하였는가? - 제시된 이미지 외의 다른 이미지에 Object Detection을 수행하였고, 1가지 이상의 사전 학습된 모델을 사용하여 결과를 비교하였다.",
    sep="\n",
)


In [None]:
# Binary classifier built on the ImageNet-pretrained VGG16 convolutional base
# (fully connected top removed, 224x224 RGB input).
base_model = VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))

# Freeze every layer of the pretrained base so only the new head is trained.
for frozen_layer in base_model.layers:
    frozen_layer.trainable = False

# Head: pool the feature maps, one hidden ReLU layer with dropout, and a
# single sigmoid unit for the binary decision.
model = models.Sequential()
model.add(base_model)
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))


In [None]:
# Compile for binary classification: binary cross-entropy loss, Adam with a
# small learning rate (1e-4), and accuracy as the reported metric.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

In [None]:
# Dataset directories produced by extracting cats_and_dogs_filtered.zip.
train_data_path = "/content/cats_and_dogs_filtered/train"
validation_data_path = "/content/cats_and_dogs_filtered/validation"

# The ImageNet VGG16 weights expect inputs prepared by
# vgg16.preprocess_input (RGB->BGR and per-channel mean subtraction on
# 0-255 pixels), not pixels rescaled to [0, 1]. Using the same preprocessing
# the weights were trained with keeps the pretrained features meaningful.
# ImageDataGenerator applies it after resizing and augmentation.
vgg16_preprocess = tf.keras.applications.vgg16.preprocess_input

# Training images are augmented; validation images only get preprocessing.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=vgg16_preprocess,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=vgg16_preprocess)

# class_mode='binary' yields a single 0/1 label per image, matching a
# one-unit sigmoid output.
train_data_gen = train_datagen.flow_from_directory(
    train_data_path, target_size=(224, 224), batch_size=16, class_mode='binary')

val_data_gen = val_datagen.flow_from_directory(
    validation_data_path, target_size=(224, 224), batch_size=16, class_mode='binary')


In [None]:
# Train for 10 epochs, evaluating on the validation generator after each one.
history = model.fit(train_data_gen, validation_data=val_data_gen, epochs=10)

# Report the metrics recorded for the final epoch.
metrics_log = history.history
print("최종 학습 정확도:", metrics_log['accuracy'][-1])
print("최종 검증 정확도:", metrics_log['val_accuracy'][-1])

In [None]:
# Load the pretrained Open Images V4 SSD/MobileNetV2 detector from TF Hub and
# keep its 'default' signature, which is what run_detector() calls.
module_handle = "https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1"
detector_module = hub.load(module_handle)
detector = detector_module.signatures['default']

def load_img(path):
    """Read an image file and decode it into a uint8 RGB tensor.

    Uses tf.io.decode_image instead of decode_jpeg so that PNG, GIF and BMP
    files can be used for detection as well as JPEGs.

    Args:
        path: Path of the image file to read.

    Returns:
        A uint8 tensor of shape (height, width, 3).
    """
    img = tf.io.read_file(path)
    # channels=3 forces RGB output; expand_animations=False keeps only the
    # first GIF frame so the result is always a rank-3 image tensor.
    img = tf.io.decode_image(img, channels=3, expand_animations=False)
    return img

def display_image(image, figsize=(10, 10)):
    """Show an image in a new matplotlib figure with grid lines disabled.

    Args:
        image: Any array accepted by plt.imshow (e.g. an HxWx3 uint8 array).
        figsize: Figure size in inches; defaults to 10x10.
    """
    plt.figure(figsize=figsize)
    plt.grid(False)
    plt.imshow(image)
    plt.show()

def _draw_detections(image, result, max_boxes, min_score):
    """Show `image` with its top-scoring detection boxes drawn; return how many were drawn."""
    height, width = image.shape[:2]
    _, ax = plt.subplots(figsize=(10, 10))
    ax.grid(False)
    ax.imshow(image)

    scores = result["detection_scores"]
    # Key names follow the TF Hub Open Images V4 detectors; boxes are assumed
    # to be normalized (ymin, xmin, ymax, xmax) as documented for those modules.
    boxes = result.get("detection_boxes")
    entities = result.get("detection_class_entities")

    drawn = 0
    if boxes is not None:
        # Visit detections from highest to lowest score, stopping at the cap
        # or at the first score below the threshold.
        for i in np.argsort(scores)[::-1][:max_boxes]:
            if scores[i] < min_score:
                break
            ymin, xmin, ymax, xmax = boxes[i]
            left, top = xmin * width, ymin * height
            ax.add_patch(plt.Rectangle(
                (left, top), (xmax - xmin) * width, (ymax - ymin) * height,
                fill=False, edgecolor="lime", linewidth=2))
            label = entities[i] if entities is not None else "object"
            if isinstance(label, bytes):
                label = label.decode("utf-8", "replace")
            ax.text(left, top, f"{label}: {scores[i]:.0%}", color="black",
                    backgroundcolor="lime", fontsize=10, verticalalignment="bottom")
            drawn += 1
    plt.show()
    return drawn


def run_detector(detector, image_path, max_boxes=10, min_score=0.1):
    """Run an object detector on one image, then report and visualize its detections.

    Args:
        detector: Callable taking a float32 batch of shape (1, height, width, 3)
            with values in [0, 1] and returning a dict of tensors that includes
            "detection_scores".
        image_path: Path of the image file to analyse.
        max_boxes: Maximum number of boxes to draw.
        min_score: Minimum detection score for a box to be drawn.

    Returns:
        The detector output with every tensor converted to a NumPy array, so
        results from different pretrained models can be compared.
    """
    img = load_img(image_path)
    # uint8 -> float32 in [0, 1], plus a leading batch dimension.
    converted_img = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]
    # perf_counter is the monotonic clock intended for measuring intervals.
    start_time = time.perf_counter()
    result = detector(converted_img)
    end_time = time.perf_counter()
    result = {key: value.numpy() for key, value in result.items()}
    print("Found %d objects." % len(result["detection_scores"]))
    print("Inference time: ", end_time - start_time)
    drawn = _draw_detections(img.numpy(), result, max_boxes, min_score)
    print("Drew %d boxes with score >= %.2f." % (drawn, min_score))
    return result

In [None]:
# Run the detector on a sample image.
image_url = "https://upload.wikimedia.org/wikipedia/commons/3/3f/JPEG_example_flower.jpg"
image_path = "/content/sample.jpg"
# Wikimedia's User-Agent policy allows rejecting requests that carry a generic
# client User-Agent (urllib's default can get HTTP 403), so send a descriptive one.
image_request = urllib.request.Request(
    image_url, headers={"User-Agent": "vgg16-bbox-notebook/1.0 (educational use)"})
with urllib.request.urlopen(image_request) as response, open(image_path, "wb") as out_file:
    out_file.write(response.read())
run_detector(detector, image_path)
