In [1]:
%matplotlib inline
import os
import matplotlib.pyplot as plt

import tensorflow as tf
import numpy as np
import pickle

from utils.Param import get_default_param
from network.dataset.ftfy_patchdata import input_fn
#from network.model_fn import triplet_model_fn
#from network.dataset.sem_patchdata import input_fn
#from network.dataset.sem_patchdata_ext import input_fn as sem_input_fn
#from network.train import TripletEstimator

In [2]:
# Seed both NumPy and TensorFlow with the same fixed value so that repeated
# runs of this notebook draw the same random numbers.
SEED = 2019
np.random.seed(SEED)
tf.set_random_seed(SEED)

In [3]:
# parameters (adjust as needed)
# Directory handed to get_default_param; presumably it is stored as
# param.log_dir, which later cells use for the statistics cache and the
# saved metrics -- TODO confirm against utils.Param.
log_dir = './log/sem_ftfy'
# Start from the project's default parameter set for mode 'AUSTIN'.
# Individual fields are overridden in the following cell.
param = get_default_param(mode='AUSTIN', log_dir=log_dir)

In [4]:
# Override selected defaults for the SEM / ftfy experiment.
# NOTE(review): absolute, machine-specific path; the same literal is repeated
# as `sem_data_dir` in the next cell.
param.data_dir = '/home/sungsooha/Desktop/Data/ftfy/sem/train'
#param.data_dir = './Data/austin'
# Used as a tag in log messages and in the stats_*.pkl / metrics_*.npy file names.
param.train_datasets = 'sem' # we will define sem dataset separately
# None makes the training loop skip its `param.test_datasets` evaluation branch.
param.test_datasets = None #'human_patch'
param.batch_size = 8 # 64 for v100
# Upper bound of the epoch loop in the training cell.
param.n_epoch = 100
# NOTE(review): the training cell currently passes a hard-coded 5000 to
# set_n_triplet_samples instead of this value.
param.n_triplet_samples = 500000
# Forwarded as `log_every` to estimator.train in the training cell.
param.train_log_every   = 100000

In [5]:
# Root folder of the SEM training data. Every sub-directory found directly
# under it is treated as one dataset; plain files are ignored.
sem_data_dir = '/home/sungsooha/Desktop/Data/ftfy/sem/train'
sem_train_datasets = sorted(
    entry
    for entry in os.listdir(sem_data_dir)
    if os.path.isdir(os.path.join(sem_data_dir, entry))
)
print(sem_train_datasets)

['set_001', 'set_003', 'set_004', 'set_005', 'set_007', 'set_008', 'set_009', 'set_010', 'set_011', 'set_012', 'set_017', 'set_018', 'set_020', 'set_027', 'set_033', 'set_034', 'set_037', 'set_038', 'set_043', 'set_045', 'set_058', 'set_059', 'set_061', 'set_065']


### Data pipeline

In [6]:
tf.logging.info("Preparing data pipeline ...")
# Build the input pipeline on the CPU, grouped under the 'input' name scope.
with tf.device('/cpu:0'), tf.name_scope('input'):
    # input_fn returns the tf.data dataset together with a sampler object;
    # later cells call load_dataset / normalize_data / set_mode on that sampler.
    train_dataset, train_data_sampler = input_fn(
        base_dir=sem_data_dir,
        cellsz=param.cellsz,
        n_parameters=param.n_parameters,
        src_size=param.src_size,
        tar_size=param.tar_size,
        # NOTE(review): hard-coded to 1 here, while the load_dataset call in
        # the next cell passes param.n_channels -- confirm the two agree.
        n_channels=1,
        batch_size=param.batch_size
    )
    # Create the iterator from the dataset's types and shapes only, so it is
    # not tied to a dataset until an initializer op is run.
    data_iterator = tf.data.Iterator.from_structure(
        train_dataset.output_types,
        train_dataset.output_shapes
    )
    # Op that binds the iterator to train_dataset; it is handed to the
    # estimator calls in the training cell as `dataset_initializer`.
    train_dataset_init = data_iterator.make_initializer(train_dataset)
    # Symbolic next batch; unpacked into (sources, targets, labels, bboxes)
    # in the build-model cell.
    batch_data = data_iterator.get_next()

INFO:tensorflow:Preparing data pipeline ...


### Load data

