In [None]:
%matplotlib inline
import os
import matplotlib.pyplot as plt

import tensorflow as tf
import numpy as np
import pickle

from utils.Param import get_default_param
from utils.eval import fpr, retrieval_recall_K

from network.model_fn import triplet_model_fn
from network.dataset.sem_patchdata import input_fn
from network.dataset.sem_patchdata_ext import input_fn as sem_input_fn
from network.train import TripletEstimator

In [None]:
# Seed NumPy and TensorFlow's graph-level RNG so runs are reproducible.
SEED = 2019
np.random.seed(SEED)
tf.set_random_seed(SEED)

In [None]:
# Start from the project's AUSTIN defaults (adjust as needed). param.log_dir --
# presumably populated from log_dir here -- is where the stats cache, the
# checkpoints and the metric files of this run are written.
log_dir = './log/sem'
param = get_default_param(log_dir=log_dir, mode='AUSTIN')

In [None]:
# Data / training configuration. The data root can be overridden with the
# FTFY_AUSTIN_DIR environment variable (e.g. './Data/austin') so the notebook
# is not tied to one machine; the default is the original author's path.
param.data_dir = os.environ.get(
    'FTFY_AUSTIN_DIR', '/home/sungsooha/Desktop/Data/ftfy/austin'
)
param.train_datasets = 'sem'  # the SEM training set is wired up separately below
param.test_datasets = None  # e.g. 'human_patch'; None skips this evaluation split
param.batch_size = 8  # 64 for v100
param.n_epoch = 100
param.n_triplet_samples = 500000
param.train_log_every = 100000

# Optional second evaluation split under param.data_dir (e.g. 'scene_patch');
# None skips it.
test_datasets = None

In [None]:
# Root of the SEM training data; every sub-directory is loaded as one training
# dataset. Override with the FTFY_SEM_TRAIN_DIR environment variable.
sem_data_dir = os.environ.get(
    'FTFY_SEM_TRAIN_DIR', '/home/sungsooha/Desktop/Data/ftfy/sem/train'
)
sem_train_datasets = sorted(
    entry for entry in os.listdir(sem_data_dir)
    if os.path.isdir(os.path.join(sem_data_dir, entry))
)
# Fail here with a clear message rather than deep inside the data pipeline.
assert sem_train_datasets, 'no SEM dataset directories found in %s' % sem_data_dir
print(sem_train_datasets)

### Data pipeline

In [None]:
tf.logging.info("Preparing data pipeline ...")
with tf.device('/cpu:0'), tf.name_scope('input'):
    # Keyword arguments shared by all three input pipelines.
    common_kwargs = dict(
        base_patch_size=param.base_patch_size,
        batch_size=param.batch_size,
        patch_size=param.patch_size,
        n_channels=param.n_channels,
    )

    # SEM training pipeline, using a fixed 10 x 10 patch grid.
    train_dataset, train_data_sampler = sem_input_fn(
        data_dir=sem_data_dir,
        patches_per_row=10,
        patches_per_col=10,
        **common_kwargs
    )

    # Two independent pipelines over param.data_dir, one per optional test split.
    (test_dataset, test_data_sampler), (test_dataset_2, test_data_sampler_2) = [
        input_fn(
            data_dir=param.data_dir,
            patches_per_row=param.patches_per_row,
            patches_per_col=param.patches_per_col,
            **common_kwargs
        )
        for _ in range(2)
    ]

    # A single reinitializable iterator serves every pipeline. It is typed after
    # the training dataset, so the test datasets must yield a compatible structure.
    data_iterator = tf.data.Iterator.from_structure(
        train_dataset.output_types, train_dataset.output_shapes
    )
    train_dataset_init, test_dataset_init, test_dataset_init_2 = (
        data_iterator.make_initializer(dataset)
        for dataset in (train_dataset, test_dataset, test_dataset_2)
    )
    batch_data = data_iterator.get_next()

### Load data

In [None]:
# Load every SEM sub-dataset (BMP images) into the training sampler.
train_data_sampler.load_dataset(
    debug=True,
    dir_name=sem_train_datasets,
    ext='bmp',
    n_channels=param.n_channels,
    patch_size=param.patch_size,
)

In [None]:
# First (optional) evaluation split, read via the param.data_dir pipeline.
if param.test_datasets is not None:
    test_data_sampler.load_dataset(
        dir_name=param.test_datasets,
        n_channels=param.n_channels,
        patch_size=param.patch_size,
        ext='bmp',
        debug=True,
    )

