In [1]:
import tensorflow as tf

In [2]:
gpus = tf.config.experimental.list_physical_devices('GPU')
# Allocate GPU memory on demand instead of reserving it all at start-up.
# Iterating (rather than indexing gpus[0]) avoids an IndexError on CPU-only
# machines and applies the same setting to every visible GPU.
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
import os
import datetime
import numpy as np
import pandas as pd
import warnings
import network
import split_extract
from multiprocessing import Pool
from tqdm import tqdm
from params import dresden_images_root, dresden_csv, ins_root, ins_csv, patch_span, \
        patch_num, ins_patches, ins_patches_db
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback

In [4]:
# Build the instance-level image index: keep Dresden rows for Canon cameras
# whose 'instance' column is 0 and write "<brand>_<model>" + file name to ins_csv.
warnings.filterwarnings("ignore")  # NOTE(review): silences every warning for the whole kernel session

if not os.path.exists(ins_patches):
    os.makedirs(ins_patches)

data = pd.read_csv(dresden_csv)
is_selected = (data['brand'] == 'Canon') & (data['instance'] == 0)
data = data.loc[is_selected]

# Comprehension variables stay local, so the module-level Keras `model` is not clobbered on re-run.
brand_model = ['_'.join((cam_brand, cam_model))
               for cam_brand, cam_model in zip(data['brand'], data['model'])]
df = pd.DataFrame({'brand_model': brand_model, 'file_name': data['filename']})
df.to_csv(ins_csv, index=False)

In [5]:
# Load the image index and either reuse a cached train/val/test split or create one.
images_db = pd.read_csv(ins_csv)
model_list = np.unique(images_db['brand_model'])
img_list = list(images_db['file_name'])

if os.path.exists(ins_patches_db):
    # NOTE(review): allow_pickle=True unpickles the file; only safe because it is
    # presumably written locally by split_extract.split on an earlier run.
    patches_db = np.load(ins_patches_db, allow_pickle=True).item()
    train_list, val_list, test_list = (patches_db[split] for split in ('train', 'val', 'test'))
    split_extract.split_info(train_list, val_list, test_list, model_list)
else:
    train_list, val_list, test_list = split_extract.split(img_list, model_list, ins_patches_db)

Canon_Ixus55 in training set: 145.
Canon_Ixus55 in validation set: 34.
Canon_Ixus55 in test set: 45.

Canon_Ixus70 in training set: 118.
Canon_Ixus70 in validation set: 34.
Canon_Ixus70 in test set: 35.

Canon_PowerShotA640 in training set: 121.
Canon_PowerShotA640 in validation set: 28.
Canon_PowerShotA640 in test set: 39.



In [6]:
print('Collecting image data...')

# Sets give O(1) membership checks; scanning the split lists for every image was O(n^2).
# NOTE(review): assumes split entries are hashable file-name strings matching
# images_db['file_name'] — confirm against split_extract.split.
train_set = set(train_list)
val_set = set(val_list)


def _make_extract_job(data_set, img_path, img_brand_model):
    """Build the argument dict passed to split_extract.extract for one image."""
    return {'data_set': data_set,
            'img_path': img_path,
            'img_brand_model': img_brand_model,
            'patch_span': patch_span,
            'patch_num': patch_num,
            'patch_root': ins_patches,
            'img_root': dresden_images_root
            }


train_imgs_list = []
val_imgs_list = []
test_imgs_list = []

for img_brand_model, img_path in tqdm(zip(images_db['brand_model'], images_db['file_name']),
                                      total=len(images_db)):
    if img_path in train_set:
        train_imgs_list.append(_make_extract_job('train', img_path, img_brand_model))
    elif img_path in val_set:
        val_imgs_list.append(_make_extract_job('val', img_path, img_brand_model))
    else:
        # Any image not in train/val is routed to the test set.
        test_imgs_list.append(_make_extract_job('test', img_path, img_brand_model))

num_processes = 4
# The context manager shuts the worker processes down once extraction finishes.
# Previously the pool was never closed.
# Pool.map blocks until every result is back, so terminate-on-exit loses nothing.
with Pool(processes=num_processes) as pool:
    print('Extracting training patches...')
    train_paths = pool.map(split_extract.extract, train_imgs_list)
    print('Extracting validation patches...')
    val_paths = pool.map(split_extract.extract, val_imgs_list)
    print('Extracting testing patches...')
    test_paths = pool.map(split_extract.extract, test_imgs_list)
