# TEST

In [None]:
%load_ext autoreload
%autoreload 2

# Notebook setup: standard-library imports, noise suppression, third-party
# imports, then project-local imports.
import os
import sys
import warnings

# Suppress all Python-level warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Set before TensorFlow is imported (directly below, or indirectly through the
# project modules): the native logger picks up this level the first time it
# emits a message, which may already happen during import, so assigning it
# afterwards can be too late to have any effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm

# Put the parent of the current working directory on the import path so the
# `tools` package imported below can be found. The membership check keeps
# sys.path free of duplicate entries when this cell is re-run.
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from tools.network import Network
from tools.decode_raw import decode_img_seg, decode_img_seg_test

In [None]:
# One Network instance drives generator creation here and is reused by the
# reading and training cells further down.
network = Network()

# Per-record feature schema: feature name -> type tag. It is handed to the
# TFRecord writer config below and reused by the reader config later on.
feature_dict = dict(
    img='bytes',
    label='bytes',
    height='int',
    width='int',
    channels='int',
    n_class='int',
)

# Keyword arguments for Network.seg_train_gen.
# NOTE(review): 'sep' looks like a regex character class ("_" or ",") used to
# split fields of train.csv -- confirm against seg_train_gen.
read_train_params = dict(
    path='../data/train.csv',
    train_path='../data/train_images',
    height=256,
    width=1600,
    col=False,
    sep='[_,]',
)
gen = network.seg_train_gen(**read_train_params)

# Keyword arguments for Network.write_tfr. The count is computed over the same
# image directory that the generator above was pointed at.
write_tfr_params = dict(
    data_generator=gen,
    count=network.count(read_train_params['train_path']),
    tfrpath='train',
    feature_dict=feature_dict,
    shards=10,
    compression='GZIP',
    c_level=1,
)

# The write step is currently disabled; uncomment the call below to run it
# (presumably to (re)generate the TFRecord shards -- confirm in write_tfr).
# network.write_tfr(**write_tfr_params)

In [None]:
# Where the TFRecord files and the raw training images live, relative to the
# notebook's working directory. The '*.tfrecord' component is presumably
# expanded as a glob by the reader -- confirm in Network.readtrain.
tfr_path = os.path.join('..', 'tmp', 'TFRecords', 'train', '*.tfrecord')
train_path = os.path.join('..', 'data', 'train_images')

# Reader settings: the record schema, the decode function imported from
# tools.decode_raw, the file pattern, and the same 'GZIP' compression tag that
# the writer configuration uses.
rt_params = dict(
    feature_dict=feature_dict,
    decode_raw=decode_img_seg,
    tfr_path=tfr_path,
    shuffle_buffer=100,
    compression='GZIP',
)

# NOTE(review): num_valid=100 presumably holds out 100 samples for
# validation -- confirm in Network.readtrain.
network.readtrain(rt_params, train_path, num_valid=100)

In [None]:
# Architecture hyper-parameters forwarded to the model selected by
# `model_name` below.
model_params = dict(
    num_layers=5,
    feature_growth_rate=16,
    n_class=4,
    channels=3,
    padding='SAME',
    dropout_rate=0.25,
)

# Keyword arguments for Network.train, gathered in one place and unpacked into
# the call. [64, 400] equals the 256 x 1600 size used when generating the data,
# divided by 4 along each dimension.
# NOTE(review): the meaning of reshape_method=2 is defined inside
# Network.train -- confirm which resize method it selects.
train_options = dict(
    model_name='unet',
    model_params=model_params,
    loss='dice',
    metric='dice',
    optimizer='adam',
    rate=0.01,
    early_stopping=True,
    epoch=1,
    batch_size=16,
    reshape=[64, 400],
    reshape_method=2,
    print_per_epoch=True,
)

network.train(**train_options)