# Pretrained CNN Inference Example (TensorFlow version)
### Dataset: cats_and_dogs_filtered
### Models: VGG16, ResNet50, MobileNetV2 (Keras Applications, ImageNet-pretrained weights)

In [1]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image_dataset_from_directory

# 1) Dataset (cats & dogs)

In [None]:
!wget https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
!unzip -qq cats_and_dogs_filtered.zip

In [3]:
!ls

cats_and_dogs_filtered	cats_and_dogs_filtered.zip  sample_data


In [7]:
# The download cell unpacks the archive into the current working directory
# (/content on Colab), so resolve paths from there rather than hard-coding
# the Colab-only "/content/" prefix.
DATA_ROOT = os.getcwd()
base_dir  = os.path.join(DATA_ROOT, "cats_and_dogs_filtered")
train_dir = os.path.join(base_dir, "train")
val_dir   = os.path.join(base_dir, "validation")

In [8]:
# Echo the resolved validation directory as a quick sanity check.
print(f"{val_dir}")

/content/cats_and_dogs_filtered/validation


In [9]:
!ls /content/cats_and_dogs_filtered/train

cats  dogs


In [10]:
IMG_SIZE   = (224, 224)  # default input size of the Keras ImageNet classifiers used below
BATCH_SIZE = 1           # batch size 1 is recommended for this inference demo
SEED       = 42          # makes the shuffled sample order reproducible across restarts

# shuffle=True without a seed shows a different random set of images on every
# kernel restart; passing SEED keeps the displayed samples reproducible.
val_ds = image_dataset_from_directory(
    val_dir, image_size=IMG_SIZE, batch_size=BATCH_SIZE, shuffle=True, seed=SEED
)

Found 1000 files belonging to 2 classes.


In [11]:
AUTOTUNE = tf.data.AUTOTUNE
# Read the label names first: prefetch() wraps val_ds in a new dataset object
# that does not carry the .class_names attribute.
class_names = val_ds.class_names
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

# 2) Models & preprocess map

In [12]:
from tensorflow.keras.applications import VGG16, ResNet50, MobileNetV2
from tensorflow.keras.applications.vgg16 import preprocess_input as vgg_pre
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_pre
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_pre
from tensorflow.keras.applications.imagenet_utils import decode_predictions

def get_model_and_preprocess(model_name: str):
    """Build an ImageNet-pretrained Keras classifier and its preprocessing function.

    Args:
        model_name: "VGG16", "ResNet50" or "MobileNetV2". Matching ignores
            surrounding whitespace and letter case.

    Returns:
        A ``(keras_model, preprocess_fn)`` tuple, where ``preprocess_fn`` is the
        ``preprocess_input`` function belonging to that architecture.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    # Lambdas defer model construction (and loading of the ImageNet weights)
    # until a name actually matches; an unknown name builds nothing.
    registry = {
        "vgg16": lambda: (VGG16(weights="imagenet"), vgg_pre),
        "resnet50": lambda: (ResNet50(weights="imagenet"), resnet_pre),
        "mobilenetv2": lambda: (MobileNetV2(weights="imagenet"), mobilenet_pre),
    }
    key = model_name.strip().lower()
    if key not in registry:
        raise ValueError(f"Unsupported model name: {model_name}")
    return registry[key]()

# 3) Pretty summary printer

In [13]:
def run_summary(model_name):
    """Build the named pretrained model and print its Keras layer summary."""
    net = get_model_and_preprocess(model_name)[0]
    banner = f"\n===== Summary: {net.name} ====="
    print(banner)
    net.summary()

# 4) Inference (by model name)

In [14]:
def run_inference(model_name, num_sample_show, top_k=3):
    """Classify a few validation images with a pretrained ImageNet model.

    For each of the first ``num_sample_show`` batches of ``val_ds``, the first
    image is displayed with its ground-truth cats/dogs label, followed by the
    model's top-``top_k`` ImageNet predictions.

    Args:
        model_name: A name accepted by ``get_model_and_preprocess``
            (case-insensitive, e.g. "VGG16").
        num_sample_show: Number of samples to display; values <= 0 show none.
        top_k: Number of ImageNet classes printed per image (default 3).
    """
    model, pre_fn = get_model_and_preprocess(model_name)
    print(f"\n\n================ Inference: {model.name} ================")

    # take() bounds the loop before iterating, so num_sample_show <= 0 shows
    # nothing rather than always processing one sample before the check.
    for images, labels in val_ds.take(max(int(num_sample_show), 0)):
        # Only the first image of each batch is displayed and classified.
        img_uint8 = images[0].numpy().astype("uint8")
        gt_idx = int(labels[0].numpy())

        # Slice to a batch of one so a larger BATCH_SIZE doesn't spend compute
        # on images that are never shown.
        x = pre_fn(tf.cast(images[:1], tf.float32))  # [1, H, W, 3]
        # A direct call avoids predict()'s per-call overhead for a single small
        # batch; decode_predictions needs a NumPy array, hence .numpy().
        preds = model(x, training=False).numpy()  # (1, 1000)
        decoded = decode_predictions(preds, top=top_k)[0]

        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(img_uint8)
        ax.set_title(f"GT: {class_names[gt_idx]}")
        ax.axis("off")
        plt.show()
        plt.close(fig)  # release the figure; this runs once per sample

        print(f"Top-{top_k} Predictions:")
        for _, name, score in decoded:
            print(f" - {name:15s}: {score*100:.2f}%")
        print("-" * 40)


# 5) Per-model inference test

### VGG16 (original architecture — the Keras Applications VGG16 has no Batch-Norm layers)

In [None]:
run_summary(model_name="VGG16")
run_inference(model_name="VGG16", num_sample_show=5)

### ResNet50

In [None]:
run_summary(model_name="ResNet50")
run_inference(model_name="ResNet50", num_sample_show=5)

### MobileNetV2

In [None]:
run_summary(model_name="MobileNetV2")
run_inference(model_name="MobileNetV2", num_sample_show=5)