In [None]:
import tensorflow as tf


# Indices into the list of physical GPUs that this notebook is allowed to use.
# Edit this tuple to train on a different set of devices.
GPU_INDICES = (1, 2)

physical_gpus = tf.config.list_physical_devices("GPU")

# Fail early with an actionable message instead of a bare "list index out of
# range" when the machine has fewer GPUs than the selection assumes.
unavailable = [idx for idx in GPU_INDICES if idx >= len(physical_gpus)]
if unavailable:
    raise IndexError(
        f"GPU indices {unavailable} requested via GPU_INDICES={GPU_INDICES}, "
        f"but only {len(physical_gpus)} physical GPU(s) were detected."
    )

# Restrict TensorFlow to the selected devices only.
tf.config.set_visible_devices([physical_gpus[idx] for idx in GPU_INDICES], "GPU")

# From here on `gpus` holds only the visible devices; the next cell iterates
# over it. The bare expression displays the list as the cell output.
gpus = tf.config.get_visible_devices("GPU")
gpus

In [None]:
# Report the devices left visible, then enable memory growth on each of them.
print(f"Detected gpus: {gpus}")
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)
print("Set dynamic GPU memory allocation.")

In [None]:
# Base identifier of this experiment. It is reused below as the W&B run name,
# as the prefix of the checkpoint names and as the training-history folder name.
model_name_base = "gan_v02"

In [None]:
import os

# Set the notebook file name for W&B through its environment variable; this is
# done before `wandb` is imported, exactly as in the original cell order.
os.environ.update({"WANDB_NOTEBOOK_NAME": "colorgan.ipynb"})

import wandb

# Start the tracking run named after the experiment. Kept as the last
# expression so the cell still displays whatever `wandb.init` returns.
wandb.init(name=model_name_base, project="colorgan", tags=["gan"])

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from utils import PROJ_ROOT
from callbacks import LogPredictionsCallback
from models import get_unet_generator, get_discriminator, ColorGan
from dataset import postprocess, folder_dataset

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
from tensorflow.data.experimental import AutoShardPolicy

from wandb.keras import WandbCallback

In [None]:
# Distribution strategy used for the rest of the notebook: its replica count
# scales the global batch size, and the model is built and trained inside
# `strategy.scope()` in the cells below.
strategy = tf.distribute.MirroredStrategy()

In [None]:
# Per-replica batch size; the global batch size is this value multiplied by
# the number of replicas the strategy keeps in sync.
BATCH_SIZE_LOCAL = 128
BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_LOCAL
PREFETCH = tf.data.AUTOTUNE

# Dataset options selecting the DATA auto-shard policy. They are attached to
# the monitoring dataset only (see the end of this cell).
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA

# Training data: augmented images from the ImageNet train folder, batched with
# the global batch size.
ds_train = folder_dataset(
    f"{PROJ_ROOT}/imagenet/ILSVRC/Data/CLS-LOC/train",
    img_size=(512, 512),
    batch_size=BATCH_SIZE,
    augment=True,
    prefetch=PREFETCH,
)

# Monitoring data, built step by step from the validation folder.
# 1) Load without augmentation and flatten back to single elements.
monitor_elements = folder_dataset(
    f"{PROJ_ROOT}/imagenet/ILSVRC/Data/CLS-LOC/val",
    img_size=(512, 512),
    batch_size=1,
    augment=False,
).unbatch()

# 2) Keep the first 2000 elements, shuffle them with a fixed seed and keep 128.
monitor_sample = (
    monitor_elements
    .take(2000)
    .shuffle(buffer_size=500, seed=1)
    .take(128)
)

# 3) Re-batch with the per-replica batch size, cache the result and attach the
#    sharding options defined above.
ds_monitor = (
    monitor_sample
    .batch(BATCH_SIZE_LOCAL)
    .cache()
    .with_options(options)
)

In [None]:
with strategy.scope():
    # Build the generator, the discriminator and the GAN wrapper (in this
    # order) inside the distribution strategy scope.
    g = get_unet_generator()
    d = get_discriminator()
    gan = ColorGan(g, d)

    # Binary cross-entropy with label smoothing and no built-in reduction; the
    # wrapper below sums the unreduced values and divides by the global batch
    # size (BATCH_SIZE), not the per-replica one.
    end_loss_base = BinaryCrossentropy(reduction=Reduction.NONE, label_smoothing=0.1)
    end_loss = lambda labels, preds: (
        tf.reduce_sum(end_loss_base(labels, preds)) / BATCH_SIZE
    )

    # Both networks get their own Adam instance with identical settings.
    adam_settings = dict(learning_rate=2e-4, beta_1=0.5)
    gan.compile(
        d_optimizer=Adam(**adam_settings),
        g_optimizer=Adam(**adam_settings),
        end_loss=end_loss,
    )

In [None]:
# The doubled braces keep "{epoch:02d}" as a literal placeholder in the name,
# so the checkpoint path below is a template rather than a fixed path.
model_name = f"{model_name_base}_epoch{{epoch:02d}}"
model_path = f"{PROJ_ROOT}/models/{model_name}"

# Save the model at the end of every epoch under the templated path.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    save_freq="epoch",
    filepath=model_path,
)

# Weights & Biases callback with its default settings.
wandb_callback = WandbCallback()

# Project callback fed with the monitoring dataset, configured with
# every_n_batch=200.
visualization_callback = LogPredictionsCallback(ds_monitor, every_n_batch=200)

In [None]:
# Callbacks are passed to `fit` in this order: checkpointing, W&B logging,
# prediction visualisation.
training_callbacks = [
    model_checkpoint_callback,
    wandb_callback,
    visualization_callback,
]

# Train for 10 epochs on the training dataset inside the strategy scope and
# keep the returned history object for the cells below.
with strategy.scope():
    history = gan.fit(ds_train, epochs=10, verbose=1, callbacks=training_callbacks)

In [None]:
# Folder that receives this experiment's training history.
history_dir = f'{PROJ_ROOT}/train_history/{model_name_base}'
# exist_ok=True keeps this cell re-runnable: without it, a second execution
# (or a re-run with the same model_name_base) raises FileExistsError.
os.makedirs(history_dir, exist_ok=True)

In [None]:
import pickle

# Serialise the `history.history` attribute (not the whole history object)
# into the experiment's history folder.
history_path = f"{history_dir}/history.pkl"
with open(history_path, "wb") as history_file:
    pickle.dump(history.history, history_file)