<a href="https://colab.research.google.com/github/justadudewhohacks/ipynbs/blob/master/age_recognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dependencies


In [0]:
# PyDrive supplies the GoogleAuth / GoogleDrive classes used by the download cell.
!pip install -U -q PyDrive

## Download Data

In [0]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import os

# Google Drive file ids of the split files and of each dataset's archives.
train_data_json_id = '1CDMRQdAhcws_g1yDw_29ZD5DNDDyi7Xw'
test_data_json_id = '1_0dpT5HRTWocnK35KLQFDHzJiwV2-IQZ'

utk_images_7z_id = '1c61PoUhIPKeoRzB0XDI23XMDyJaCfKSh'
utk_landmarks_7z_id = '1Nxg7KKfEkDBWCqhusE1S6Edp6n3tTOuN'

appareal_labels_json_id = '1_zfGunGuqyrftDJIEKw6NVJOS55vyOrh'
appareal_images_7z_id = '1BDm6r88XLwDFsqOa2ZbbUtW1HDyHo5yA'
appareal_landmarks_7z_id = '1Am36Tk-BnjfV1d8_iUpRcW-cPfQtAN0H'

wiki_labels_json_id = '1BamAqN3tNEMh6kNQQ4C8nWf6gOA2IS6X'
wiki_images_7z_id = '1Fy3pi-Pra1IsN9HDD268nRvXa1TbsryE'
wiki_landmarks_7z_id = '1M-YeSGEEboVqNK8pTCJhbxeVaLp0TKJ4'

# Create ./data and one sub-directory per dataset (parents first).
for data_dir in ['./data', './data/utk', './data/appareal', './data/wiki']:
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Authenticate the Colab user and build a Drive client from the
# application-default credentials.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# (progress message, [(drive file id, local target path), ...]) in download order.
download_plan = [
  ('downloading trainData.json and testData.json ...', [
    (train_data_json_id, './data/trainData.json'),
    (test_data_json_id, './data/testData.json'),
  ]),
  ('downloading utk data ...', [
    (utk_images_7z_id, './data/utk/images.7z'),
    (utk_landmarks_7z_id, './data/utk/landmarks.7z'),
  ]),
  ('downloading appareal data ...', [
    (appareal_labels_json_id, './data/appareal/labels.json'),
    (appareal_images_7z_id, './data/appareal/images.7z'),
    (appareal_landmarks_7z_id, './data/appareal/landmarks.7z'),
  ]),
  ('downloading wiki data ...', [
    (wiki_labels_json_id, './data/wiki/labels.json'),
    (wiki_images_7z_id, './data/wiki/images.7z'),
    (wiki_landmarks_7z_id, './data/wiki/landmarks.7z'),
  ]),
]

for message, files in download_plan:
  print(message)
  for file_id, target_path in files:
    drive.CreateFile({ 'id': file_id }).GetContentFile(target_path)

print('done!')

# Remove ./sample_data (presumably Colab's default example folder), then
# extract each dataset's image and landmark archives inside its own directory.
# Extraction output is appended to *.unzip.txt files in the working directory,
# which are deleted afterwards.
# NOTE(review): load_image reads ./data/<db>/cropped-images/ and
# ./data/<db>/landmarks/, so the archives presumably unpack to those folders;
# p7zip -d also presumably removes each .7z after extracting — confirm both.
!rm -rf ./sample_data
!cd ./data/utk && p7zip -d ./images.7z >> ../../utk-images.unzip.txt
!cd ./data/utk && p7zip -d ./landmarks.7z >> ../../utk-landmarks.unzip.txt
!cd ./data/appareal && p7zip -d ./images.7z >> ../../appareal-images.unzip.txt
!cd ./data/appareal && p7zip -d ./landmarks.7z >> ../../appareal-landmarks.unzip.txt
!cd ./data/wiki && p7zip -d ./images.7z >> ../../wiki-images.unzip.txt
!cd ./data/wiki && p7zip -d ./landmarks.7z >> ../../wiki-landmarks.unzip.txt
!rm -rf *.unzip.txt

## Training

### Imports

In [0]:
import cv2
import math
import json
import random
import time
import os
import numpy as np
import google.colab as colab
import tensorflow as tf

### Preprocessing

In [0]:
def num_in_range(val, min_val, max_val):
  """Clamp ``val`` to ``[min_val, max_val]`` (lower bound applied first)."""
  lower_bounded = max(min_val, val)
  return min(lower_bounded, max_val)

