In [1]:
import os
import sys

In [2]:
# Make the parent of the notebook's working directory importable so the
# project package (`ml.*`, imported in the next cell) can be found.
# The membership check keeps re-runs of this cell from growing sys.path.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
import ml.generators.mip_generator as generator
import tensorflow as tf

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [4]:
# Create a TF1 session that logs the device (CPU/GPU) each op is placed on,
# and register it as the Keras backend session. Without `set_session`, Keras
# lazily creates and uses its own default session, so this config (and
# `sess` itself) would have no effect on the model built and trained below.
# This must run before any Keras model is constructed.
# NOTE(review): placement logs are written by TensorFlow to the kernel's
# stderr (typically the Jupyter server console), not to the notebook output.
from keras import backend as K

sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
K.set_session(sess)

In [5]:
# Both generators use the same project class; the aliases are kept so other
# cells that refer to TrainGen / ValGen keep working.
TrainGen = generator.MipGenerator
ValGen = generator.MipGenerator

# Settings shared by the training and validation generators. `dims` has to
# match the `input_shape` given to VGG16 further down.
GENERATOR_KWARGS = dict(dims=(220, 220, 3),
                        batch_size=16,
                        extend_dims=False)

# Data augmentation (presumably what `augment_data` toggles in MipGenerator;
# not visible here) is enabled for training batches only. Augmenting the
# validation set would make val_loss / val_acc noisy and unrepresentative of
# the unaugmented inputs the model is evaluated on.
train_gen = TrainGen(augment_data=True, validation=False, **GENERATOR_KWARGS)
val_gen = ValGen(augment_data=False, validation=True, **GENERATOR_KWARGS)

In [6]:
from keras.models import Model
from keras.applications.vgg16 import VGG16
from keras.layers import Dense, GlobalAveragePooling2D
from keras.callbacks import ModelCheckpoint
from keras.optimizers import adam

In [7]:
def add_new_last_layer(base_model, nb_classes, hidden_units=2048,
                       activation='sigmoid'):
    """Stack a new classification head on top of a convolutional base model.

    The head is: global average pooling over ``base_model.output``, a
    ReLU-activated dense layer, and a dense output layer.

    Args:
        base_model: Keras model whose ``output`` is a spatial feature map
            (assumed 4-D, e.g. VGG16 built with ``include_top=False``;
            not checked here).
        nb_classes: number of units in the output layer.
        hidden_units: width of the intermediate dense layer. Defaults to
            2048, the value previously hard-coded.
        activation: activation of the output layer. Defaults to
            ``'sigmoid'`` (suits binary / multi-label targets trained with
            ``binary_crossentropy``); pass ``'softmax'`` for mutually
            exclusive classes.

    Returns:
        A new ``Model`` mapping ``base_model.input`` to the head's output.
    """
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(hidden_units, activation='relu')(x)
    predictions = Dense(nb_classes, activation=activation)(x)
    # Keras 2 API: `inputs` / `outputs` (the Keras 1 `input` / `output`
    # keyword names are deprecated and emit a warning).
    return Model(inputs=base_model.input, outputs=predictions)

In [8]:
# ImageNet-pretrained VGG16 convolutional base, without its fully connected
# top. `input_shape` has to agree with the `dims` the generators produce.
base_model = VGG16(include_top=False,
                   weights='imagenet',
                   input_shape=(220, 220, 3))

# Freeze every pretrained layer before compiling, so only the new head trains.
for layer in base_model.layers:
    layer.trainable = False

# One output unit trained with binary cross-entropy -> binary classifier.
finished_model = add_new_last_layer(base_model, nb_classes=1)
finished_model.compile(loss='binary_crossentropy',
                       optimizer=adam(lr=0.0001),
                       metrics=['accuracy'])
finished_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 220, 220, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 220, 220, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 220, 220, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 110, 110, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 110, 110, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 110, 110, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 55, 55, 128)       0         
__________

  


In [None]:
CHECKPOINT_PATH = os.path.join('tmp', 'vgg16_pretrained_weights.hdf5')

# ModelCheckpoint (via h5py) does not create missing parent directories, so
# the first save would fail if `tmp/` does not exist yet.
checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Only overwrite the checkpoint when validation loss improves; otherwise every
# epoch overwrites the same file and it ends up holding the last (possibly
# over-fitted) epoch instead of the best one.
mc_callback = ModelCheckpoint(filepath=CHECKPOINT_PATH,
                              monitor='val_loss',
                              save_best_only=True,
                              verbose=1)

