# Train a U-Net segmentation model

In [1]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm

import sys
sys.path.append(os.path.abspath('..'))

from tools.network import Network
from tools.decode_raw import decode_img_seg, decode_img_seg_test

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [2]:
# Input locations, relative to the notebook's working directory.
data_dir = os.path.join('..', 'data')
csv_path = os.path.join(data_dir, 'train.csv')
train_path = os.path.join(data_dir, 'train_images')
tfr_path = os.path.join('..', 'tmp', 'TFRecords', 'train')

# Feature name -> type tag ('bytes' / 'int') for each TFRecord example.
# This dict is handed to `network.readtrain` via `rt_params['feature_dict']`.
feature_dict = dict(
    img='bytes',
    label='bytes',
    height='int',
    width='int',
    channels='int',
    n_class='int',
)

network = Network()

In [3]:
# Reader options for the TFRecord shards: schema, per-example decoder,
# shard glob, shuffle buffer size, compression and parallel file reads.
rt_params = dict(
    feature_dict=feature_dict,
    decode_raw=decode_img_seg,
    tfr_path=os.path.join(tfr_path, '*.tfrecord'),
    shuffle_buffer=100,
    compression='GZIP',
    num_parallel_reads=4,
)

# Training-input options passed to `network.readtrain`.
# NOTE(review): `reshape=[32, 200]` is presumably the target (height, width)
# and `reshape_method=3` presumably a tf.image.ResizeMethod code — confirm
# against Network.readtrain before changing either value.
readtrain_params = dict(
    rt_params=rt_params,
    train_path=train_path,
    epoch=1,
    batch_size=32,
    reshape=[32, 200],
    reshape_method=3,
)

network.readtrain(**readtrain_params)

W0927 11:12:49.035653 139789207226176 deprecation.py:323] From /home/leechh/anaconda3/envs/severstal/lib/python3.7/site-packages/tensorflow/python/data/util/random_seed.py:58: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0927 11:12:49.131274 139789207226176 deprecation.py:323] From /home/leechh/code/Severstal/tools/data_tfr.py:143: DatasetV1.make_initializable_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_initializable_iterator(dataset)`.


In [4]:
# Architecture hyper-parameters for the 'unet' model.
model_params = dict(
    num_layers=3,
    feature_growth_rate=16,
    n_class=5,
    channels=3,
    padding='VALID',
    dropout_rate=0.25,
)

# 'momentun' (sic) is the optimizer key this project's Network accepts: the
# run log below shows a MomentumOptimizer being created from it. Do not
# "correct" the spelling here without also changing Network.
network.model(
    model_name='unet',
    model_params=model_params,
    loss='cross_entropy',
    metric='neg_dice',
    optimizer='momentun',
    rate=1e-6,
)

W0927 11:12:49.607770 139789207226176 deprecation_wrapper.py:119] From /home/leechh/code/Severstal/tools/model_component.py:23: The name tf.truncated_normal is deprecated. Please use tf.random.truncated_normal instead.

W0927 11:12:49.654023 139789207226176 deprecation_wrapper.py:119] From /home/leechh/code/Severstal/tools/model_component.py:62: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.

W0927 11:12:49.951290 139789207226176 deprecation_wrapper.py:119] From /home/leechh/code/Severstal/tools/model_component.py:123: The name tf.losses.softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.softmax_cross_entropy instead.

W0927 11:12:50.000592 139789207226176 deprecation_wrapper.py:119] From /home/leechh/code/Severstal/tools/model_component.py:99: The name tf.train.MomentumOptimizer is deprecated. Please use tf.compat.v1.train.MomentumOptimizer instead.



In [5]:
# Training-loop options; the log shows both a train loss and a validation
# metric being reported, so `train_percentage` presumably sets the
# train/validation split — confirm in Network.train.
train_params = dict(
    ckpt_dir=os.path.join('..', 'tmp', 'ckpt'),
    train_percentage=0.8,
    early_stopping=1,
    verbose=2,
    retrain=False,
)

network.train(**train_params)

W0927 11:12:52.947725 139789207226176 deprecation_wrapper.py:119] From /home/leechh/code/Severstal/tools/network.py:128: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.

W0927 11:12:53.003290 139789207226176 deprecation_wrapper.py:119] From /home/leechh/code/Severstal/tools/network.py:129: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

W0927 11:12:53.126159 139789207226176 deprecation_wrapper.py:119] From /home/leechh/code/Severstal/tools/network.py:140: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.

W0927 11:12:53.129343 139789207226176 deprecation_wrapper.py:119] From /home/leechh/code/Severstal/tools/network.py:140: The name tf.local_variables_initializer is deprecated. Please use tf.compat.v1.local_variables_initializer instead.

epoch 1, train cross_entropy: 1.6623, valid neg_dice: 0.8004: 100%|██████████| 629/629 [03:28<00:00,  7.88s/it]
epoch 

best epoch is 2,  train score is 1.5742340228136846, valid score is 0.7852048294993634