def random_crop(img, landmarks):
  """Randomly crop ``img`` to a window that still encloses every landmark.

  Landmark coordinates are clamped to [0, 1] and multiplied by the image
  width/height, so they are treated as fractions of the image size.

  Args:
    img: array of shape (height, width, channels).
    landmarks: iterable of dicts with 'x' and 'y' entries.

  Returns:
    ``img[y0:y1, x0:x1]``. The top-left corner is drawn uniformly between the
    image origin and the landmark box's top-left corner. The bottom-right
    corner is drawn between the landmark box's bottom-right corner and the
    image border.
  """
  height, width, _ = img.shape

  # Bounding box of the landmarks; these seeds are what remain when there
  # are no landmarks at all.
  box_left, box_top, box_right, box_bottom = width, height, 0, 0
  for pt in landmarks:
    x, y = pt['x'], pt['y']
    if x < box_left:
      box_left = x
    if y < box_top:
      box_top = y
    if not x < box_right:
      box_right = x
    if not y < box_bottom:
      box_bottom = y

  # Fractions -> pixel indices.
  box_left = int(num_in_range(box_left, 0, 1) * width)
  box_top = int(num_in_range(box_top, 0, 1) * height)
  box_right = int(num_in_range(box_right, 0, 1) * width)
  box_bottom = int(num_in_range(box_bottom, 0, 1) * height)

  # The draw order (x0, y0, x1, y1) keeps results reproducible for a seed.
  x0 = random.randint(0, box_left)
  y0 = random.randint(0, box_top)
  x1 = box_right + random.randint(0, abs(width - box_right))
  y1 = box_bottom + random.randint(0, abs(height - box_bottom))

  return img[y0:y1, x0:x1]

def resize_preserve_aspect_ratio(img, size):
  """Scale ``img`` so its longer side is ``size`` pixels, keeping the aspect ratio.

  Args:
    img: image array of shape (height, width, channels).
    size: target length, in pixels, of the longer side.

  Returns:
    The resized image of shape (round(height * r), round(width * r), channels)
    with r = size / max(height, width). Each side is at least 1 pixel.
  """
  height, width, _ = img.shape
  ratio = size / float(max(height, width))
  # cv2.resize expects dsize as (width, height). Passing (height, width)
  # transposes the aspect ratio of every non-square image.
  # Clamp to 1 so very thin inputs never produce an invalid zero-size target.
  new_width = max(1, int(round(width * ratio)))
  new_height = max(1, int(round(height * ratio)))
  return cv2.resize(img, (new_width, new_height))
  
def pad_to_square(img):
  """Center ``img`` on a zero-filled square canvas whose side is its longer dimension.

  Args:
    img: array of shape (height, width, channels).

  Returns:
    A float64 array of shape (side, side, channels) with side = max(height, width).
    When the padding is odd, the extra row/column goes to the bottom/right.
  """
  rows, cols, depth = img.shape
  side = max(rows, cols)
  top = (side - rows) // 2
  left = (side - cols) // 2

  canvas = np.zeros([side, side, depth])
  canvas[top:top + rows, left:left + cols] = img
  return canvas

def preprocess(img, size, landmarks = None, with_random_crop = True):
  """Optionally crop, then resize and zero-pad ``img`` to a ``size`` x ``size`` square.

  Args:
    img: image array of shape (height, width, channels).
    size: side length of the returned square.
    landmarks: landmark dicts passed to ``random_crop``; only used when
      ``with_random_crop`` is True.
    with_random_crop: when False, the full image is resized without cropping.

  Returns:
    The padded square image (see ``pad_to_square``).
  """
  # Without cropping, fall back to the original image. The previous code
  # referenced ``cropped_img`` here before it was assigned.
  cropped_img = random_crop(img, landmarks) if with_random_crop else img
  resized_img = resize_preserve_aspect_ratio(cropped_img, size)
  return pad_to_square(resized_img)

### Weight Serialization

In [0]:
class WeightProcessor:
  """Walks a network's parameter layout and reports each tensor's shape and path.

  Both callbacks are called as ``callback(shape, tensor_path)``, where shape is
  a list of ints and tensor_path is a '/'-separated name. ``process_weights``
  receives filters/matrices and ``processor_bias`` receives bias vectors. The
  call order is fixed, so it can serve as a serialization order.
  """

  def __init__(self, process_weights, processor_bias):
    # Network classes call these attributes directly, so keep the names.
    self.process_weights = process_weights
    self.processor_bias = processor_bias

  def process_conv_weights(self, channels_in, channels_out, prefix, filter_size = 3):
    """Report a square conv filter [k, k, in, out] followed by its bias [out]."""
    filter_shape = [filter_size, filter_size, channels_in, channels_out]
    self.process_weights(filter_shape, prefix + '/filter')
    self.processor_bias([channels_out], prefix + '/bias')

  def process_depthwise_separable_conv2d_weights(self, channels_in, channels_out, prefix):
    """Report a 3x3 depthwise filter, a 1x1 pointwise filter and a bias."""
    report_weights = self.process_weights
    report_weights([3, 3, channels_in, 1], prefix + '/depthwise_filter')
    report_weights([1, 1, channels_in, channels_out], prefix + '/pointwise_filter')
    self.processor_bias([channels_out], prefix + '/bias')

  def process_dense_block_weights(self, channels_in, channels_out, prefix, is_first_layer = False):
    """Report conv0..conv3 of a dense block.

    conv0 is a regular conv when ``is_first_layer`` is set, otherwise it is
    depthwise separable. conv1-conv3 are always depthwise separable.
    """
    first_prefix = prefix + '/conv0'
    if is_first_layer:
      self.process_conv_weights(channels_in, channels_out, first_prefix)
    else:
      self.process_depthwise_separable_conv2d_weights(channels_in, channels_out, first_prefix)
    for layer_name in ('conv1', 'conv2', 'conv3'):
      self.process_depthwise_separable_conv2d_weights(channels_out, channels_out, prefix + '/' + layer_name)

  def process_bottleneck_weights(self, channels_in, channels_out, expansion_factor, prefix):
    """Report a 1x1 expansion conv to in*expansion channels, then a separable conv to out."""
    expanded_channels = channels_in * expansion_factor
    self.process_conv_weights(channels_in, expanded_channels, prefix + '/expansion_conv', filter_size = 1)
    self.process_depthwise_separable_conv2d_weights(expanded_channels, channels_out, prefix + '/separable_conv')

