<a href="https://colab.research.google.com/github/dseuss/tf20-unet/blob/master/notebooks/Colorizer%20Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Show the GPU(s) attached to this runtime, if any, before training starts.
!nvidia-smi

In [0]:
# Install the TF 2.0 beta GPU build plus the dataset loader and PIL used below.
# NOTE(review): only tensorflow-gpu is pinned; tensorflow-datasets and pillow
# resolve to whatever pip picks at run time — consider pinning them too.
# NOTE(review): Colab may require a runtime restart before the newly installed
# TensorFlow is picked up — confirm.
!pip install tensorflow-gpu==2.0.0beta1 tensorflow-datasets pillow

In [0]:
# Reload edited modules (e.g. the cloned `models` code) automatically before
# each cell executes, and enable the `%tensorboard` magic used in the last cell.
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [0]:
# Start from a clean slate: delete the TensorBoard logs and checkpoints of any
# earlier run (they are all written under logs/) and re-clone the repository
# that provides the U-Net model code.
!rm -rf logs
!rm -rf tf20-unet
!git clone https://github.com/dseuss/tf20-unet

In [0]:
import sys
import os

# Switch on TensorFlow's automatic mixed-precision graph rewrite; it is set
# here, before the next cell imports TensorFlow.
os.environ.update(TF_ENABLE_AUTO_MIXED_PRECISION='1')

# Make the freshly cloned repository importable (it provides `models.Unet`).
sys.path += ['tf20-unet']

In [0]:
import functools as ft
import tensorflow as tf
import tensorflow.keras as k
import tensorflow_datasets as tfds
import datetime
import numpy as np
from PIL import Image
from IPython.display import display
from tqdm import tqdm_notebook as tqdm

from models import Unet

In [0]:
def to_grayscale(image):
    """Build a colorization training pair from an RGB image.

    Returns a dict holding the single-channel grayscale version of `image`
    under 'x' (model input) and the original RGB image under 'y' (target),
    both divided by 255.
    """
    scale = 255
    return {
        'x': tf.image.rgb_to_grayscale(image) / scale,
        'y': image / scale,
    }

def show_img(tensor):
    """Render an image tensor inline in the notebook.

    The tensor is cast to uint8 (truncating, so values are expected to already
    be on the 0-255 scale) and every size-1 dimension, such as a single
    grayscale channel, is squeezed away before display.
    """
    pixels = np.squeeze(tensor.numpy().astype(np.uint8))
    display(Image.fromarray(pixels))

def build_dataset(split, image_size=320):
    """Load a VOC2007 split as (grayscale input, RGB target) example dicts.

    Each image is resized with padding (aspect ratio preserved) into a square
    `image_size` x `image_size` frame and then passed through `to_grayscale`,
    so every element is a dict ``{'x': gray, 'y': rgb}``.

    Args:
        split: a `tfds.Split` (or split string) accepted by `tfds.load`.
        image_size: side length, in pixels, of the padded output images.
            Defaults to 320, the value that used to be hard-coded.

    Returns:
        An unbatched `tf.data.Dataset` of ``{'x', 'y'}`` dicts.
    """
    def preprocess(sample):
        # Only the image feature is used; all other feature keys are dropped.
        image = tf.image.resize_with_pad(
            sample['image'], target_height=image_size, target_width=image_size)
        return to_grayscale(image)

    data = tfds.load('voc2007', split=split)
    # One fused, parallel map instead of three sequential single-threaded maps;
    # a parallel map still yields elements in input order.
    return data.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

def to_tuple(data):
    """Split a feature dict into the ``(inputs, targets)`` pair Keras expects.

    Reads the 'x' entry as inputs and the 'y' entry as targets; any other keys
    are ignored.
    """
    inputs = data['x']
    targets = data['y']
    return inputs, targets

ds_test = build_dataset(tfds.Split.TEST)

# Preview one test example: the grayscale network input and its RGB target.
# to_grayscale scaled both by 1/255, so scale back up for display.
features = next(iter(ds_test))
show_img(255 * features['x'])
show_img(255 * features['y'])

# Convert to (input, target) tuples for Keras, then batch. Prefetch goes last
# so the whole pipeline, including the tuple conversion, overlaps with training.
ds_test = ds_test.map(to_tuple).batch(8).prefetch(10)

ds_train = build_dataset(tfds.Split.TRAIN)
ds_train = ds_train.shuffle(128).map(to_tuple).batch(8).prefetch(10)

In [0]:
class VisualizeImages(tf.keras.callbacks.Callback):
    """Keras callback that logs the model's colorizations of fixed examples to TensorBoard.

    On construction, the ground-truth RGB images are written once under the
    'groundtruth' tag at step 0. At the end of every epoch, the model's
    prediction for the matching grayscale inputs is written under the
    'prediction' tag, with the epoch as the step.

    Args:
        model: the model whose predictions are visualized.
        example_images: an ``(imgs_gray, imgs_rgb)`` pair of image batches, e.g.
            one batch from a dataset mapped through `to_tuple`.
        log_dir: directory the summary file writer writes to.
    """

    def __init__(self, model, example_images, log_dir):
        super().__init__()
        self.summary_writer = tf.summary.create_file_writer(log_dir)
        self.model = model

        # Unpack the constructor argument, not a notebook global: relying on a
        # global named `examples` only worked by accident of execution order.
        self.imgs_gray, imgs_rgb = example_images
        with self.summary_writer.as_default():
            tf.summary.image('groundtruth', imgs_rgb, step=0)

    def on_epoch_end(self, epoch, logs=None):
        """Log the model's current prediction for the stored grayscale inputs."""
        imgs_rgb_pred = self.model(self.imgs_gray)
        with self.summary_writer.as_default():
            tf.summary.image('prediction', imgs_rgb_pred, step=epoch)
        # Flush so a live TensorBoard shows the new images right away instead of
        # waiting for the writer's periodic flush.
        self.summary_writer.flush()

In [0]:
# One fixed (gray, rgb) batch from the test set, visualized after every epoch.
examples = next(iter(ds_test))

model = Unet(output_channels=3, num_filters=[64, 128, 256])
model.compile(
    loss=k.losses.MeanSquaredError(),
    optimizer=k.optimizers.Adam())

current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f'logs/{current_time}/'

history = model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=10,
    callbacks=[
        k.callbacks.TensorBoard(log_dir),
        VisualizeImages(model, examples, log_dir),
        # Keep the checkpoint with the *lowest* validation loss; mode='max'
        # would instead keep the worst epoch seen so far.
        k.callbacks.ModelCheckpoint(os.path.join(log_dir, 'ckpts'), monitor='val_loss',
                                    verbose=1, save_best_only=True, mode='min')
    ]
)

In [0]:
# Launch TensorBoard inline over every run written under logs/: the scalars from
# the TensorBoard callback plus the groundtruth/prediction images.
%tensorboard --logdir 'logs/'