<a href="https://colab.research.google.com/github/bijmuj/SketchColorization/blob/master/sketch_colorization_AE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Report the GPU attached to this runtime (model, driver/CUDA version, memory usage).
!nvidia-smi

Fri Jun 18 14:45:26 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install tensorflow-addons
!pip install wandb
!pip install kaggle
!mkdir ~/.kaggle
!chmod 600 kaggle.json
!cp kaggle.json ~/.kaggle
!kaggle datasets download -d ktaebum/anime-sketch-colorization-pair
!unzip -q /content/anime-sketch-colorization-pair.zip -x "data/data/*" "data/colorgram/*"

Collecting tensorflow-addons
[?25l  Downloading https://files.pythonhosted.org/packages/66/4b/e893d194e626c24b3df2253066aa418f46a432fdb68250cde14bf9bb0700/tensorflow_addons-0.13.0-cp37-cp37m-manylinux2010_x86_64.whl (679kB)
[K     |▌                               | 10kB 25.0MB/s eta 0:00:01[K     |█                               | 20kB 18.3MB/s eta 0:00:01[K     |█▌                              | 30kB 15.3MB/s eta 0:00:01[K     |██                              | 40kB 14.1MB/s eta 0:00:01[K     |██▍                             | 51kB 7.5MB/s eta 0:00:01[K     |███                             | 61kB 7.3MB/s eta 0:00:01[K     |███▍                            | 71kB 8.1MB/s eta 0:00:01[K     |███▉                            | 81kB 8.8MB/s eta 0:00:01[K     |████▍                           | 92kB 9.2MB/s eta 0:00:01[K     |████▉                           | 102kB 7.4MB/s eta 0:00:01[K     |█████▎                          | 112kB 7.4MB/s eta 0:00:01[K     |█████▉     

In [None]:
import os
import glob
import time
import wandb
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.layers import Input,Conv2D, Conv2DTranspose, BatchNormalization, Dropout, LeakyReLU
from tensorflow.keras.models import Model
from tqdm import tqdm
from PIL import Image as im

In [None]:
# Training hyperparameters (read by the data pipeline, optimizer, W&B config and training loop).
BATCH_SIZE = 28
BUFFER_SIZE = 800   # shuffle buffer size (in elements) for the training dataset
EPOCHS = 20
LR = 1e-4
beta_1 = 0.5        # Adam beta_1
beta_2 = 0.999      # Adam beta_2
SEED = 42

# Seed TensorFlow's global RNG so weight initialization and dataset shuffling
# are repeatable from run to run (GPU kernels may still add some nondeterminism).
tf.random.set_seed(SEED)

# N(0, 0.02) kernel initializer shared by every Conv2D / Conv2DTranspose layer.
initializer = tf.random_normal_initializer(mean=0.0, stddev=0.02)

In [None]:
wandb.login()
run = wandb.init(project="sketch_colorization", entity="bijin")
config = wandb.config
# Mirror the constants from the config cell instead of hard-coding copies,
# so the hyperparameters logged to W&B are always the ones actually used.
config.opt = "adam"
config.lr = LR
config.beta_1 = beta_1
config.beta_2 = beta_2
config.batch_size = BATCH_SIZE
config.buffer_size = BUFFER_SIZE
config.epochs = EPOCHS

[34m[1mwandb[0m: Currently logged in as: [33mbijin[0m (use `wandb login --relogin` to force relogin)


In [None]:
# Count the training images instead of hard-coding the split size.
N_TRAIN_IMAGES = len(glob.glob('/content/data/train/*.png'))
# Integer ceiling division: batch() keeps the final partial batch, so it must be counted,
# and tqdm gets an int total rather than a float.
n_batches = (N_TRAIN_IMAGES + BATCH_SIZE - 1) // BATCH_SIZE
n_batches

508.0

In [None]:
def load_image(path):
    """Decode one side-by-side training PNG into a (color, sketch) pair.

    The file is decoded as 3-channel RGB and rescaled from [0, 255] to [-1, 1].
    It is then split down the middle of its width: the left half is the color
    target and the right half is the sketch. Each half is resized to 256x256.

    Args:
        path: string tensor holding the PNG file path.

    Returns:
        Tuple ``(color, sketch)`` of float32 tensors, each 256x256x3.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.cast(tf.io.decode_png(raw_bytes, channels=3), tf.float32) / 127.5 - 1

    mid = tf.shape(pixels)[1] // 2
    color_half, sketch_half = pixels[:, :mid, :], pixels[:, mid:, :]

    # Resizing gives every element a static (256, 256, 3) shape (see the dataset repr
    # after mapping); presumably why skipping it "causes issues down the road".
    target_size = (256, 256)
    return tf.image.resize(color_half, target_size), tf.image.resize(sketch_half, target_size)

In [None]:
# glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# dataset order (and therefore seeded shuffling) is reproducible.
dataset_path = sorted(glob.glob('/content/data/train/*.png'))
assert dataset_path, "no training PNGs found under /content/data/train - did the unzip step run?"
dataset = tf.data.Dataset.from_tensor_slices(dataset_path)

In [None]:
# Decode each path into a (color, sketch) pair in parallel.
# NOTE: this rebinds `dataset`. Re-running only this cell would map load_image over the
# already-decoded tuples and fail, so re-run the path-listing cell first.
dataset = dataset.map(
    load_image,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
dataset

<ParallelMapDataset shapes: ((256, 256, 3), (256, 256, 3)), types: (tf.float32, tf.float32)>

In [None]:
# Shuffle with a BUFFER_SIZE-element window, group into BATCH_SIZE batches,
# and let tf.data prefetch the next batches while the GPU trains on the current one.
train_dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

In [None]:
# Validation pairs, batched in fives for the per-epoch image log.
# Use a separate name so the training `dataset_path` list is not clobbered,
# and sort for a deterministic order (glob order is filesystem-dependent).
val_paths = sorted(glob.glob('/content/data/val/*.png'))
assert val_paths, "no validation PNGs found under /content/data/val - did the unzip step run?"
val_dataset = (
    tf.data.Dataset.from_tensor_slices(val_paths)
    .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(5)
)

In [None]:
def encoder_unit(input, filters, norm=None, act=None):
    """Downsampling block: [normalization] -> 4x4 stride-2 Conv2D -> [activation].

    Args:
        input: Keras tensor fed into the block.
        filters: number of output channels of the convolution.
        norm: "Instance", "Batch", or None. The normalization is applied to the
            block's input, *before* the convolution.
        act: "LReLU" (LeakyReLU with slope 0.2) or None for a linear output.

    Returns:
        Keras tensor with `filters` channels and half the spatial size of `input`
        (stride 2 with "same" padding, rounded up).

    Raises:
        ValueError: if `norm` or `act` is not one of the supported values, so a
            typo such as "instance" fails loudly instead of silently skipping the layer.
    """
    if norm not in (None, "Instance", "Batch"):
        raise ValueError(f"encoder_unit: unsupported norm {norm!r}; expected 'Instance', 'Batch' or None")
    if act not in (None, "LReLU"):
        raise ValueError(f"encoder_unit: unsupported act {act!r}; expected 'LReLU' or None")

    if norm == "Instance":
        # axis=3 is the channel axis for the channels-last inputs built by autoencoder().
        # NOTE(review): Keras "random_uniform" samples from [-0.05, 0.05] by default, so
        # beta/gamma start near zero; kept as-is to preserve existing training behaviour.
        x = tfa.layers.InstanceNormalization(axis=3, center=True, scale=True,
                                             beta_initializer="random_uniform",
                                             gamma_initializer="random_uniform")(input)
    elif norm == "Batch":
        x = BatchNormalization()(input)
    else:
        x = input
    x = Conv2D(filters, kernel_size=(4, 4), strides=(2, 2), padding="same",
               kernel_initializer=initializer)(x)
    if act == "LReLU":
        x = LeakyReLU(0.2)(x)
    return x

In [None]:
def decoder_unit(input, filters, norm=None, act=None):
    """Upsampling block: [normalization] -> 4x4 stride-2 Conv2DTranspose -> [activation].

    Args:
        input: Keras tensor fed into the block.
        filters: number of output channels of the transposed convolution.
        norm: "Instance", "Batch", or None. The normalization is applied to the
            block's input, *before* the transposed convolution.
        act: "LReLU" (LeakyReLU with slope 0.2), "tanh", or None for a linear output.

    Returns:
        Keras tensor with `filters` channels and twice the spatial size of `input`
        (stride 2 with "same" padding).

    Raises:
        ValueError: if `norm` or `act` is not one of the supported values, so a
            typo such as "Tanh" fails loudly instead of silently skipping the layer.
    """
    if norm not in (None, "Instance", "Batch"):
        raise ValueError(f"decoder_unit: unsupported norm {norm!r}; expected 'Instance', 'Batch' or None")
    if act not in (None, "LReLU", "tanh"):
        raise ValueError(f"decoder_unit: unsupported act {act!r}; expected 'LReLU', 'tanh' or None")

    if norm == "Instance":
        # axis=3 is the channel axis for the channels-last inputs built by autoencoder().
        # NOTE(review): Keras "random_uniform" samples from [-0.05, 0.05] by default, so
        # beta/gamma start near zero; kept as-is to preserve existing training behaviour.
        x = tfa.layers.InstanceNormalization(axis=3, center=True, scale=True,
                                             beta_initializer="random_uniform",
                                             gamma_initializer="random_uniform")(input)
    elif norm == "Batch":
        x = BatchNormalization()(input)
    else:
        x = input
    x = Conv2DTranspose(filters, kernel_size=4, strides=(2, 2), padding='same',
                        kernel_initializer=initializer)(x)
    if act == "LReLU":
        x = LeakyReLU(0.2)(x)
    elif act == "tanh":
        # tanh bounds the output to (-1, 1), the same range load_image maps pixels into.
        x = tf.math.tanh(x)
    return x

In [None]:
def autoencoder():
    """Build the sketch -> color convolutional encoder-decoder (no skip connections).

    Encoder: five stride-2 conv units (32, 64, 128, 256, 512 filters; all but the
    first instance-normalized), taking the 256x256x3 input down to 8x8x512.
    Decoder: five stride-2 transposed-conv units (256, 128, 64, 32, 3 filters)
    back up to 256x256x3, with a final tanh so outputs lie in (-1, 1).

    Returns:
        An uncompiled Keras functional ``Model``.
    """
    sketch_in = Input((256, 256, 3))

    x = encoder_unit(sketch_in, 32, act="LReLU")
    for n_filters in (64, 128, 256, 512):
        x = encoder_unit(x, n_filters, norm="Instance", act="LReLU")

    for n_filters in (256, 128, 64, 32):
        x = decoder_unit(x, n_filters, norm="Instance", act="LReLU")
    color_out = decoder_unit(x, 3, norm="Instance", act="tanh")

    return Model(inputs=[sketch_in], outputs=color_out)

In [None]:
# Build the network, an Adam optimizer configured from the config cell,
# and the L1 (mean absolute error) reconstruction loss; then show the architecture.
ae = autoencoder()
opt = tf.keras.optimizers.Adam(learning_rate=LR, beta_1=beta_1, beta_2=beta_2)
l1_loss = tf.keras.losses.MeanAbsoluteError()
ae.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 128, 128, 32)      1568      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 128, 128, 32)      0         
_________________________________________________________________
instance_normalization (Inst (None, 128, 128, 32)      64        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 64, 64)        32832     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64, 64, 64)        0         
_________________________________________________________________
instance_normalization_1 (In (None, 64, 64, 64)        128   

In [None]:
# Save model and optimizer state together; only the 3 most recent checkpoints are kept on disk.
ckpt_dir = '/content/ckpt'
ckpt = tf.train.Checkpoint(ae=ae, opt=opt)
manager = tf.train.CheckpointManager(checkpoint=ckpt, directory=ckpt_dir, max_to_keep=3)

In [None]:
@tf.function
def train_step(x, y):
    """Run one optimization step of `ae` on a batch.

    Args:
        x: batch of model inputs (fit() passes the sketches).
        y: batch of targets (fit() passes the color images).

    Returns:
        Scalar L1 loss for this batch, computed before the weight update.
    """
    with tf.GradientTape() as tape:
        # training=True so train/inference-dependent layers (e.g. BatchNormalization,
        # which encoder_unit/decoder_unit insert for norm="Batch") run in training mode.
        y_hat = ae(x, training=True)
        loss = l1_loss(y, y_hat)
    grads = tape.gradient(loss, ae.trainable_variables)
    opt.apply_gradients(zip(grads, ae.trainable_variables))
    return loss

In [None]:
def log(loss, color, sketch, output, n=BATCH_SIZE):
    """Log `loss` and an image grid of (color | sketch | output) rows to W&B.

    Args:
        loss: scalar loss value to log under "loss".
        color, sketch, output: batches of images in [-1, 1], indexable along the
            first (batch) axis; one grid row is built per example.
        n: maximum number of examples to show. It is clamped to the batch size, so a
            short final batch does not raise IndexError.
    """
    n = min(n, len(color))
    if n <= 0:
        # Nothing to draw; still record the loss.
        wandb.log({"loss": loss})
        return
    rows = [
        np.hstack([np.asarray(color[i]), np.asarray(sketch[i]), np.asarray(output[i])])
        for i in range(n)
    ]
    # Map [-1, 1] -> [0, 1], and clip so W&B always receives a valid float image.
    grid = np.clip(np.vstack(rows) * 0.5 + 0.5, 0.0, 1.0)
    images = wandb.Image(grid, caption="Left: Color, Mid: Sketch, Right: Output")
    wandb.log({"loss": loss, "examples": images})

In [None]:
def _timestamp():
    """Current local time formatted for the epoch start/end banners."""
    return time.strftime("%m/%d/%Y, %H:%M:%S", time.localtime())


def fit():
    """Train `ae` for EPOCHS epochs over `train_dataset`.

    After each epoch this prints the mean training loss and saves a checkpoint via
    `manager`. It then runs the model on the next validation batch and logs the
    loss and the first 5 (color | sketch | output) examples to W&B.
    """
    # repeat() so drawing one validation batch per epoch never raises StopIteration,
    # however EPOCHS compares with the number of validation batches.
    val_iter = iter(val_dataset.repeat())
    for ep in range(EPOCHS):
        print(f'\nEpoch: {ep + 1} of {EPOCHS}\nStarted: {_timestamp()}')

        losses = []
        for colors, sketches in tqdm(train_dataset, total=n_batches):
            losses.append(train_step(sketches, colors))

        loss = float(np.mean(losses))
        print(f'\nEnded: {_timestamp()}\nloss: {loss}')

        path = manager.save()
        print(f'\nCheckpoint saved in: {path}')

        colors, sketches = next(val_iter)
        outs = ae(sketches, training=False)
        log(loss, colors, sketches, outs, n=5)

In [None]:
# Run the training loop defined above (saves one checkpoint and one W&B log per epoch).
fit()

  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 1 of 20
Started: 06/18/2021, 15:04:18


100%|██████████| 508/508.0 [03:00<00:00,  2.82it/s]



Ended: 06/18/2021, 15:07:19
loss: 0.5684372186660767

Checkpoint saved in: /content/ckpt/ckpt-1


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 2 of 20
Started: 06/18/2021, 15:07:21


100%|██████████| 508/508.0 [02:34<00:00,  3.29it/s]



Ended: 06/18/2021, 15:09:56
loss: 0.253737211227417

Checkpoint saved in: /content/ckpt/ckpt-2


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 3 of 20
Started: 06/18/2021, 15:09:58


100%|██████████| 508/508.0 [03:21<00:00,  2.52it/s]



Ended: 06/18/2021, 15:13:20
loss: 0.20248804986476898

Checkpoint saved in: /content/ckpt/ckpt-3


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 4 of 20
Started: 06/18/2021, 15:13:22


100%|██████████| 508/508.0 [02:35<00:00,  3.27it/s]



Ended: 06/18/2021, 15:15:58
loss: 0.1883743405342102

Checkpoint saved in: /content/ckpt/ckpt-4


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 5 of 20
Started: 06/18/2021, 15:15:58


100%|██████████| 508/508.0 [02:35<00:00,  3.26it/s]



Ended: 06/18/2021, 15:18:34
loss: 0.18069003522396088

Checkpoint saved in: /content/ckpt/ckpt-5


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 6 of 20
Started: 06/18/2021, 15:18:35


100%|██████████| 508/508.0 [02:34<00:00,  3.28it/s]



Ended: 06/18/2021, 15:21:10
loss: 0.1754618138074875

Checkpoint saved in: /content/ckpt/ckpt-6


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 7 of 20
Started: 06/18/2021, 15:21:10


100%|██████████| 508/508.0 [02:35<00:00,  3.27it/s]



Ended: 06/18/2021, 15:23:45
loss: 0.17095227539539337

Checkpoint saved in: /content/ckpt/ckpt-7


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 8 of 20
Started: 06/18/2021, 15:23:46


100%|██████████| 508/508.0 [02:35<00:00,  3.27it/s]



Ended: 06/18/2021, 15:26:22
loss: 0.16656070947647095

Checkpoint saved in: /content/ckpt/ckpt-8


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 9 of 20
Started: 06/18/2021, 15:26:22


100%|██████████| 508/508.0 [02:35<00:00,  3.26it/s]



Ended: 06/18/2021, 15:28:58
loss: 0.16239042580127716

Checkpoint saved in: /content/ckpt/ckpt-9


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 10 of 20
Started: 06/18/2021, 15:28:58


100%|██████████| 508/508.0 [02:35<00:00,  3.28it/s]



Ended: 06/18/2021, 15:31:33
loss: 0.1580495685338974

Checkpoint saved in: /content/ckpt/ckpt-10


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 11 of 20
Started: 06/18/2021, 15:31:34


100%|██████████| 508/508.0 [02:34<00:00,  3.28it/s]



Ended: 06/18/2021, 15:34:09
loss: 0.15401937067508698

Checkpoint saved in: /content/ckpt/ckpt-11


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 12 of 20
Started: 06/18/2021, 15:34:09


100%|██████████| 508/508.0 [02:34<00:00,  3.28it/s]



Ended: 06/18/2021, 15:36:44
loss: 0.15002867579460144

Checkpoint saved in: /content/ckpt/ckpt-12


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 13 of 20
Started: 06/18/2021, 15:36:44


100%|██████████| 508/508.0 [02:34<00:00,  3.29it/s]



Ended: 06/18/2021, 15:39:19
loss: 0.14651882648468018

Checkpoint saved in: /content/ckpt/ckpt-13


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 14 of 20
Started: 06/18/2021, 15:39:19


100%|██████████| 508/508.0 [02:34<00:00,  3.29it/s]



Ended: 06/18/2021, 15:41:54
loss: 0.14311684668064117

Checkpoint saved in: /content/ckpt/ckpt-14


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 15 of 20
Started: 06/18/2021, 15:41:54


100%|██████████| 508/508.0 [02:35<00:00,  3.27it/s]



Ended: 06/18/2021, 15:44:30
loss: 0.1404239386320114

Checkpoint saved in: /content/ckpt/ckpt-15


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 16 of 20
Started: 06/18/2021, 15:44:30


100%|██████████| 508/508.0 [02:34<00:00,  3.29it/s]



Ended: 06/18/2021, 15:47:05
loss: 0.1380797028541565

Checkpoint saved in: /content/ckpt/ckpt-16


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 17 of 20
Started: 06/18/2021, 15:47:05


100%|██████████| 508/508.0 [02:34<00:00,  3.28it/s]



Ended: 06/18/2021, 15:49:40
loss: 0.13549792766571045

Checkpoint saved in: /content/ckpt/ckpt-17


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 18 of 20
Started: 06/18/2021, 15:49:40


100%|██████████| 508/508.0 [02:34<00:00,  3.30it/s]



Ended: 06/18/2021, 15:52:15
loss: 0.13324987888336182

Checkpoint saved in: /content/ckpt/ckpt-18


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 19 of 20
Started: 06/18/2021, 15:52:15


100%|██████████| 508/508.0 [02:35<00:00,  3.26it/s]



Ended: 06/18/2021, 15:54:51
loss: 0.13152816891670227

Checkpoint saved in: /content/ckpt/ckpt-19


  0%|          | 0/508.0 [00:00<?, ?it/s]


Epoch: 20 of 20
Started: 06/18/2021, 15:54:51


100%|██████████| 508/508.0 [02:36<00:00,  3.25it/s]



Ended: 06/18/2021, 15:57:28
loss: 0.1298578530550003

Checkpoint saved in: /content/ckpt/ckpt-20


In [None]:
# Export the trained autoencoder as a single HDF5 file and upload it to W&B as a model artifact.
# NOTE(review): the model contains tfa InstanceNormalization layers; reloading the .h5
# presumably needs tensorflow_addons imported / passed via custom_objects - confirm.
model_path = os.path.join(ckpt_dir, 'sketch_colorization_ae.h5')
ae.save(model_path)
model_artifact = wandb.Artifact("sketch_colorization_AE", type="model",
                                description="autoencoder model checkpoint")
model_artifact.add_file(model_path)
run.log_artifact(model_artifact)
# In a notebook the run stays open until finished explicitly; finish it so W&B
# uploads everything and marks the run complete.
run.finish()



<wandb.sdk.wandb_artifacts.Artifact at 0x7f2360066dd0>