class WeightMap:
  """Nested dict of tensors addressed by '/'-separated paths.

  ``WeightMap([t0, t1], ['conv/filter', 'conv/bias']).weights`` is
  ``{'conv': {'filter': t0, 'bias': t1}}``.
  """

  def __init__(self, tensors, tensor_paths):
    self.weights = {}
    for position, tensor in enumerate(tensors):
      *parents, leaf = tensor_paths[position].split('/')
      node = self.weights
      for key in parents:
        # Reuse an existing sub-dict, or create it on first use.
        node = node.setdefault(key, {})
      node[leaf] = tensor

  def get_tensor_from_path(self, tensor_path):
    """Return the entry stored under ``tensor_path`` (KeyError if missing)."""
    node = self.weights
    for segment in tensor_path.split('/'):
      node = node[segment]
    return node

  def set_tensor_from_path(self, tensor_path, tensor):
    """Replace the entry under ``tensor_path``; all parent dicts must already exist."""
    *parents, leaf = tensor_path.split('/')
    node = self.weights
    for segment in parents:
      node = node[segment]
    node[leaf] = tensor
    
def init_trainable_weights(net, weight_initializer = tf.keras.initializers.glorot_normal(), bias_initializer = tf.zeros):
  """Create a freshly initialized ``tf.Variable`` for every parameter of ``net``.

  Args:
    net: object whose ``process_weights(weights_cb, bias_cb)`` reports each
      parameter's shape and path.
    weight_initializer: called with a shape to initialize filters/matrices.
      The default instance is created once, when this function is defined.
    bias_initializer: called with a shape to initialize biases.

  Returns:
    A ``WeightMap`` of the created variables keyed by their paths.
  """
  tensors, tensor_paths = [], []

  def make_registrar(initializer):
    # Build a callback that creates a variable with ``initializer`` and records it.
    def register(shape, tensor_path):
      tensors.append(tf.Variable(initializer(shape)))
      tensor_paths.append(tensor_path)
    return register

  net.process_weights(make_registrar(weight_initializer), make_registrar(bias_initializer))
  return WeightMap(tensors, tensor_paths)
  
def load_weights(net, checkpoint_file):
  """Rebuild ``net``'s weights from a flat array stored with ``np.save``.

  The array is consumed front to back in the order ``net.process_weights``
  reports parameters. Each reported shape takes the next prod(shape) values,
  which are reshaped and converted to a float32 tensor (not a Variable).

  Returns:
    A ``WeightMap`` of the loaded tensors keyed by their paths.
  """
  flat_values = np.load(checkpoint_file)
  cursor = 0
  tensors = []
  tensor_paths = []

  def take_next(shape, tensor_path):
    nonlocal cursor
    element_count = 1
    for dim in shape:
      element_count *= dim
    chunk = flat_values[cursor:cursor + element_count]
    cursor += element_count
    tensors.append(tf.convert_to_tensor(np.reshape(chunk, shape), dtype=tf.float32))
    tensor_paths.append(tensor_path)

  # Weights and biases share one flat array, so both use the same reader.
  net.process_weights(take_next, take_next)
  return WeightMap(tensors, tensor_paths)

def save_weights(net, weight_map, checkpoint_file):
  """Flatten every parameter of ``net`` into one float64 array and ``np.save`` it.

  Parameters are visited in the order ``net.process_weights`` reports them,
  the same order ``load_weights`` uses to slice the array back apart.

  Args:
    net: object exposing ``process_weights(weights_cb, bias_cb)``.
    weight_map: ``WeightMap`` holding the tensors to save. Each tensor is read
      via ``.eval()``, so presumably a default TF session must be active
      (TODO confirm at call sites).
    checkpoint_file: path passed to ``np.save``.
  """
  chunks = []

  def append_weights(shape, tensor_path):
    chunks.append(weight_map.get_tensor_from_path(tensor_path).eval().flatten())

  net.process_weights(append_weights, append_weights)
  # Concatenate once: calling np.append per tensor copies the whole growing
  # array every time, which is quadratic in the total number of weights.
  # Saved as float64; an empty array is saved when there are no parameters.
  if chunks:
    checkpoint_data = np.concatenate(chunks).astype(np.float64)
  else:
    checkpoint_data = np.array([])
  np.save(checkpoint_file, checkpoint_data)

