<a href="https://colab.research.google.com/github/euns-tory/AIFFEL_quest_cr/blob/main/practice/jellyfish_ensemble.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
!mkdir -p ~/.kaggle
!mv /content/kaggle.json ~/.kaggle/kaggle.json
!chmod 600 ~/.kaggle/kaggle.json

In [8]:
!mkdir -p /content/jellyfish

In [9]:
!kaggle datasets download -d anshtanwar/jellyfish-types -p /content/jellyfish

Dataset URL: https://www.kaggle.com/datasets/anshtanwar/jellyfish-types
License(s): Attribution 4.0 International (CC BY 4.0)


In [10]:
!unzip /content/jellyfish/jellyfish-types.zip -d /content/jellyfish/

Archive:  /content/jellyfish/jellyfish-types.zip
  inflating: /content/jellyfish/Moon_jellyfish/01.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/02.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/03.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/04.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/05.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/06.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/08.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/10.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/12.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/13.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/14.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/16.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/17.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/18.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/21.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/22.jpg  
  inflating: /content/jellyfish/Moon_jellyfish/24.jpg  

In [11]:
!pip install tensorflow keras



In [12]:
import tensorflow as tf
import numpy as np
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50, EfficientNetB0, MobileNetV2
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# 데이터 경로 설정
data_dir = "/content/jellyfish/"

# 이미지 크기 및 하이퍼파라미터 설정
img_size = (224, 224)
batch_size = 32
epochs = 10

# 데이터 증강 및 로드
datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_generator = datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode="categorical",
    subset="training"
)

val_generator = datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode="categorical",
    subset="validation"
)

num_classes = len(train_generator.class_indices)

# 기본 모델 생성 함수
def create_model(base_model):
    base_model.trainable = False  # 사전 학습된 가중치 고정
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    output = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=output)
    return model

# 개별 모델 정의
base_models = {
    "ResNet50": ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3)),
    "EfficientNetB0": EfficientNetB0(weights="imagenet", include_top=False, input_shape=(224, 224, 3)),
    "MobileNetV2": MobileNetV2(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
}

models = {}
for name, base in base_models.items():
    model = create_model(base)
    model.compile(optimizer=Adam(learning_rate=0.0001), loss="categorical_crossentropy", metrics=["accuracy"])
    models[name] = model

# 개별 모델 학습
for name, model in models.items():
    print(f"Training {name}...")
    model.fit(train_generator, validation_data=val_generator, epochs=epochs)

# 앙상블 예측 함수 (Soft Voting)
def ensemble_predict(models, data):
    predictions = [model.predict(data) for model in models.values()]
    avg_prediction = np.mean(predictions, axis=0)  # Soft Voting (확률 평균)
    return avg_prediction

# 모델 평가 (앙상블 적용)
val_data, val_labels = next(val_generator)
ensemble_preds = ensemble_predict(models, val_data)
ensemble_labels = np.argmax(ensemble_preds, axis=1)
true_labels = np.argmax(val_labels, axis=1)

accuracy = np.mean(ensemble_labels == true_labels)
print(f"앙상블 모델 정확도: {accuracy * 100:.2f}%")


Found 1504 images belonging to 7 classes.
Found 375 images belonging to 7 classes.
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
[1m16705208/16705208[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Training ResNet50...


  self._warn_if_super_not_called()


Epoch 1/10
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m373s[0m 8s/step - accuracy: 0.2448 - loss: 2.0850 - val_accuracy: 0.5200 - val_loss: 1.5534
Epoch 2/10
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m357s[0m 8s/step - accuracy: 0.4958 - loss: 1.6248 - val_accuracy: 0.5200 - val_loss: 1.5511
Epoch 3/10
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m346s[0m 7s/step - accuracy: 0.5164 - loss: 1.6297 - val_accuracy: 0.5200 - val_loss: 1.5513
Epoch 4/10
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m353s[0m 8s/step - accuracy: 0.5204 - loss: 1.5910 - val_accuracy: 0.5200 - val_loss: 1.5494
Epoch 5/10
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m345s[0m 7s/step - accuracy: 0.5196 - loss: 1.6091 - val_accuracy: 0.5200 - val_loss: 1.5510
Epoch 6/10
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m359s[0m 8s/step - accuracy: 0.5274 - loss: 1.5893 - val_accuracy: 0.5200 - val_loss: 1.5519
Epoch 7/10
[1m47/47[0m [32m━━━━