### Big image classification test

class가 자전거인 경우에는 자전거가 포함된 이미지를 1/3 비율로 가져오고, 랜덤하게 선택한 다른 클래스의 이미지를 2/3 비율로 가져와서 테스트를 진행해보았습니다. 이 과정에서 자전거가 아닌 이미지는 클래스 0으로 예측하고, 자전거가 포함된 이미지는 클래스1로 예측하게 됩니다

In [2]:
import cv2
from tensorflow import keras
from keras.models import load_model
from tensorflow.keras.models import load_model
import numpy as np
from joblib import load
import tensorflow as tf
from vit_keras import vit, layers
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import img_to_array
from keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report

# 모델 예측 성능 체크 완료!!!!!! 잘 예측됨 !!!!!!!!
# 우선 먼저, 이미지 예측 되나 확인부터 해보자. 하나의 이미지 가져와서.
vit_model = keras.models.load_model('../vit_model.h5')

# hyperparameter
BATCH_SIZE = 32
ROWS = 224
COLS = 224
input_shape = (ROWS, COLS, 3) # 이미지의 높이, 너비가 각각 224 pixel이고, 채널 수가 3인 RGB이미지임.

image_path = '../Data/Big Image Data'

# 이미지 데이터 전처리
train_data = ImageDataGenerator(rescale=1/255.0, # 이미지 픽셀 값 0~1 normalization
                                zoom_range=0.05, # 이미지 확대/ 축소 범위 -> 다양한 시각적 변화 도입
                                horizontal_flip=True, # 수평으로 뒤집기 -> 이미지의 좌우 대칭성 고려
                                validation_split=0.2)


# 훈련 데이터 생성기
train_generator = train_data.flow_from_directory(directory=image_path,
                                                 target_size=(224,224),
                                                 color_mode='rgb',
                                                 batch_size=BATCH_SIZE,
                                                 class_mode='categorical',
                                                 subset='training',
                                                 shuffle=True,
                                                 seed=42
                                                 )

# 검증 데이터 생성기
valid_generator = train_data.flow_from_directory(directory=image_path,
                                                 target_size=(224,224),
                                                 color_mode='rgb',
                                                 batch_size=BATCH_SIZE,
                                                 class_mode='categorical',
                                                 subset='validation',
                                                 shuffle=True,
                                                 seed=42
                                                 )


Found 9388 images belonging to 12 classes.
Found 2342 images belonging to 12 classes.


In [4]:
def create_test_generator(image_path, class_name, exclude_classes):
  # 각 라벨별 테스트 생성기 함수
  # 라벨 포함 생성기
  test_class_generator = train_data.flow_from_directory(
    directory=image_path,
    target_size=(224, 224),
    color_mode='rgb',
    batch_size=10,
    classes=[class_name],
    class_mode='categorical',
    shuffle=True
  )

  # 라벨 제외 생성기
  test_non_class_generator = train_data.flow_from_directory(
      directory=image_path,
      target_size=(224, 224),
      color_mode='rgb',
      batch_size=10,
      classes=[class_name for class_name in train_generator.class_indices.keys() if class_name not in exclude_classes],
      class_mode='categorical',
      shuffle=True
  )

  return test_class_generator, test_non_class_generator

In [5]:
test_bicycle_generator, test_non_bicycle_generator = create_test_generator(image_path, 'Bicycle', 'Bicycle')
test_bridge_generator, test_non_bridge_generator = create_test_generator(image_path, 'Bridge', 'Bridge')
test_bus_generator, test_non_bus_generator = create_test_generator(image_path, 'Bus', 'Bus')
test_car_generator, test_non_car_generator = create_test_generator(image_path, 'Car', 'Car')
test_chimney_generator, test_non_chimney_generator = create_test_generator(image_path, 'Chimney', 'Chimney')
test_crosswalk_generator, test_non_crosswalk_generator = create_test_generator(image_path, 'Crosswalk', 'Crosswalk')
test_hydrant_generator, test_non_hydrant_generator = create_test_generator(image_path, 'Hydrant', 'Hydrant')
test_motorcycle_generator, test_non_motorcycle_generator = create_test_generator(image_path, 'Motorcycle', 'Motorcycle')
test_palm_generator, test_non_palm_generator = create_test_generator(image_path, 'Palm', 'Palm')
test_stair_generator, test_non_stair_generator = create_test_generator(image_path, 'Stair', 'Stair')
test_traffic_light_generator, test_non_traffic_light_generator = create_test_generator(image_path, 'Traffic Light', 'Traffic Light')