### Neural Network

In [0]:
def conv2d(x, weights, stride):
  """'SAME'-padded 2-D convolution plus bias.

  Args:
    x: input tensor.
    weights: dict with 'filter' and 'bias' tensors.
    stride: strides list passed to ``tf.nn.conv2d`` (e.g. [1, 2, 2, 1]).
  """
  convolved = tf.nn.conv2d(x, weights['filter'], stride, 'SAME')
  return tf.add(convolved, weights['bias'])

def depthwise_separable_conv2d(x, weights, stride):
  """'SAME'-padded depthwise-separable convolution plus bias.

  Args:
    x: input tensor.
    weights: dict with 'depthwise_filter', 'pointwise_filter' and 'bias'.
    stride: strides list passed to ``tf.nn.separable_conv2d``.
  """
  separable = tf.nn.separable_conv2d(x, weights['depthwise_filter'], weights['pointwise_filter'], stride, 'SAME')
  return tf.add(separable, weights['bias'])
  
def fully_connected(x, weights):
  """Flatten ``x`` to rows of the weight matrix's input size, then apply x @ W + b.

  Args:
    x: input tensor; its total size must be a multiple of W's first dimension.
    weights: dict with 'weights' (2-D matrix) and 'bias'.
  """
  in_features = weights['weights'].get_shape().as_list()[0]
  rows = tf.reshape(x, [-1, in_features])
  return tf.add(tf.matmul(rows, weights['weights']), weights['bias'])

def dense_block(x, weights, is_first_layer = False, is_scale_down = True):
  """Four densely connected convolutions.

  Each of conv1..conv3 receives the ReLU of the sum of all previous conv
  outputs in the block, and the block returns the ReLU of the sum of all four.

  Args:
    x: input tensor.
    weights: dict with 'conv0'..'conv3' weight dicts.
    is_first_layer: use a regular conv (instead of depthwise separable) for conv0.
    is_scale_down: give conv0 a stride of 2 (otherwise 1).
  """
  unit_stride = [1, 1, 1, 1]
  if is_scale_down:
    entry_stride = [1, 2, 2, 1]
  else:
    entry_stride = [1, 1, 1, 1]

  if is_first_layer:
    out1 = conv2d(x, weights['conv0'], entry_stride)
  else:
    out1 = depthwise_separable_conv2d(x, weights['conv0'], entry_stride)

  out2 = depthwise_separable_conv2d(tf.nn.relu(out1), weights['conv1'], unit_stride)
  out3 = depthwise_separable_conv2d(tf.nn.relu(tf.add(out1, out2)), weights['conv2'], unit_stride)
  out4 = depthwise_separable_conv2d(tf.nn.relu(tf.add(out1, tf.add(out2, out3))), weights['conv3'], unit_stride)

  return tf.nn.relu(tf.add(out1, tf.add(out2, tf.add(out3, out4))))

def bottleneck(x, weights, stride, is_residual = False):
  """Inverted-residual-style block.

  A conv with unit stride (no activation) is followed by a depthwise-separable
  conv with ``stride``. The input is optionally added back, and a ReLU is
  applied to the result.

  NOTE(review): the MobileNetV2 paper applies ReLU6 after the expansion and
  depthwise convs and keeps the block output linear; here the only
  nonlinearity is the final ReLU (see the TODO below).
  """
  #TODO: Relu6?
  expanded = conv2d(x, weights['expansion_conv'], [1, 1, 1, 1])
  projected = depthwise_separable_conv2d(expanded, weights['separable_conv'], stride)
  summed = tf.add(x, projected) if is_residual else projected
  return tf.nn.relu(summed)

def normalize(x, mean_rgb):
  """Subtract a per-channel mean from an image batch and divide by 255.

  Args:
    x: float tensor whose last axis holds 3 channels ordered like
      ``mean_rgb`` (presumably NHWC — TODO confirm against callers).
    mean_rgb: three per-channel mean values.

  Returns:
    ``(x - mean_rgb) / 255``, broadcast over all leading axes.
  """
  # A length-3 constant broadcasts across batch/height/width. This works for
  # any batch size, including an unknown (None) batch dimension, and avoids
  # materialising a full-size mean tensor.
  mean = tf.constant(mean_rgb, dtype = x.dtype)
  return tf.divide(tf.subtract(x, mean), 255)

