In [4]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [48]:
import sys
sys.path.insert(0, "./")
import tensorflow as tf
import basic_block
import time
import os
import load_data

In [49]:

def train(style_list, content_list, save_model_dir,batch_size=1, num_epochs=1, style_weight=1, content_weight=1,
          ngf=64, log_interval=1):

    ########################
    # Data loader
    ########################
    
    content_loader = load_data.get_dataloader(content_list, batch_size, target_size=(400,400))
    style_loader = load_data.get_dataloader(style_list, batch_size, target_size=(400, 400))
    content_loader_iter = iter(content_loader)
    style_loader_iter = iter(style_loader)

    ########################
    # Init model
    ########################  
    vgg = basic_block.Vgg()
    style_model = basic_block.Net(ngf = ngf)
    ########################
    # optimizer and loss
    ########################
    mse_loss = tf.keras.losses.mean_squared_error
    optimizer = tf.keras.optimizers.Adam()
    ########################
    # Start training loop
    ########################
    for epoch in range(num_epochs):
        agg_content_loss = 0.0
        agg_style_loss = 0.0
        count = 0
        for batch_id, content_img in enumerate(content_loader_iter):
            print(batch_id)
            with tf.GradientTape() as tape:
                n_batch = len(content_img)
                count += n_batch
                # data preparation. TODO: figure out these helper functions
                style_image = next(style_loader_iter)
                #style_v = utils.subtract_imagenet_mean_preprocess_batch(style_image.copy())

                feature_style = vgg(style_image)
                gram_style = [basic_block.gram_matrix(y) for y in feature_style]

                f_xc_c = vgg(content_img)[1]

                style_model.set_target(style_image)
                y = style_model(content_img)
                features_y = vgg(y)
                
                print(y.shape, content_img.shape)
                print(features_y[1].shape, f_xc_c.shape)
                
                # TODO: why the coefficient 2?
                content_loss = 2 * content_weight * mse_loss(features_y[1], f_xc_c)

                style_loss = 0.0
                for m in range(len(features_y)):
                    gram_y = basic_block.gram_matrix(features_y[m])
                    _, C, _ = gram_style[m].shape
                    gram_s = tf.expand_dims(gram_style[m], 0).broadcast_to(batch_size, 1, C, C)
                    style_loss += 2 * style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])
                total_loss = content_loss + style_loss
                agg_content_loss += content_loss[0]
                agg_style_loss += style_loss[0]
            gradients = tape.gradient(total_loss, style_model.variables)
            optimizer.apply_gradients(zip(gradients, style_model.trainable_variables))

            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\tcontent: {:.3f}\tstyle: {:.3f}\ttotal: {:.3f}".format(
                    time.ctime(), epoch + 1,
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if (batch_id + 1) % (4 * log_interval) == 0:
                # save model
                save_model_filename = "Epoch_" + str(epoch) + "iters_" + \
                    str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
                    content_weight) + "_" + str(style_weight) + ".params"
                save_model_path = os.path.join(save_model_dir, save_model_filename)
                tf.saved_model.save(style_model, save_model_path)
                print("\nCheckpoint, trained model saved at", save_model_path)




In [50]:
style_list = load_data.get_data_paths("content")
content_list = load_data.get_data_paths("style")
train(style_list, content_list, "models")

0
(1, 396, 396, 64) (1, 198, 198, 128) (1, 198, 198, 128)
(1, 198, 198, 128) (1, 99, 99, 256) (1, 99, 99, 256)
downsample: (1, 99, 99, 256)
(1, 396, 396, 64) (1, 198, 198, 128) (1, 198, 198, 128)
(1, 198, 198, 128) (1, 99, 99, 256) (1, 99, 99, 256)
(1, 99, 99, 256) (1, 99, 99, 256) (1, 99, 99, 256)
(1, 99, 99, 256) (1, 99, 99, 256) (1, 99, 99, 256)
(1, 99, 99, 256) (1, 99, 99, 256) (1, 99, 99, 256)
(1, 99, 99, 256) (1, 99, 99, 256) (1, 99, 99, 256)
(1, 99, 99, 256) (1, 99, 99, 256) (1, 99, 99, 256)
(1, 99, 99, 256) (1, 99, 99, 256) (1, 99, 99, 256)
(1, 99, 99, 256) (1, 198, 198, 128) (1, 198, 198, 128)
(1, 198, 198, 128) (1, 396, 396, 64) (1, 396, 396, 64)
upsample: (1, 396, 396, 3)
(1, 396, 396, 3) (1, 400, 400, 3)
(1, 198, 198, 128) (1, 200, 200, 128)


InvalidArgumentError: Incompatible shapes: [1,200,200,128] vs. [1,198,198,128] [Op:SquaredDifference]