In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pickle

In [None]:
from typing import List, Tuple, Dict
import math
from models.cnn import CNN
from dataloader.cifar10 import DataLoader
from utils import helper

In [None]:
# Shorthand dtype aliases used by the graph-building cells (e.g. casting accuracy to float).
myint, myfloat = tf.int32, tf.float32

In [None]:
# Experiment hyperparameters come from this YAML file; later cells read keys such as
# batch_size, learning_rate, num_epoch and num_print_step from it.
config_path = './config/bn_both_dropout4.yaml'
config = helper.load_config(config_path)
config

In [None]:
# CIFAR-10 loader (dataloader.cifar10). Later cells use its data/labels and
# test_data/test_labels arrays plus num_step() and next_batch().
datasource = DataLoader()

In [None]:
# Sanity-check the loaded splits before building the model.
print('train data:  ', datasource.data.shape)
print('train labels:', datasource.labels.shape)
print('test data:   ', np.shape(datasource.test_data))
print('test labels: ', datasource.test_labels.shape)
print('steps/epoch: ', datasource.num_step(config['batch_size']))
# Test accuracy is computed by comparing per-example predictions on test_data
# against test_labels, so the two must line up one-to-one (same for training).
assert len(datasource.data) == len(datasource.labels), 'train data/labels length mismatch'
assert len(datasource.test_data) == len(datasource.test_labels), 'test data/labels length mismatch'

In [None]:
def horizontal_flip(img: np.ndarray, rate: float = 0.5) -> np.ndarray:
    """Randomly mirror an image left-to-right (data augmentation).

    Args:
        img: image array whose axis 1 is the width axis, presumably (H, W, C)
            for the CIFAR-10 images in this notebook; a 2-D (H, W) array also works.
        rate: probability of flipping. Because np.random.rand() lies in [0, 1),
            1.0 always flips and 0.0 never flips.

    Returns:
        ``img`` reversed along axis 1 (a view, no copy) when flipped,
        otherwise ``img`` itself.
    """
    if rate > np.random.rand():
        # `...` keeps any trailing axes (e.g. channels) intact, so 2-D and 3-D inputs both work.
        return img[:, ::-1, ...]
    return img

# network

In [None]:
# Build the model from the loaded config. Later cells use cnn.x, cnn.y, cnn.is_training
# (all fed via feed_dict), cnn.logits and cnn.predicted_classes — presumably
# placeholders/tensors defined in models/cnn.py (confirm there).
cnn = CNN(config)

In [None]:
# Per-example sparse softmax cross-entropy between integer labels cnn.y and cnn.logits.
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=cnn.y, logits=cnn.logits)
# Class probabilities; the training loop compares their per-row max for flip test-time augmentation.
probs_op = tf.nn.softmax(cnn.logits)
# Scalar training loss: mean cross-entropy over the batch.
loss_op = tf.reduce_mean(crossent)
optimizer = tf.train.AdamOptimizer(config['learning_rate'])
# Step counter incremented by optimizer.minimize; the loop uses it as the TensorBoard step.
global_step = tf.train.get_or_create_global_step()
# Ops registered in UPDATE_OPS while the CNN was built (presumably batch-norm
# moving-average updates, given the BN config — confirm in models/cnn.py) are not
# run automatically; the control dependency forces them before every train step.
# This must stay after CNN(config) so the collection is already populated.
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
    train_op = optimizer.minimize(loss_op, global_step=global_step)
# Batch accuracy. NOTE(review): tf.equal needs predicted_classes and y to share a
# dtype (tf.argmax defaults to int64) — confirm against models/cnn.py.
correct = tf.equal(cnn.predicted_classes, cnn.y)
acc_op = tf.reduce_mean(tf.cast(correct, myfloat))

In [None]:
# Training-batch scalars for TensorBoard, grouped under the 'train' name scope.
with tf.name_scope('train'):
    smr_loss, smr_acc = (
        tf.summary.scalar(tag, tensor)
        for tag, tensor in (('loss', loss_op), ('accuracy', acc_op))
    )
    merged_summary = tf.summary.merge([smr_loss, smr_acc])

# Test-accuracy scalar under the 'test' name scope.
# NOTE(review): the training loop logs test accuracy through a hand-built
# tf.Summary rather than running this op — confirm whether it is still needed.
with tf.name_scope('test'):
    test_smr_acc = tf.summary.scalar('accuracy', acc_op)

In [None]:
from datetime import datetime