In [None]:
# Second (optional) evaluation split, read via its own param.data_dir pipeline.
if test_datasets is not None:
    test_data_sampler_2.load_dataset(
        dir_name=test_datasets,
        n_channels=param.n_channels,
        patch_size=param.patch_size,
        ext='bmp',
        debug=True,
    )

### Compute data statistics

In [None]:
# Training-data normalization statistics (mean, std) are cached in
# param.log_dir so later runs can reuse them. The cache is keyed only by
# param.train_datasets, so delete the file if the SEM sub-datasets change.
stats_path = os.path.join(param.log_dir, 'stats_%s.pkl' % param.train_datasets)
tf.logging.info('Loading training stats: %s' % param.train_datasets)
try:
    # NOTE: pickle.load can execute arbitrary code; only load a stats file
    # that this notebook wrote itself.
    with open(stats_path, 'rb') as stats_file:
        mean, std = pickle.load(stats_file)
except (OSError, EOFError, ValueError, TypeError, pickle.UnpicklingError) as err:
    # Missing or unreadable cache: recompute and (re)write it. Unlike a bare
    # `except:`, this no longer swallows KeyboardInterrupt or unrelated bugs.
    tf.logging.info('Stats cache unavailable (%s)' % err)
    tf.logging.info('Calculating train data stats (mean, std)')
    mean, std = train_data_sampler.generate_stats()
    os.makedirs(param.log_dir, exist_ok=True)
    with open(stats_path, 'wb') as stats_file:
        pickle.dump([mean, std], stats_file)
tf.logging.info('Mean: {:.5f}'.format(mean))
tf.logging.info('Std : {:.5f}'.format(std))

# Every loaded split is normalized with the *training* statistics.
train_data_sampler.normalize_data(mean, std)
if param.test_datasets is not None:
    test_data_sampler.normalize_data(mean, std)
if test_datasets is not None:
    test_data_sampler_2.normalize_data(mean, std)

### Build model

In [None]:
tf.logging.info("Creating the model ...")
# The iterator yields (anchor, positive, negative) batches, one per triplet branch.
anchors, positives, negatives = batch_data
spec = triplet_model_fn(
    anchors, positives, negatives,
    name='triplet-net',
    mode='TRAIN',
    n_feats=param.n_features,
    cnn_name=param.cnn_name,
    loss_name=param.loss_name,
    optimizer_name=param.optimizer_name,
    learning_rate=param.learning_rate,
    margin=param.margin,
    use_regularization_loss=param.use_regularization,
    shared_batch_layers=True,
)
# The estimator drives training/evaluation for this spec; its save_dir is param.log_dir.
estimator = TripletEstimator(spec, save_dir=param.log_dir)

### Training

In [None]:
# Retrieval recall is reported at these cut-offs.
K = [1, 5, 10, 20, 30]
# Triplets drawn per training epoch. Lower it (e.g. 5000) only for a quick
# smoke run; the previous hard-coded 5000 silently ignored the configured value.
N_TRAIN_TRIPLETS = param.n_triplet_samples
# Matched pairs drawn when evaluating on the training split.
N_TRAIN_EVAL_PAIRS = 5000

all_loss = []  # avg. loss per epoch, as returned by estimator.train
train_fpr95 = []  # fpr95 with training dataset
train_retrieval = []  # retrieval with training dataset
test_fpr95 = []
test_retrieval = []
test_fpr95_2 = []
test_retrieval_2 = []


def evaluate_split(epoch, name, sampler, dataset_init,
                   fpr_history, retrieval_history, n_matched_pairs=None):
    """Evaluate the current model on one data split and record the results.

    Runs the patch-matching pass (FPR at 95% recall) and then the retrieval
    pass (recall at each cut-off in ``K``), logs both, and appends them to the
    given history lists.

    Args:
        epoch: current epoch index (used only for logging).
        name: dataset name shown in the log.
        sampler: data sampler feeding ``dataset_init``; it is switched to mode 1
            before matching and mode 2 before retrieval (presumably selecting
            matched-pair vs. retrieval sampling -- confirm in the sampler).
        dataset_init: iterator initializer op for this split.
        fpr_history: list that receives the FPR95 value.
        retrieval_history: list that receives the retrieval recall result.
        n_matched_pairs: if given, number of matched pairs the sampler draws;
            None keeps the sampler's current setting.
    """
    tf.logging.info('-'*50)
    tf.logging.info('TEST {:d}, {:s} start ...'.format(epoch, name))
    sampler.set_mode(1)
    if n_matched_pairs is not None:
        sampler.set_n_matched_pairs(n_matched_pairs)
    match = estimator.run_match(dataset_init)
    fpr95 = fpr(match.labels, match.scores, recall_rate=0.95)
    fpr_history.append(fpr95)
    tf.logging.info('FPR95: {:.5f}'.format(fpr95))

    sampler.set_mode(2)
    retrieval = estimator.run_retrieval(dataset_init)
    rrr = retrieval_recall_K(
        features=retrieval.features,
        labels=sampler.get_labels(retrieval.index),
        is_query=retrieval.scores,
        K=K
    )[0]
    retrieval_history.append(rrr)
    tf.logging.info('Retrieval: {}'.format(rrr))
    tf.logging.info('-'*50)


