In [None]:
# ResNet50 모델로 ImageDataGenerator를 이용해서 glaucoma 데이터 분류모델 학습시키기

In [9]:
# To force CPU-only execution, uncomment the two lines below; they must run
# before TensorFlow is imported/initialised:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Disable GPU
import tensorflow as tf

# Let every visible GPU allocate memory on demand instead of reserving the
# whole device up front. If the GPUs were already initialised (e.g. this cell
# is re-run after training), TensorFlow raises RuntimeError; the message is
# printed rather than interrupting the notebook.
gpus = tf.config.list_physical_devices('GPU')
try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
    if gpus:
        print("✔ GPU 메모리 동적 할당 설정 완료.")
except RuntimeError as err:
    print(err)

✔ GPU 메모리 동적 할당 설정 완료.


In [11]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Input, Lambda

# ResNet50 backbone (default weights='imagenet') without its classification
# head. The whole backbone is left trainable, so it is fine-tuned end to end.
resnet_model = ResNet50(input_shape=(224,224,3), include_top=False)
resnet_model.trainable = True

model = Sequential()
model.add(Input(shape=(224, 224, 3)))
# Fix: the pretrained ResNet50 weights expect inputs prepared by
# resnet50.preprocess_input, but the ImageDataGenerators and the single-image
# inference in this notebook pass raw 0-255 RGB pixels. Doing the preprocessing
# inside the model keeps training, validation and inference consistent.
# NOTE(review): models containing a Lambda layer may need extra care
# (custom_objects / safe_mode) when reloaded from disk — verify with the
# Keras version in use.
model.add(Lambda(preprocess_input, name='resnet50_preprocess'))
model.add(resnet_model)
# NOTE(review): Flatten hands the full spatial feature map to Dense(1024),
# which makes this layer very large for a ~1.4k-image dataset;
# GlobalAveragePooling2D is a common lighter alternative if it overfits.
model.add(Flatten())
model.add(Dense(1024, activation='relu'))
# One output per class; flow_from_directory reports 3 class subdirectories.
model.add(Dense(3, activation='softmax'))
model.summary()

In [12]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Training images get random augmentation on every batch.
train_gen = ImageDataGenerator(
    rotation_range=20,        # random rotations within ±20 degrees
    width_shift_range=0.2,    # random horizontal shifts up to 20% of the width
    height_shift_range=0.2,   # random vertical shifts up to 20% of the height
    horizontal_flip=True,     # random left-right mirroring
)

# Class labels are taken from the subdirectory names under the train folder.
train_data = train_gen.flow_from_directory(
    './datasets/glaucoma/train',
    target_size=(224,224),
    batch_size=32,
    class_mode='sparse',      # integer labels, for sparse_categorical_crossentropy
)

Found 1394 images belonging to 3 classes.


In [13]:
# Evaluation images: no augmentation, only loading and resizing.
# flow_from_directory shuffles by default; shuffle=False keeps the batch order
# equal to test_data.filenames / test_data.classes, so model.predict(test_data)
# can be compared directly against the true labels. Validation metrics reported
# by fit() are unaffected by the order.
test_gen = ImageDataGenerator()
test_data = test_gen.flow_from_directory('./datasets/glaucoma/test', target_size=(224,224),
                                        batch_size=32, class_mode='sparse',
                                        shuffle=False)

Found 150 images belonging to 3 classes.


In [14]:
from tensorflow.keras.optimizers import Adam

# The entire ResNet50 backbone is trainable and the Dense head starts from
# random weights. With Adam's default learning rate (1e-3), the large early
# gradients risk wiping out the pretrained features. A smaller step is the usual
# choice when fine-tuning; tune as needed.
LEARNING_RATE = 1e-4

model.compile(loss='sparse_categorical_crossentropy',  # integer labels (class_mode='sparse')
              optimizer=Adam(learning_rate=LEARNING_RATE),
              metrics=['accuracy'])

In [None]:
# ImageDataGenerator's random rotation/shift transforms need SciPy; run
# `%pip install scipy` if it is missing from the kernel's environment.
# The History object is kept so per-epoch loss/accuracy can be inspected or
# plotted after training (history.history).
# NOTE(review): the test split doubles as validation data here. If these scores
# are used to pick epochs or hyperparameters, the test accuracy is no longer an
# unbiased estimate — consider a separate validation split.
history = model.fit(train_data, validation_data=test_data, epochs=20)

In [None]:
# Persist the trained model; the .h5 extension selects the HDF5 format.
model_path = './models/glaucoma_model.h5'
model.save(model_path)

In [None]:
from tensorflow.keras.preprocessing import image

# Load a single image, resized to the 224x224 network input. Pixels are not
# rescaled, matching the training/test generators.
img = image.load_img('test.png', target_size=(224,224))
# Prepend a batch axis: (224, 224, 3) -> (1, 224, 224, 3).
x = image.img_to_array(img)[None, ...]

# Softmax probabilities, one row per image and one column per class.
pred = model.predict(x)
print(pred)

In [None]:
import numpy as np

# Index of the most probable class for each image in the batch.
# NOTE(review): indices follow train_data.class_indices (subdirectory names in
# sorted order); invert that dict to turn an index into a label name.
pred_class = pred.argmax(axis=1)
print(pred_class)