In [12]:
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf

from tensorflow.python.keras.applications import vgg16, vgg19
from tensorflow.python.keras.applications.vgg16 import preprocess_input
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img
from tensorflow.python.keras.callbacks import ModelCheckpoint
from tensorflow.python.keras import layers, models, Model, optimizers

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, plot_confusion_matrix

In [2]:
train_data_dir = "../data_split/train/"
test_data_dir = "../data_split/validation/"

category_names = sorted(os.listdir('../data_split/train/'))
nb_categories = len(category_names)
nb_categories

4

In [3]:
img_height=400
img_width=518
batch_size=32

In [4]:
conv_base = vgg16.VGG16(weights='imagenet', include_top=False, pooling = 'max', input_shape=(img_height,img_width,3))

In [5]:
for layer in conv_base.layers :
    print(layer, layer.trainable)

<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x2b5ec008a350> True
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x2b5e471a63d0> True
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x2b5ebbdc4b90> True
<tensorflow.python.keras.layers.pooling.MaxPooling2D object at 0x2b5ec0328490> True
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x2b5ec3ef10d0> True
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x2b5ec03288d0> True
<tensorflow.python.keras.layers.pooling.MaxPooling2D object at 0x2b5ec3ef89d0> True
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x2b5ec3efd610> True
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x2b5ec3f03d50> True
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x2b5ec3f068d0> True
<tensorflow.python.keras.layers.pooling.MaxPooling2D object at 0x2b5ec3f11f10> True
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x2b5ec3f16

In [6]:
model = models.Sequential()
model.add(conv_base)
model.add(layers.Dense(2, activation = 'softmax'))
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
vgg16 (Functional)           (None, 512)               14714688  
_________________________________________________________________
dense (Dense)                (None, 2)                 1026      
Total params: 14,715,714
Trainable params: 14,715,714
Non-trainable params: 0
_________________________________________________________________


In [7]:
# rescale
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

print('total no. of images for trainig : ')
train_generator = train_datagen.flow_from_directory(train_data_dir, 
                                                    target_size=(img_height, img_width), batch_size=batch_size,
                                                    class_mode="categorical")

print('total no. of images for test :')
test_generator = test_datagen.flow_from_directory(test_data_dir, 
                                                  target_size=(img_height, img_width), batch_size=batch_size,
                                                  class_mode="categorical", shuffle=False)

total no. of images for trainig : 
Found 224 images belonging to 2 classes.
total no. of images for test :
Found 56 images belonging to 2 classes.


In [14]:
learning_rate = 5e-5
epochs = 100

checkpoint = ModelCheckpoint("cell_classifier.h5", monitor = 'val_acc', verbose = 1, save_best_only = True,
                            save_weights_only=False, mode='auto', save_freq=1)
model.compile(loss ="categorical_crossentropy", optimizer = tf.keras.optimizers.Adam(lr=learning_rate, clipnorm=1.),
              metrics=['acc'])

In [16]:
model.fit(train_generator, epochs=epochs, validation_data=test_generator)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100

KeyboardInterrupt: 