logdir_base = 'logs/'   # root folder for TensorBoard event files
now = datetime.now()    # timestamp captured when this cell runs

In [None]:
# Train the CNN and log train/test metrics to TensorBoard.
# A fresh timestamp is taken on every execution so re-running this cell writes to a
# new log directory instead of appending to the previous run's event files.
logdir = logdir_base + datetime.now().strftime("%Y%m%d-%H%M%S") + "/"
batch_size = config['batch_size']
tf_config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(
        allow_growth=True
    ))

# The mirrored test set never changes, so build it once instead of at every evaluation.
flipped_test_data = [horizontal_flip(d, 1.0) for d in datasource.test_data]


def _predict_with_probs(sess, images):
    """Run inference over ``images`` in chunks of ``batch_size``.

    Returns ``(predicted_classes, softmax_probs)`` concatenated across chunks.
    Chunking avoids pushing the whole test set through the graph in a single
    ``sess.run`` call. NOTE(review): assumes inference (is_training=False) treats
    examples independently, e.g. batch norm uses moving statistics — confirm in
    models/cnn.py.
    """
    class_chunks, prob_chunks = [], []
    for start in range(0, len(images), batch_size):
        classes, probs = sess.run([cnn.predicted_classes, probs_op], feed_dict={
            cnn.x: images[start:start + batch_size],
            cnn.is_training: False
        })
        class_chunks.append(classes)
        prob_chunks.append(probs)
    return np.concatenate(class_chunks), np.concatenate(prob_chunks)


def _log_test_accuracy(sess, writer, step):
    """Compute flip-TTA test accuracy, write it as 'test/accuracy' at ``step``, and return it.

    Every test image is classified both as-is and mirrored; per image, whichever
    prediction has the higher top softmax probability is kept.
    """
    predicted_classes, probs = _predict_with_probs(sess, datasource.test_data)
    f_predicted_classes, f_probs = _predict_with_probs(sess, flipped_test_data)
    predicted_labels = np.where(np.max(probs, axis=1) >= np.max(f_probs, axis=1),
                                predicted_classes, f_predicted_classes)
    test_acc = np.mean((predicted_labels == datasource.test_labels).astype(np.float32))
    test_acc_smr = tf.Summary()
    test_acc_smr.value.add(tag='test/accuracy', simple_value=test_acc)
    writer.add_summary(test_acc_smr, step)
    return test_acc


with tf.Session(config=tf_config) as sess:
    writer = tf.summary.FileWriter(logdir, sess.graph)
    try:
        sess.run(tf.global_variables_initializer())
        # Bind the values reported at epoch end up front so the report never hits
        # an unbound name, even for an epoch with zero steps.
        step = sess.run(global_step)
        loss = float('nan')
        last_eval_step = None
        num_epoch = config['num_epoch']
        for i in range(num_epoch):
            step_size = datasource.num_step(batch_size)
            for s in range(step_size):
                data, labels = datasource.next_batch(batch_size)
                data = [horizontal_flip(d) for d in data]
                fd = {
                    cnn.x: data,
                    cnn.y: labels,
                    cnn.is_training: True
                }
                loss, _, acc, smr, step = sess.run(
                    [loss_op, train_op, acc_op, merged_summary, global_step], feed_dict=fd)
                if step % config['num_print_step'] == 0:
                    writer.add_summary(smr, step)
                    test_acc = _log_test_accuracy(sess, writer, step)
                    last_eval_step = step
            # The periodic evaluation may not have fired at this epoch's last step
            # (num_print_step longer than an epoch or not aligned with its end);
            # evaluate now so the report reflects the current model, not a stale
            # or undefined value.
            if last_eval_step != step:
                test_acc = _log_test_accuracy(sess, writer, step)
                last_eval_step = step
            print('{} steps, test accuracy:  {:.4f}, loss: {:.4f} ({}/{} epochs)'.format(
                step, test_acc, loss, i + 1, num_epoch))
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

In [None]:
# Inspect one test image as loaded by the DataLoader.
sample_idx = 2
testimg = datasource.test_data[sample_idx]
fig, ax = plt.subplots()
ax.imshow(testimg)
ax.set_title('test image #{}'.format(sample_idx))
ax.axis('off');

In [None]:
# The same image mirrored left-to-right, i.e. what the training-time augmentation
# produces when it fires; reuse the helper rather than duplicating the slice.
testimg2 = horizontal_flip(testimg, 1.0)
fig, ax = plt.subplots()
ax.imshow(testimg2)
ax.set_title('test image, horizontally flipped')
ax.axis('off');