In [1]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm

import sys
sys.path.append(os.path.abspath('..'))

from fiat.arch.image_seg import unet, resunet_arch
# from fiat.components.arch import 
from fiat.os import count
from fiat.data import TFR
from fiat.train import train
from fiat.losses import focal_tversky_loss, lossfromname
from fiat.optimizer import optfromname
from tools.data import feature_dict, decode_img_seg, seg_train_gen
from fiat.DataAugment import flip_up_down, flip_left_right

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [2]:
# Raw inputs live under ../data; generated artefacts (TFRecord shards and
# checkpoints) live under ../tmp. Paths are relative to the working directory.
csv_path, train_path = (os.path.join('..', 'data', leaf)
                        for leaf in ('train.csv', 'train_images'))
tfr_path, ckpt_path = (os.path.join('..', 'tmp', *parts)
                       for parts in (('TFRecords', 'train'), ('ckpt',)))

In [3]:
# TFRecord store for the training images: 30 shards, GZIP-compressed at
# compression level 1, described by `feature_dict`.
# NOTE(review): `count(train_path)` presumably counts the files in the
# training-image directory -- confirm against fiat.os.count.
tfr = TFR(
    path=tfr_path,
    feature_dict=feature_dict,
    count=count(train_path),
    compression='GZIP',
    c_level=1,
    shards=30,
)

In [4]:
# U-Net segmentation model with 4 output classes and a sigmoid output
# activation; dropout is disabled (rate 0). Calling `arch(images)` returns a
# (y_pred, y_map) pair, as used by the debug cells further down.
# NOTE(review): num_layers / feature_growth_rate semantics (depth, filter
# growth per level) are assumed from the names -- confirm in fiat.arch.
arch = unet(
    n_class=4,
    channels=3,
    num_layers=5,
    feature_growth_rate=32,
    padding='SAME',
    active='sigmoid',
    dropout_rate=0.,
)

In [None]:
# Alternative architecture, currently DISABLED: the code is wrapped in a bare
# string literal, so running this cell only evaluates the string and `arch`
# keeps the U-Net defined above. To try it, remove the triple quotes and run
# this cell instead of the `unet` cell.
# NOTE(review): reslayer='seresnet' / numlayers='18' presumably select an
# SE-ResNet-18 style encoder -- confirm in fiat.arch.image_seg.
"""
arch = resunet_arch(reslayer='seresnet',
                    numlayers='18',
                    numstages=4,
                    channels=64,
                    n_class=4,
                    padding='SAME',
                    rate=0.25,
                    active='sigmoid')
"""

In [5]:
# Input-pipeline factory over the TFRecords. Later cells call `read()` with no
# arguments and unpack (datasets_by_split, traincount, validcount); the
# datasets dict has 'train' and 'valid' entries.
# NOTE(review): split=10 / valid=9 presumably means 10 folds with fold 9 held
# out for validation -- confirm against TFR.read.
read = tfr.read(
    decode_raw=decode_img_seg,
    num_parallel_reads=6,
    buffer_size=400,
    split=10,
    valid=9,
    # augment=[flip_up_down(), flip_left_right()],
)

In [None]:
# Full training run via fiat.train.train: focal Tversky loss, Adam at 1e-6,
# one epoch of batch 32, 'dice' as the validation metric, inputs resized to
# 64x400. The returned dict (displayed further down) holds the 'train' and
# 'valid' results.
# NOTE(review): reshape_method=3 presumably maps to TF1's
# ResizeMethod.AREA; the hand-written debug cells below resize with method=2.
info = train(
    arch,
    read,
    optimizer='adam',
    rate=1e-6,
    loss=focal_tversky_loss(),
    metric='dice',
    batch_size=32,
    epoch=1,
    early_stopping=1,
    reshape=[64, 400],
    reshape_method=3,
    retrain=False,
    ckpt_path=ckpt_path,
    verbose=2,
)

Instructions for updating:
Colocations handled automatically by placer.


epoch 1, train loss: 0.9763, valid metric: 0:  56%|█████▋    | 200/354 [00:47<00:32,  4.69it/s]

