In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
from glob import glob
from tqdm.auto import tqdm
import tensorflow as tf
from tensorflow.keras import layers, models, utils
from sklearn.model_selection import train_test_split

import imgaug.augmenters as iaa
import imgaug as ia

#### Download dataset

In [None]:
# download from google drive
!pip install --upgrade gdown
!gdown --id '1w0ldWmLkbaypadIDiFexve3bW1IJuosc' --output pneumonia.zip

Downloading...
From: https://drive.google.com/uc?id=1w0ldWmLkbaypadIDiFexve3bW1IJuosc
To: /content/pneumonia.zip
100% 1.22G/1.22G [00:09<00:00, 128MB/s] 


In [None]:
# unzip dataset file
!unzip -q pneumonia.zip

#### Prepare dataset


In [None]:
IMG_SIZE = 200
BATCH_SIZE = 32

# Label index of each class, in list order: 'normal' -> 0, 'bacteria' -> 1, 'virus' -> 2
all_class = ['normal', 'bacteria', 'virus']
class_map = dict(zip(all_class, range(len(all_class))))
class_map

{'bacteria': 1, 'normal': 0, 'virus': 2}

In [None]:
# Collect every training image path (one sub-folder per top-level class).
# Sorted because glob's order depends on the filesystem; a fixed order keeps
# everything derived from this list reproducible.
img_paths_train = sorted(glob('pneumonia-kaggle/train/*/*.jpeg'))
assert img_paths_train, 'no images found - did the download/unzip cells run?'

In [None]:
# Hold out 20% of the training images for validation. Sorting the input and fixing
# random_state make the split identical across kernel restarts.
# NOTE: this overwrites img_paths_train; re-run the glob cell before re-running this one.
img_paths_train, img_paths_val = train_test_split(sorted(img_paths_train),
                                                  test_size=0.2,
                                                  random_state=42)

In [None]:
# Number of images per split: (train, val)
tuple(len(paths) for paths in (img_paths_train, img_paths_val))

(4172, 1044)

