In [None]:
# Mount Google Drive at /content/drive using the google.colab helper.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pandas as pd
import os

In [None]:
# Download the 'anshtanwar/jellyfish-types' dataset through kagglehub and keep
# the returned location in `path`; the following cells list and join onto it.
import kagglehub
path = kagglehub.dataset_download('anshtanwar/jellyfish-types')
print("데이터셋 파일 경로:", path)

In [None]:
# Show the top-level entries of the downloaded dataset, header line first.
print("다운로드된 파일 목록:", os.listdir(path), sep="\n")

In [None]:
# Inspect the folder structure of the downloaded dataset and preview the
# entries of the training and validation splits.
#
# Both split paths are bound up front (not only inside the `if` branches) so
# that the later cells that read `train_path` / `valid_path` never hit a
# NameError when a folder turns out to be missing.
train_test_valid_path = os.path.join(path, 'Train_Test_Valid')
train_path = os.path.join(train_test_valid_path, 'Train')
valid_path = os.path.join(train_test_valid_path, 'valid')

# isdir (rather than exists) guards the os.listdir calls against a regular
# file sitting at the expected location.
if os.path.isdir(train_test_valid_path):
    print("\nTrain_Test_Valid 폴더 내용:")
    print(os.listdir(train_test_valid_path))

    # Training split
    if os.path.isdir(train_path):
        train_files = pd.DataFrame(os.listdir(train_path), columns=['Files_Name'])
        print("\n훈련 파일:")
        print(train_files.head())
    else:
        print("\n'Train' 폴더를 찾을 수 없습니다.")

    # Validation split
    if os.path.isdir(valid_path):
        valid_files = pd.DataFrame(os.listdir(valid_path), columns=['Files_Name'])
        print("\n검증 파일:")
        print(valid_files.head())
    else:
        print("\n'valid' 폴더를 찾을 수 없습니다.")
else:
    print("\n'Train_Test_Valid' 폴더를 찾을 수 없습니다.")

In [None]:
import glob
import numpy as np

def create_dataframe(image_folder, seed=None):
    """Build a shuffled table of file paths and their class labels.

    Every regular file whose name contains a dot, anywhere below
    ``image_folder``, becomes one row. The label of a file is the name of the
    directory that directly contains it.

    Args:
        image_folder: Root directory to scan recursively.
        seed: Optional integer. When given, the files are sorted and then
            shuffled with a dedicated generator seeded with this value, so the
            row order is reproducible. When ``None`` (default) the global
            ``np.random`` state is used, as before.

    Returns:
        A ``pandas.DataFrame`` with the columns ``'Image'`` (file path) and
        ``'Label'`` (parent directory name). Empty when nothing matches.
    """
    pattern = os.path.join(image_folder, '**', '*.*')
    # '*.*' also matches directories whose name contains a dot; keep only
    # regular files so a directory is never listed as an image.
    files = [f for f in glob.glob(pattern, recursive=True) if os.path.isfile(f)]

    if seed is None:
        np.random.shuffle(files)
    else:
        # Sort first: glob order depends on the filesystem, and a seeded
        # shuffle is only reproducible from a fixed starting order.
        files.sort()
        np.random.default_rng(seed).shuffle(files)

    labels = [os.path.basename(os.path.dirname(f)) for f in files]
    return pd.DataFrame({'Image': files, 'Label': labels})

# Index the training and validation folders as tables of (file path, label).
dataframe_train = create_dataframe(train_path)
dataframe_valid = create_dataframe(valid_path)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Tag each frame with the split it came from so both can be plotted from a
# single combined table.
dataframe_train['Dataset'] = 'Training'
dataframe_valid['Dataset'] = 'Validation'

combined_df = pd.concat([dataframe_train, dataframe_valid], ignore_index=True)

sns.set_theme(style="whitegrid")

# Rows are labels, columns are splits. The percentages are normalised per
# split: every column of `percentage_data` sums to 100.
count_data = combined_df.groupby(['Label', 'Dataset']).size().unstack(fill_value=0)
percentage_data = count_data.div(count_data.sum(axis=0), axis=1) * 100

print(percentage_data)

# NOTE: this list is not passed to the plot below, which uses the named
# "pastel" palette directly; it is kept in case another cell reads it.
palette = sns.color_palette("pastel", len(count_data))

# Order the x axis by total image count per label, largest first.
label_order = count_data.sum(axis=1).sort_values(ascending=False).index

plt.figure(figsize=(12, 6))
ax = sns.countplot(
    data=combined_df,
    x="Label",
    hue="Dataset",
    order=label_order,
    palette="pastel",
)

