In [1]:
import numpy as np
import tensorflow as tf

from PIL import Image
from stylegan2.utils import postprocess_images, adjust_dynamic_range
from stylegan2.generator import Generator

import matplotlib.pyplot as plt
import os
import time
from glob import glob

In [2]:
# model = tf.keras.models.load_model("C:/repos/gan-refinement/nets/lutz_new_classifier_tf1.14.h5")
# model.summary()

In [3]:
# w = np.abs(np.mean(model.trainable_weights[0].numpy().reshape((1024, 1024, 3)), axis=-1))
# w = np.clip(w, 0, np.quantile(w, 0.999))
# plt.axis("off")
# plt.imshow(w, cmap=plt.cm.inferno)
# plt.colorbar()

In [4]:
# Directory holding the target images; file names must contain "seed<N>".
IMAGES_DIR = "D:/datasets/adv_images_eps32"
# Directory intended for training artefacts.
RESULT_DIR = "models/"

# Fail early, and say which directory is missing, instead of one opaque assert.
assert os.path.isdir(IMAGES_DIR), f"IMAGES_DIR does not exist: {IMAGES_DIR}"
assert os.path.isdir(RESULT_DIR), f"RESULT_DIR does not exist: {RESULT_DIR}"

In [5]:
# Number of passes over train_dataset performed by train().
EPOCHS           = 20
# Number of samples per dataset batch (one optimizer update per batch).
BATCH_SIZE       = 4
# Number of samples per forward/backward pass; a batch is processed in
# BATCH_SIZE // BATCH_PER_GPU chunks whose gradients are accumulated.
BATCH_PER_GPU    = 4
# NOTE(review): not referenced anywhere in the visible notebook -- the
# validation set size is fixed via VAL_SIZE instead. Confirm whether it is needed.
VALIDATION_SPLIT = 0.2

# Lets tf.data choose the parallelism / prefetch buffer sizes.
AUTOTUNE         = tf.data.experimental.AUTOTUNE

# The chunked loops in train_step/val_step require an integer number of chunks.
assert BATCH_SIZE % BATCH_PER_GPU == 0

In [6]:
def build_generator(g_params, ckpt_dir='./official-converted'):
    """Construct a StyleGAN2 generator and restore its weights from a checkpoint.

    Adapted from inference_from_official_weights.py.

    Args:
        g_params: Generator configuration dict; must contain at least 'z_dim'.
        ckpt_dir: Directory searched for the latest checkpoint. Defaults to the
            location the notebook used originally.

    Returns:
        The built Generator. If no checkpoint is found in `ckpt_dir`, the
        generator keeps its freshly initialised weights and a warning is printed.
    """
    g_clone = Generator(g_params)

    # Run the model once in each mode so that all variables are created before
    # the checkpoint is restored.
    test_latent = np.ones((1, g_params['z_dim']), dtype=np.float32)
    _ = g_clone([test_latent, None], training=False)
    _ = g_clone([test_latent, None], training=True)

    # Restore the most recent checkpoint, if there is one.
    ckpt = tf.train.Checkpoint(g_clone=g_clone)
    manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)
    latest = manager.latest_checkpoint
    if latest:
        ckpt.restore(latest)
        print('Restored from {}'.format(latest))
    else:
        # Previously this case passed silently, leaving random weights.
        print('WARNING: no checkpoint found in {}; generator is not restored'.format(ckpt_dir))

    return g_clone

In [7]:
# Generator configuration handed to build_generator().
# Only 'z_dim' (latent vector length) and 'resolutions' (last entry is used as
# the image side length) are read directly in this notebook; the remaining
# keys are consumed by stylegan2.generator.Generator.
g_params = {
        'z_dim': 512,
        'w_dim': 512,
        'labels_dim': 0,
        'n_mapping': 8,
        'resolutions': [4, 8, 16, 32, 64, 128, 256, 512, 1024],
        'featuremaps': [512, 512, 512, 512, 512, 256, 128, 64, 32],
        'w_ema_decay': 0.995,
        'style_mixing_prob': 0.0,
        'randomize_noise': False,
    }

In [8]:
# Build the generator from g_params and restore its weights from the checkpoint.
generator = build_generator(g_params)
# Show the layer / parameter overview.
generator.summary()

Restored from ./official-converted\ckpt-0
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
g_mapping (Mapping)          multiple                  2101248   
_________________________________________________________________
lambda_1 (Lambda)            multiple                  0         
_________________________________________________________________
g_synthesis (Synthesis)      multiple                  28268812  
Total params: 30,370,572
Trainable params: 30,370,060
Non-trainable params: 512
_________________________________________________________________


