In [1]:
import cv2
from tensorflow import *
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks, layers, metrics, models, optimizers, regularizers
from keras.models import Sequential
from keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Dropout, BatchNormalization
from keras.callbacks import ModelCheckpoint, EarlyStopping
import keras.utils as image

In [2]:
train_path = "archive/chest_xray/train"
val_path = "archive/chest_xray/val"
test_path = "archive/chest_xray/test"

batch_size = 32
img_size = (150, 150)  # (height, width); the model's input_shape must match this
seed = 42              # seed for training-set shuffling and random augmentation

# Only the training split is augmented; every split has pixel values scaled by 1/255.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
validation_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)


def make_generator(datagen, directory, shuffle=True, seed=None):
    """Return a batch iterator over the class-per-subfolder images in `directory`.

    Every split is loaded identically: resized to `img_size`, single-channel
    grayscale, `batch_size` images per batch, binary (0/1) labels.
    """
    return datagen.flow_from_directory(
        directory,
        target_size=img_size,
        batch_size=batch_size,
        class_mode="binary",
        color_mode="grayscale",
        shuffle=shuffle,
        seed=seed,
    )


train_generator = make_generator(train_datagen, train_path, seed=seed)
# Evaluation splits are not shuffled so batch order follows `.filenames` / `.classes`;
# otherwise predictions cannot be matched back to their true labels.
validation_generator = make_generator(validation_datagen, val_path, shuffle=False)
test_generator = make_generator(test_datagen, test_path, shuffle=False)

# The label assigned to each class folder must be the same in every split.
assert (
    train_generator.class_indices
    == validation_generator.class_indices
    == test_generator.class_indices
), "class folders differ between splits"

train_num = train_generator.samples
val_num = validation_generator.samples

# Peek at one batch to confirm shapes. The names deliberately avoid `image`,
# which would clobber the `keras.utils as image` alias from the imports cell.
sample_images, sample_labels = next(train_generator)
sample_images.shape, sample_labels.shape

Found 5216 images belonging to 2 classes.
Found 624 images belonging to 2 classes.
Found 16 images belonging to 2 classes.
(32, 150, 150, 1)
(32,)


In [3]:
# Four conv + max-pool stages with shrinking filter counts (256 -> 128 -> 64 -> 32),
# followed by a dense head ending in one sigmoid unit for binary classification.
model = Sequential([
    # Input: 150x150 single-channel images, matching the generators' target_size.
    Conv2D(256, (3, 3), activation="relu", input_shape=(150, 150, 1)),
    MaxPooling2D((2, 2)),
    Dropout(0.2),

    Conv2D(128, (3, 3), activation="relu"),
    MaxPooling2D((2, 2)),

    Conv2D(64, (3, 3), activation="relu"),
    MaxPooling2D((2, 2)),

    Conv2D(32, (3, 3), activation="relu"),
    MaxPooling2D((2, 2)),

    # Classifier head.
    Dropout(0.5),
    Flatten(),
    Dense(128, activation="relu"),
    Dense(1, activation="sigmoid"),
])

model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["acc"],  # logged in history under "acc" / "val_acc"
)

model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 148, 148, 256)     2560      
                                                                 
 max_pooling2d (MaxPooling2  (None, 74, 74, 256)       0         
 D)                                                              
                                                                 
 dropout (Dropout)           (None, 74, 74, 256)       0         
                                                                 
 conv2d_1 (Conv2D)           (None, 72, 72, 128)       295040    
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 36, 36, 128)       0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 34, 34, 64)        7

In [None]:
epochs = 8

# Keras expects integer step counts. `samples / batch_size` is true division and
# gives a float when the split is not a multiple of the batch size (624 / 32 == 19.5).
# len() of a flow_from_directory iterator is the number of batches per pass,
# ceil(samples / batch_size), so every image -- including the last partial batch --
# is used each epoch.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)

Epoch 1/8

In [None]:
acc = history.history["acc"]
val_acc = history.history["val_acc"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]

# `epochs` from the training cell is the epoch *count* (an int), so plotting
# against it raises "x and y must have same first dimension". Build the x-axis
# from the recorded history instead. np.arange is used rather than range()
# because `from tensorflow import *` may shadow builtins such as range.
epoch_axis = np.arange(1, len(loss) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(epoch_axis, loss, label="Training loss")
loss_ax.plot(epoch_axis, val_loss, label="Validation loss")
loss_ax.set(title="Loss per epoch", xlabel="Epoch", ylabel="Loss")
loss_ax.legend()

acc_ax.plot(epoch_axis, acc, label="Training accuracy")
acc_ax.plot(epoch_axis, val_acc, label="Validation accuracy")
acc_ax.set(title="Accuracy per epoch", xlabel="Epoch", ylabel="Accuracy")
acc_ax.legend()

plt.show()