In [7]:
# Load every dataset directory listed in sem_train_datasets into the sampler.
# Within each dataset, source images are read from 'sources' and target
# patches from 'patches', both as .bmp files.
train_data_sampler.load_dataset(
    data_dirs=sem_train_datasets,
    src_dir='sources', tar_dir='patches',
    src_ext='bmp', src_size=param.src_size, n_src_channels=param.n_channels, 
    # presumably each image file is a 10 x 10 grid of tiles that the loader
    # splits apart -- TODO confirm against the sampler implementation
    src_per_col=10, src_per_row=10,
    tar_ext='bmp', tar_size=param.tar_size, n_tar_channels=param.n_channels,
    tar_per_col=10, tar_per_row=10,
    # the recorded output below shows the array shapes logged after loading
    debug=True
)

Loading dataset set_001/patches: 100%|██████████| 48/48 [00:00<00:00, 98.95it/s]
Loading dataset set_001/sources: 100%|██████████| 5/5 [00:00<00:00, 17.19it/s]
Loading dataset set_003/patches: 100%|██████████| 16/16 [00:00<00:00, 95.29it/s]
Loading dataset set_003/sources: 100%|██████████| 2/2 [00:00<00:00, 17.85it/s]
Loading dataset set_004/patches: 100%|██████████| 23/23 [00:00<00:00, 99.41it/s]
Loading dataset set_004/sources: 100%|██████████| 3/3 [00:00<00:00, 18.04it/s]
Loading dataset set_005/patches: 100%|██████████| 11/11 [00:00<00:00, 89.14it/s]
Loading dataset set_005/sources: 100%|██████████| 2/2 [00:00<00:00, 18.62it/s]
Loading dataset set_007/patches: 100%|██████████| 6/6 [00:00<00:00, 74.76it/s]
Loading dataset set_007/sources: 100%|██████████| 1/1 [00:00<00:00, 15.06it/s]
Loading dataset set_008/patches: 100%|██████████| 25/25 [00:00<00:00, 94.40it/s]
Loading dataset set_008/sources: 100%|██████████| 3/3 [00:00<00:00, 15.31it/s]
Loading dataset set_009/patches: 100%|████

INFO:tensorflow:datamap : (106010, 3)
INFO:tensorflow:bboxes  : (106010, 4)
INFO:tensorflow:targets : (106010, 128, 128, 1)
INFO:tensorflow:sources : (10601, 256, 256, 1)


### Compute data statistics

In [8]:
# Obtain the normalization statistics (mean, std) of the training data.
# They are cached in <log_dir>/stats_<train_datasets>.pkl: the cache is used
# when it can be read, otherwise the statistics are recomputed by the sampler
# and the cache is (re)written.
tf.logging.info('Loading training stats: %s' % param.train_datasets)
stats_path = os.path.join(param.log_dir, 'stats_%s.pkl' % param.train_datasets)
try:
    # NOTE: unpickling can execute arbitrary code; only load cache files that
    # were written by this notebook.
    with open(stats_path, 'rb') as stats_file:
        mean, std = pickle.load(stats_file)
except Exception:
    # Missing, unreadable or malformed cache -> fall back to recomputing.
    # Catching Exception (instead of a bare `except:`) keeps the best-effort
    # fallback but lets KeyboardInterrupt / SystemExit propagate.
    tf.logging.info('Calculating train data stats (mean, std)')
    mean, std = train_data_sampler.generate_stats()
    # Make sure the log directory exists before writing the cache file.
    os.makedirs(param.log_dir, exist_ok=True)
    with open(stats_path, 'wb') as stats_file:
        pickle.dump([mean, std], stats_file)
tf.logging.info('Mean: {:.5f}'.format(mean))
tf.logging.info('Std : {:.5f}'.format(std))
# Apply the statistics to the loaded source and target images.
train_data_sampler.normalize_data(mean, std)

INFO:tensorflow:Loading training stats: sem
INFO:tensorflow:Calculating train data stats (mean, std)
INFO:tensorflow:Mean: 0.30739
INFO:tensorflow:Std : 0.21744


Normalizing targets: 100%|██████████| 106010/106010 [00:02<00:00, 37152.83it/s]
Normalizing sources: 100%|██████████| 10601/10601 [00:01<00:00, 7418.14it/s]


### Build model

In [None]:
tf.logging.info("Creating the model ...")
# Unpack the symbolic batch produced by the input iterator.
sources, targets, labels, bboxes = batch_data
# NOTE(review): ftfy_model_fn is not imported in the import cell of this
# notebook; add the import from the module that defines it before running.
spec = ftfy_model_fn(
    sources, targets, labels, bboxes,
    # for triplet feature extractor
    feature_mode='TEST',
    feature_trainable=False,
    # the comma after this argument was missing, which made the cell a SyntaxError
    feature_name='triplet-net',
    cnn_name=param.cnn_name,
    shared_batch_layers=True,
    # for ftfy
    ftfy_mode='TRAIN',
    optimizer_name=param.optimizer_name,
    learning_rate=param.learning_rate,
)
# NOTE(review): the training cell calls `estimator.train(...)`, but the only
# line that would create `estimator` is commented out below (and the
# TripletEstimator import is commented out as well).
#estimator = TripletEstimator(spec, save_dir=param.log_dir)