In [9]:
def generate_image(seed, generator, latent_dim):
    """Synthesize one image from the latent vector determined by `seed`.

    Args:
        seed: Seed for the NumPy RandomState that draws the latent vector.
        generator: Callable invoked as generator([latents, None], ...) and
            returning a pair whose first element is the image batch.
        latent_dim: Length of the latent vector.

    Returns:
        The first (and only) image of the post-processed batch as a NumPy array.
    """
    # Batch of one latent vector, drawn deterministically from the seed.
    latents = np.random.RandomState(seed).randn(1, latent_dim).astype(np.float32)

    outputs = generator([latents, None], training=False, truncation_psi=1.0)
    image_out, _ = outputs

    batch = postprocess_images(image_out).numpy()
    return batch[0]

In [10]:
def generate_save_images(save_path, generator, latent_dim, num_images=10_000, start_seed=0):
    """Generate images for consecutive seeds and save them as PNG files.

    Files are written as `seed<NNNNN>.png`, so the seed (and therefore the
    latent vector) can be recovered from the file name later.

    Args:
        save_path: Existing directory the images are written to.
        generator: Generator passed through to generate_image().
        latent_dim: Length of the latent vector.
        num_images: Number of images to generate.
        start_seed: First seed; seeds start_seed .. start_seed + num_images - 1
            are used. (This argument used to be ignored.)
    """
    assert os.path.isdir(save_path)

    for idx, seed in enumerate(range(start_seed, start_seed + num_images)):
        image = generate_image(seed, generator, latent_dim)
        Image.fromarray(image, 'RGB').save(os.path.join(save_path, f'seed{seed:05d}.png'))
        # Carriage return keeps the progress display on a single line.
        print(f"\rImage {idx + 1}/{num_images}", end="")

# generate_save_images("D:/datasets/test_images", generator, g_params["z_dim"], num_images=100)

In [11]:
def noise_from_path(path, latent_dim):
    """Recreate the latent vector that belongs to an image file.

    The seed is parsed from the part of the file name (without extension)
    that follows the text "seed", and is used to seed a NumPy RandomState.

    Args:
        path: Path of an image file named like `seed00042.png`.
        latent_dim: Length of the latent vector to draw.

    Returns:
        float32 array of shape (latent_dim,).
    """
    stem, _extension = os.path.splitext(os.path.basename(path))
    seed = int(stem.split("seed")[1])
    return np.random.RandomState(seed).randn(latent_dim).astype(np.float32)

# glob() returns files in arbitrary order; sort so that the take()/skip()
# based train/validation split below is identical on every run and machine.
image_paths = sorted(glob(os.path.join(IMAGES_DIR, "*.png")))
assert image_paths, f"No PNG images found in {IMAGES_DIR}"

# One latent vector per image, reconstructed from the seed in the file name.
noise_np = np.asarray([noise_from_path(path, g_params["z_dim"]) for path in image_paths])
assert noise_np.shape == (len(image_paths), g_params["z_dim"])

In [12]:
# Square RGB images (height, width, channels) at the generator's final resolution.
final_resolution = g_params["resolutions"][-1]
image_shape = (final_resolution, final_resolution, 3)
print(f"Image shape is assumed as {image_shape}")