tf.logging.info('='*50)
tf.logging.info('Start training ...')
tf.logging.info('='*50)
for epoch in range(param.n_epoch):
    tf.logging.info('-'*50)
    tf.logging.info('TRAIN {:d}, {:s} start ...'.format(epoch, param.train_datasets))
    train_data_sampler.set_mode(0)
    train_data_sampler.set_n_triplet_samples(N_TRAIN_TRIPLETS)
    loss = estimator.train(
        dataset_initializer=train_dataset_init,
        log_every=param.train_log_every
    )
    all_loss.append(loss)
    tf.logging.info('-'*50)

    # Evaluate on the training split, then on each configured test split.
    # (A leftover debugging `break` used to end training here after epoch 0,
    # skipping test evaluation and checkpointing.)
    evaluate_split(epoch, param.train_datasets, train_data_sampler,
                   train_dataset_init, train_fpr95, train_retrieval,
                   n_matched_pairs=N_TRAIN_EVAL_PAIRS)
    if param.test_datasets is not None:
        evaluate_split(epoch, param.test_datasets, test_data_sampler,
                       test_dataset_init, test_fpr95, test_retrieval)
    if test_datasets is not None:
        evaluate_split(epoch, test_datasets, test_data_sampler_2,
                       test_dataset_init_2, test_fpr95_2, test_retrieval_2)

    # Save a checkpoint every param.save_every epochs and after the last one.
    if epoch % param.save_every == 0 or epoch + 1 == param.n_epoch:
        estimator.save(param.project_name, global_step=epoch)

### Plot results

In [None]:
# Average training loss per epoch.
fig, ax = plt.subplots()
ax.plot(all_loss)
ax.set(title='Training loss ({})'.format(param.train_datasets),
       xlabel='epoch', ylabel='avg. loss');

In [None]:
# FPR at 95% recall per epoch for the train split and both optional test splits
# (a test panel stays empty when that split is disabled).
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
panels = [
    (train_fpr95, 'train: {}'.format(param.train_datasets)),
    (test_fpr95, 'test: {}'.format(param.test_datasets)),
    (test_fpr95_2, 'test: {}'.format(test_datasets)),
]
for ax, (history, title) in zip(axes, panels):
    ax.plot(history)
    ax.set(title=title, xlabel='epoch', ylabel='FPR @ 95% recall')
fig.tight_layout()

In [None]:
# Retrieval recall per epoch for the train split and both optional test splits
# (a test panel stays empty when that split is disabled).
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
panels = [
    (train_retrieval, 'train: {}'.format(param.train_datasets)),
    (test_retrieval, 'test: {}'.format(param.test_datasets)),
    (test_retrieval_2, 'test: {}'.format(test_datasets)),
]
for ax, (history, title) in zip(axes, panels):
    # Presumably each entry holds one recall value per cut-off in K, giving one
    # line per K -- confirm against retrieval_recall_K.
    ax.plot(history)
    ax.set(title=title, xlabel='epoch', ylabel='retrieval recall')
fig.tight_layout()

In [None]:
# Save results: one metrics file per evaluated split.
def save_metrics(eval_name, fpr_history, retrieval_history):
    """Write loss/FPR95/retrieval histories to metrics_<train>_<eval_name>.npy.

    np.save wraps the dict in a 0-d object array; read it back with
    ``np.load(path, allow_pickle=True).item()``.
    """
    os.makedirs(param.log_dir, exist_ok=True)
    metrics_path = os.path.join(param.log_dir, 'metrics_{}_{}.npy'.format(
        param.train_datasets, eval_name
    ))
    metric = dict(
        loss=np.array(all_loss),
        fpr95=np.array(fpr_history),
        retrieval=np.asarray(retrieval_history)
    )
    np.save(metrics_path, metric)


save_metrics(param.train_datasets, train_fpr95, train_retrieval)
# Skip disabled splits: previously both test blocks wrote 'metrics_<train>_None.npy'
# when disabled, the second silently overwriting the first.
if param.test_datasets is not None:
    save_metrics(param.test_datasets, test_fpr95, test_retrieval)
if test_datasets is not None:
    save_metrics(test_datasets, test_fpr95_2, test_retrieval_2)