In [7]:
# Display the dict returned by train(): one entry each for 'train' and 'valid'.
# NOTE(review): the first element in both lists is -4294967296 (= -2**32),
# which looks like an overflow or an uninitialised sentinel inside train().
info

{'train': [-4294967296, 0.9771569177441178],
 'valid': [-4294967296, 0.014239655357907968]}

In [17]:
# Scratch check that the literal 1e-2 is the float 0.01; nothing in this
# section uses the result, so the cell can be deleted.
float(1e-2)

0.01

In [None]:
# Hand-built debug loop: rebuild the input pipeline and model, run `to` steps,
# and plot the network outputs for the first image of every step >= `fm`.
fm = 200               # first step whose outputs are plotted
to = 201               # total number of steps (batches) to run
apply_updates = True   # False -> forward passes only; weights stay at their initial values

rate = 1e-4
batch_size = 32

tf.reset_default_graph()

with tf.name_scope('data'):
    data, traincount, validcount = read()

    for key, val in data.items():
        data[key] = val.batch(batch_size)
    # Train batches followed by valid batches, repeated without limit: the step
    # loop below is bounded by `to`, so it can never run out of data.
    dataset = data['train'].concatenate(data['valid']).repeat()
    iterator = dataset.make_initializable_iterator()
    img, label = iterator.get_next()
    # method=2 is ResizeMethod.BICUBIC in TF1. NOTE(review): train() above is
    # called with reshape_method=3 -- confirm the preprocessing should differ.
    img = tf.image.resize(img, size=[64, 400], method=2)

    y_pred, y_map = arch(img)
    shape = tf.shape(y_pred)
    # Bring the ground-truth masks to the prediction's spatial size.
    y_true = tf.image.resize(label, [shape[1], shape[2]])
    loss = focal_tversky_loss()(y_true, y_pred)
    opt = optfromname('adam', learning_rate=rate, loss=loss)

# Op executed on every step. Anything downstream of `iterator.get_next()`
# consumes exactly one batch per sess.run; the optimizer (and therefore
# `rate`) only has an effect if it is actually fetched.
step_op = opt if apply_updates else loss

#saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run(iterator.initializer)
    #saver.restore(sess, os.path.join(ckpt_path, 'epoch_1', 'model.ckpt'))

    for i in tqdm(range(to)):
        if i < fm:
            # Steps before `fm` only advance the data (and train, if enabled);
            # skip copying the activations back to the host.
            sess.run(step_op)
            continue

        img_arr, y_pred_arr, y_map_arr, loss_arr, _ = sess.run((img, y_pred, y_map, loss, step_op))
        # Summarise y_map instead of dumping the whole 4-D array.
        print('step', i, 'loss', loss_arr, 'y_map shape', y_map_arr.shape,
              'min', y_map_arr.min(), 'max', y_map_arr.max())

        # One row for the input, then one row per class for y_pred and y_map.
        n_class = y_pred_arr.shape[-1]
        fig, axes = plt.subplots(1 + 2 * n_class, 1, figsize=(20, 20), squeeze=False)
        axes = axes[:, 0]
        axes[0].imshow(img_arr[0].astype('int'))
        axes[0].set_title(f'input, step {i}')
        for j in range(n_class):
            axes[1 + j].imshow(y_pred_arr[0, :, :, j])
            axes[1 + j].set_title(f'y_pred, class {j}')
            axes[1 + n_class + j].imshow(y_map_arr[0, :, :, j])
            axes[1 + n_class + j].set_title(f'y_map, class {j}')
        plt.show()

In [None]:
# Hand-built training/debug loop with plain gradient descent: run `to` steps,
# applying the optimizer on each one, and plot the network outputs for the
# first image of every step >= `fm`.
fm = 1000   # first step whose outputs are plotted
to = 1001   # total number of optimizer steps (batches) to run

rate = 1e-8
batch_size = 32

tf.reset_default_graph()

