# Fine tune a la red de PlantVillage 
Para el dataset de arroz
https://www.kaggle.com/minhhuy2810/rice-diseases-image-dataset

In [1]:
%matplotlib inline

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import numpy as np
import scipy 
import pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
import IPython.display as display
import time

keras = tf.keras
print(tf.__version__)

2.0.0


In [3]:
data_dir = './dataset/rice/Labelled/'
data_dir = pathlib.Path(data_dir)

In [4]:
N = 0  # total files
for dirpath, dirnames, filenames in os.walk(data_dir):    
    dirpath = dirpath.split("/")[-1]
    if dirpath != ".ipynb_checkpoints":
        N_c = len(filenames)
        N += N_c
        print( dirpath+ ": -> " + str(N_c))
print( "Total Files " + str(N) )

Labelled: -> 0
LeafBlast: -> 779
Healthy: -> 1488
Hispa: -> 565
BrownSpot: -> 523
Total Files 3355


In [5]:
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,
                                                                  validation_split=0.3,                            
                                                                  shear_range=0.2,
                                                                  zoom_range=0.4,
                                                                  horizontal_flip=True,
                                                                  rotation_range=20
                                                                 )

In [6]:
image_count = len(list(data_dir.glob('*/*.jpg'))) + len(list(data_dir.glob('*/*.png')))
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
STEPS_PER_EPOCH = np.ceil((image_count*0.85)/BATCH_SIZE)
int(STEPS_PER_EPOCH)

90

In [7]:
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*') if item.name != ".ipynb_checkpoints"])
CLASS_NAMES

array(['LeafBlast', 'Healthy', 'Hispa', 'BrownSpot'], dtype='<U9')

In [17]:
seed = 3

In [18]:
train_data_gen = image_generator.flow_from_directory(directory = str(data_dir),
                                                     batch_size = BATCH_SIZE,
                                                     shuffle = True,
                                                     seed = seed,
                                                     target_size = (IMG_HEIGHT, IMG_WIDTH),
                                                     subset = "training",
                                                     classes = list(CLASS_NAMES))

Found 2351 images belonging to 4 classes.


In [19]:
valid_data_gen = image_generator.flow_from_directory(directory = str(data_dir),
                                                     batch_size = BATCH_SIZE,
                                                     shuffle = True,
                                                     seed = seed,
                                                     target_size = (IMG_HEIGHT, IMG_WIDTH),
                                                     subset = "validation",
                                                     classes = list(CLASS_NAMES))

Found 1004 images belonging to 4 classes.


## MobileNetV2

In [10]:
model_m = tf.keras.models.load_model('models/plant_village_MobileNetV2_trained_01.h5')

In [11]:
prediction_layer = keras.layers.Dense(len(CLASS_NAMES), activation="softmax", 
                                      kernel_initializer=keras.initializers.he_normal(seed=None))
model_m.pop()
model_m.add(prediction_layer)

In [None]:
#for layer in model_m.layers[0].layers[:70]: 
#    layer.trainable = False

In [12]:
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=2, min_lr=0.0001)
early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=4)
filepathM = "models/plant_village_rice_MobileNetV2.{epoch:02d}-{val_accuracy:.2f}.h5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath = filepathM, 
                                                          save_best_only=True , monitor='val_accuracy', mode='max')

In [13]:
base_learning_rate = 0.0003
model_m.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [14]:
total_epochs = 30
start_time = time.time()
history_m = model_m.fit_generator(train_data_gen,
                         epochs=total_epochs,
                         steps_per_epoch=train_data_gen.__len__(),
                         validation_data=valid_data_gen,
                         validation_steps= valid_data_gen.__len__(),
                         callbacks = [reduce_lr , checkpoint]
                        )
duration = time.time() - start_time
print('took: ' + str(duration/60))

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30

KeyboardInterrupt: 

## Xception

In [14]:
model_x = tf.keras.models.load_model('models/plant_village_Xception_trained_01.h5')
prediction_layer_x = keras.layers.Dense(len(CLASS_NAMES), activation="softmax", 
                                      kernel_initializer=keras.initializers.he_normal(seed=4))
model_x.pop()
model_x.add(prediction_layer_x)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=2, min_lr=0.0001, verbose = 1)
early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=4)
filepath_x = "models/plant_village_rice_Xception.{epoch:02d}-{val_accuracy:.2f}.h5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath = filepath_x, 
                                                          save_best_only=True , monitor='val_accuracy', mode='max')

In [15]:
base_learning_rate = 0.001
#base_learning_rate = 0.0003
model_x.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [20]:
total_epochs = 30
start_time = time.time()
history_x = model_x.fit_generator(train_data_gen,
                         epochs=total_epochs,
                         steps_per_epoch=train_data_gen.__len__(),
                         validation_data=valid_data_gen,
                         validation_steps= valid_data_gen.__len__(),
                         callbacks = [reduce_lr , checkpoint]
                        )
duration = time.time() - start_time
print('took: ' + str(duration/60))

Epoch 1/30




Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30

KeyboardInterrupt: 