In [1]:
import tensorflow as tf


gpus = tf.config.list_physical_devices("GPU")
tf.config.set_visible_devices([
    gpus[1],
    gpus[2],
], "GPU")
gpus = tf.config.get_visible_devices("GPU")
gpus

2022-05-01 14:21:42.837988: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1917] Ignoring visible gpu device (device: 4, name: Quadro P1000, pci bus id: 0000:c8:00.0, compute capability: 6.1) with core count: 5. The minimum required count is 8. You can adjust this requirement with the env var TF_MIN_GPU_MULTIPROCESSOR_COUNT.


[PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU')]

In [2]:
print(f'Detected gpus: {gpus}')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
print('Set dynamic GPU memory allocation.')

Detected gpus: [PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU')]
Set dynamic GPU memory allocation.


In [3]:
model_name_base = "gan_v01"

In [4]:
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "colorgan.ipynb"

import wandb
wandb.init(project="colorgan", tags=["gan"], name=model_name_base)

[34m[1mwandb[0m: Currently logged in as: [33mmarkcheeky[0m (use `wandb login --relogin` to force relogin)


In [5]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from utils import PROJ_ROOT
from callbacks import LogPredictionsCallback
from models import get_unet_generator, get_discriminator, ColorGan
from dataset import postprocess, folder_dataset

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
from tensorflow.data.experimental import AutoShardPolicy

from wandb.keras import WandbCallback

In [6]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


2022-05-01 14:21:48.669435: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-01 14:21:49.743186: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 79124 MB memory:  -> device: 1, name: NVIDIA A100-SXM-80GB, pci bus id: 0000:46:00.0, compute capability: 8.0
2022-05-01 14:21:49.744067: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 79124 MB memory:  -> device: 2, name: NVIDIA A100-SXM-80GB, pci bus id: 0000:85:00.0, compute capability: 8.0


In [7]:
BATCH_SIZE_LOCAL = 128
BATCH_SIZE = BATCH_SIZE_LOCAL * strategy.num_replicas_in_sync
PREFETCH = tf.data.AUTOTUNE

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA

ds_train = folder_dataset(
    f"{PROJ_ROOT}/imagenet/ILSVRC/Data/CLS-LOC/train",
    augment=True,
    img_size=(512, 512),
    batch_size=BATCH_SIZE,
    prefetch=PREFETCH,
)

ds_monitor = (
    folder_dataset(
        f"{PROJ_ROOT}/imagenet/ILSVRC/Data/CLS-LOC/val",
        augment=False,
        img_size=(512, 512),
        batch_size=1,
    )
    .unbatch()
    .take(2000)
    .shuffle(buffer_size=500, seed=1)
    .take(128)
).batch(BATCH_SIZE_LOCAL).cache().with_options(options)

Found 1281167 files belonging to 1 classes.
Found 50000 files belonging to 1 classes.


In [8]:
with strategy.scope():
    g = get_unet_generator()
    d = get_discriminator()
    gan = ColorGan(g, d)

    end_loss_base = BinaryCrossentropy(label_smoothing=0.2, reduction=Reduction.NONE)
    end_loss = lambda labels, preds: tf.reduce_sum(end_loss_base(labels, preds)) / BATCH_SIZE

    gan.compile(
        d_optimizer=Adam(2e-4, beta_1=0.5),
        g_optimizer=Adam(2e-4, beta_1=0.5),
        end_loss=end_loss,
    )

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Redu

In [9]:
model_name = model_name_base + "_epoch{epoch:02d}"
model_path = f"{PROJ_ROOT}/models/{model_name}"

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    save_freq="epoch",
)

wandb_callback = WandbCallback()

visualization_callback = LogPredictionsCallback(ds_monitor, every_n_batch=500)

In [10]:
with strategy.scope():
    history = gan.fit(
        ds_train,
        epochs=10,
        verbose=1,
        callbacks=[
            model_checkpoint_callback,
            wandb_callback,
            visualization_callback,
        ]
    )

2022-05-01 14:22:10.057728: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:695] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_1"
op: "TensorSliceDataset"
input: "Placeholder/_0"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
    }
  }
}

2022-05-01 14:22:10.100466: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


Epoch 1/10
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:GPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1').
INFO:tensorflow:Reduce to /job:localhost/replica:0/tas

2022-05-01 14:22:32.538044: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8201
2022-05-01 14:22:33.795055: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8201
2022-05-01 14:22:34.203960: W tensorflow/stream_executor/gpu/asm_compiler.cc:231] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.0
2022-05-01 14:22:34.203987: W tensorflow/stream_executor/gpu/asm_compiler.cc:234] Used ptxas at ptxas
2022-05-01 14:22:34.204079: W tensorflow/stream_executor/gpu/redzone_allocator.cc:314] Unimplemented: ptxas ptxas too old. Falling back to the driver to compile.
Relying on driver to perform ptx compilation. 
Modify $PATH to customize ptxas location.
This message will be only logged once.
2022-05-01 14:22:55.330592: I tensorflow/stream_executor/cuda/cuda_blas.cc:1760] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


 236/6406 [>.............................] - ETA: 2:38:00 - d_loss: 0.3910 - g_loss: 15.3784

KeyboardInterrupt: 

In [None]:
history_dir = f'{PROJ_ROOT}/train_history/{model_name_base}'
os.makedirs(history_dir)

In [None]:
import pickle
with open(f"{history_dir}/history.pkl", 'wb') as file:
    pickle.dump(history.history, file)