In [None]:
import tensorflow as tf


# Pin training to a single physical GPU. GPU_INDEX is 0-based; this run uses the
# second card. Device visibility can only be changed before TensorFlow
# initializes its GPUs, so this cell must run before anything touches a GPU.
GPU_INDEX = 1

physical_gpus = tf.config.list_physical_devices("GPU")
if GPU_INDEX < len(physical_gpus):
    try:
        tf.config.set_visible_devices([physical_gpus[GPU_INDEX]], "GPU")
    except RuntimeError as err:
        # Raised when the runtime is already initialized (e.g. this cell is
        # re-run after later cells); the existing visibility stays in effect.
        print(f"Could not change visible GPUs: {err}")
else:
    # Fewer GPUs than expected (single-GPU or CPU-only machine): keep the
    # default visibility instead of crashing with IndexError.
    print(
        f"GPU index {GPU_INDEX} not available "
        f"({len(physical_gpus)} GPU(s) found); using default visibility."
    )
gpus = tf.config.get_visible_devices("GPU")
gpus

In [None]:
# Report which GPUs TensorFlow will use, then let each one allocate memory on
# demand instead of reserving the whole card up front. Like device visibility,
# memory growth must be configured before the GPUs are initialized.
print(f'Detected gpus: {gpus}')
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)
print('Set dynamic GPU memory allocation.')

In [None]:
# Base run name: used as the W&B run name and as the checkpoint filename prefix.
model_name_base = "gan_lab_v02"

In [None]:
import tempfile
# Scratch locations under the system temp dir for W&B run files and Keras
# checkpoints. The "xkadlci2" prefix presumably namespaces them per user on a
# shared machine (assumption — confirm).
_tmp_root = tempfile.gettempdir()
log_dir = f"{_tmp_root}/xkadlci2_wandb"
checkpoint_dir = f"{_tmp_root}/xkadlci2_checkpoints"
log_dir, checkpoint_dir

In [None]:
import os
# Create the log and checkpoint directories. exist_ok=True keeps this cell
# idempotent: re-running the notebook, or reusing a temp dir left over from an
# earlier session, must not raise FileExistsError.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Tell W&B which notebook produced this run; set before importing wandb in case
# it cannot detect the notebook filename on its own.
os.environ["WANDB_NOTEBOOK_NAME"] = "colorgan_lab.ipynb"

import wandb

# Start the run; its local files go under log_dir.
wandb.init(
    project="colorgan",
    tags=["gan_lab"],
    name=model_name_base,
    dir=log_dir,
)

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from utils import PROJ_ROOT
from callbacks import LogPredictionsCallback
from models import get_unet_generator, get_discriminator, ColorGan
from dataset import postprocess, postprocess_lab, folder_dataset

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
# from tensorflow.data.experimental import AutoShardPolicy

from wandb.keras import WandbCallback

In [None]:
#strategy = tf.distribute.MirroredStrategy()

In [None]:
# Batch sizes. BATCH_SIZE would be BATCH_SIZE_LOCAL * num_replicas under a
# distribution strategy; distribution is currently disabled, so they are equal.
BATCH_SIZE_LOCAL = 64
BATCH_SIZE = BATCH_SIZE_LOCAL
PREFETCH = tf.data.AUTOTUNE

# Shared input settings for both pipelines.
IMG_SIZE = (512, 512)
IMAGENET_DIR = f"{PROJ_ROOT}/imagenet/ILSVRC/Data/CLS-LOC"

# Training pipeline: augmented, batched and prefetched, with use_lab=True.
ds_train = folder_dataset(
    f"{IMAGENET_DIR}/train",
    augment=True,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    prefetch=PREFETCH,
    use_lab=True,
)

# Monitoring set for visual progress logging: from the first 2000 validation
# images, take 128 after a seeded shuffle. cache() freezes the selection after
# the first full pass, so later logs reuse the same images.
ds_monitor_images = (
    folder_dataset(
        f"{IMAGENET_DIR}/val",
        augment=False,
        img_size=IMG_SIZE,
        batch_size=1,
        use_lab=True,
    )
    .unbatch()
    .take(2000)
    .shuffle(buffer_size=500, seed=1)
    .take(128)
)
ds_monitor = ds_monitor_images.batch(BATCH_SIZE_LOCAL).cache()

In [None]:
# Hyperparameters and logging cadences, stored on the W&B run; later cells read
# them back through wandb.config.
# NOTE(review): "lr_dicriminator" is misspelled, but the compile cell reads that
# exact key, so any rename must change both places together.
wandb.config.update(
    {
        "batch_size": BATCH_SIZE,
        "weight_mae_loss": 80,
        "lr_dicriminator": 1e-5,
        "lr_generator": 2e-4,
        "label_smoothing": 0.2,
        "log_loss_every_n_batch": 10,
        "log_vis_every_n_batch": 200,
        "param_hist_every_n_batch": 50,
        "epochs": 5,
    }
)

In [None]:
# Build generator and discriminator (use_lab=True for both) and wrap them in the
# ColorGan training model. Distribution via strategy.scope() is disabled.
g = get_unet_generator(use_lab=True)
d = get_discriminator(use_lab=True)
gan = ColorGan(g, d, weight_mae_loss=wandb.config.weight_mae_loss)

# Separate Adam optimizers for the two networks, both with beta_1=0.5.
d_optimizer = Adam(wandb.config.lr_dicriminator, beta_1=0.5)
g_optimizer = Adam(wandb.config.lr_generator, beta_1=0.5)
# from_logits=True assumes the discriminator outputs raw logits, not
# probabilities — TODO confirm against models.get_discriminator.
end_loss = BinaryCrossentropy(
    label_smoothing=wandb.config.label_smoothing, from_logits=True
)

gan.compile(d_optimizer=d_optimizer, g_optimizer=g_optimizer, end_loss=end_loss)

In [None]:
# Checkpoint the model at the end of every epoch. "{epoch:02d}" is a Keras
# placeholder filled in at save time, so it stays a literal here (not an f-string).
model_name = model_name_base + "_epoch{epoch:02d}"
model_path = f"{checkpoint_dir}/{model_name}"
checkpoint_options = dict(filepath=model_path, save_freq="epoch")
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(**checkpoint_options)

# Stock W&B Keras integration with its default settings.
wandb_callback = WandbCallback()

# Project callback that periodically logs to W&B on the fixed monitor set; the
# cadences come from the run config.
cfg = wandb.config
visualization_callback = LogPredictionsCallback(
    ds_monitor,
    vis_every_n_batch=cfg.log_vis_every_n_batch,
    loss_every_n_batch=cfg.log_loss_every_n_batch,
    param_hist_every_n_batch=cfg.param_hist_every_n_batch,
    use_lab=True,
)

In [None]:
# Train the GAN; `history` holds the History object returned by Keras fit().
training_callbacks = [
    model_checkpoint_callback,
    wandb_callback,
    visualization_callback,
]
history = gan.fit(
    ds_train,
    epochs=wandb.config.epochs,
    verbose=1,
    callbacks=training_callbacks,
)