In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
import os

### CIFAR-10 데이터셋 불러오기

In [2]:
import ssl
# WARNING: this disables TLS certificate verification for HTTPS connections that rely on
# the stdlib default context (e.g. urllib), for the rest of the kernel session -- not only
# for the CIFAR-10 download in the next cell. It is a common workaround for a Python
# install that lacks CA certificates (e.g. some macOS setups); prefer installing the
# certificates and deleting this line.
# NOTE(review): as far as I know Keras' CIFAR-10 loader reads the downloaded batches with
# pickle, so an unverified (tamperable) download is a code-execution risk -- confirm
# whether the installed Keras version verifies a file hash for this dataset.
ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
# Load CIFAR-10 through Keras: uint8 images plus integer labels.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Scale pixels to [0, 1] as float32. Dividing the uint8 arrays by a Python float would
# produce float64, doubling memory (the augmented training set gets large) for no benefit,
# since Keras' default float type is float32.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# One-hot encode the labels for categorical_crossentropy (10 classes).
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

In [4]:
gen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True
)

augment_ratio = 1.5  # extra augmented samples, as a fraction of the original training set
augment_seed = 42    # makes the sampled indices, random transforms and shuffle reproducible


def augment_dataset(x, y, generator, ratio=1.5, seed=None):
    """Append randomly transformed copies of (x, y) and return the shuffled union.

    ``int(ratio * len(x))`` samples are drawn from ``x`` *with replacement*, passed once
    through ``generator.flow`` (so each copy receives one random transform), concatenated
    after the originals, and the combined arrays are shuffled with a single permutation
    shared by images and labels.

    Args:
        x: image array accepted by ``generator.flow`` (batch axis first).
        y: labels aligned with ``x`` along axis 0.
        generator: an object with an ``ImageDataGenerator``-style ``flow`` method.
        ratio: number of augmented samples relative to ``len(x)``.
        seed: seed for index sampling and the shuffle; also forwarded to
            ``generator.flow``. ``None`` means nondeterministic.

    Returns:
        ``(x_all, y_all)`` holding ``len(x) + int(ratio * len(x))`` samples. If that
        augmented count is <= 0, ``x`` and ``y`` are returned unchanged.
    """
    rng = np.random.default_rng(seed)
    augment_size = int(ratio * x.shape[0])
    if augment_size <= 0:
        return x, y

    randidx = rng.integers(x.shape[0], size=augment_size)
    # Fancy indexing already returns copies, so the originals are never modified.
    # batch_size=augment_size pulls every sampled image through the generator in one batch.
    x_augmented, y_augmented = next(generator.flow(
        x[randidx], y[randidx], batch_size=augment_size, shuffle=False, seed=seed
    ))

    x_all = np.concatenate((x, x_augmented))
    y_all = np.concatenate((y, y_augmented))

    order = rng.permutation(x_all.shape[0])
    return x_all[order], y_all[order]


# NOTE: this rebinds x_train/y_train, so re-running this cell augments the already
# augmented set again -- re-run the data-loading cell first.
x_train, y_train = augment_dataset(x_train, y_train, gen, ratio=augment_ratio, seed=augment_seed)

### 클래스 이름 지정

In [5]:
# Display names for the 10 output classes, in label-index order.
class_names = "Airplane Automobile Bird Cat Deer Dog Frog Horse Ship Truck".split()

### 모델 파일 경로 지정

In [6]:
# File name the trained Keras model is written to; the .h5 extension selects the HDF5 format.
model_path: str = "cnn_model.h5"

### CNN 모델링

In [14]:
# Build the CNN layer by layer. With padding='same' every Conv2D keeps the spatial size
# and every MaxPooling2D((2, 2)) halves it: 32 -> 16 -> 8 -> 4 -> 2 -> 1.
# NOTE(review): the layer named 'last_conv' therefore sees only a 2x2 feature map, so
# CAM heatmaps taken from it are very coarse (2x2 cells).
model = models.Sequential()

# Stage 1: two Conv(32) at 32x32, then pool to 16x16
model.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))
model.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Dropout(0.25))

# Stage 2: two Conv(64) at 16x16, then pool to 8x8
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Dropout(0.25))

# Stage 3: Conv(128) at 8x8, then pool to 4x4
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Dropout(0.25))

# Stage 4: Conv(128) at 4x4, then pool to 2x2
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Dropout(0.25))

# Stage 5: Conv(256) at 2x2 -- the layer generate_cam looks up by name -- then pool to 1x1
model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same', name='last_conv'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Dropout(0.25))