### Training

In [None]:
# Training / evaluation loop.
# NOTE(review): `estimator`, `fpr` and `retrieval_recall_K` are used below but
# are neither imported nor defined in the visible cells (the estimator
# creation in the build-model cell is commented out).

# Cut-offs passed as `K` to retrieval_recall_K.
K=[1, 5, 10, 20, 30]

# Histories collected across epochs; each list receives one entry per epoch
# from the matching section of the loop.
all_loss = [] # avg. loss over epochs
train_fpr95 = [] # fpr95 with training dataset
train_retrieval = [] # retrieval with training dataset
test_fpr95 = []
test_retrieval = []
test_fpr95_2 = []
test_retrieval_2 = []

tf.logging.info('='*50)
tf.logging.info('Start training ...')
tf.logging.info('='*50)
for epoch in range(param.n_epoch):
    tf.logging.info('-'*50)
    tf.logging.info('TRAIN {:d}, {:s} start ...'.format(epoch, param.train_datasets))
    # The sampler is switched to mode 0 before training, 1 before matching
    # and 2 before retrieval; what each mode does is defined by the sampler
    # class, which is not visible here -- TODO confirm.
    train_data_sampler.set_mode(0)
    # NOTE(review): the configured param.n_triplet_samples is bypassed in
    # favour of a hard-coded 5000.
    #train_data_sampler.set_n_triplet_samples(param.n_triplet_samples)
    train_data_sampler.set_n_triplet_samples(5000)
    loss = estimator.train(
        dataset_initializer=train_dataset_init,
        log_every=param.train_log_every
    )
    all_loss.append(loss)
    tf.logging.info('-'*50)

    # for evaluation with training dataset
    tf.logging.info('-'*50)
    tf.logging.info('TEST {:d}, {:s} start ...'.format(epoch, param.train_datasets))
    train_data_sampler.set_mode(1)
    train_data_sampler.set_n_matched_pairs(5000)
    test_match = estimator.run_match(train_dataset_init)
    # False-positive rate at 95% recall, computed from the match labels/scores.
    fpr95 = fpr(test_match.labels, test_match.scores, recall_rate=0.95)
    train_fpr95.append(fpr95)
    tf.logging.info('FPR95: {:.5f}'.format(fpr95))
    
    train_data_sampler.set_mode(2)
    test_rrr = estimator.run_retrieval(train_dataset_init)
    # Only the first element of retrieval_recall_K's return value is kept.
    # NOTE(review): `test_rrr.scores` is passed as `is_query` -- confirm that
    # this field really carries the query flags in retrieval mode.
    rrr = retrieval_recall_K(
        features=test_rrr.features,
        labels=train_data_sampler.get_labels(test_rrr.index),
        is_query=test_rrr.scores,
        K=K
    )[0]
    train_retrieval.append(rrr)
    tf.logging.info('Retrieval: {}'.format(rrr))
    tf.logging.info('-'*50)
    
    # NOTE(review): unconditional break -- the loop stops after the first
    # epoch and everything below it in the loop body (test-set evaluation and
    # checkpoint saving) is never executed. Looks like a debugging leftover.
    break
    
    # for evaluation with test dataset
    # NOTE(review): `test_data_sampler` and `test_dataset_init` are not
    # defined in the visible cells.
    if param.test_datasets is not None:
        tf.logging.info('-'*50)
        tf.logging.info('TEST {:d}, {:s} start ...'.format(epoch, param.test_datasets))
        test_data_sampler.set_mode(1)
        #test_data_sampler.set_n_matched_pairs(1000)
        test_match = estimator.run_match(test_dataset_init)
        fpr95 = fpr(test_match.labels, test_match.scores, recall_rate=0.95)
        test_fpr95.append(fpr95)
        tf.logging.info('FPR95: {:.5f}'.format(fpr95))

        test_data_sampler.set_mode(2)
        test_rrr = estimator.run_retrieval(test_dataset_init)
        rrr = retrieval_recall_K(
            features=test_rrr.features,
            labels=test_data_sampler.get_labels(test_rrr.index),
            is_query=test_rrr.scores,
            K=K
        )[0]
        test_retrieval.append(rrr)
        tf.logging.info('Retrieval: {}'.format(rrr))
        tf.logging.info('-'*50)
    
    # for evaluation with test dataset
    # NOTE(review): second test set. `test_datasets`, `test_data_sampler_2`
    # and `test_dataset_init_2` are not defined in the visible cells, so this
    # branch would raise NameError if it were ever reached.
    if test_datasets is not None:
        tf.logging.info('-'*50)
        tf.logging.info('TEST {:d}, {:s} start ...'.format(epoch, test_datasets))
        test_data_sampler_2.set_mode(1)
        #test_data_sampler.set_n_matched_pairs(1000)
        test_match = estimator.run_match(test_dataset_init_2)
        fpr95 = fpr(test_match.labels, test_match.scores, recall_rate=0.95)
        test_fpr95_2.append(fpr95)
        tf.logging.info('FPR95: {:.5f}'.format(fpr95))

        test_data_sampler_2.set_mode(2)
        test_rrr = estimator.run_retrieval(test_dataset_init_2)
        rrr = retrieval_recall_K(
            features=test_rrr.features,
            labels=test_data_sampler_2.get_labels(test_rrr.index),
            is_query=test_rrr.scores,
            K=K
        )[0]
        test_retrieval_2.append(rrr)
        tf.logging.info('Retrieval: {}'.format(rrr))
        tf.logging.info('-'*50)
    
    # save checkpoint
    # Saved every `param.save_every` epochs and always on the final epoch.
    if epoch % param.save_every == 0 or epoch+1 == param.n_epoch:
        estimator.save(param.project_name, global_step=epoch)
    
    #if epoch > 10:
    #    break

