In [0]:
# Colab setup: install the pinned TF 2.0 beta plus the dataset and image libraries used below.
# NOTE(review): tensorflow-datasets and pillow are unpinned, so a fresh run may resolve versions
# that are incompatible with tensorflow==2.0.0beta1 — consider pinning them as well.
!pip install tensorflow==2.0.0beta1 tensorflow-datasets pillow

In [0]:
# Reload edited modules (e.g. the local `models` package imported in a later cell) before each
# cell executes, so changes to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Makes the %tensorboard magic available.
# NOTE(review): nothing in this notebook view writes summaries or calls %tensorboard — drop if unused.
%load_ext tensorboard

In [0]:
# Run this cell to mount your Google Drive and make the project's modules importable.
# Alternatively, clone the git repo onto the Colab VM and point CODE_DIR at the clone instead.
from google.colab import drive
drive.mount('/content/drive')

import sys

# Directory that the later `from models import Unet` is expected to resolve against.
CODE_DIR = '/content/drive/My Drive/Code/colab_workflow'
# Guard the append so re-running this cell does not push duplicate entries onto sys.path.
if CODE_DIR not in sys.path:
    sys.path.append(CODE_DIR)

In [0]:
import os
import tensorflow as tf

# The TPU's address is read from COLAB_TPU_ADDR; indexing os.environ raises KeyError when the
# runtime does not set it (i.e. no TPU attached).
TPU_WORKER = f"grpc://{os.environ['COLAB_TPU_ADDR']}"
# Register the remote worker with this eager context under the job name 'tpu_worker'.
# NOTE(review): no distribution strategy / device placement targeting the TPU is visible in this
# notebook — confirm whether training below actually runs on the TPU.
tf.config.experimental_connect_to_host(TPU_WORKER, 'tpu_worker')

In [0]:
import functools as ft
import tensorflow as tf
import tensorflow.keras as k
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
from IPython.display import display
from tqdm import tqdm_notebook as tqdm

from models import Unet

In [0]:
def to_grayscale(image):
    """Turn an RGB image into an (input, target) pair for grayscale-to-RGB training.

    Args:
        image: RGB image tensor; ``tf.image.rgb_to_grayscale`` requires the last
            axis to have size 3.

    Returns:
        dict with ``'x'``: the grayscale version of ``image`` and ``'y'``: the
        unchanged ``image``. The keys match the ``x``/``y`` parameters of
        ``train_step`` and ``test_step``, which are called with ``**data``.
    """
    gray = tf.image.rgb_to_grayscale(image)
    return dict(x=gray, y=image)

def show_img(tensor):
    """Render an eager image tensor inline in the notebook.

    The tensor is converted to NumPy and cast to uint8 (values are not clipped
    before the cast). Every size-1 axis is then squeezed away, so a
    single-channel image becomes 2-D as PIL expects.
    """
    pixels = np.squeeze(tensor.numpy().astype(np.uint8))
    display(Image.fromarray(pixels))

def build_dataset(split, image_size=320):
    """Load a VOC2007 split as grayscale -> RGB example pairs.

    Each image is resized with padding (aspect ratio kept) to a square of side
    ``image_size``. It is then mapped by ``to_grayscale`` to a
    ``{'x': grayscale, 'y': rgb}`` dict.

    Args:
        split: split passed to ``tfds.load`` (e.g. ``tfds.Split.TRAIN``).
        image_size: side length of the square output images. Defaults to 320,
            the value that used to be hard-coded.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``{'x', 'y'}`` dicts.
    """
    autotune = tf.data.experimental.AUTOTUNE
    data = tfds.load('voc2007', split=split)
    # Parallel maps still yield elements in input order by default, so the
    # output matches the previous sequential pipeline.
    data = data.map(lambda s: s['image'], num_parallel_calls=autotune)
    resize = ft.partial(tf.image.resize_with_pad,
                        target_height=image_size, target_width=image_size)
    data = data.map(resize, num_parallel_calls=autotune)
    data = data.map(to_grayscale, num_parallel_calls=autotune)
    return data

# Input-pipeline settings shared by the train and test datasets.
BATCH_SIZE = 8
SHUFFLE_BUFFER = 128
PREFETCH_BATCHES = 10

ds_test = build_dataset(tfds.Split.TEST)
# Preview one (grayscale input, RGB target) pair before batching.
features = next(iter(ds_test.take(1)))
show_img(features['x'])
show_img(features['y'])
ds_test = ds_test.batch(BATCH_SIZE).prefetch(PREFETCH_BATCHES)

ds_train = build_dataset(tfds.Split.TRAIN)
ds_train = ds_train.shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE).prefetch(PREFETCH_BATCHES)

In [0]:
loss_fn = k.losses.MeanSquaredError()
optimizer = k.optimizers.Adam()

# 3 output channels to match the RGB targets `y` produced by `to_grayscale`.
# NOTE(review): num_filters presumably gives the filters per U-Net level — confirm in models.Unet.
model = Unet(output_channels=3, num_filters=[64, 128, 256])
# Running means of per-batch losses, reported by the epoch loop. Building each metric from its
# key guarantees the metric's `name` matches the key (test_loss was previously named 'train_loss').
metrics = {name: k.metrics.Mean(name=name) for name in ('train_loss', 'test_loss')}

@tf.function
def train_step(x, y):
    """Run one optimizer step on a batch and record its loss in ``metrics['train_loss']``.

    Args:
        x: batch of grayscale inputs (the ``'x'`` entries of ``ds_train``).
        y: batch of RGB targets (the ``'y'`` entries of ``ds_train``).
    """
    with tf.GradientTape() as tape:
        # NOTE(review): if models.Unet contains dropout or batch-norm layers, this probably
        # should be `model(x, training=True)` — confirm against the model definition.
        y_pred = model(x)
        loss = loss_fn(y, y_pred)
    # Only the forward pass needs recording. Computing gradients and applying the update
    # after leaving the tape context keeps those ops off the tape.
    grad = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))

    metrics['train_loss'](loss)
    
@tf.function
def test_step(x, y):
    """Score the model on one batch and fold the loss into ``metrics['test_loss']``.

    No gradients are computed and the optimizer is not touched.
    """
    batch_loss = loss_fn(y, model(x))
    metrics['test_loss'](batch_loss)
    
    
NUM_EPOCHS = 1

for epoch in range(NUM_EPOCHS):
    # Mean metrics accumulate across calls, so clear them each epoch. The printed losses are
    # then per-epoch, and re-running this cell does not mix in results from an earlier run.
    for metric in metrics.values():
        metric.reset_states()

    for data in tqdm(ds_train, desc=f'epoch {epoch} train'):
        train_step(**data)

    for data in tqdm(ds_test, desc=f'epoch {epoch} test'):
        test_step(**data)

    metric_strs = (f'{name}={value.result():.04f}' for name, value in metrics.items())
    print(f'[{epoch}] ' + ' '.join(metric_strs))