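# TEST: U-Net segmentation pipeline

Writes GZIP-compressed TFRecord shards from `../data/train.csv` and `../data/train_images`, reads them back as a batched training set, builds a U-Net trained with a negative-Dice loss, then trains and evaluates it.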

In [4]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm

import sys
sys.path.append(os.path.abspath('..'))

from tools.network import Network
from tools.decode_raw import decode_img_seg, decode_img_seg_test

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Input CSV and image folder, plus the output folder for the TFRecord shards.
# All paths are relative to the notebook's directory.
csv_path, train_path, tfr_path = (
    os.path.join('..', 'data', 'train.csv'),
    os.path.join('..', 'data', 'train_images'),
    os.path.join('..', 'tmp', 'TFRecords', 'train'),
)

# Name -> storage kind for every field serialized into each TFRecord example.
# The same mapping is passed to both the writer and the reader below.
feature_dict = dict(
    img='bytes',
    label='bytes',
    height='int',
    width='int',
    channels='int',
    n_class='int',
)

# One Network instance drives every later step: TFRecord writing, reading,
# model construction, training and testing are all methods called on it.
# Results of readtrain/model are not captured in variables, so the object
# appears to keep that state internally (see tools.network).
network = Network()

In [None]:
# Options for the training-example generator that reads train.csv and the images.
seg_train_gen_params = dict(
    csv_path=csv_path,
    train_path=train_path,
    height=256,     # presumably the source image height in pixels -- confirm
    width=1600,     # presumably the source image width in pixels -- confirm
    col=False,
    sep='[_,]',     # regex character class: splits on '_' or ','
    n_class=5,      # must agree with model_params['n_class'] in the model cell
)

gen = network.seg_train_gen(**seg_train_gen_params)

# Serialize the generator's examples into TFRecord shards under tfr_path.
# NOTE(review): if `gen` is a one-shot generator, re-run the previous cell before
# re-running this one, otherwise nothing would be left to write -- confirm.
write_tfr_params = dict(
    data_generator=gen,
    train_path=train_path,
    tfrpath=tfr_path,
    feature_dict=feature_dict,
    shards=10,
    compression='GZIP',   # must match rt_params['compression'] in the reader cell
    c_level=1,            # presumably the compression level -- confirm in tools.network
)

network.write_tfr(**write_tfr_params)

In [None]:
# Options for reading the shards back: glob over the written files, decode each
# record with decode_img_seg, shuffle, and decompress with the writer's codec.
rt_params = dict(
    feature_dict=feature_dict,
    decode_raw=decode_img_seg,
    tfr_path=os.path.join(tfr_path, '*.tfrecord'),
    shuffle_buffer=100,
    compression='GZIP',
)

# Batching/resizing of the training input pipeline.
# [32, 200] is exactly 1/8 of the 256 x 1600 size used by seg_train_gen, so the
# aspect ratio is preserved.
# NOTE(review): reshape_method=3 is an opaque integer; if it is forwarded to
# tf.image.resize as a TF1 ResizeMethod it would mean AREA -- confirm.
# NOTE(review): epoch=1 presumably means a single pass over the data; check this
# is intended alongside early_stopping=5 in network.train.
readtrain_params = dict(
    rt_params=rt_params,
    train_path=train_path,
    epoch=1,
    batch_size=4,
    reshape=[32, 200],
    reshape_method=3,
)

network.readtrain(**readtrain_params)

In [None]:
# U-Net hyper-parameters.
model_params = dict(
    num_layers=3,
    feature_growth_rate=16,
    n_class=5,          # must agree with seg_train_gen_params['n_class']
    channels=3,         # presumably the input image channel count -- confirm
    padding='SAME',
    dropout_rate=0.25,
)

network.model(
    model_name='unet',
    model_params=model_params,
    loss='neg_dice',
    metric='dice',
    # NOTE(review): 'momentun' looks like a misspelling of 'momentum'. It is kept
    # verbatim because the accepted optimizer names are defined in tools.network,
    # which is not visible here -- confirm which spelling Network.model expects.
    optimizer='momentun',
    rate=1e-5,
)

In [None]:
# Train the model, writing checkpoints under ../tmp/ckpt.
# train_percentage, early_stopping and retrain are interpreted by tools.network;
# presumably an 80/20 train/validation split, a patience of 5, and starting from
# scratch rather than resuming a checkpoint -- confirm.
network.train(
    ckpt_dir=os.path.join('..', 'tmp', 'ckpt'),
    train_percentage=0.8,
    early_stopping=5,
    verbose=2,
    retrain=False,
)

In [None]:
# Evaluate the trained model. Which data this uses (e.g. the held-out part of the
# train_percentage split, or a separate set decoded with decode_img_seg_test) is
# decided inside tools.network -- TODO confirm.
network.test()