# Annotate each bar with its share of ALL images (both splits combined).
total_images = len(combined_df)
for p in ax.patches:
    height = p.get_height()
    # Skip patches that carry no bar (zero or NaN height), e.g. a label that
    # is absent from one split, or helper patches that some seaborn versions
    # appear to add to `ax.patches`; they would only yield stray "0.00%" text.
    if not height > 0:
        continue
    percentage = f'{(height / total_images) * 100:.2f}%'
    ax.annotate(percentage,
                (p.get_x() + p.get_width() / 2., height),
                ha='center', va='bottom', fontsize=10, color='black', xytext=(0, 8), textcoords='offset points')

plt.xticks(rotation=45, ha='right', fontsize=12)

plt.title("Label Distribution in Training and Validation Datasets", fontsize=16, pad=20)
plt.xlabel("Labels", fontsize=14, labelpad=10)
plt.ylabel("Count", fontsize=14, labelpad=10)

sns.despine()

plt.tight_layout()
plt.show()

In [None]:
import tensorflow as tf

batch_size = 16
target_size = (224, 224)

# Options shared by both splits: no internal train/validation split, images
# resized to `target_size`, batches of `batch_size`.
loader_options = dict(
    validation_split=None,
    image_size=target_size,
    batch_size=batch_size,
)

# One dataset per split, each read from its own directory.
train = tf.keras.preprocessing.image_dataset_from_directory(train_path, **loader_options)
validation = tf.keras.preprocessing.image_dataset_from_directory(valid_path, **loader_options)

In [None]:
# Draw one example image per class in a single row.
class_labels = train.class_names
num_classes = len(class_labels)

plt.figure(figsize=(15, 10))

shown_classes = set()

# Walk through batches until every class has been drawn once: a single batch
# is not guaranteed to contain an example of each class.
for images, labels in train:
    for i in range(len(images)):
        # Convert the label tensor to a plain int before indexing the list.
        class_name = class_labels[int(labels[i])]
        if class_name in shown_classes:
            continue

        # One column per class, so the grid adapts to the number of classes.
        ax = plt.subplot(1, num_classes, len(shown_classes) + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_name)
        plt.axis("off")
        shown_classes.add(class_name)

        if len(shown_classes) == num_classes:
            break

    # Stop reading further batches once all classes are covered.
    if len(shown_classes) == num_classes:
        break

plt.tight_layout()
plt.show()

In [None]:
import wandb
from wandb.integration.keras import WandbCallback
from datetime import datetime

# Start a W&B run whose name carries the launch timestamp.
launch_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
run_name = f"vgg16-{launch_stamp}"
wandb.init(project="jellyfish-classification", name=run_name)

# Record the hyperparameters of this experiment on the run's config object.
config = wandb.config
config.update({
    "learning_rate": 0.001,
    "batch_size": 16,
    "epochs": 10,
    "optimizer": "adam",
    "model": "VGG16",
})


In [None]:
# Model definition

input_shape = (224, 224, 3)  # input image size (height, width, channels)

base_model = tf.keras.applications.VGG16(
    include_top=False,  # leave out the fully connected layers
    weights='imagenet',
    input_shape=input_shape
)

# Freeze the pretrained convolutional base; only the new head is trained.
base_model.trainable = False

# Add a new classification head
model = tf.keras.Sequential([
    tf.keras.Input(shape=input_shape),
    # The ImageNet weights expect inputs transformed by the VGG16-specific
    # preprocess_input; the datasets deliver raw pixel values, so the
    # conversion is applied inside the model, before the pretrained base.
    tf.keras.layers.Lambda(tf.keras.applications.vgg16.preprocess_input),
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(len(train.class_names), activation='softmax')  # one output per jellyfish class
])

In [None]:
# Compile: Adam with the learning rate stored in the wandb config,
# sparse categorical cross-entropy loss, accuracy as the tracked metric.

adam = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)
model.compile(optimizer=adam, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Custom wandb callback (the stock callback is said to fail often with a VGG16-based model)

class MyWandbCallback(tf.keras.callbacks.Callback):
    """Keras callback that forwards each epoch's metrics to Weights & Biases."""

    def on_epoch_end(self, epoch, logs=None):
        """Log ``logs`` to wandb with ``epoch`` as the step; do nothing when ``logs`` is None."""
        if logs is None:
            return
        wandb.log(logs, step=epoch)


In [None]:
# fit
# `train` and `validation` were created with their own batch size, so no
# `batch_size` is passed to fit(): Keras documents that it must not be given
# for dataset inputs. The epoch count is read from the wandb config so the
# logged hyperparameters match what is actually run.
with tf.device('/GPU:0'):
    history = model.fit(
        train,
        validation_data=validation,
        epochs=config.epochs,
        callbacks=[MyWandbCallback()]
    )