print('Completed.')

0it [00:00, ?it/s]

Collecting image data...


599it [00:00, 44055.34it/s]


Extracting training patches...
Extracting validation patches...
Extracting testing patches...
Completed.


In [7]:
img_height = 256
img_width = 256
batch_size = 64

# Load the training and validation patch datasets.
# NOTE(review): these directories are hard-coded; they presumably correspond to
# the 'train'/'val' folders under `ins_patches` written by split_extract.extract.
# Confirm, then derive them from `ins_patches` to avoid drift.

# Random flips are applied to training patches only; validation is just rescaled.
train_generator = ImageDataGenerator(preprocessing_function=None,
    rescale=1./255, horizontal_flip=True, vertical_flip=True)

validation_generator = ImageDataGenerator(preprocessing_function=None,
    rescale=1./255)

# Keras expects target_size as (height, width). The previous (width, height)
# order only worked because both values are 256.
target_size = (img_height, img_width)

train_data_gen = train_generator.flow_from_directory(
    directory=r"./instance/patches/train/",
    target_size=target_size, color_mode='grayscale',
    batch_size=batch_size, class_mode="categorical", shuffle=True)

validation_data_gen = validation_generator.flow_from_directory(
    directory=r"./instance/patches/val/",
    target_size=target_size, color_mode='grayscale',
    batch_size=batch_size, class_mode="categorical", shuffle=True)

Found 9600 images belonging to 3 classes.
Found 2400 images belonging to 3 classes.


In [10]:
# NOTE(review): `model` is not defined anywhere in this saved notebook
# (execution count jumps from In[7] to In[10]). It presumably came from a
# deleted cell that builds it via the `network` module. Restore that cell so
# the notebook survives Restart & Run All.
# `learning_rate` replaces the deprecated `lr` alias.
# NOTE(review): `decay` is rejected by the Keras optimizers introduced in TF 2.11.
# Switch to tf.keras.optimizers.legacy.SGD or a LearningRateSchedule if upgrading.
sgd = tf.optimizers.SGD(learning_rate=0.001, momentum=0.9, decay=0.0005)
model.compile(
    optimizer=sgd,
    loss='categorical_crossentropy',
    metrics=['accuracy'])

In [None]:
#           ------------ Train the Model ------------
if not os.path.exists('./instance/saved_model'):
    os.makedirs('./instance/saved_model')
    
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

ConstrainLayer = network.ConstrainLayer(model)
callbacks = [ModelCheckpoint('./instance/saved_model/weights.{epoch:02d}.h5',
    monitor='acc',verbose=1, save_best_only=False,
    save_freq=1), ConstrainLayer, tensorboard_callback]

history = model.fit_generator(generator=train_data_gen, epochs=45, workers=10,
     callbacks=callbacks)  #removed validation data

Epoch 1/45

Epoch 00001: saving model to ./instance/saved_model/weights.01.h5
  1/150 [..............................] - ETA: 13:24 - loss: 1.1573 - accuracy: 0.3281
Epoch 00001: saving model to ./instance/saved_model/weights.01.h5
  2/150 [..............................] - ETA: 7:14 - loss: 1.1643 - accuracy: 0.3281 
Epoch 00001: saving model to ./instance/saved_model/weights.01.h5
  3/150 [..............................] - ETA: 5:05 - loss: 1.2224 - accuracy: 0.2760
Epoch 00001: saving model to ./instance/saved_model/weights.01.h5
  4/150 [..............................] - ETA: 4:01 - loss: 1.0564 - accuracy: 0.4453
Epoch 00001: saving model to ./instance/saved_model/weights.01.h5
  5/150 [>.............................] - ETA: 3:21 - loss: 1.2353 - accuracy: 0.4313
Epoch 00001: saving model to ./instance/saved_model/weights.01.h5
  6/150 [>.............................] - ETA: 2:54 - loss: 1.4086 - accuracy: 0.4271
Epoch 00001: saving model to ./instance/saved_model/weights.01.h5
  

In [None]:
# Persist the final trained model; the .h5 suffix selects Keras' HDF5 format.
final_model_path = './instance/model.h5'
model.save(final_model_path)