Found 780 images belonging to 1 classes.
Found 10950 images belonging to 11 classes.
Found 533 images belonging to 1 classes.
Found 11197 images belonging to 11 classes.
Found 1209 images belonging to 1 classes.
Found 10521 images belonging to 11 classes.
Found 3558 images belonging to 1 classes.
Found 8172 images belonging to 11 classes.
Found 124 images belonging to 1 classes.
Found 11606 images belonging to 11 classes.
Found 1240 images belonging to 1 classes.
Found 10490 images belonging to 11 classes.
Found 952 images belonging to 1 classes.
Found 10778 images belonging to 11 classes.
Found 81 images belonging to 1 classes.
Found 11649 images belonging to 11 classes.
Found 911 images belonging to 1 classes.
Found 10819 images belonging to 11 classes.
Found 211 images belonging to 1 classes.
Found 11519 images belonging to 11 classes.
Found 791 images belonging to 1 classes.
Found 10939 images belonging to 11 classes.


In [6]:
def load_images_from_generator(generator, num_images):
    images = []
    labels = []
    num_loaded = 0

    while num_loaded < num_images:
        batch_images, batch_labels = next(generator)
        images.append(batch_images)
        labels.append(batch_labels)
        num_loaded += len(batch_images)

    images = np.concatenate(images)[:num_images]
    labels = np.concatenate(labels)[:num_images]

    return images, labels

In [7]:
bicycle_images, bicycle_labels = load_images_from_generator(test_bicycle_generator, 200)
non_bicycle_images, non_bicycle_labels = load_images_from_generator(test_non_bicycle_generator, 400)

In [8]:
bridge_images, bridge_labels = load_images_from_generator(test_bridge_generator, 200)
non_bridge_images, non_bridge_labels = load_images_from_generator(test_non_bridge_generator, 400)

In [9]:
bus_images, bus_labels = load_images_from_generator(test_bus_generator, 200)
non_bus_images, non_bus_labels = load_images_from_generator(test_non_bus_generator, 400)

In [10]:
car_images, car_labels = load_images_from_generator(test_car_generator, 200)
non_car_images, non_car_labels = load_images_from_generator(test_non_car_generator, 400)

In [11]:
chimney_images, chimney_labels = load_images_from_generator(test_chimney_generator, 200)
non_chimney_images, non_chimney_labels = load_images_from_generator(test_non_chimney_generator, 400)

In [12]:
cro_images, cro_labels = load_images_from_generator(test_crosswalk_generator, 200)
non_cro_images, non_cro_labels = load_images_from_generator(test_non_crosswalk_generator, 400)

In [13]:
hydrant_images, hydrant_labels = load_images_from_generator(test_hydrant_generator, 200)
non_hydrant_images, non_hydrant_labels = load_images_from_generator(test_non_hydrant_generator, 400)

In [14]:
moto_images, moto_labels = load_images_from_generator(test_motorcycle_generator, 81)
non_moto_images, non_moto_labels = load_images_from_generator(test_non_motorcycle_generator, 160)

In [15]:
palm_images, palm_labels = load_images_from_generator(test_palm_generator, 200)
non_palm_images, non_palm_labels = load_images_from_generator(test_non_palm_generator, 400)

In [16]:
stair_images, stair_labels = load_images_from_generator(test_stair_generator, 200)
non_stair_images, non_stair_labels = load_images_from_generator(test_non_stair_generator, 400)

