In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np

from datetime import datetime, timedelta, timezone

import sys
sys.path.append('../')
from utils import data
from models.resnet50 import Resnet50
from utils import visualization

In [None]:
# ---- constants --------------------------------------------------------------
# Run name. It becomes part of the log directory below, so its spelling is
# kept exactly as-is.
name = 'fasionClassfictionBN'

# Timestamp the run in Japan Standard Time (fixed UTC+9 offset).
JST = timezone(timedelta(hours=9), 'JST')
now = datetime.now(tz=JST)
nowStr = now.strftime("%Y%m%d_%H%M%S")

# How many input images the TensorBoard image summary shows.
showImgCount = 10

# Input geometry (height, width, channels). The leading -1 in imgShape lets
# tf.reshape infer the batch dimension.
imgHeight, imgWidth, imgChannel = 224, 224, 3
imgShape = [-1, imgHeight, imgWidth, imgChannel]

# Width of the label vector fed to the network.
labelSize = 13
tfrecordPath = './img/fashionDataset/tfrecord/dataset224.tfrecord'
# NOTE(review): intended train fraction, but the training cell never passes
# it to splitDataset() - confirm where (or whether) it is used.
trainRatio = 0.7

# ---- hyper parameters -------------------------------------------------------
bs, lr, ep = 128, 1e-4, 50  # batch size, Adam learning rate, number of epochs

# ---- output locations (per-run, timestamped directory) ----------------------
logDir = '../logs/{}/{}/'.format(name, nowStr)
metadataDir = '{}metadata.tsv'.format(logDir)  # a file path despite the name
checkPointDir = '{}images.ckpt'.format(logDir)  # prefix handed to Saver.save

In [None]:
# main
# Build the input pipeline, the ResNet-50 graph and TensorBoard summaries,
# then train for `ep` epochs. After each epoch the mean train/validation loss
# and accuracy over all drawn batches are written to TensorBoard, a checkpoint
# is saved, and the accuracies are printed.


def epochSummary(scope, loss, acc):
    """Return a tf.Summary holding `<scope>/acc` and `<scope>/loss` scalars.

    Args:
        scope: tag prefix, e.g. 'train' or 'validation'.
        loss: epoch-mean loss (anything convertible to float).
        acc: epoch-mean accuracy (anything convertible to float).
    """
    return tf.Summary(value=[
        tf.Summary.Value(tag='{}/acc'.format(scope), simple_value=float(acc)),
        tf.Summary.Value(tag='{}/loss'.format(scope), simple_value=float(loss)),
    ])


with tf.Graph().as_default():
    # load data
    with tf.variable_scope('tfrecord'):
        tfrecord = data.TFRecord(tfrecordPath, labelSize)
        # NOTE(review): `dataset` is not used below; toDataset() is kept in
        # case splitDataset() relies on it having been called first - confirm
        # in utils.data.
        dataset = tfrecord.toDataset()
        # split dataset
        # NOTE(review): trainRatio from the config cell is not passed here;
        # confirm which split ratio splitDataset() actually applies.
        trainSize, testSize, trainDataset, testDataset = tfrecord.splitDataset(bs)
        # Only size // bs batches are drawn per epoch, so a trailing partial
        # batch (presumably produced when the split is not a multiple of bs)
        # is never used.
        trainIteration = trainSize // bs
        testIteration = testSize // bs
        # Fail fast: each loop below must fetch at least one batch, otherwise
        # the per-epoch metrics and the image summary have nothing to use.
        assert trainIteration > 0, \
            'train split ({}) is smaller than one batch ({})'.format(trainSize, bs)
        assert testIteration > 0, \
            'test split ({}) is smaller than one batch ({})'.format(testSize, bs)
        # TODO: data augmentation

    # make iterator (re-initialised at the start of every epoch)
    with tf.variable_scope('train_data'):
        trainIterator = trainDataset.make_initializable_iterator()
        trainNextBatch = trainIterator.get_next(name='train_next_batch')
        trainIteratorInitOp = trainIterator.initializer
    with tf.variable_scope('test_data'):
        testIterator = testDataset.make_initializable_iterator()
        testNextBatch = testIterator.get_next(name='test_next_batch')
        testIteratorInitOp = testIterator.initializer

    # Batches are pulled from the iterators in Python and fed back through
    # these placeholders.
    x = tf.placeholder("float", [None, imgHeight, imgWidth, imgChannel], name='x')
    y = tf.placeholder("float", [None, labelSize], name='label')

    # NOTE(review): the run name suggests batch normalisation, but no
    # train/inference switch is fed anywhere - confirm how Resnet50 handles
    # BN mode during validation.
    network = Resnet50(x, labelSize=labelSize)
    cost = network.loss(y)
    optimizer = tf.train.AdamOptimizer(lr)
    accuracy = network.accuracy(y)
    trainOp = network.training(cost, optimizer)

    # tensor board: image summary of the fed input batch. Scalar metrics are
    # written once per epoch from Python via epochSummary().
    drawImage = tf.summary.image('train_images', tf.reshape(x, imgShape), showImgCount)
    allVariables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    summaryOp = tf.summary.merge([drawImage])

    saver = tf.train.Saver()
    # The context manager closes the session even if training raises.
    with tf.Session() as sess:
        summaryWriter = tf.summary.FileWriter(logDir, graph=sess.graph)
        try:
            sess.run(tf.global_variables_initializer())

            # Training cycle
            for e in range(ep):
                # train: loss/accuracy are fetched in the same run as the
                # update step, so no extra forward pass is needed.
                sess.run(trainIteratorInitOp)
                trainLosses, trainAccs = [], []
                for step in range(trainIteration):
                    _, batchX, batchY = sess.run(trainNextBatch)
                    _, batchLoss, batchAcc = sess.run(
                        [trainOp, cost, accuracy], feed_dict={x: batchX, y: batchY})
                    trainLosses.append(batchLoss)
                    trainAccs.append(batchAcc)
                trainLoss, trainAcc = np.mean(trainLosses), np.mean(trainAccs)
                summaryWriter.add_summary(epochSummary('train', trainLoss, trainAcc), e)
                # Images come from the last training batch of the epoch.
                summaryWriter.add_summary(sess.run(summaryOp, feed_dict={x: batchX}), e)

                # validation: evaluation only, no trainOp.
                sess.run(testIteratorInitOp)
                valLosses, valAccs = [], []
                for step in range(testIteration):
                    _, testBatchX, testBatchY = sess.run(testNextBatch)
                    batchLoss, batchAcc = sess.run(
                        [cost, accuracy], feed_dict={x: testBatchX, y: testBatchY})
                    valLosses.append(batchLoss)
                    valAccs.append(batchAcc)
                valLoss, valAcc = np.mean(valLosses), np.mean(valAccs)
                summaryWriter.add_summary(epochSummary('validation', valLoss, valAcc), e)

                saver.save(sess, checkPointDir, global_step=e)
                print('Epoch: {} | Train Accuracy: {} | Validation Accuracy: {}'.format(e, trainAcc, valAcc))
        finally:
            summaryWriter.close()