class DenseMobilenet_4_4:
  """Age regressor: four dense blocks, 7x7 average pool, one-unit FC layer.

  Every dense block downscales by 2. The 7x7 pool after four blocks
  presumably assumes 112x112 inputs (TODO confirm).
  """

  def process_weights(self, process_weights, processor_bias):
    """Report all parameters, in serialization order, to the two callbacks."""
    weight_processor = WeightProcessor(process_weights, processor_bias)
    block_channels = [(3, 32), (32, 64), (64, 128), (128, 256)]
    for block_idx, (channels_in, channels_out) in enumerate(block_channels):
      weight_processor.process_dense_block_weights(channels_in, channels_out, 'dense' + str(block_idx), block_idx == 0)
    weight_processor.process_weights([256, 1], 'fc_age/weights')
    weight_processor.processor_bias([1], 'fc_age/bias')

  def forward(self, batch_tensor, weights):
    """Build the graph mapping ``batch_tensor`` to one predicted age per sample."""
    mean_rgb = [122.782, 117.001, 104.298]
    out = normalize(batch_tensor, mean_rgb)
    for block_idx in range(4):
      # Only dense0 uses a regular conv for its first layer.
      out = dense_block(out, weights['dense' + str(block_idx)], block_idx == 0)
    out = tf.nn.avg_pool(out, [1, 7, 7, 1], [1, 2, 2, 1], 'VALID')
    return fully_connected(out, weights['fc_age'])
  
class DenseMobilenet_4_5:
  """Age regressor: five dense blocks, 7x7 average pool, one-unit FC layer.

  dense0 keeps the input resolution and the other four blocks downscale by 2.
  The 7x7 pool presumably assumes 112x112 inputs (TODO confirm).
  """

  def process_weights(self, process_weights, processor_bias):
    """Report all parameters, in serialization order, to the two callbacks."""
    weight_processor = WeightProcessor(process_weights, processor_bias)
    block_channels = [(3, 32), (32, 64), (64, 128), (128, 256), (256, 512)]
    for block_idx, (channels_in, channels_out) in enumerate(block_channels):
      weight_processor.process_dense_block_weights(channels_in, channels_out, 'dense' + str(block_idx), block_idx == 0)
    weight_processor.process_weights([512, 1], 'fc_age/weights')
    weight_processor.processor_bias([1], 'fc_age/bias')

  def forward(self, batch_tensor, weights):
    """Build the graph mapping ``batch_tensor`` to one predicted age per sample."""
    mean_rgb = [122.782, 117.001, 104.298]
    out = normalize(batch_tensor, mean_rgb)
    for block_idx in range(5):
      block_weights = weights['dense' + str(block_idx)]
      if block_idx == 0:
        out = dense_block(out, block_weights, is_first_layer = True, is_scale_down = False)
      else:
        out = dense_block(out, block_weights)
    out = tf.nn.avg_pool(out, [1, 7, 7, 1], [1, 2, 2, 1], 'VALID')
    return fully_connected(out, weights['fc_age'])
  
class MobilenetV2:
  """MobileNetV2-style age regressor with a ReLU'd 1x1 conv producing one age value.

  The stem conv uses stride 1, so the network presumably targets 112x112
  inputs (see the comment in ``forward``).
  """

  # One row per bottleneck group, in application and serialization order:
  # (expansion_factor, channels_out, repeats, stride of the first repeat).
  # The first repeat of a group is not residual; later repeats use stride 1
  # and are residual.
  _BOTTLENECK_GROUPS = (
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
  )

  def process_weights(self, process_weights, processor_bias):
    """Report all parameters, in serialization order, to the two callbacks."""
    weight_processor = WeightProcessor(process_weights, processor_bias)
    weight_processor.process_conv_weights(3, 32, 'conv_in', filter_size = 3)
    channels = 32
    for group_idx, (expansion, channels_out, repeats, _) in enumerate(self._BOTTLENECK_GROUPS):
      for repeat_idx in range(repeats):
        prefix = 'bottleneck' + str(group_idx) + '/n' + str(repeat_idx)
        weight_processor.process_bottleneck_weights(channels, channels_out, expansion, prefix)
        channels = channels_out
    weight_processor.process_conv_weights(320, 1280, 'conv_expand', filter_size = 1)
    weight_processor.process_conv_weights(1280, 1, 'conv_age_out', filter_size = 1)

  def forward(self, batch_tensor, weights):
    """Build the graph mapping ``batch_tensor`` to a (batch, 1) age tensor."""
    mean_rgb = [122.782, 117.001, 104.298]
    normalized = normalize(batch_tensor, mean_rgb)
    unit_stride = [1, 1, 1, 1]

    # initial stride of 1 (112x112 input) instead of 2 (224x224 input)
    out = tf.nn.relu(conv2d(normalized, weights['conv_in'], unit_stride))
    for group_idx, (_, _, repeats, first_stride) in enumerate(self._BOTTLENECK_GROUPS):
      group_weights = weights['bottleneck' + str(group_idx)]
      for repeat_idx in range(repeats):
        if repeat_idx == 0:
          out = bottleneck(out, group_weights['n0'], [1, first_stride, first_stride, 1])
        else:
          out = bottleneck(out, group_weights['n' + str(repeat_idx)], unit_stride, True)

    out = tf.nn.relu(conv2d(out, weights['conv_expand'], unit_stride))
    out = tf.nn.avg_pool(out, [1, 7, 7, 1], [1, 2, 2, 1], 'VALID')
    out = tf.nn.relu(conv2d(out, weights['conv_age_out'], unit_stride))

    return tf.reshape(out, [out.shape[0], out.shape[3]])
  