with tf.name_scope('data'):
    data, traincount, validcount = read()

    for key, val in data.items():
        data[key] = val.batch(batch_size)
    # Repeat indefinitely: the loop below is bounded by `to`. The previous
    # `repeat(1001 * batch_size // 5000)` hard-coded a dataset size of 5000
    # and rounded down, which can exhaust the data (OutOfRangeError) early.
    dataset = data['train'].concatenate(data['valid']).repeat()
    iterator = dataset.make_initializable_iterator()
    img, label = iterator.get_next()
    # method=2 is ResizeMethod.BICUBIC in TF1. NOTE(review): train() above is
    # called with reshape_method=3 -- confirm the preprocessing should differ.
    img = tf.image.resize(img, size=[64, 400], method=2)

    y_pred, y_map = arch(img)
    shape = tf.shape(y_pred)
    # Bring the ground-truth masks to the prediction's spatial size.
    y_true = tf.image.resize(label, [shape[1], shape[2]])
    loss = focal_tversky_loss()(y_true, y_pred)
    opt = optfromname('gd', learning_rate=rate, loss=loss)

#saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run(iterator.initializer)
    #saver.restore(sess, os.path.join(ckpt_path, 'epoch_1', 'model.ckpt'))

    for i in tqdm(range(to)):
        if i < fm:
            # Train on this batch without copying activations back to the host.
            sess.run(opt)
            continue

        img_arr, y_pred_arr, y_map_arr, loss_arr, _ = sess.run((img, y_pred, y_map, loss, opt))
        # Summarise y_map instead of dumping the whole 4-D array.
        print('step', i, 'loss', loss_arr, 'y_map shape', y_map_arr.shape,
              'min', y_map_arr.min(), 'max', y_map_arr.max())

        # One row for the input, then one row per class for y_pred and y_map.
        n_class = y_pred_arr.shape[-1]
        fig, axes = plt.subplots(1 + 2 * n_class, 1, figsize=(20, 20), squeeze=False)
        axes = axes[:, 0]
        axes[0].imshow(img_arr[0].astype('int'))
        axes[0].set_title(f'input, step {i}')
        for j in range(n_class):
            axes[1 + j].imshow(y_pred_arr[0, :, :, j])
            axes[1 + j].set_title(f'y_pred, class {j}')
            axes[1 + n_class + j].imshow(y_map_arr[0, :, :, j])
            axes[1 + n_class + j].set_title(f'y_map, class {j}')
        plt.show()

In [None]:
# Visual check on a few single training images, resized to 128x800 (larger
# than the 64x400 used by train() above): plot input, per-class prediction
# and per-class label.
# NOTE(review): the checkpoint restore below is commented out, so the
# predictions come from freshly initialised weights, not a trained model.
n_images = 5

tf.reset_default_graph()
data, _, _ = read()
iterator = data['train'].batch(1).make_initializable_iterator()
img_t, label_t = iterator.get_next()
img_t = tf.image.resize(img_t, size=[128, 800], method=2)  # 2 == ResizeMethod.BICUBIC in TF1
y, y_map = arch(img_t)

#saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run(iterator.initializer)
    #saver.restore(sess, os.path.join(ckpt_path, 'epoch_1', 'model.ckpt'))

    for i in range(n_images):
        # y_map is not plotted here, so it is not fetched.
        img_arr, label_arr, y_arr = sess.run((img_t, label_t, y))

        # One row for the input, then one row per class for prediction and label.
        n_class = y_arr.shape[-1]
        fig, axes = plt.subplots(1 + 2 * n_class, 1, figsize=(20, 20), squeeze=False)
        axes = axes[:, 0]
        axes[0].imshow(img_arr[0].astype('int'))
        axes[0].set_title(f'input #{i}')
        for j in range(n_class):
            axes[1 + j].imshow(y_arr[0, :, :, j])
            axes[1 + j].set_title(f'prediction, class {j}')
            axes[1 + n_class + j].imshow(label_arr[0, :, :, j])
            axes[1 + n_class + j].set_title(f'label, class {j}')
        plt.show()