In [17]:
tf_images, tf_labels = load_images_from_generator(test_traffic_light_generator, 200)
non_tf_images, non_tf_labels = load_images_from_generator(test_non_traffic_light_generator, 400)

In [19]:
from sklearn.metrics import classification_report

def evaluate_class_performance(vit_model, target_class, target_images, non_target_images):
    class_indices = train_generator.class_indices

    # Index of the target class
    target_index = class_indices[target_class]

    predictions = []
    true_labels = []

    # For target images
    target_preds = vit_model.predict(target_images)

    for pred in target_preds:
        predicted_class = np.argmax(pred)
        if predicted_class == target_index:
            predictions.append(1)  # target class
        else:
            predictions.append(0)  # non-target class
        true_labels.append(1)  # Always 1 for target class

    # For non-target images
    non_target_preds = vit_model.predict(non_target_images)

    for pred in non_target_preds:
        predicted_class = np.argmax(pred)
        if predicted_class == target_index:
            predictions.append(1)  # target class
        else:
            predictions.append(0)  # non-target class
        true_labels.append(0)  # Always 0 for non-target class

    return classification_report(true_labels, predictions)


0 : 특정 클래스 제외  
1 : 해당 클래스

In [20]:
classes = [
    ('Bicycle', bicycle_images, non_bicycle_images),
    ('Bridge', bridge_images, non_bridge_images),
    ('Bus', bus_images, non_bus_images),
    ('Car', car_images, non_car_images),
    ('Chimney', chimney_images, non_chimney_images),
    ('Crosswalk', cro_images, non_cro_images),
    ('Hydrant', hydrant_images, non_hydrant_images),
    ('Motorcycle', moto_images, non_moto_images),
    ('Palm', palm_images, non_palm_images),
    ('Stair', stair_images, non_stair_images),
    ('Traffic Light', tf_images, non_tf_images)
]

In [21]:
def evaluate_all_classes(vit_model, classes):
    for class_name, target_images, non_target_images in classes:
        report = evaluate_class_performance(vit_model, class_name, target_images, non_target_images)
        print(f"Performance for {class_name.capitalize()}:")
        print(report)
        print("============================================")

evaluate_all_classes(vit_model, classes)


Performance for Bicycle:
              precision    recall  f1-score   support

           0       0.98      1.00      0.99       400
           1       1.00      0.95      0.98       200

    accuracy                           0.98       600
   macro avg       0.99      0.98      0.98       600
weighted avg       0.99      0.98      0.98       600

Performance for Bridge:
              precision    recall  f1-score   support

           0       0.96      1.00      0.98       400
           1       0.99      0.93      0.96       200

    accuracy                           0.97       600
   macro avg       0.98      0.96      0.97       600
weighted avg       0.97      0.97      0.97       600

Performance for Bus:
              precision    recall  f1-score   support

           0       0.99      0.99      0.99       400
           1       0.99      0.98      0.98       200

    accuracy                           0.99       600
   macro avg       0.99      0.99      0.99       600
weig

| Class          | Accuracy | F1-score (Target Class) | F1-score (Non-Target Class) | Count |
|----------------|----------|-------------------------|-----------------------------|-------|
| Bicycle        |   0.99   |      0.98               |      0.99                   | 600   |
| Bridge         |   0.97   |      0.96               |      0.98                   | 600   |
| Bus            |   0.98   |      0.97               |      0.99                   | 600   |
| Car            |   0.95   |      0.93               |      0.96                   | 600   |
| Chimney        |   0.97   |      0.96               |      0.98                   | 600   |
| Crosswalk      |   0.95   |      0.92               |      0.97                   | 600   |
| Hydrant        |   1.00   |      1.00               |      1.00                   | 600   |
| Motorcycle     |   0.99   |      0.98               |      0.99                   | 241   |
| Palm           |   0.97   |      0.95               |      0.98                   | 600   |
| Stair          |   0.97   |      0.96               |      0.98                   | 600   |
| Traffic Light  |   0.95   |      0.93               |      0.97                   | 600   |
| **Average Total**      |   **0.97** | **0.96**  | **0.98** | 6051 |