### Data Loader

In [0]:
def load_json(json_file_path):
  """Parse the JSON file at ``json_file_path`` and return the resulting object."""
  with open(json_file_path) as handle:
    parsed = json.load(handle)
  return parsed

def load_image(data, with_random_crop = True):
  """Read one sample's image from ./data/<db>/cropped-images/ and preprocess it.

  Args:
    data: sample descriptor with 'db' ('utk', 'appareal' or 'wiki') and
      'file' (image file name) entries.
    with_random_crop: when True, the sample's landmark JSON is read from
      ./data/<db>/landmarks/ and used for a random crop before resizing.

  Returns:
    The image preprocessed to a 112x112 square (see ``preprocess``).

  Raises:
    IOError: if OpenCV cannot read the image file.

  NOTE(review): cv2.imread returns channels in BGR order and nothing here
  reorders them, while the networks normalize with a ``mean_rgb`` triple;
  confirm the channel order is intended.
  """
  db = data['db']
  img_file = data['file']
  # The landmark file has the image's name with '.jpg' -> '.json'. For utk and
  # appareal the replacement is anchored on a 'chip_0' / 'face_0' tag,
  # presumably so other '.jpg' substrings in the name stay untouched.
  file_suffix = 'chip_0' if db == 'utk' else ('face_0' if db == 'appareal' else '')
  landmarks_file = img_file.replace(file_suffix + '.jpg', file_suffix + '.json')
  img_file_path = './data/' + db + '/cropped-images/' + img_file
  landmarks_file_path = './data/' + db + '/landmarks/' + landmarks_file

  img = cv2.imread(img_file_path)
  if img is None:
    # cv2.imread reports failure by returning None. Raising a plain string is
    # itself a TypeError in Python 3, so raise a real exception.
    raise IOError('failed to read image from path: ' + img_file_path)

  landmarks = load_json(landmarks_file_path) if with_random_crop else None
  return preprocess(img, 112, landmarks, with_random_crop)
    
def load_image_batch(datas, with_random_crop = True):
  """Load and preprocess every sample in ``datas`` and stack them along a new axis 0."""
  return np.stack([load_image(data, with_random_crop) for data in datas], axis=0)

def shuffle_array(arr):
  """Return a randomly shuffled copy of the list ``arr``; ``arr`` itself is not modified."""
  shuffled = list(arr)
  random.shuffle(shuffled)
  return shuffled

class LabelExtractor:
  """Looks up the ground-truth age of a sample descriptor."""

  def __init__(self):
    # Age labels for appareal and wiki are keyed by image file name.
    self.appareal_labels = load_json('./data/appareal/labels.json')
    self.wiki_labels = load_json('./data/wiki/labels.json')

  def extract_labels(self, data):
    """Return the age for ``data`` (a dict with 'db' and 'file' entries).

    For utk, the age is the first '_'-separated field of the file name. For
    appareal and wiki, it is read from the loaded label files.

    Raises:
      ValueError: if ``data['db']`` is not a known dataset.
    """
    db = data['db']
    img_file = data['file']

    if db == 'utk':
      return int(float(img_file.split('_')[0]))
    if db == 'appareal':
      return self.appareal_labels[img_file]['age']
    if db == 'wiki':
      return self.wiki_labels[img_file]['age']
    # Raising a plain string is a TypeError in Python 3; raise a real exception.
    raise ValueError('unknown db: ' + db)
  