### Plot results

In [None]:
# Average training loss, one point per completed epoch.
fig, ax = plt.subplots()
ax.plot(all_loss)
# Label the figure so it can be read on its own; the trailing ';' suppresses
# the text repr of the return value in the notebook output.
ax.set(title='Training loss', xlabel='epoch', ylabel='avg. loss');

In [None]:
# FPR95 per epoch, one panel per evaluated dataset: the training data, the
# first test set (param.test_datasets) and the second test set. The test
# histories stay empty when the corresponding evaluation was skipped, in
# which case their panels are drawn empty.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))
fpr95_panels = [
    ('train', train_fpr95),
    ('test', test_fpr95),
    ('test 2', test_fpr95_2),
]
for axis, (panel_name, history) in zip(ax, fpr95_panels):
    axis.plot(history)
    # Without titles the three panels cannot be told apart.
    axis.set(title='FPR95 ({})'.format(panel_name), xlabel='epoch', ylabel='FPR95')
fig.tight_layout()

In [None]:
# Retrieval results per epoch, one panel per evaluated dataset: the training
# data, the first test set (param.test_datasets) and the second test set.
# The test histories stay empty when the corresponding evaluation was
# skipped, in which case their panels are drawn empty.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))
retrieval_panels = [
    ('train', train_retrieval),
    ('test', test_retrieval),
    ('test 2', test_retrieval_2),
]
for axis, (panel_name, history) in zip(ax, retrieval_panels):
    axis.plot(history)
    # Without titles the three panels cannot be told apart.
    axis.set(title='Retrieval ({})'.format(panel_name), xlabel='epoch', ylabel='retrieval')
fig.tight_layout()

In [None]:
# save results
def _save_metrics(eval_name, fpr95_history, retrieval_history):
    """Save the metric histories for one evaluated dataset.

    The file is written to
    <param.log_dir>/metrics_<param.train_datasets>_<eval_name>.npy and holds
    a dict with the keys 'loss', 'fpr95' and 'retrieval'.
    """
    metrics_path = os.path.join(param.log_dir, 'metrics_{}_{}.npy'.format(
        param.train_datasets, eval_name
    ))
    metric = dict(
        loss=np.array(all_loss),
        fpr95=np.array(fpr95_history),
        retrieval=np.asarray(retrieval_history)
    )
    # np.save stores the dict as a pickled object array; read it back with
    # np.load(metrics_path, allow_pickle=True).item()
    np.save(metrics_path, metric)


# Metrics evaluated on the training data.
_save_metrics(param.train_datasets, train_fpr95, train_retrieval)

# First test set: skipped when no test set is configured, mirroring the
# guard in the training loop (previously this wrote 'metrics_<train>_None.npy').
if param.test_datasets is not None:
    _save_metrics(param.test_datasets, test_fpr95, test_retrieval)

# Second test set: `test_datasets` is not assigned anywhere in the visible
# cells, so referencing it directly raised NameError. Look it up defensively
# and skip the file when it is undefined or None.
second_test_name = globals().get('test_datasets')
if second_test_name is not None:
    _save_metrics(second_test_name, test_fpr95_2, test_retrieval_2)