In [None]:
import tensorflow as tf

# Index (into the physical GPU list) of the card this notebook trains on.
# The original run pinned the second GPU; adjust to suit the host.
GPU_INDEX = 1

# Restrict TensorFlow to a single GPU. This must run before the GPU runtime is
# initialised (i.e. before any op executes), otherwise set_visible_devices
# raises RuntimeError.
gpus = tf.config.list_physical_devices("GPU")
if not gpus:
    print("No GPU found; TensorFlow will run on CPU.")
else:
    if 0 <= GPU_INDEX < len(gpus):
        selected_gpu = gpus[GPU_INDEX]
    else:
        # Fall back to the first card rather than failing with IndexError on
        # hosts that have fewer GPUs than GPU_INDEX assumes.
        print(f"GPU index {GPU_INDEX} unavailable ({len(gpus)} GPU(s) found); using GPU 0.")
        selected_gpu = gpus[0]
    tf.config.set_visible_devices([selected_gpu], "GPU")
tf.config.get_visible_devices("GPU")

In [None]:
gpus = tf.config.get_visible_devices("GPU")
print(f'gpus: {gpus}')

# Let each visible GPU grow its allocation on demand instead of reserving all
# of its memory up front. Like device visibility, this has to be configured
# before TensorFlow initialises the GPU runtime.
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)
print('Set dynamic GPU memory allocation.')

In [None]:
import os
from utils import PROJ_ROOT
# Tell wandb which notebook this run comes from; it must be set before
# wandb.init is called.
os.environ.update(WANDB_NOTEBOOK_NAME="baseline.ipynb")

In [None]:
import wandb
from wandb.keras import WandbCallback

# Open the W&B run that WandbCallback (used during fit) logs into.
wandb.init(
    tags=["baseline"],
    project="colorgan",
)

In [None]:
import numpy as np
import pandas as pd

from models import get_unet_generator
from dataset import folder_dataset, postprocess
from utils import LogPredictionsCallback


In [None]:
BATCH_SIZE = 128
PREFETCH = tf.data.AUTOTUNE
IMG_SIZE = (512, 512)  # shared by train and validation so both pipelines agree

# Validation-subset parameters.
VALID_POOL_SIZE = 2000      # only the first N validation images are read
VALID_SHUFFLE_BUFFER = 500
VALID_SEED = 1
# Caps the subset at 100 batches. With BATCH_SIZE=128 this is 12800, which is
# larger than VALID_POOL_SIZE, so the pool size is what currently bounds it.
VALID_MAX_SIZE = BATCH_SIZE * 100

ds_train = folder_dataset(
    f"{PROJ_ROOT}/imagenet/ILSVRC/Data/CLS-LOC/train",
    augment=True,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    prefetch=PREFETCH,
)

# Stream of individual (unbatched) validation examples.
ds_valid_stream = (
    folder_dataset(
        f"{PROJ_ROOT}/imagenet/ILSVRC/Data/CLS-LOC/val",
        augment=False,
        img_size=IMG_SIZE,
        batch_size=1,
    )
    .unbatch()
    .take(VALID_POOL_SIZE)
    .shuffle(buffer_size=VALID_SHUFFLE_BUFFER, seed=VALID_SEED)
    .take(VALID_MAX_SIZE)
)

# Materialise the subset once so every validation pass (and ds_monitor, its
# first batch) sees the same examples in the same order; iterating the
# shuffled stream again would reshuffle it. Each component of the element
# tuple (e.g. input, target) is stacked into one array of shape (N, ...).
# NOTE: this holds the whole subset in host memory.
valid_examples = list(ds_valid_stream.as_numpy_iterator())
assert valid_examples, "validation pipeline yielded no examples"
ds_valid = tf.data.Dataset.from_tensor_slices(
    tuple(np.stack(component) for component in zip(*valid_examples))
).batch(BATCH_SIZE)
del valid_examples  # per-example arrays are no longer needed once stacked

In [None]:
# One fixed batch (the first of ds_valid) reused for visual spot checks.
ds_monitor = ds_valid.take(count=1)
ds_monitor

In [None]:
import matplotlib.pyplot as plt
from skimage.io import imshow


# Preview the first `limit` (input, target) pairs of the monitoring batch in a
# single figure: inputs on the top row, targets on the bottom row.
limit = 6
# squeeze=False keeps a 2-D Axes array even when limit == 1.
f, axarr = plt.subplots(2, limit, figsize=(limit * 4, 8), squeeze=False)

for i, (x, y) in enumerate(ds_monitor.unbatch().take(limit)):
    axarr[0, i].imshow(postprocess(x.numpy()))
    axarr[1, i].imshow(postprocess(y.numpy()))

axarr[0, 0].set_ylabel("input")
axarr[1, 0].set_ylabel("target")
for ax in axarr.flat:
    # Hide ticks but keep the axes so the row labels remain visible.
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
# Build the generator from the project's `models` module (presumably a U-Net,
# per its name — see models.get_unet_generator).
g = get_unet_generator()

In [None]:
# Adam with a fixed step size for training the generator.
LEARNING_RATE = 1e-4
optim = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

In [None]:
# Train the generator to minimise mean absolute error against the targets.
g.compile(loss="mae", optimizer=optim)

In [None]:
# Predictions on the monitoring batch, taken before any call to fit.
preds = g.predict(x=ds_monitor)

In [None]:
# Show the first `limit` predictions from `preds` in a single row.
limit = 6
# squeeze=False keeps a 2-D Axes array even when limit == 1.
f, axarr = plt.subplots(1, limit, figsize=(limit * 4, 4), squeeze=False)

for i, pred in enumerate(preds[:limit]):
    axarr[0, i].imshow(postprocess(pred))

axarr[0, 0].set_ylabel("prediction")
for ax in axarr.flat:
    # Hide ticks but keep the axes so the row label remains visible.
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
model_name = "baseline_v2"
model_path = f"{PROJ_ROOT}/models/{model_name}"

# Keep the model with the lowest validation loss seen so far at `model_path`.
# `val_loss` only appears in the epoch-end logs, so the check must run per
# epoch. With an integer save_freq, Keras evaluates `monitor` against the
# batch-level logs, finds no `val_loss`, and (with save_best_only=True) skips
# every save.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_freq="epoch",
)

# Logs training/validation metrics to the W&B run opened by wandb.init.
wandb_callback = WandbCallback()

# Project callback; presumably logs generator outputs for the monitoring batch
# every 200 training batches — TODO confirm against utils.LogPredictionsCallback.
visualization_callback = LogPredictionsCallback(ds_monitor, every_n_batch=200)

In [None]:
EPOCHS = 10

# Train on the streaming training pipeline and validate on ds_valid after
# every epoch (validation_freq=1). Checkpointing, W&B logging and prediction
# visualisation all happen through the callbacks.
fit_callbacks = [
    model_checkpoint_callback,
    wandb_callback,
    visualization_callback,
]
history = g.fit(
    ds_train,
    validation_data=ds_valid,
    epochs=EPOCHS,
    validation_freq=1,
    callbacks=fit_callbacks,
    verbose=1,
)

In [None]:
# Save the model as it stands after the last epoch, next to `model_path`.
final_model_path = f"{model_path}_final"
g.save(final_model_path)