class DataLoader:
  """Serves (images, ages) batches from a list of sample descriptors.

  Train mode (``is_test=False``) serves a shuffled copy of the data. When the
  current buffer cannot fill a batch, the unserved tail is kept in front of a
  fresh shuffle and ``epoch`` is incremented, so ``next_batch`` never runs dry.
  Test mode serves the data once, in the given order, with a possibly smaller
  final batch; after that ``next_batch`` returns ``None``.
  """

  def __init__(self, data_json, start_epoch = None, is_test = False):
    if not is_test and start_epoch is None:
      # Raising a plain string is a TypeError in Python 3; raise a real exception.
      raise ValueError('DataLoader - start_epoch has to be defined in train mode')

    self.label_extractor = LabelExtractor()
    self.is_test = is_test
    self.data = data_json
    self.buffered_data = shuffle_array(self.data) if not is_test else self.data
    self.current_idx = 0
    self.epoch = start_epoch

  def get_end_idx(self):
    """Return the length of the current buffer."""
    return len(self.buffered_data)

  def extract_all_labels(self, datas):
    """Return the ages of ``datas`` as an array of shape (len(datas), 1)."""
    labels = [self.label_extractor.extract_labels(data) for data in datas]
    return np.expand_dims(np.stack(labels, axis = 0), axis = 1)

  def next_batch(self, batch_size):
    """Return the next ``(images, labels)`` batch, or ``None`` once a test pass is done.

    Raises:
      ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
      raise ValueError('DataLoader.next_batch - invalid batch_size: ' + str(batch_size))

    from_idx = self.current_idx
    to_idx = self.current_idx + batch_size

    # end of epoch
    if to_idx > len(self.buffered_data):
      if self.is_test:
        # Serve whatever is left; nothing left means the pass is over.
        to_idx = len(self.buffered_data)
        if to_idx == self.current_idx:
          return None
      else:
        # Keep the unserved tail in front of a fresh shuffle so no sample is
        # skipped, and take this batch from the head of the new buffer.
        self.epoch += 1
        self.buffered_data = self.buffered_data[from_idx:] + shuffle_array(self.data)
        from_idx = 0
        to_idx = batch_size

    self.current_idx = to_idx
    next_data = self.buffered_data[from_idx:to_idx]

    # NOTE(review): load_image_batch is called with its default
    # with_random_crop=True, so test-mode batches are randomly cropped too and
    # evaluation is not deterministic; confirm this is intended.
    batch_x = load_image_batch(next_data)
    batch_y = self.extract_all_labels(next_data)

    return batch_x, batch_y

### Training Loop

In [0]:
#tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
# Session options: allocate GPU memory on demand (allow_growth), let
# TensorFlow fall back to another device when an op cannot run on the
# requested one (allow_soft_placement), and log each op's device.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
config.log_device_placement = True

# Start from an empty default graph so re-running this cell does not pile up ops.
tf.reset_default_graph()

# Network to train. model_name is the path prefix for its checkpoints
# (saved as model_name + '.ckpt-<epoch>').
net = MobilenetV2()
model_name = './mobilenetv2'

def get_checkpoint(epoch):
  """Return the checkpoint path prefix for ``epoch`` under the global ``model_name``."""
  return '{}.ckpt-{}'.format(model_name, epoch)
  
# training parameters
learning_rate = 0.001
start_epoch = 0
end_epoch = 2000
batch_size = 16

#optimizer = tf.train.GradientDescentOptimizer(learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)

# Fixed-size placeholders: every fed batch must hold exactly batch_size samples.
X = tf.placeholder(tf.float32, [batch_size, 112, 112, 3])
Y = tf.placeholder(tf.float32, [batch_size, 1])

train_data = load_json('./data/trainData.json')
data_loader = DataLoader(train_data, start_epoch = start_epoch)
weight_map = init_trainable_weights(net)

age = net.forward(X, weight_map.weights)
# Mean absolute error between predicted and labelled ages over the batch.
loss_op = tf.reduce_mean(tf.abs(tf.subtract(age, Y)))
train_op = optimizer.minimize(loss_op)

init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep = None)

total_loss = 0
iteration_count = 0
ts_epoch = time.time()
# `with` flushes and closes the log even if training stops with an exception
# (e.g. a KeyboardInterrupt). Otherwise buffered log lines would be lost.
with open('./log.txt', 'w') as log_file:
  #with tf.Session(tpu_address) as sess:
  with tf.Session(config = config) as sess:
    sess.run(init)

    if start_epoch != 0:
      # Resume from the checkpoint saved when epoch start_epoch - 1 finished.
      checkpoint = get_checkpoint(start_epoch - 1)
      saver.restore(sess, checkpoint)
      print('done restoring session')

    # NOTE(review): all ops were created above, outside this scope, so this
    # device context does not appear to affect their placement.
    with tf.device('/gpu:0'):

      while data_loader.epoch <= end_epoch:
        epoch = data_loader.epoch
        current_idx = data_loader.current_idx
        end_idx = data_loader.get_end_idx()

        ts = time.time()

        batch_x, batch_y = data_loader.next_batch(batch_size)

        loss, _ = sess.run([loss_op, train_op], feed_dict = { X: batch_x, Y: batch_y })
        total_loss += loss
        iteration_count += 1
        log_file.write("epoch " + str(epoch) + ", (" + str(current_idx) + " of " + str(end_idx) + "), loss= " + "{:.4f}".format(loss)
              + ", time= " + str((time.time() - ts) * 1000) + "ms \n")

        # next_batch increments data_loader.epoch when it wraps around, so a
        # change here means the epoch stored in `epoch` just finished.
        if epoch != data_loader.epoch:
          print('next epoch: ' + str(data_loader.epoch))
          saver.save(sess, model_name + '.ckpt', global_step = epoch)

          epoch_txt_file_path = 'epoch_' + str(epoch) + '.txt'
          with open(epoch_txt_file_path, 'w') as epoch_txt:
            epoch_txt.write('total_loss= ' + str(total_loss) + '\n')
            epoch_txt.write('avg_loss= ' + str(total_loss / iteration_count) + '\n')
            epoch_txt.write('learning_rate= ' + str(learning_rate) + '\n')
            epoch_txt.write('batch_size= ' + str(batch_size) + '\n')
            epoch_txt.write('epoch_time= ' + str(time.time() - ts_epoch) + 's \n')

          #colab.files.download(epoch_txt_file_path)
          #colab.files.download(get_checkpoint(epoch) + '.index')
          #colab.files.download(get_checkpoint(epoch) + '.meta')
          #colab.files.download(get_checkpoint(epoch) + '.data-00000-of-00001')

          total_loss = 0
          iteration_count = 0
          ts_epoch = time.time()

      print('done!')

### Testing

In [0]:
# Session options: allocate GPU memory on demand, allow soft device placement
# and log each op's device (same settings as the training cell).
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
config.log_device_placement = True

# Network to evaluate. model_name is the path prefix of its saved checkpoints.
net = DenseMobilenet_4_4()
model_name = './dense_mobilenet_4_4'

def get_checkpoint(epoch):
  """Return the checkpoint path prefix for ``epoch`` under the global ``model_name``."""
  return '{}.ckpt-{}'.format(model_name, epoch)

def compile_loss_op(X, Y, weight_map):
  """Build the sum over the batch of |predicted age - label| for the global ``net``."""
  predicted_age = net.forward(X, weight_map.weights)
  absolute_error = tf.abs(tf.subtract(predicted_age, Y))
  return tf.reduce_sum(absolute_error)

batch_size = 32
dbs = ['utk', 'wiki', 'appareal']
test_data = load_json('./data/testData.json')

# Evaluate each saved checkpoint in the range and write one report per epoch.
for epoch in range(33, 120):
  tf.reset_default_graph()
  weight_map = init_trainable_weights(net)

  X = tf.placeholder(tf.float32, [batch_size, 112, 112, 3])
  Y = tf.placeholder(tf.float32, [batch_size, 1])
  loss_op = compile_loss_op(X, Y, weight_map)

  init = tf.global_variables_initializer()
  saver = tf.train.Saver(max_to_keep = None)

  total_loss = 0
  iteration_count = 0
  ts_test = time.time()
  # `with` closes the report even if restoring the checkpoint or a batch fails.
  with open('test_epoch_' + str(epoch) + '.txt', 'w') as test_txt:
    #with tf.Session(tpu_address) as sess:
    with tf.Session(config = config) as sess:
      checkpoint = get_checkpoint(epoch)
      sess.run(init)
      saver.restore(sess, checkpoint)

      with tf.device('/gpu:0'):

        for db in dbs:
          db_data = [data for data in test_data if data['db'] == db]
          total_loss_db = 0
          iteration_count_db = 0

          data_loader = DataLoader(db_data, is_test = True)
          next_batch = data_loader.next_batch(batch_size)
          while next_batch is not None:
            batch_x, batch_y = next_batch
            if batch_x.shape[0] != batch_size:
              # The final batch can be smaller than the fixed-size
              # placeholders, so build a loss op sized for it.
              X_tmp = tf.placeholder(tf.float32, [batch_x.shape[0], 112, 112, 3])
              Y_tmp = tf.placeholder(tf.float32, [batch_x.shape[0], 1])
              loss_op_tmp = compile_loss_op(X_tmp, Y_tmp, weight_map)
              loss = sess.run(loss_op_tmp, feed_dict = { X_tmp: batch_x, Y_tmp: batch_y }) / batch_x.shape[0]
            else:
              loss = sess.run(loss_op, feed_dict = { X: batch_x, Y: batch_y }) / batch_size
            # NOTE(review): `loss` is a per-batch mean, so the averages below
            # weight a short final batch as much as a full one.
            total_loss += loss
            total_loss_db += loss
            iteration_count += 1
            iteration_count_db += 1
            next_batch = data_loader.next_batch(batch_size)

          # Guard against a db with no test samples (previously a ZeroDivisionError).
          avg_loss_db = total_loss_db / iteration_count_db if iteration_count_db else float('nan')
          test_txt.write(str(db) + ":" + '\n')
          test_txt.write('total_loss= ' + str(total_loss_db) + '\n')
          test_txt.write('avg_loss= ' + str(avg_loss_db) + '\n')
          test_txt.write('\n')

        avg_loss = total_loss / iteration_count if iteration_count else float('nan')
        test_txt.write('----------------\n\n')
        test_txt.write('total_loss= ' + str(total_loss) + '\n')
        test_txt.write('avg_loss= ' + str(avg_loss) + '\n')
        test_txt.write('test_time= ' + str(time.time() - ts_test) + 's \n')