# Train

In [None]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm

import sys
sys.path.append(os.path.abspath('..'))

from fiat.arch.image_seg import unet, resunet_arch
# from fiat.components.arch import 
from fiat.os import count
from fiat.data import TFR
from fiat.train import train
from tools.data import feature_dict, decode_img_seg, seg_train_gen
from fiat.DataAugment import flip_up_down, flip_left_right

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [None]:
# Input files and working directories, all resolved relative to the parent
# of the current working directory.
csv_path, train_path, tfr_path, ckpt_path = (
    os.path.join('..', *parts)
    for parts in (
        ('data', 'train.csv'),           # csv_path
        ('data', 'train_images'),        # train_path
        ('tmp', 'TFRecords', 'train'),   # tfr_path
        ('tmp', 'ckpt'),                 # ckpt_path
    )
)

In [None]:
# TFRecord helper bound to `tfr_path`, configured for 30 shards with GZIP
# compression at level 1. The record count comes from count(train_path)
# (presumably the number of files in the training-image folder — confirm).
tfr = TFR(
    path=tfr_path,
    feature_dict=feature_dict,
    count=count(train_path),
    compression='GZIP',
    c_level=1,
    shards=30,
)

In [None]:
# One-off step, kept disabled: uncomment both lines to build the sample
# generator from the CSV / image folder and write it out through `tfr`.
# NOTE(review): presumably this must have been run once before any cell that
# reads from `tfr` can work on a fresh checkout — confirm.
# NOTE(review): nclass=5 here vs n_class=4 in the model cells — confirm this
# difference is intended.
#gen = seg_train_gen(csv_path, train_path, sep='[_,]', nclass=5)
#tfr.write(gen)

In [None]:
# Model option A: plain U-Net with 4 output classes, sigmoid activation and
# dropout disabled.
# NOTE: the next cell rebinds `arch` to a ResUNet; on a top-to-bottom run the
# later assignment is the one that takes effect.
arch = unet(
    num_layers=4,
    feature_growth_rate=64,
    channels=3,
    n_class=4,
    dropout_rate=0.,
    padding='SAME',
    active='sigmoid',
)


In [None]:
# Model option B: ResUNet built from 'resnext' residual layers ('101'
# variant), 4 stages, 4 output classes, sigmoid activation.
# This assignment replaces any `arch` defined by an earlier cell.
arch = resunet_arch(
    reslayer='resnext',
    numlayers='101',
    numstages=4,
    channels=64,
    n_class=4,
    rate=0.25,
    padding='SAME',
    active='sigmoid',
)

In [None]:
# Reader over the TFRecords, decoded with `decode_img_seg`.
# Augmentation is currently switched off (see the commented-out argument).
read = tfr.read(
    decode_raw=decode_img_seg,
    split=10,
    valid=0,
    # augment=[flip_up_down(), flip_left_right()],
    num_parallel_reads=6,
    buffer_size=400,
)


In [None]:
# Qualitative check: rebuild the graph, restore the checkpoint saved under
# `<ckpt_path>/epoch_3`, and plot five samples from the 'train' split.
# Each figure stacks 9 rows: the input image, the 4 predicted channels, then
# the 4 label channels.
# NOTE(review): requires that the `epoch_3` checkpoint already exists, i.e.
# that a training cell has been run beforehand.
tf.reset_default_graph()
data, _, _ = read()
dataset = data['train'].batch(1)
iterator = dataset.make_initializable_iterator()
img_t, label_t = iterator.get_next()
y, y_map = arch(img_t)
saver = tf.train.Saver()
with tf.Session() as sess:

    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run(iterator.initializer)
    # Restore after initialisation so checkpoint values overwrite the fresh ones.
    saver.restore(sess, os.path.join(ckpt_path, 'epoch_3', 'model.ckpt'))

    for i in range(5):
        img, label, y_p = sess.run((img_t, label_t, y_map))
        fig, axes = plt.subplots(9, 1, figsize=[20, 20])
        axes[0].imshow(img[0].astype('int'))
        axes[0].set_title(f'sample {i}: input image')
        for j in range(4):
            # NOTE(review): astype('int') truncates towards zero; if `y_map`
            # holds probabilities rather than a hard 0/1 map, these panels
            # will be (almost) entirely zero — confirm what `arch` returns.
            axes[j + 1].imshow(y_p[0, :, :, j].astype('int'))
            axes[j + 1].set_title(f'sample {i}: prediction, channel {j}')
            axes[j + 5].imshow(label[0, :, :, j].astype('int'))
            axes[j + 5].set_title(f'sample {i}: label, channel {j}')
        # Render and release each figure inside the loop so they do not pile up.
        plt.show()
        plt.close(fig)

In [None]:
# Training stage 1: one epoch with dice loss, Adam at 2e-5, batch size 32.
# retrain=False — presumably starts from freshly initialised weights (confirm
# against fiat.train).
info = train(
    arch,
    read,
    loss='dice',
    metric='mean_dice',
    optimizer='adam',
    rate=2e-5,
    epoch=1,
    batch_size=32,
    early_stopping=1,
    verbose=2,
    retrain=False,
    reshape=[128, 800],
    reshape_method=3,
    ckpt_path=ckpt_path,
)

In [None]:
# Training stage 2: five epochs with binary cross-entropy loss, same optimiser
# settings as stage 1. The stage-1 result `info` is passed as `retrain`
# (presumably to continue from that run — confirm against fiat.train).
info2 = train(
    arch,
    read,
    loss='bce',
    metric='mean_dice',
    optimizer='adam',
    rate=2e-5,
    epoch=5,
    batch_size=32,
    early_stopping=1,
    verbose=2,
    retrain=info,
    reshape=[128, 800],
    reshape_method=3,
    ckpt_path=ckpt_path,
)

In [None]:
# Separate experiment: a narrower U-Net (feature_growth_rate=16, dropout 0.25)
# trained on down-sized inputs with batch size 1.
# NOTE: this cell rebinds `arch`, `read` and `info2` from the cells above.
arch = unet(
    num_layers=4,
    feature_growth_rate=16,
    n_class=4,
    channels=3,
    padding='SAME',
    dropout_rate=0.25,
    active='sigmoid',
)

# Lighter reader settings than the main run (smaller buffer, fewer parallel reads).
read = tfr.read(
    decode_raw=decode_img_seg,
    split=10,
    valid=0,
    buffer_size=100,
    num_parallel_reads=2,
)

# NOTE(review): 'momentun' looks like a misspelling of 'momentum'. It is left
# untouched because the accepted optimizer names are defined in fiat.train —
# confirm which spelling that module expects.
# NOTE(review): `retrain=info` passes the stage-1 result, not `info2` — confirm
# this is intended.
info2 = train(
    arch,
    read,
    loss='dice',
    metric='dice',
    optimizer='momentun',
    rate=1e-6,
    epoch=1,
    batch_size=1,
    early_stopping=1,
    verbose=2,
    retrain=info,
    reshape=[32, 200],
    reshape_method=3,
    ckpt_path=ckpt_path,
    distrainable=['3', '4'],
)

In [None]:
# Show every global variable in the current default graph.
# `tf.all_variables()` is deprecated; `tf.global_variables()` is its replacement.
tf.global_variables()

In [None]:
# For each global variable, take the top-level scope (text before the first
# '/') and keep the part after its last underscore.
# (Presumably the tokens matched by `distrainable` in the training call — confirm.)
# Uses `tf.global_variables()` in place of the deprecated `tf.all_variables()`.
[var.name.split('/')[0].split('_')[-1] for var in tf.global_variables()]