# Keep the History object so loss/accuracy curves can be inspected later.
# NOTE(review): max_queue_size=1 limits batch prefetching; confirm this is
# intentional (e.g. memory limits in the generator).
history = finished_model.fit_generator(
    generator=train_gen.generate(),
    steps_per_epoch=train_gen.get_steps_per_epoch(),
    validation_data=val_gen.generate(),
    validation_steps=val_gen.get_steps_per_epoch(),
    epochs=20,
    callbacks=[mc_callback],
    verbose=1,
    max_queue_size=1)

0
Epoch 1/20
Loaded entire batch.
(16, 220, 220, 3)
1
 1/87 [..............................] - ETA: 8:25 - loss: 0.7028 - acc: 0.4375Loaded entire batch.
(16, 220, 220, 3)
2
 2/87 [..............................] - ETA: 5:57 - loss: 0.6924 - acc: 0.5312Loaded entire batch.
(16, 220, 220, 3)
3
 3/87 [>.............................] - ETA: 5:57 - loss: 0.6955 - acc: 0.4375Loaded entire batch.
(16, 220, 220, 3)
4
 4/87 [>.............................] - ETA: 5:48 - loss: 0.7008 - acc: 0.3906Loaded entire batch.
(16, 220, 220, 3)
5
 5/87 [>.............................] - ETA: 5:44 - loss: 0.6991 - acc: 0.4250Loaded entire batch.
(16, 220, 220, 3)
6
 6/87 [=>............................] - ETA: 5:44 - loss: 0.6993 - acc: 0.4271Loaded entire batch.
(16, 220, 220, 3)
7
 7/87 [=>............................] - ETA: 5:35 - loss: 0.6963 - acc: 0.4732Loaded entire batch.
(16, 220, 220, 3)
8
 8/87 [=>............................] - ETA: 5:30 - loss: 0.6961 - acc: 0.4844Loaded entire batch.
(16, 2



Loaded entire batch.
(16, 220, 220, 3)
13
13/87 [===>..........................] - ETA: 5:02 - loss: 0.7020 - acc: 0.4952Loaded entire batch.
(16, 220, 220, 3)
14
14/87 [===>..........................] - ETA: 4:57 - loss: 0.6928 - acc: 0.5179Loaded entire batch.
(16, 220, 220, 3)
15
15/87 [====>.........................] - ETA: 4:53 - loss: 0.6932 - acc: 0.5208Loaded entire batch.
(16, 220, 220, 3)
16
16/87 [====>.........................] - ETA: 4:49 - loss: 0.6922 - acc: 0.5234Loaded entire batch.
(16, 220, 220, 3)
17
17/87 [====>.........................] - ETA: 4:44 - loss: 0.6863 - acc: 0.5368Loaded entire batch.
(16, 220, 220, 3)
18
18/87 [=====>........................] - ETA: 4:40 - loss: 0.6811 - acc: 0.5486Loaded entire batch.
(16, 220, 220, 3)
19
19/87 [=====>........................] - ETA: 4:34 - loss: 0.6774 - acc: 0.5559Loaded entire batch.
(16, 220, 220, 3)
20
20/87 [=====>........................] - ETA: 4:29 - loss: 0.6751 - acc: 0.5594Loaded entire batch.
(16, 220, 2

(16, 220, 220, 3)
57
(16, 220, 220, 3)
58
(16, 220, 220, 3)
59
(16, 220, 220, 3)
60
(16, 220, 220, 3)
61
(16, 220, 220, 3)
62
(16, 220, 220, 3)
63
(16, 220, 220, 3)
64
(16, 220, 220, 3)
65
(16, 220, 220, 3)
66
(16, 220, 220, 3)
67
(16, 220, 220, 3)
68
(16, 220, 220, 3)
69
(16, 220, 220, 3)
70
(16, 220, 220, 3)
71
(16, 220, 220, 3)
72
(16, 220, 220, 3)
73
(16, 220, 220, 3)
74
(16, 220, 220, 3)
75
(16, 220, 220, 3)
76
(16, 220, 220, 3)
77
(16, 220, 220, 3)
78
(16, 220, 220, 3)
79
(16, 220, 220, 3)
80
(16, 220, 220, 3)
81
(16, 220, 220, 3)
82
(16, 220, 220, 3)
83
(16, 220, 220, 3)
84
(16, 220, 220, 3)
85
(16, 220, 220, 3)
86
(16, 220, 220, 3)
0
1
Loaded entire batch.
(16, 220, 220, 3)
Loaded entire batch.
(16, 220, 220, 3)
2
Loaded entire batch.
(16, 220, 220, 3)
3
Loaded entire batch.
(16, 220, 220, 3)
4
Loaded entire batch.
(16, 220, 220, 3)
5
Loaded entire batch.
(16, 220, 220, 3)
6
Loaded entire batch.
(16, 220, 220, 3)
7
Loaded entire batch.
(16, 220, 220, 3)
8
Loaded entire batch.
(