In [None]:
class DataGenerator(utils.Sequence):
    """Keras ``Sequence`` yielding batches of (RGB image, one-hot label).

    The label comes from each path. The parent directory name is the class
    (e.g. ``normal``), except for the ``pneumonia`` directory, where the subclass
    (``bacteria`` / ``virus``) is the second ``_``-separated token of the filename.

    Args:
        paths: list of image file paths.
        batch_size: images per batch (the last batch may be smaller).
        img_size: every image is resized to ``img_size x img_size``.
        shuffle: shuffle the sample order at construction and in ``on_epoch_end``.
        aug: apply the random flip/affine augmentation to each batch.
        class_names: class names in label-index order.
    """

    def __init__(self, paths, batch_size, img_size, shuffle=True, aug=False,
                 class_names=('normal', 'bacteria', 'virus')):
        super().__init__()
        self.paths = paths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.paths))  # sample order; shuffled in on_epoch_end
        # Use the instance's own label map, never the notebook-global class_map.
        self.class_map = {cls: i for i, cls in enumerate(class_names)}
        self.num_classes = len(self.class_map)
        self.img_size = img_size
        # Augmentation
        self.aug = aug
        self.seq = iaa.Sequential([
            iaa.Fliplr(0.5),  # 50% horizontal flip
            iaa.Affine(
                rotate=(-10, 10),  # random rotate -10 ~ +10 degree
                shear=(-16, 16),  # random shear -16 ~ +16 degree
                scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}  # scale x, y: 80%~120%
            ),
        ])
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (a partial last batch counts as one)."""
        return int(np.ceil(len(self.paths) / self.batch_size))

    def __getitem__(self, batch_index):
        """Return batch ``batch_index`` as ``(x, y)``."""
        idxs = self.indexes[batch_index * self.batch_size:(batch_index + 1) * self.batch_size]
        batch_paths = [self.paths[i] for i in idxs]
        X, y = self.__data_generation(batch_paths)
        return X, y

    def on_epoch_end(self):
        """Reshuffle the sample order when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _load_image(self, path):
        """Read ``path`` as RGB, resize to ``img_size`` and scale to [0, 1]."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for a missing/unreadable file instead of raising;
            # fail here with the path rather than with an opaque TypeError later.
            raise FileNotFoundError(f'could not read image: {path}')
        img = img[:, :, ::-1]  # OpenCV loads BGR; convert to RGB
        img = cv2.resize(img, (self.img_size, self.img_size))
        return img / 255.  # normalize to 0~1

    def _label_from_path(self, path):
        """Return the integer class index encoded in ``path``."""
        # os.path copes with '/' as well as the native separator
        cls = os.path.basename(os.path.dirname(path)).lower()
        if cls == 'pneumonia':
            # pneumonia subclass is the second '_'-separated token of the filename
            cls = os.path.basename(path).split('_')[1].lower()
        return self.class_map[cls]

    def __data_generation(self, paths):
        """Load, label and (optionally) augment the images in ``paths``.

        Returns:
            x: float32 array of shape (len(paths), img_size, img_size, 3).
            y: one-hot labels of shape (len(paths), num_classes).
        """
        x = np.empty((len(paths), self.img_size, self.img_size, 3), dtype=np.float32)
        y = np.empty(len(paths), dtype=np.int64)
        for i, path in enumerate(paths):
            x[i] = self._load_image(path)
            y[i] = self._label_from_path(path)
        # one-hot encoding
        y = tf.keras.utils.to_categorical(y, num_classes=self.num_classes)
        if self.aug:
            x = self.augmentation(x)
        return x, y

    def augmentation(self, imgs):
        """Apply the random flip/affine pipeline ``self.seq`` to a batch of images."""
        return self.seq.augment_images(imgs)

In [None]:
# Training batches are reshuffled each epoch and augmented; validation batches keep
# a fixed order and no augmentation.
gen_train = DataGenerator(paths=img_paths_train, batch_size=BATCH_SIZE, img_size=IMG_SIZE,
                          shuffle=True, aug=True)
gen_val = DataGenerator(paths=img_paths_val, batch_size=BATCH_SIZE, img_size=IMG_SIZE,
                        shuffle=False, aug=False)

In [None]:
# Batches per epoch (train, val), then fetch one training batch to inspect.
n_batches = (len(gen_train), len(gen_val))
print(*n_batches)

x, y = gen_train[1]

131 33


In [None]:
# Image batch and one-hot label batch shapes
tuple(arr.shape for arr in (x, y))

((32, 200, 200, 3), (32, 3))

#### Build model

In [None]:
def build_model(img_size=IMG_SIZE, num_classes=len(all_class)):
    """Build a frozen EfficientNetB0 backbone topped with a small trainable classifier.

    The data generators feed images scaled to [0, 1]. The Keras EfficientNet models
    rescale and normalize internally and expect raw [0, 255] pixels, so the inputs are
    multiplied back by 255 before entering the backbone. Without this, the backbone
    sees almost constant inputs and the head collapses to predicting one class.

    Args:
        img_size: height/width of the square RGB input images.
        num_classes: number of softmax outputs.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    base_model = tf.keras.applications.EfficientNetB0(include_top=False,
                                                      weights='imagenet',
                                                      input_shape=(img_size, img_size, 3))
    base_model.trainable = False  # only the classifier head is trained

    inputs = layers.Input(shape=(img_size, img_size, 3))
    h = layers.Rescaling(255.)(inputs)  # [0, 1] -> [0, 255], the range EfficientNet expects
    # training=False keeps the frozen backbone's BatchNorm layers in inference mode
    h = base_model(h, training=False)
    h = layers.GlobalAveragePooling2D()(h)
    h = layers.Dense(128, activation='relu')(h)
    h = layers.Dropout(0.5)(h)
    outputs = layers.Dense(num_classes, activation='softmax')(h)
    return models.Model(inputs, outputs)


model = build_model()

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5


In [None]:
# Print the model architecture and its parameter counts.
model.summary()

#### Training

In [None]:
# Name the metric 'categorical_accuracy' (not 'accuracy') so Keras logs it under
# 'categorical_accuracy' / 'val_categorical_accuracy' - the keys the training-log
# cells read from the history.
model.compile(tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.categorical_crossentropy,
              metrics=['categorical_accuracy'])

