## **From Here to There: Video Inbetweening Using Direct 3D Convolutions**

Li, Y., Roblek, D., & Tagliasacchi, M. (2019). From here to there: Video inbetweening using direct 3d convolutions. arXiv preprint arXiv:1905.10240.

In [1]:
import tensorflow as tf
tf.__version__

'2.4.1'

In [2]:
!nvidia-smi

Mon Feb  8 02:23:17 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.39       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    24W / 300W |      0MiB / 16130MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## **Hyperparameters**

In [3]:
import datetime
import os
import time

In [4]:
class HParams(object):
    def __init__(self):
        ## Noise vector.
        self.D = 120

        ## Original image/video.
        self.T = 16
        self.H_0 = 64
        self.W_0 = 64
        self.channels = 3

        ## Feature map.
        self.H = 8
        self.W = 8
        self.C = 64

        self.L = 24
        
        ## Image/video/feature map size.
        self.u_sz     = [self.D, ]
        self.image_sz = [self.H_0, self.W_0, self.channels]
        self.video_sz = [self.T] + self.image_sz

        self.E_x_sz   = [self.H, self.W, self.C]
        self.z_sz     = [self.T, self.H, self.W, self.C]

        ## Train.
        self.epochs = 5
        self.batch_sz = 32
        
HPARAMS = HParams()

## **Model Architecture**

### **Generator**

In [5]:
def Conv2D_BN_LeakyReLU(
    x, 
    filters, 
    kernel_sz, 
    strides = 1, 
    padding = "same",
):
    x = tf.keras.layers.Conv2D(filters, kernel_sz, strides = strides, padding = padding)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation(tf.nn.leaky_relu)(x)
    return x


def ImageEncoder(
    model_name = "ImageEncoder",
):
    model_input = tf.keras.layers.Input(shape = HPARAMS.image_sz, dtype = tf.dtypes.float32)
    
    args = [
        [ 64, 4, 2],
        [ 64, 3, 1],
        [128, 4, 2],
        [128, 3, 1],
        [256, 4, 2],
        [256, 3, 1],
        [ 64, 3, 1]]

    x = model_input

    ## L1 to L7.
    for (filters, kernel_sz, strides) in args:
        x = Conv2D_BN_LeakyReLU(x, filters, kernel_sz, strides = strides)

    model_output = x

    return tf.keras.Model(
        inputs = model_input,
        outputs = model_output,
        name = model_name)

In [6]:
tmp = ImageEncoder()
tmp.summary()