# Classifier head: 1x1x256 -> 256 features -> 128 hidden units -> 10 class probabilities
model.add(layers.Flatten())
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(10, activation='softmax'))

### 모델 학습

In [None]:
# Reuse a previously saved model when one exists at `model_path`: 50 epochs of training
# is expensive, and a full Restart & Run All should not silently retrain and overwrite it.
retrain = False  # set True to ignore any saved model and train from scratch

if os.path.exists(model_path) and not retrain:
    # Replaces the freshly built (untrained) `model`. If the architecture cell was edited,
    # delete the saved file or set retrain=True, otherwise the old architecture is used.
    model = tf.keras.models.load_model(model_path)
    print(f"Model Load : {model_path}")
else:
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # NOTE(review): the test set is reused as validation data, so any tuning based on
    # these validation curves leaks test-set information.
    history = model.fit(x_train, y_train, epochs=50, validation_data=(x_test, y_test), batch_size=256)
    model.save(model_path)
    print(f"Model Save : {model_path}")

### Grad-CAM 구현

In [16]:
def generate_cam(image, model, class_idx, layer_name="last_conv"):
    """Compute a Grad-CAM heatmap for a single image.

    The gradient of the model's output for ``class_idx`` with respect to the feature map
    of ``layer_name`` is averaged over the two spatial axes to give one weight per
    channel. The weighted channel sum is passed through ReLU and scaled to [0, 1].

    Args:
        image: one image without a batch axis (a batch axis of size 1 is added here);
            presumably (H, W, C) matching the model input -- TODO confirm for other models.
        model: a Keras model that contains a layer named ``layer_name``.
        class_idx: index of the output unit to explain.
        layer_name: name of the convolutional layer whose activations are weighted.

    Returns:
        A 2-D numpy array at the spatial resolution of ``layer_name``'s output (not the
        input image's resolution) with values in [0, 1]. It is all zeros when the
        weighted activations are non-positive everywhere.
    """
    grad_model = tf.keras.models.Model(
        inputs=[model.input],
        outputs=[model.get_layer(layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(np.expand_dims(image, axis=0))
        # Output score of the requested class (a probability for a softmax-headed model).
        loss = predictions[:, class_idx]

    grads = tape.gradient(loss, conv_outputs)[0]
    # Average the gradients over the spatial axes: one importance weight per channel.
    weights = tf.reduce_mean(grads, axis=(0, 1))
    cam = tf.reduce_sum(weights * conv_outputs[0], axis=-1)
    cam = np.maximum(cam, 0)
    max_val = cam.max()
    # Guard against 0 / 0 -> NaN when no location has positive evidence for the class.
    if max_val > 0:
        cam = cam / max_val
    return cam

### Grad-CAM 시각화

In [None]:
num_images = 10
sample_seed = 0  # fixes which test images are shown, so the figure is reproducible

rng = np.random.default_rng(sample_seed)
random_indices = rng.choice(len(x_test), size=num_images, replace=False)
test_images = x_test[random_indices]
true_labels = y_test[random_indices]

# One batched forward pass instead of a separate model.predict call per image.
pred_indices = np.argmax(model.predict(test_images, verbose=0), axis=1)

# squeeze=False keeps `axes` 2-D even when num_images == 1.
fig, axes = plt.subplots(num_images, 2, figsize=(20, 30), squeeze=False)

for i, (test_img, class_idx) in enumerate(zip(test_images, pred_indices)):
    predicted_class = class_names[class_idx]
    true_class = class_names[np.argmax(true_labels[i])]
    cam = generate_cam(test_img, model, class_idx)
    height, width = test_img.shape[:2]

    ax_orig, ax_cam = axes[i]
    ax_orig.set_title(f"Original: {true_class}\nPredict: {predicted_class}")
    ax_orig.imshow(test_img)
    ax_orig.axis('off')

    ax_cam.set_title("Overlay CAM")
    ax_cam.imshow(test_img)
    # The CAM is at the conv layer's resolution (much smaller than the image). Without an
    # explicit extent it would be drawn over only its own few pixels in the top-left
    # corner, so stretch it across the image's pixel extent so the overlay lines up.
    ax_cam.imshow(cam, alpha=0.5, cmap='jet', interpolation='bilinear',
                  extent=(-0.5, width - 0.5, height - 0.5, -0.5))
    ax_cam.axis('off')

fig.tight_layout()
plt.show()