In [None]:
# Train until EarlyStopping fires; the huge epoch count is only an upper bound.
# Keep the returned History object - the training-log cells read `logs.history`.
logs = model.fit(
    gen_train,
    epochs=100000,
    validation_data=gen_val,
    callbacks=[
        # stop after 10 epochs without a val_loss improvement
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10),
        # best.h5: checkpoint with the lowest val_loss so far; last.h5: latest epoch
        tf.keras.callbacks.ModelCheckpoint('./best.h5', monitor='val_loss', save_best_only=True),
        tf.keras.callbacks.ModelCheckpoint('./last.h5', save_best_only=False),
        # multiply the learning rate by 0.3 after 3 epochs without a val_loss improvement
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=3, factor=0.3),
    ],
)

Epoch 1/100000

  layer_config = serialize_layer_fn(layer)


Epoch 2/100000

KeyboardInterrupt: ignored

#### Training logs

In [None]:
history = logs.history
val_losses = history['val_loss']
best_epoch = int(np.argmin(val_losses))  # 0-based index of the lowest val_loss
print('val loss ', val_losses[best_epoch])
print('val acc', history['val_categorical_accuracy'][best_epoch])

In [None]:
# Learning curves: accuracy (left) and loss (right), training vs. validation per epoch.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 5))
for ax, metric in ((ax_acc, 'categorical_accuracy'), (ax_loss, 'loss')):
    ax.plot(history[metric], label=metric)
    ax.plot(history[f'val_{metric}'], label=f'val_{metric}')
    ax.set(title=metric, xlabel='epoch', ylabel=metric)
    ax.legend()
plt.show()

#### Metrics

In [None]:
# Restore the checkpoint written by the save_best_only ModelCheckpoint during training.
model = models.load_model('./best.h5')

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

# Predict the validation set batch by batch (gen_val is unshuffled and unaugmented)
# and collect integer class indices for predictions and ground truth.
pred_batches, true_batches = [], []
for x_val, y_val in tqdm(gen_val):
    # predict_on_batch skips model.predict's per-call setup and progress-bar output
    pred = model.predict_on_batch(x_val)
    pred_batches.append(np.argmax(pred, axis=-1))
    true_batches.append(np.argmax(y_val, axis=-1))
# Concatenate once at the end instead of np.append in the loop (quadratic copying)
y_pred = np.concatenate(pred_batches)
y_true = np.concatenate(true_batches)

  0%|          | 0/33 [00:00<?, ?it/s]

In [None]:
# First 10 predictions vs. ground-truth labels
tuple(arr[:10] for arr in (y_pred, y_true))

(array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 array([1., 2., 2., 0., 1., 1., 1., 2., 1., 2.]))

In [None]:
# Per-class precision / recall / F1, labelled with the class names.
# labels= pins the report to every class so it always lines up with target_names;
# zero_division=0 reports 0 for a never-predicted class instead of emitting
# an UndefinedMetricWarning.
print(classification_report(y_true, y_pred,
                            labels=list(range(len(all_class))),
                            target_names=all_class,
                            zero_division=0))

              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       285
         1.0       0.47      1.00      0.64       495
         2.0       0.00      0.00      0.00       264

    accuracy                           0.47      1044
   macro avg       0.16      0.33      0.21      1044
weighted avg       0.22      0.47      0.31      1044



  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Confusion matrix - rows: ground truth, columns: prediction
cm = confusion_matrix(y_true, y_pred)
print(cm)

[[  0 285   0]
 [  0 495   0]
 [  0 264   0]]


#### Save model

In [None]:
# 1. Save the whole model (architecture + weights) to one file, then reload it.
model_path = 'my_model.h5'
model.save(model_path)
model2 = models.load_model(model_path)

In [None]:
# 2. Save and load weights only. A weights file carries no architecture, so an
# identical model must exist first: clone_model rebuilds `model`'s architecture with
# freshly initialised weights, which load_weights then overwrites.
model.save_weights('my_model_weights.h5')

model2 = models.clone_model(model)
model2.load_weights('my_model_weights.h5')