def process_sample(noise, path):
    """Load the image at `path` and convert it to the network's value range.

    Args:
        noise: Latent vector belonging to the image; passed through unchanged.
        path: Path of the image file.

    Returns:
        (noise, image), where image is float32, scaled from [0, 255] to
        [-1, 1] and transposed from (H, W, C) to (C, H, W).
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.cast(tf.io.decode_image(raw_bytes), tf.float32)
    # Set the static shape explicitly (workaround for tensorflow issue 24520).
    pixels.set_shape(image_shape)

    # Rescale to the net input range [-1, 1].
    scaled = adjust_dynamic_range(pixels, range_in=(0.0, 255.0), range_out=(-1.0, 1.0), out_dtype=tf.dtypes.float32)
    channels_first = tf.transpose(scaled, [2, 0, 1])

    return noise, channels_first

Image shape is assumed as (1024, 1024, 3)


In [13]:
# Pair each latent vector with the path of its image; images are only loaded
# later, when process_sample is mapped over the dataset.
dataset = tf.data.Dataset.from_tensor_slices((noise_np, image_paths))
dataset.element_spec

(TensorSpec(shape=(512,), dtype=tf.float32, name=None),
 TensorSpec(shape=(), dtype=tf.string, name=None))

In [14]:
def test_dataset(dataset, grid=(4,4)):
    """Plot generated images next to their dataset targets as a sanity check.

    Each pair occupies two neighbouring cells: the image produced by the
    global `generator` from the stored latent vector (left) and the image
    loaded from disk (right).

    Args:
        dataset: Unbatched dataset yielding (noise, image) samples, e.g.
            `dataset.map(process_sample)`.
        grid: (rows, columns) of the subplot grid. The column count must be
            even so that a pair never wraps onto the next row.
    """
    assert isinstance(grid, (tuple, list)) and len(grid) == 2 and grid[1] % 2 == 0
    rows, cols = grid
    num_pairs = rows * cols // 2

    # Fetch exactly as many samples as the grid can show (was hard-coded to 8).
    nois, imgs = next(iter(dataset.batch(num_pairs)))
    gens, _ = generator([nois, None], training=False)
    gens = postprocess_images(gens).numpy()
    imgs = postprocess_images(imgs).numpy()
    assert num_pairs <= gens.shape[0], "dataset holds fewer samples than the grid needs"

    # figsize is (width, height), i.e. (columns, rows).
    plt.figure(figsize=(cols * 2, rows * 2))
    for idx in range(num_pairs):
        # Use the requested grid instead of a fixed 4x4 layout.
        plt.subplot(rows, cols, idx * 2 + 1)
        plt.axis("off")
        plt.imshow(gens[idx], cmap="gray")

        plt.subplot(rows, cols, idx * 2 + 2)
        plt.axis("off")
        plt.imshow(imgs[idx], cmap="gray")

# test_dataset(dataset.map(process_sample, AUTOTUNE))

In [15]:
# Number of samples used for training (alternative: len(image_paths)).
TRAIN_SIZE = 1000
# Number of samples used for validation, taken right after the training samples.
VAL_SIZE = 200

# First TRAIN_SIZE samples: shuffle -> load images -> batch -> prefetch.
train_dataset = (
    dataset
    .take(TRAIN_SIZE)
    .shuffle(TRAIN_SIZE)
    .map(process_sample, AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
train_dataset.element_spec

(TensorSpec(shape=(None, 512), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 3, 1024, 1024), dtype=tf.float32, name=None))

In [16]:
# The VAL_SIZE samples following the training samples, run through the same pipeline.
val_dataset = (
    dataset
    .skip(TRAIN_SIZE)
    .take(VAL_SIZE)
    .shuffle(VAL_SIZE)
    .map(process_sample, AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

In [17]:
# Print names of all trainable variables
# The names show which resolution block each variable belongs to; freeze_vars
# below matches substrings such as "4x4" against these names.
[var.name for var in generator.trainable_variables]

['g_mapping/dense_0/w:0',
 'g_mapping/dense_1/w:0',
 'g_mapping/dense_2/w:0',
 'g_mapping/dense_3/w:0',
 'g_mapping/dense_4/w:0',
 'g_mapping/dense_5/w:0',
 'g_mapping/dense_6/w:0',
 'g_mapping/dense_7/w:0',
 'g_mapping/bias_0/b:0',
 'g_mapping/bias_1/b:0',
 'g_mapping/bias_2/b:0',
 'g_mapping/bias_3/b:0',
 'g_mapping/bias_4/b:0',
 'g_mapping/bias_5/b:0',
 'g_mapping/bias_6/b:0',
 'g_mapping/bias_7/b:0',
 'g_synthesis/4x4/const/const:0',
 'g_synthesis/4x4/const/conv/w:0',
 'g_synthesis/4x4/const/conv/mod_dense/w:0',
 'g_synthesis/4x4/const/conv/mod_bias/b:0',
 'g_synthesis/4x4/const/noise/w:0',
 'g_synthesis/4x4/const/bias/b:0',
 'g_synthesis/4x4/ToRGB/conv/w:0',
 'g_synthesis/4x4/ToRGB/conv/mod_dense/w:0',
 'g_synthesis/4x4/ToRGB/conv/mod_bias/b:0',
 'g_synthesis/4x4/ToRGB/bias/b:0',
 'g_synthesis/8x8/block/conv_0/w:0',
 'g_synthesis/8x8/block/conv_0/mod_dense/w:0',
 'g_synthesis/8x8/block/conv_0/mod_bias/b:0',
 'g_synthesis/8x8/block/noise_0/w:0',
 'g_synthesis/8x8/block/bias_0/b:0',

In [18]:
def freeze_vars(gen, freeze_layers, freeze_mapping=True, freeze_synthesis=False):
    """Select the generator variables that should still be trained.

    Args:
        gen: Model exposing `layers` and `trainable_variables`.
        freeze_layers: Name fragments; a variable whose name contains any of
            them is excluded from the result.
        freeze_mapping: If true, mark gen.layers[0] as non-trainable.
        freeze_synthesis: If true, mark gen.layers[2] as non-trainable.

    Returns:
        List of the variables in gen.trainable_variables whose names match
        none of the fragments. Name-matched variables are only left out of
        this list; their own trainable state is not modified.
    """
    if freeze_mapping:
        gen.layers[0].trainable = False
        print("Freezed mapping")
    if freeze_synthesis:
        gen.layers[2].trainable = False
        print("Freezed synthesis")

    train_vars = []
    for variable in gen.trainable_variables:
        is_frozen = any(fragment in variable.name for fragment in freeze_layers)
        if not is_frozen:
            train_vars.append(variable)
            continue
        print(f"Freezed {variable.name}")

    print(f"{len(gen.trainable_variables) - len(train_vars)} variables were freezed")
    return train_vars

# Freeze the mapping network (default) and exclude every variable whose name
# contains one of the listed resolution tags; the remaining variables are the
# ones train_step optimizes.
trainable_variables = freeze_vars(generator, ["4x4", "8x8", "16x16", "32x32", "64x64"])
len(trainable_variables)

Freezed mapping
Freezed g_synthesis/4x4/const/const:0
Freezed g_synthesis/4x4/const/conv/w:0
Freezed g_synthesis/4x4/const/conv/mod_dense/w:0
Freezed g_synthesis/4x4/const/conv/mod_bias/b:0
Freezed g_synthesis/4x4/const/noise/w:0
Freezed g_synthesis/4x4/const/bias/b:0
Freezed g_synthesis/4x4/ToRGB/conv/w:0
Freezed g_synthesis/4x4/ToRGB/conv/mod_dense/w:0
Freezed g_synthesis/4x4/ToRGB/conv/mod_bias/b:0
Freezed g_synthesis/4x4/ToRGB/bias/b:0
Freezed g_synthesis/8x8/block/conv_0/w:0
Freezed g_synthesis/8x8/block/conv_0/mod_dense/w:0
Freezed g_synthesis/8x8/block/conv_0/mod_bias/b:0
Freezed g_synthesis/8x8/block/noise_0/w:0
Freezed g_synthesis/8x8/block/bias_0/b:0
Freezed g_synthesis/8x8/block/conv_1/w:0
Freezed g_synthesis/8x8/block/conv_1/mod_dense/w:0
Freezed g_synthesis/8x8/block/conv_1/mod_bias/b:0
Freezed g_synthesis/8x8/block/noise_1/w:0
Freezed g_synthesis/8x8/block/bias_1/b:0
Freezed g_synthesis/16x16/block/conv_0/w:0
Freezed g_synthesis/16x16/block/conv_0/mod_dense/w:0
Freezed g_

56

In [19]:
# Adam optimizer (learning rate 1e-3) applied to `trainable_variables` in train_step.
g_optimizer = tf.keras.optimizers.Adam(1e-3)
# Running mean of the per-sample training loss; reported and reset in train().
mean_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)

@tf.function
def mean_squared_error(y_true, y_pred):
    """Squared difference of the inputs, averaged over axis 1.

    Inputs are expected as [batch, c, h, w], so axis 1 is the channel axis
    and one value per pixel remains.
    """
    residual = y_true - y_pred
    return tf.reduce_mean(tf.square(residual), axis=1)

@tf.function
def mean_absolute_error(y_true, y_pred):
    """Absolute difference of the inputs, averaged over axis 1.

    Inputs are expected as [batch, c, h, w], so axis 1 is the channel axis
    and one value per pixel remains.
    """
    residual = y_true - y_pred
    return tf.reduce_mean(tf.abs(residual), axis=1)

@tf.function
def train_step(latents, target_images):
    """Run one optimizer update on a batch, using gradient accumulation.

    The batch is processed in chunks of BATCH_PER_GPU samples. The gradients
    of all chunks are summed, divided by the number of chunks and applied
    once to `trainable_variables`. The per-sample loss of every chunk is
    added to the `mean_loss` metric.

    Args:
        latents: Latent vectors of the batch.
        target_images: Images the generator output is compared against.
    """
    num_rounds = BATCH_SIZE // BATCH_PER_GPU
    grad_sums = []

    for lo in range(0, BATCH_SIZE, BATCH_PER_GPU):
        hi = lo + BATCH_PER_GPU

        with tf.GradientTape() as tape:
            # training=False deactivates style mixing etc.
            predicted, _ = generator([latents[lo:hi], None], training=False)
            predicted = tf.clip_by_value(predicted, -1., 1.)

            per_pixel_loss = mean_absolute_error(target_images[lo:hi], predicted)

        # NOTE(review): the loss is not reduced to a scalar here, so the tape
        # differentiates the sum of its elements rather than their mean --
        # confirm this scaling is intended.
        chunk_grads = tape.gradient(per_pixel_loss, trainable_variables)
        if lo > 0:
            grad_sums = [total + grad for total, grad in zip(grad_sums, chunk_grads)]
        else:
            # The first chunk initialises the accumulator.
            grad_sums = chunk_grads

        # Reduce the remaining two axes to one loss value per sample.
        mean_loss.update_state(tf.reduce_mean(per_pixel_loss, axis=[1, 2]))

    # Average over the chunks (nothing to do for a single chunk), then apply.
    if num_rounds > 1:
        grad_sums = [total / num_rounds for total in grad_sums]
    g_optimizer.apply_gradients(zip(grad_sums, trainable_variables))

In [20]:
# Running mean of the per-sample validation loss.
mean_val_loss = tf.keras.metrics.Mean('val_loss', dtype=tf.float32)

def val_step(latents, target_images):
    """Evaluate one batch and add its per-sample loss to `mean_val_loss`.

    The batch is processed in chunks of BATCH_PER_GPU samples; no gradients
    are computed. Note that this uses the squared error, whereas train_step
    uses the absolute error.

    Args:
        latents: Latent vectors of the batch.
        target_images: Images the generator output is compared against.
    """
    chunk_starts = range(0, BATCH_SIZE, BATCH_PER_GPU)
    for lo in chunk_starts:
        hi = lo + BATCH_PER_GPU

        # training=False deactivates style mixing etc.
        predicted, _ = generator([latents[lo:hi], None], training=False)
        predicted = tf.clip_by_value(predicted, -1., 1.)

        per_pixel_loss = mean_squared_error(target_images[lo:hi], predicted)

        # Reduce the remaining two axes to one loss value per sample.
        mean_val_loss.update_state(tf.reduce_mean(per_pixel_loss, axis=[1, 2]))

In [21]:
# num_steps = tf.data.experimental.cardinality(val_dataset).numpy()
# for step, (latents, images) in enumerate(val_dataset):
#     val_step(latents, images)
#     if step % 10 == 0: print(f"\r{step+10:04d}/{num_steps:04d} loss {mean_val_loss.result().numpy():.6f}", end="")

In [22]:
def train():
    """Fine-tune the generator on `train_dataset` for EPOCHS epochs.

    For every batch train_step() is executed; the running mean loss is
    printed every 10 steps, and once per epoch together with the elapsed time.
    """
    print("Starting training...")
    num_steps = tf.data.experimental.cardinality(train_dataset).numpy()

    for epoch in range(EPOCHS):
        print(f"{epoch+1:02d}/{EPOCHS:02d}")
        # Start every epoch from a clean metric. Without this, an interrupted
        # run (e.g. KeyboardInterrupt) leaves values in `mean_loss` that would
        # distort the first epoch of the next train() call.
        mean_loss.reset_states()
        start = time.time()

        for step, (latents, images) in enumerate(train_dataset):
            train_step(latents, images)

            # Carriage return keeps the progress display on a single line.
            if step % 10 == 0: print(f"\r{step:04d}/{num_steps:04d} loss {mean_loss.result().numpy():.6f}", end="")

        print (f'\r{time.time()-start:.3f} sec, avg loss {mean_loss.result().numpy():.6f}')
        mean_loss.reset_states()
        
        # # Save the model every 10 epochs
        # if (epoch + 1) % 10 == 0:
        #     generator.save(os.path.join(SAVE_DIR, "{:03d}_gen.h5".format(epoch + 1)),
        #         save_format="h5")
            
        #     # Produce images for the GIF as we go
        #     generate_images(epoch=epoch + 1)
        # else:
        #     generate_images()

    # # Generate after the final epoch
    # generate_images()

In [23]:
# Start fine-tuning; the captured output below ends in a manual interrupt.
train()

Starting training...
01/20
198.220 sec, avg loss 0.236826
02/20
0120/0250 loss 0.236913

KeyboardInterrupt: 