Model: "ImageEncoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 64, 64, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 32, 32, 64)        3136      
_________________________________________________________________
batch_normalization (BatchNo (None, 32, 32, 64)        256       
_________________________________________________________________
activation (Activation)      (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        36928     
_________________________________________________________________
batch_normalization_1 (Batch (None, 32, 32, 64)        256       
_________________________________________________________________
activation_1 (Activation)    (None, 32, 32, 64)       

In [None]:
# tf.keras.utils.plot_model(tmp, show_shapes = True)

In [7]:
del tmp

In [9]:
def Conv3D_T_BN_LeakyReLU(
    x, 
    filters, 
    kernel_sz, 
    strides = 1, 
    padding = "same"
):
    x = tf.keras.layers.Conv3DTranspose(filters, kernel_sz, strides = strides, padding = padding)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation(tf.nn.leaky_relu)(x)
    return x


def VideoGenerator(
    model_name = "VideoGenerator",
):
    model_input = tf.keras.layers.Input(shape = HPARAMS.z_sz, dtype = tf.dtypes.float32)

    args = [
        [256, (3, 3, 3), (1, 1, 1)],
        [256, (3, 3, 3), (1, 1, 1)],
        [128, (3, 4, 4), (1, 2, 2)],
        [128, (3, 3, 3), (1, 1, 1)],
        [ 64, (3, 4, 4), (1, 2, 2)],
        [ 64, (3, 3, 4), (1, 1, 1)],
        [  3, (3, 4, 4), (1, 2, 2)]]
    
    x = model_input

    ## L1 to L7.
    for (filters, kernel_sz, strides) in args:
        x = Conv3D_T_BN_LeakyReLU(x, filters, kernel_sz, strides = strides)

    model_output = x

    return tf.keras.Model(
        inputs = model_input,
        outputs = model_output,
        name = model_name)

In [10]:
tmp = VideoGenerator()
tmp.summary()

Model: "VideoGenerator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 16, 8, 8, 64)]    0         
_________________________________________________________________
conv3d_transpose (Conv3DTran (None, 16, 8, 8, 256)     442624    
_________________________________________________________________
batch_normalization_7 (Batch (None, 16, 8, 8, 256)     1024      
_________________________________________________________________
activation_7 (Activation)    (None, 16, 8, 8, 256)     0         
_________________________________________________________________
conv3d_transpose_1 (Conv3DTr (None, 16, 8, 8, 256)     1769728   
_________________________________________________________________
batch_normalization_8 (Batch (None, 16, 8, 8, 256)     1024      
_________________________________________________________________
activation_8 (Activation)    (None, 16, 8, 8, 256)  

In [None]:
# tf.keras.utils.plot_model(tmp, show_shapes = True)

In [11]:
del tmp

In [12]:
class LatentRepresentationGeneratorBlock(tf.keras.Model):
    def __init__(
        self, 
        l, 
        T_l, 
        H = HPARAMS.H, 
        W = HPARAMS.W, 
        C = HPARAMS.C,
        model_name = "LatentRepresentationGeneratorBlock"
    ):
        super(LatentRepresentationGeneratorBlock, self).__init__(name = f"{model_name}_{l}")

        self.T_l = T_l
        self.H = H
        self.W = W
        self.C = C

        self.linear = tf.keras.layers.Dense(self.T_l * self.C)

        self.conv1d_s = tf.keras.layers.Conv1D(self.C, 3, strides = 1, padding = "same")
        self.conv1d_e = tf.keras.layers.Conv1D(self.C, 3, strides = 1, padding = "same")
        self.conv1d_n = tf.keras.layers.Conv1D(self.C, 3, strides = 1, padding = "same")

        self.conv3d_1 = tf.keras.layers.Conv3D(self.C, 3, strides = 1, padding = "same")
        self.conv3d_2 = tf.keras.layers.Conv3D(self.C, 3, strides = 1, padding = "same")


    def call(
        self, 
        u, 
        E_xs, 
        E_xe, 
        z_last
    ):        
        u_l = self.linear(u)
        u_l = tf.reshape(u_l, (-1, self.T_l, self.C)) ## [batch, T_l, C]

        g_e = tf.nn.sigmoid(self.conv1d_e(u_l))
        g_e = tf.tile(g_e[:, :, tf.newaxis, tf.newaxis, :], [1, 1, self.H, self.W, 1])

        g_s = tf.nn.sigmoid(self.conv1d_s(u_l))
        g_s = tf.tile(g_s[:, :, tf.newaxis, tf.newaxis, :], [1, 1, self.H, self.W, 1])
        
        x = tf.math.maximum(0., 1 - g_s - g_e)

        n = self.conv1d_n(u_l)
        n = tf.tile(n[:, :, tf.newaxis, tf.newaxis, :], [1, 1, self.H, self.W, 1])

        E_xs = tf.tile(E_xs[:, tf.newaxis, ...], [1, self.T_l, 1, 1, 1])
        E_xe = tf.tile(E_xe[:, tf.newaxis, ...], [1, self.T_l, 1, 1, 1])

        z = g_s * E_xs + g_e * E_xe + x * z_last + n
        z_residual = z
        
        z = tf.nn.leaky_relu(self.conv3d_1(z))
        z = tf.nn.leaky_relu(self.conv3d_2(z) + z_residual)

        return z
        

def LatentRepresentationGenerator(
    image_encoder,
    video_generator,
    T = HPARAMS.T,
    L = HPARAMS.L,
    model_name = "LatentRepresentationGenerator",
):
    model_input_1 = tf.keras.layers.Input(shape = HPARAMS.u_sz, dtype = tf.dtypes.float32) ## u
    model_input_2 = tf.keras.layers.Input(shape = HPARAMS.image_sz, dtype = tf.dtypes.float32) ## E(x_s)
    model_input_3 = tf.keras.layers.Input(shape = HPARAMS.image_sz, dtype = tf.dtypes.float32) ## E(x_e)

    u = model_input_1
    E_xs = image_encoder(model_input_2)
    E_xe = image_encoder(model_input_3)

    z_last = tf.keras.layers.Lambda(lambda xs: tf.stack(xs, axis = 1))([E_xs, E_xe])
    
    for l in range(L):
        if not (l % 8):
            z_last = tf.keras.layers.UpSampling3D((2, 1, 1))(z_last)

        T_l = int(T / 2 ** (2 - l // 8))
        z_last = LatentRepresentationGeneratorBlock(l, T_l)(u, E_xs, E_xe, z_last)

    model_output = video_generator(z_last)

    return tf.keras.Model(
        inputs = [model_input_1, model_input_2, model_input_3],
        outputs = model_output,
        name = model_name)

In [13]:
tmp = LatentRepresentationGenerator(
    ImageEncoder(), 
    VideoGenerator())

tmp.summary()

Model: "LatentRepresentationGenerator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
input_7 (InputLayer)            [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
ImageEncoder (Functional)       (None, 8, 8, 64)     1584832     input_6[0][0]                    
                                                                 input_7[0][0]                    
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 2, 8, 8, 64)  0           Image

In [None]:
# tf.keras.utils.plot_model(tmp, show_shapes = True, rankdir = "LR")

In [14]:
del tmp

### **Discriminator**

In [15]:
def Conv3D_LN_LeakyReLU(
    x, 
    filters, 
    kernel_sz, 
    strides = 1, 
    padding = "same"
):
    x = tf.keras.layers.Conv3D(filters, kernel_sz, strides = strides, padding = padding)(x)
    x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.Activation(tf.nn.leaky_relu)(x)
    return x


def VideoDiscriminator(
    model_name = "VideoDiscriminator",
):
    """MoCoGAN-style"""
    model_input = tf.keras.layers.Input(shape = HPARAMS.video_sz, dtype = tf.dtypes.float32)
    x = model_input

    args = [
        [ 64, 4, (1, 2, 2)],
        [128, 4, (1, 2, 2)],
        [256, 4, (1, 2, 2)],
        [512, 4, (1, 2, 2)]]
    
    ## L1 to L4.
    for (filters, kernel_sz, strides) in args:
        x = tf.keras.layers.ZeroPadding3D(padding = (0, 1, 1))(x)
        x = Conv3D_LN_LeakyReLU(x, filters, kernel_sz, strides = strides, padding = "valid")

    ## L5.
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(1)(x)
    model_output = tf.keras.layers.Activation(tf.nn.sigmoid)(x)

    return tf.keras.Model(
        inputs = model_input,
        outputs = model_output,
        name = model_name)

In [16]:
tmp = VideoDiscriminator()
tmp.summary()

Model: "VideoDiscriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_8 (InputLayer)         [(None, 16, 64, 64, 3)]   0         
_________________________________________________________________
zero_padding3d (ZeroPadding3 (None, 16, 66, 66, 3)     0         
_________________________________________________________________
conv3d_48 (Conv3D)           (None, 13, 32, 32, 64)    12352     
_________________________________________________________________
layer_normalization (LayerNo (None, 13, 32, 32, 64)    128       
_________________________________________________________________
activation_28 (Activation)   (None, 13, 32, 32, 64)    0         
_________________________________________________________________
zero_padding3d_1 (ZeroPaddin (None, 13, 34, 34, 64)    0         
_________________________________________________________________
conv3d_49 (Conv3D)           (None, 10, 16, 16, 

In [None]:
# tf.keras.utils.plot_model(tmp, show_shapes = True)

In [17]:
del tmp

In [18]:
def Conv2D_LN_LeakyReLU(
    x, 
    filters, 
    kernel_sz, 
    strides = 1, 
    padding = "same"
):
    x = tf.keras.layers.Conv2D(filters, kernel_sz, strides = strides, padding = padding)(x)
    x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.Activation(tf.nn.leaky_relu)(x)
    return x


def Shortcut(
    x, 
    filters,
    kernel_sz = 1,
    pool_kernel_sz = 2,
    pool_strides = 2,
    pool_padding = "same"
):
    x = tf.keras.layers.AveragePooling2D(pool_kernel_sz, strides = pool_strides, padding = pool_padding)(x)
    x = tf.keras.layers.Conv2D(filters, kernel_sz)(x)
    return x


def ImageDiscriminator(
    model_name = "ImageDiscriminator",
):
    """Resnet-based"""
    model_input = tf.keras.layers.Input(shape = HPARAMS.image_sz, dtype = tf.dtypes.float32)

    ## L1.
    x = tf.keras.layers.Conv2D(3, 3, padding = "same")(model_input)

    ## L2 to L8.
    for filters in [64, 128, 256, 512]:
        residual = Shortcut(x, filters = filters)
        x = Conv2D_LN_LeakyReLU(x, filters, 4, strides = 2)
        x = Conv2D_LN_LeakyReLU(x, filters, 3, strides = 1)
        x = tf.keras.layers.Add()([x, residual])

    ## L9.
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(1)(x)
    model_output = tf.keras.layers.Activation(tf.nn.sigmoid)(x)

    return tf.keras.Model(
        inputs = model_input,
        outputs = model_output,
        name = model_name)

In [19]:
tmp = ImageDiscriminator()
tmp.summary()

Model: "ImageDiscriminator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_9 (InputLayer)            [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 3)    84          input_9[0][0]                    
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 32, 32, 64)   3136        conv2d_14[0][0]                  
__________________________________________________________________________________________________
layer_normalization_4 (LayerNor (None, 32, 32, 64)   128         conv2d_16[0][0]                  
_________________________________________________________________________________

In [None]:
# tf.keras.utils.plot_model(tmp, show_shapes = True)

In [20]:
del tmp

## **Loss Function**

In [55]:
def DiscriminatorLoss(
    real_output,
    generated_output,
    epsilon = 1e-7
):
    """Adopting the non-saturating log-loss for discriminators."""
    real_loss = -1 * tf.math.log(real_output + epsilon)
    generated_loss = tf.math.log(tf.ones_like(generated_output) - generated_output + epsilon)
    total_disc_loss = tf.math.reduce_mean(real_loss + generated_loss)

    return total_disc_loss


def GeneratorLoss(
    generated_video_output,
    generated_image_output,
    epsilon = 1e-7,
):
    """Loss function for encoder, feature map generator, and video generator."""
    generated_video_loss = -1 * tf.math.reduce_mean(tf.math.log(generated_video_output + epsilon))
    generated_image_loss = -1 * tf.math.reduce_mean(tf.math.log(generated_image_output + epsilon))
    total_gen_loss = generated_video_loss + generated_image_loss

    return total_gen_loss

In [93]:
## D_V
foo = tf.ones((32, 16, 64, 64, 3))  ## real_output
bar = tf.zeros((32, 16, 64, 64, 3)) ## generated_output

DiscriminatorLoss(foo, bar)

<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

In [17]:
## D_I
foo = tf.ones((32, 14, 64, 64, 3))  ## real_output
bar = tf.zeros((32, 14, 64, 64, 3)) ## generated_output

DiscriminatorLoss(foo, bar)

<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

In [18]:
## G := {E, G_Z, G_V}
foo = tf.ones((32, 16, 64, 64, 3))          ## generated_video_output
bar = tf.ones((32, 14, 64, 64, 3)) ## generated_image_output

GeneratorLoss(foo, bar)

<tf.Tensor: shape=(), dtype=float32, numpy=-2.3841855e-07>

## **Fit**

### **Generate Each Parts**

In [84]:
## Generator.
latent_representation_generator = LatentRepresentationGenerator(
    image_encoder = ImageEncoder(),
    video_generator = VideoGenerator())

## Discriminator.
video_discriminator = VideoDiscriminator()
image_discriminator = ImageDiscriminator()

### **Optimizers and Checkpoints**

In [85]:
generator_optimizer           = tf.keras.optimizers.Adam(lr = 5e-5, beta_1 = 0.5, beta_2 = 0.999, epsilon = 1e-8)
video_discriminator_optimizer = tf.keras.optimizers.Adam(lr = 5e-5, beta_1 = 0.5, beta_2 = 0.999, epsilon = 1e-8)
image_discriminator_optimizer = tf.keras.optimizers.Adam(lr = 5e-5, beta_1 = 0.5, beta_2 = 0.999, epsilon = 1e-8)

In [86]:
checkpoint_dir = "./training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    ## Optimizers.
    generator_optimizer = generator_optimizer,
    video_discriminator_optimizer = video_discriminator_optimizer,
    image_discriminator_optimizer = image_discriminator_optimizer,

    ## Generators.
    latent_representation_generator = latent_representation_generator,
    
    ## Discriminators.
    video_discriminator = video_discriminator,
    image_discriminator = image_discriminator)

### **Train**

In [87]:
# !rm -rf logs

In [88]:
log_dir = "logs/"

summary_writer = tf.summary.create_file_writer(
    log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [89]:
@tf.function
def train_step(videos, epoch):
    noise = tf.random.normal([videos.shape[0], HPARAMS.D]) ## not [HPARAMS.batch_sz, HPARAMS.D]

    with tf.GradientTape() as gen_tape, tf.GradientTape() as video_disc_tape, tf.GradientTape() as image_disc_tape:
        ## Split key frames.
        x_s = videos[:, 0]
        x_e = videos[:, -1]

        ## Genearte videos.
        generated_videos = latent_representation_generator([noise, x_s, x_e], training = True)

        ## Discriminate videos/images.
        video_disc_output_for_real = video_discriminator(videos, training = True)
        video_disc_output_for_gen  = video_discriminator(generated_videos, training = True)

        image_disc_output_for_real = tf.stack([
            image_discriminator(image, training = True) \
            for image in tf.unstack(videos, axis = 1)[1:-1]], axis = 1)
        image_disc_output_for_gen  = tf.stack([
            image_discriminator(generated_image, training = True) \
            for generated_image in tf.unstack(generated_videos, axis = 1)[1:-1]], axis = 1)
        
        ## Calculate losses.
        gen_loss = GeneratorLoss(video_disc_output_for_gen, image_disc_output_for_gen) 
        disc_video_loss = DiscriminatorLoss(video_disc_output_for_real, video_disc_output_for_gen)
        disc_image_loss = DiscriminatorLoss(image_disc_output_for_real[1:-1], image_disc_output_for_gen[1:-1])

    ## Calculate and apply gradients.
    gradients_of_generator = gen_tape.gradient(gen_loss, latent_representation_generator.trainable_variables)
    gradients_of_video_discriminator = video_disc_tape.gradient(disc_video_loss, video_discriminator.trainable_variables)
    gradients_of_image_discriminator = image_disc_tape.gradient(disc_image_loss, image_discriminator.trainable_variables)
        
    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, latent_representation_generator.trainable_variables))
    video_discriminator_optimizer.apply_gradients(
        zip(gradients_of_video_discriminator, video_discriminator.trainable_variables))
    image_discriminator_optimizer.apply_gradients(
        zip(gradients_of_image_discriminator, image_discriminator.trainable_variables))
    
    ## Record loss graph.
    with summary_writer.as_default():
        tf.summary.scalar("gen_loss", gen_loss, step = epoch)
        tf.summary.scalar("disc_video_loss", disc_video_loss, step = epoch)
        tf.summary.scalar("disc_image_loss", disc_image_loss, step = epoch)

In [90]:
def train(
    dataset, 
    epochs = HPARAMS.epochs,
):
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch, epoch)

        ## Save model every epochs.
        checkpoint.save(file_prefix = checkpoint_prefix)

        ## Display training times of each epochs.
        print (f"Time for epoch {epoch + 1} is {time.time() - start:.2f} sec")

In [91]:
%%time
## Dummy training dataset with 100 videos.
dummy_tr_tensor = tf.random.uniform(shape = [100] + HPARAMS.video_sz, maxval = 1.)
dummy_tr_dataset = tf.data.Dataset.from_tensor_slices(dummy_tr_tensor) \
                            .batch(HPARAMS.batch_sz) \
                            .cache() \
                            .prefetch(-1)

train(dummy_tr_dataset)

Time for epoch 1 is 103.76 sec
Time for epoch 2 is 80.82 sec
Time for epoch 3 is 80.79 sec
Time for epoch 4 is 80.33 sec
Time for epoch 5 is 80.09 sec
CPU times: user 6min 54s, sys: 13.1 s, total: 7min 7s
Wall time: 7min 5s


In [None]:
%load_ext tensorboard
%tensorboard --logdir logs/fit