<a href="https://colab.research.google.com/github/jsk245/frame_interpolator/blob/main/frame_interpolation_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install dm-haiku optax scikit-video tfds-nightly

In [None]:
from google.colab import drive
# Mount Google Drive so that the checkpoint paths under /content/gdrive/MyDrive/
# used by the training and restore cells resolve.
drive.mount('/content/gdrive')

In [None]:
import haiku as hk
import jax
import optax
import jax.numpy as jnp
import pickle
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import requests
import os
import random
import skvideo.io
from functools import partial
from typing import Any, NamedTuple

# TF is only used for the input pipeline; hide GPUs from it (presumably so it
# does not claim accelerator memory that JAX needs).
tf.enable_v2_behavior()
tf.config.set_visible_devices([], device_type='GPU')

for library, version in (("JAX", jax.__version__),
                         ("Haiku", hk.__version__),
                         ("TF", tf.__version__)):
  print(f"{library} version {version}")

In [None]:
# Load the DAVIS training split from TFDS as a tf.data.Dataset (make_dataset
# below maps/shuffles/batches it); downloads are cached under data_dir.
## there's also a validation split for more data to train on
data_dir = '/tmp/tfds'
data, info = tfds.load(
    name="davis",
    split="train",
    data_dir=data_dir,
    with_info=True,
)

## extra data to train on
"""
config = tfds.download.DownloadConfig(verify_ssl=False)
data = tfds.load(name="ucf101", split="train", shuffle_files=True, data_dir=data_dir, download_and_prepare_kwargs={"download_config" : config})
"""

In [5]:
def make_dataset(data=data, shuffle_amount=2, batch_size=1):
  """Turns a TFDS video dataset into an endless iterator of NumPy batches.

  Each element's ``["video"]["frames"]`` is converted to float32 in [0, 1] and
  then rescaled to [-1, 1]. The stream is shuffled with a buffer of
  ``shuffle_amount`` elements, repeated forever, and batched.

  Args:
    data: a tf.data.Dataset of TFDS video samples (defaults to the DAVIS split
      loaded above).
    shuffle_amount: shuffle buffer size. Larger is more random but takes more
      memory.
    batch_size: number of videos per yielded batch.

  Returns:
    An iterator over NumPy batches of frames.
  """
  def frames_in_symmetric_range(sample):
    frames = tf.image.convert_image_dtype(sample["video"]["frames"], tf.float32)
    # [0, 1] -> [-1, 1] to stabilize training.
    return 2.0 * frames - 1.0

  pipeline = (
      data
      .map(map_func=frames_in_symmetric_range,
           num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .shuffle(shuffle_amount)
      .repeat()
      .batch(batch_size)
  )
  return iter(tfds.as_numpy(pipeline))

In [6]:
def save(ckpt_dir: str, state) -> None:
  """Writes a pytree of arrays to ``ckpt_dir`` as two files.

  ``arrays.npy`` holds every leaf, written back-to-back with ``np.save`` in
  ``jax.tree_util.tree_leaves`` order. ``tree.pkl`` holds the tree structure
  with every leaf replaced by 0, so ``restore`` can rebuild the same container
  layout. The directory is created if it does not exist.

  Args:
    ckpt_dir: destination directory.
    state: pytree whose leaves are array-like (e.g. the result of
      ``jax.device_get`` on model params).
  """
  # Previously a missing directory made open() fail.
  os.makedirs(ckpt_dir, exist_ok=True)
  with open(os.path.join(ckpt_dir, "arrays.npy"), "wb") as f:
    for x in jax.tree_util.tree_leaves(state):
      np.save(f, x, allow_pickle=False)

  # `jax.tree_map` is deprecated and removed in recent JAX releases.
  tree_struct = jax.tree_util.tree_map(lambda t: 0, state)
  with open(os.path.join(ckpt_dir, "tree.pkl"), "wb") as f:
    pickle.dump(tree_struct, f)


def restore(ckpt_dir):
  """Loads a pytree previously written by ``save`` from ``ckpt_dir``.

  Returns:
    The saved container structure, with NumPy arrays as leaves.

  NOTE: ``tree.pkl`` is read with ``pickle.load``, which can execute arbitrary
  code. Only restore checkpoints you created yourself.
  """
  with open(os.path.join(ckpt_dir, "tree.pkl"), "rb") as f:
    tree_struct = pickle.load(f)

  leaves, treedef = jax.tree_util.tree_flatten(tree_struct)
  with open(os.path.join(ckpt_dir, "arrays.npy"), "rb") as f:
    # One np.load per leaf, in the same order save() wrote them.
    flat_state = [np.load(f) for _ in leaves]

  return jax.tree_util.tree_unflatten(treedef, flat_state)

In [7]:
## using more than two surrounding frames should generate better results
def process_data(sample_frames, vary_frame_distance=False, use_adjustments=False):
  """Samples one (previous, middle, next) frame triplet from a video and splits it.

  Args:
    sample_frames: frames of a single video, indexed by frame along axis 0.
      Assumed (T, H, W, C); the tf.image flips and crop_and_resize below expect
      image-shaped frames. TODO confirm.
    vary_frame_distance: if True, the outer frames are 1-3 frames away from the
      middle frame (chosen at random); otherwise they are adjacent to it.
    use_adjustments: if True, the video is randomly flipped up/down and
      left/right and played in reverse (each with probability 0.25). A random
      crop box is also used with probability 0.5.

  Returns:
    (sample, goal_frames): the two outer frames with shape (1, 2, 256, 256, C)
    and the middle frame with shape (1, 1, 256, 256, C).

  NOTE: the sequence of `random` calls determines the sampled triplet. Keep
  their order if editing this function.
  """
  ## maybe adjust the random cutoffs
  ## maybe add adjustments affecting color
  if use_adjustments: # each with prob 0.25: flip up/down, flip left/right, play frames in reverse
    is_tensor = False
    if random.random() < 0.25:
      sample_frames = tf.image.flip_up_down(sample_frames)
      is_tensor = True
    if random.random() < 0.25:
      sample_frames = tf.image.flip_left_right(sample_frames)
      is_tensor = True
    if is_tensor:
      # tf.image ops return tf.Tensors; convert back to NumPy for the indexing below.
      sample_frames = tfds.as_numpy(sample_frames)
    if random.random() < 0.25:
      sample_frames = jnp.flip(sample_frames, axis=0)
  # The JAX key is seeded from Python's `random` (only 1001 distinct seeds), so
  # the middle-frame choice follows the `random` stream.
  key = jax.random.PRNGKey(random.randint(0,1000))
  if vary_frame_distance: #will randomly choose to use frames from further away to predict one in the center
    ## maybe increase the spread if the videos trained on have high frame rates
    distance_spread = random.randint(1,3)
  else:
    distance_spread = 1
  # Middle index in [spread, T - spread - 1] (maxval is exclusive), so both outer frames stay in range.
  idx = jax.random.randint(key, [1], distance_spread, sample_frames.shape[0]-distance_spread)
  idx = jnp.repeat(idx, 3)
  # Add offsets (-spread, 0, +spread) -> indices of the (previous, middle, next) frames.
  idx = idx + jnp.tile(jnp.arange(-1*distance_spread,2*distance_spread, distance_spread), 1)
  # Guards against out-of-range indices (e.g. videos too short for the chosen spread).
  idx = jnp.clip(idx, a_min=0, a_max=sample_frames.shape[0]-1)
  idx = jax.device_get(idx)
  sample = sample_frames[idx, :]
  if use_adjustments and random.random() < 0.5: # random crop box (same for all 3 frames); resized to 256x256 below
    ## maybe allow smaller portions to be cropped. Right now the smallest crop possible always includes the middle 50% of pixels
    x1 = random.random()/4
    x2 = random.random()/4 + 0.75
    y1 = random.random()/4
    y2 = random.random()/4 + 0.75
    # One normalized (y1, x1, y2, x2) box per frame, as tf.image.crop_and_resize expects.
    crops = jnp.reshape(jnp.tile(jnp.array((y1, x1, y2, x2)), 1*3), (1*3, 4))
  else:
    # Full-frame boxes: each frame is only resized to 256x256.
    crops = jnp.reshape(jnp.tile(jnp.array((0, 0, 1, 1)), 1*3), (1*3, 4))
  # Box i is applied to frame i of `sample`; output is (3, 256, 256, C).
  sample = tf.image.crop_and_resize(
            sample,
            crops,
            jnp.arange(1*3),
            (256, 256)).numpy()
  sample = jnp.reshape(sample, [1,3] + list(sample.shape[1:]))
  # The middle frame is the prediction target; the outer two are the model input.
  goal_frames = sample[:,1,:,:,:]
  goal_frames = jnp.reshape(goal_frames, [1,1] + list(sample.shape[2:]))
  sample = jnp.delete(sample, 1, axis=1)
  return sample, goal_frames

def make_batch(dataset, vary_frame_distance=False, use_adjustments=False, batch_size=4, skip_prob=0.5):
  """Assembles a batch of (surrounding frames, goal frame) training examples.

  Each example comes from its own video, taken as element 0 of the next item
  of `dataset` (as yielded by `make_dataset`). With probability `skip_prob`
  the drawn video is discarded and the following one is used instead
  (presumably to add variety given the small shuffle buffer).

  Args:
    dataset: iterator yielding batches whose first element is a video's frames.
    vary_frame_distance, use_adjustments: forwarded to `process_data`.
    batch_size: number of examples to build.
    skip_prob: probability of skipping a drawn video.

  Returns:
    (surrounding_frames, goal_frames) with shapes (batch_size, 2, 256, 256, 3)
    and (batch_size, 1, 256, 256, 3).
  """
  surrounding_list, goal_list = [], []
  for _ in range(batch_size):
    video = next(dataset)[0]
    if random.random() > (1 - skip_prob):
      video = next(dataset)[0]
    next_surrounding, next_goal = process_data(video, vary_frame_distance, use_adjustments)
    surrounding_list.append(next_surrounding)
    goal_list.append(next_goal)

  if not surrounding_list:
    # Keep the original empty-batch shapes.
    return jnp.empty((0, 2, 256, 256, 3)), jnp.empty((0, 1, 256, 256, 3))
  # One concatenate instead of growing the arrays every iteration (was O(n^2) copying).
  return jnp.concatenate(surrounding_list), jnp.concatenate(goal_list)

In [9]:
# useful if using attention
"""def positional_encoding(input_tensor):
  encoding = jnp.ones(input_tensor.shape)
  first_helper = encoding[:,:,:,:encoding.shape[-1]//3] * jnp.arange(encoding.shape[0])[:,None,None,None]
  second_helper = encoding[:,:,:,encoding.shape[-1]//3:2*encoding.shape[-1]//3] * jnp.arange(encoding.shape[1])[None,:,None,None]
  third_helper = encoding[:,:,:,2*encoding.shape[-1]//3:] * jnp.arange(encoding.shape[2])[None,None,:,None]
  encoding = (jnp.concatenate([first_helper, second_helper, third_helper], axis=-1))

  encoding_helper = jnp.ones(encoding.shape)
  first_helper = encoding_helper[:,:,:,:encoding.shape[-1]//3]
  second_helper = encoding_helper[:,:,:,encoding.shape[-1]//3:2*encoding.shape[-1]//3]
  third_helper = encoding_helper[:,:,:,2*encoding.shape[-1]//3:]
  first_helper = first_helper * jnp.repeat(jnp.arange((first_helper.shape[-1])//2+1), 2)[None,None,None,:first_helper.shape[-1]]
  second_helper = second_helper * jnp.repeat(jnp.arange((second_helper.shape[-1])//2+1), 2)[None,None,None,:second_helper.shape[-1]]
  third_helper = third_helper * jnp.repeat(jnp.arange((third_helper.shape[-1])//2+1), 2)[None,None,None,:third_helper.shape[-1]]
  encoding_helper = (jnp.concatenate([first_helper, second_helper, third_helper], axis=-1))
  encoding_helper = 10000 ** (encoding_helper * 6 / encoding_helper.shape[-1])

  encoding = encoding / encoding_helper
  encoding = encoding.at[:,:,:,::2].set(jnp.sin(encoding[:,:,:,::2]))
  encoding = encoding.at[:,:,:,1::2].set(jnp.cos(encoding[:,:,:,1::2]))
  return input_tensor + encoding"""

class FourierConv(hk.Module):
  """Spectral convolution block over the (frames, H, W) axes of one example.

  Steps:
  1. A 1x1x1 projection + BatchNorm + ReLU to `channels // 2` channels.
  2. An FFT over axes (0, 1, 2), with real and imaginary parts interleaved on
     the channel axis.
  3. A convolution + BatchNorm + ReLU in the frequency domain.
  4. An inverse FFT and a residual add.
  5. A 1x1x1 output projection with ReLU.

  The input is 4-D (frames, H, W, C). In this file it is called from
  FourierBlock, which runs under `jax.vmap` over the batch axis.

  Args:
    channels: output channels. The spectral branch uses `channels // 2`
      complex channels, i.e. `channels` real inputs to conv2.
    conv_length: spatial kernel size of the frequency-domain convolution.
    is_training: forwarded to both BatchNorm layers.
    temporal: if True, the frequency-domain kernel spans 2 taps on the frame
      axis; otherwise 1.
  """
  def __init__(self, channels, conv_length, is_training, temporal=True, name=None):
    super(FourierConv, self).__init__(name=name)
    hidden_channels = channels//2
    # Creation order is unchanged: Haiku derives parameter names from it.
    self.conv1 = hk.ConvND(3, hidden_channels, 1, 1)
    # cross_replica_axis is the jax.vmap axis name used by the callers, so batch
    # statistics are pooled across the batch rather than per example.
    self.bn1 = hk.BatchNorm(False, False, 0.9, cross_replica_axis="jax_vmap_fourier")

    kernel_depth = 2 if temporal else 1
    self.conv2 = hk.ConvND(3, channels, [kernel_depth, conv_length, conv_length], 1)
    self.bn2 = hk.BatchNorm(False, False, 0.9, cross_replica_axis="jax_vmap_fourier")

    self.conv3 = hk.ConvND(3, channels, 1, 1)

    self.is_training = is_training

  def __call__(self, x):
    x = self.conv1(x)
    x = self.bn1(x, self.is_training)
    x = jax.nn.relu(x)
    x_res = x
    spectrum = jnp.fft.rfftn(x, axes=(0, 1, 2))
    # (..., C) complex -> (..., 2C) real, with real/imag interleaved per channel.
    x = jnp.reshape(jnp.stack([spectrum.real, spectrum.imag], axis=-1),
                    spectrum.shape[:3] + (spectrum.shape[3] * 2,))
    x = self.conv2(x)
    x = self.bn2(x, self.is_training)
    x = jax.nn.relu(x)
    # Undo the interleaving: (..., 2C') real -> (..., C') complex.
    x = jnp.reshape(x, x.shape[:3] + (x.shape[3] // 2, 2))
    x = jax.lax.complex(x[..., 0], x[..., 1])
    # Pass the original sizes explicitly. Without `s`, irfftn rebuilds the last
    # axis with length 2*(n-1), which drops a column whenever W is odd and then
    # breaks the residual add below.
    x = jnp.fft.irfftn(x, s=x_res.shape[:3], axes=(0, 1, 2))
    x = x + x_res
    x = jax.nn.relu(self.conv3(x))
    return x

class FourierBlock(hk.Module):
  """Two-branch block mixing local (per-quadrant) and global spectral features.

  The channel axis is split in half.
  - The "local" half gets a 3x3 conv, plus a FourierConv applied separately to
    each spatial quadrant (one shared FourierConv for all four).
  - The "global" half gets a FourierConv over the whole map, plus a 3-D conv.
  Each side receives the other's cross-branch output, then BatchNorm + ReLU.
  Both halves are finally recombined to `channels` channels.

  Called from FConvBlock / FConvBlockFlow through `jax.vmap`, so `x` here is a
  single example of shape (frames, H, W, channels).
  """
  def __init__(self, channels, is_training, temporal=True, name=None):
    super(FourierBlock, self).__init__(name=name)
    self.channels = channels
    half = channels//2

    self.conv_local_1 = hk.Conv2D(half, 3, 1)
    # A single FourierConv instance: the four quadrants share its weights and BatchNorm state.
    self.fourier_local = FourierConv(half, 3, is_training, temporal)
    self.local_bn = hk.BatchNorm(False, False, 0.9, cross_replica_axis="jax_vmap_fourier")

    # NOTE(review): this conv always spans 2 frames, regardless of `temporal`.
    self.fourier_global_1 = hk.Conv3D(half, [2,3,3], 1)
    self.fourier_global_2 = FourierConv(half, 3, is_training, temporal)
    self.global_bn = hk.BatchNorm(False, False, 0.9, cross_replica_axis="jax_vmap_fourier")

    self.is_training = is_training

  def __call__(self, x):
    split = self.channels // 2
    local_side = x[:,:,:,0:split]
    global_side = x[:,:,:,split:]
    pure_local = self.conv_local_1(local_side)
    # Under FConvBlock's vmap the leading axis here is the frame axis.
    N, H, W, C = local_side.shape
    # Spectral features of each quadrant computed independently (top-left, top-right, bottom-left, bottom-right).
    q1 = self.fourier_local(local_side[:,:H//2,:W//2,:])
    q2 = self.fourier_local(local_side[:,:H//2,W//2:,:])
    q3 = self.fourier_local(local_side[:,H//2:,:W//2,:])
    q4 = self.fourier_local(local_side[:,H//2:,W//2:,:])
    # Reassemble the quadrants into an (N, H, W, C) map. After the reshape, row
    # 2r is [q1[r] | q2[r]] and row 2r+1 is [q3[r] | q4[r]]. Gathering the even
    # rows then the odd rows puts [q1 | q2] on top and [q3 | q4] below.
    local_to_global = jnp.reshape(jnp.concatenate([q1, q2, q3, q4], axis=2), (N,H,W,C))
    local_to_global = jnp.concatenate([local_to_global[:,::2,:,:], local_to_global[:,1::2,:,:]], axis=1)

    global_to_local = self.fourier_global_1(global_side)
    ## replace this with a regular conv to test if the fourier is helping
    pure_global = self.fourier_global_2(global_side)
    # Cross-connect: each branch adds the other branch's contribution.
    local_side = pure_local + global_to_local
    global_side = local_to_global + pure_global
    local_side = jax.nn.relu(self.local_bn(local_side, self.is_training))
    global_side = jax.nn.relu(self.global_bn(global_side, self.is_training))
    # stack(..., axis=-1) + reshape INTERLEAVES the channels (local_0, global_0,
    # local_1, ...) rather than concatenating the two halves.
    # NOTE(review): the next block's split then takes the first/second half of
    # this interleaved layout — confirm that mixing is intended.
    x = jnp.reshape(jnp.stack([local_side, global_side], axis=-1), (x.shape[0], x.shape[1], x.shape[2], self.channels))
    ## dropout?
    return x

class MyChannelMatcher(hk.Module):
  """Bias-free 1x1x1 3-D convolution mapping the channel axis to `out_channels`.

  The input is 5-D in NDHWC layout, i.e. (batch, frames, H, W, C), as fixed by
  the dimension-number spec below.
  """

  def __init__(self, out_channels, name=None):
    super(MyChannelMatcher, self).__init__(name=name)
    self.out_channels = out_channels

  def __call__(self, x):
    _, frames, height, width, in_channels = x.shape
    # NOTE(review): the init stddev is scaled by frames*H*W*C rather than the
    # true fan-in of a 1x1x1 kernel (C), which gives very small initial weights
    # at high resolution. Kept as-is to preserve training behaviour.
    stddev = 1. / np.sqrt(frames * height * width * in_channels)
    kernel = hk.get_parameter(
        "w",
        shape=[1, 1, 1, in_channels, self.out_channels],
        dtype=x.dtype,
        init=hk.initializers.TruncatedNormal(stddev))
    dimension_numbers = jax.lax.conv_dimension_numbers(
        x.shape, kernel.shape, ('NDHWC', 'HWDIO', 'NDHWC'))
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1, 1),
        padding='SAME',
        lhs_dilation=(1, 1, 1),
        rhs_dilation=(1, 1, 1),
        dimension_numbers=dimension_numbers)

class Padder(hk.Module):
  """Zero-pads axes 2 and 3 (H, W) of a 5-D tensor up to a reference tensor's size.

  Any odd leftover goes on the trailing side. A reference smaller than `x`
  yields negative pad widths, which jnp.pad rejects.
  """
  def __init__(self, name=None):
    super(Padder, self).__init__(name=name)

  def __call__(self, x, x_res_future):
    def split_padding(total):
      before = total // 2
      return before, total - before

    pad_h = split_padding(x_res_future.shape[2] - x.shape[2])
    pad_w = split_padding(x_res_future.shape[3] - x.shape[3])
    untouched = (0, 0)
    return jnp.pad(x, (untouched, untouched, pad_h, pad_w, untouched))

class FConvBlock(hk.Module):
  """Channel projection (+ReLU) followed by two residual, post-LayerNormed FourierBlocks.

  Each FourierBlock is vmapped over the leading (batch) axis under the axis name
  "jax_vmap_fourier", which the BatchNorms inside use as their cross-replica axis.
  """
  def __init__(self, channels, is_training, name=None):
    super(FConvBlock, self).__init__(name=name)
    # Keep this construction order: Haiku derives submodule names from it.
    self.conv = MyChannelMatcher(channels)
    self.fourier_1 = FourierBlock(channels, is_training)
    self.ln_1 = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)
    self.fourier_2 = FourierBlock(channels, is_training)
    self.ln_2 = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)

  def __call__(self, x):
    h = jax.nn.relu(self.conv(x))
    for block, norm in ((self.fourier_1, self.ln_1), (self.fourier_2, self.ln_2)):
      batched_block = jax.vmap(block, axis_name="jax_vmap_fourier")
      h = norm(batched_block(h) + h)
    return h

class Downsampler(hk.Module):
  """Multi-scale encoder producing five residual feature levels.

  Three FConvBlocks (32, 64 and 128 channels) are shared across scales. After
  each pooling step, every live branch advances to the next, wider block, and
  branches that have already gone through the 128-channel block are dropped.
  At the second and third levels, a fresh branch also starts from the input
  image resized to that level. Branches at the same level are concatenated on
  the channel axis, giving levels with 32, 96 (64+32), 224 (128+64+32),
  192 (128+64) and 128 channels, from finest to coarsest.

  Input is 5-D (batch, frames, H, W, 3); MyChannelMatcher inside FConvBlock
  unpacks five axes, and the resize calls below assume 3 image channels.
  Returns (x_res1, ..., x_res5).
  """
  def __init__(self, is_training, name=None):
    super(Downsampler, self).__init__(name=name)
    # Stride-2 average pooling, applied per example through jax.vmap below.
    self.avg_pooler = hk.AvgPool(2, 2, "VALID", channel_axis=-1)

    self.fconv_block_downsample_1 = FConvBlock(32, is_training)
    self.fconv_block_downsample_2 = FConvBlock(64, is_training)
    self.fconv_block_downsample_3 = FConvBlock(128, is_training)


  def __call__(self, x):
    # NOTE: the call order of the shared blocks below also fixes Haiku's
    # parameter-init order; keep it if editing.
    avg_pooler = partial(self.avg_pooler)

    original_image = x
    # Level 1: image -> block 1.
    x = self.fconv_block_downsample_1(x)
    x_res1 = x

    # Level 2: pooled branch -> block 2; resized image -> block 1.
    x = jax.vmap(avg_pooler)(x)
    original_image = jax.image.resize(original_image, list(x.shape[:-1]) + [3], "bilinear")
    x = self.fconv_block_downsample_2(x)
    x2 = self.fconv_block_downsample_1(original_image)
    x_res2 = jax.numpy.concatenate([x, x2], axis=-1)

    # Level 3: branches advance to blocks 3 and 2; resized image -> block 1.
    x = jax.vmap(avg_pooler)(x)
    x2 = jax.vmap(avg_pooler)(x2)
    original_image = jax.image.resize(original_image, list(x.shape[:-1]) + [3], "bilinear")
    x = self.fconv_block_downsample_3(x)
    x2 = self.fconv_block_downsample_2(x2)
    x3 = self.fconv_block_downsample_1(original_image)
    original_image = None
    x_res3 = jax.numpy.concatenate([x, x2, x3], axis=-1)

    # Level 4: the branch that reached block 3 stops; the other two advance.
    x = jax.vmap(avg_pooler)(x2)
    x2 = jax.vmap(avg_pooler)(x3)
    x = self.fconv_block_downsample_3(x)
    x2 = self.fconv_block_downsample_2(x2)
    x_res4 = jax.numpy.concatenate([x, x2], axis=-1)

    # Level 5: only the branch coming from block 2 advances (to block 3).
    x = jax.vmap(avg_pooler)(x2)
    x = self.fconv_block_downsample_3(x)
    x_res5 = x

    return (x_res1, x_res2, x_res3, x_res4, x_res5)

class FConvBlockFlow(hk.Module):
  """An FConvBlock followed by one more residual, LayerNormed FourierBlock (vmapped over batch)."""
  def __init__(self, channels, is_training, name=None):
    super(FConvBlockFlow, self).__init__(name=name)
    self.fourier_1_and_2 = FConvBlock(channels, is_training)
    self.fourier_3 = FourierBlock(channels, is_training)
    self.ln_3 = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)

  def __call__(self, x):
    h = self.fourier_1_and_2(x)
    batched_fourier_3 = jax.vmap(self.fourier_3, axis_name="jax_vmap_fourier")
    return self.ln_3(batched_fourier_3(h) + h)

class Combiner(hk.Module):
  """Top-down fusion of the five Downsampler feature levels.

  Works from the coarsest level (x_res5) to the finest (x_res1). At each level:
  1. The level is refined by a FConvBlockFlow.
  2. The running coarser map is resized to that level's H/W (and
     channel-matched where the widths differ).
  3. The two are added, followed by a LayerNorm.

  Returns the five fused levels in the same order as the inputs.
  """
  def __init__(self, is_training, name=None):
    super(Combiner, self).__init__(name=name)
    # Submodules are created in the original order. Haiku derives module (and
    # therefore parameter) names from creation order, so saved checkpoints
    # still load.
    self.fconv_block_flow_1 = FConvBlockFlow(32, is_training)
    self.fconv_block_flow_2 = FConvBlockFlow(96, is_training)
    self.fconv_block_flow_3 = FConvBlockFlow(224, is_training)

    self.channel_matcher_res_5 = MyChannelMatcher(224)
    self.channel_matcher_res_4 = MyChannelMatcher(224)
    self.channel_matcher_res_2 = MyChannelMatcher(96)
    self.channel_matcher_res_1 = MyChannelMatcher(32)

    self.ln_res_4 = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)
    self.ln_res_3 = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)
    self.ln_res_2 = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)
    self.ln_res_1 = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)

  def __call__(self, x_res1, x_res2, x_res3, x_res4, x_res5):
    def to_spatial_size_of(x, ref):
      """Bilinearly resizes axes 2-3 (H, W) of `x` to those of `ref`; no-op if they already match.

      The previous implementation zero-padded `x` instead (via Padder). That
      dropped the coarser map into the centre of the finer grid and left a zero
      border, so the fused features were spatially misaligned.
      """
      if x.shape[2:4] == ref.shape[2:4]:
        return x
      return jax.image.resize(x, x.shape[:2] + ref.shape[2:4] + x.shape[4:], "bilinear")

    x_res5 = jax.nn.relu(self.channel_matcher_res_5(x_res5))
    x_res5 = self.fconv_block_flow_3(x_res5)
    x = x_res5

    x_res4 = jax.nn.relu(self.channel_matcher_res_4(x_res4))
    x_res4 = self.fconv_block_flow_3(x_res4)
    x_res4 = self.ln_res_4(x_res4 + to_spatial_size_of(x, x_res4))
    x = x_res4

    x_res3 = self.fconv_block_flow_3(x_res3)
    x_res3 = self.ln_res_3(x_res3 + to_spatial_size_of(x, x_res3))
    x = x_res3

    # From here on the running map (224 channels) is narrowed to each level's width.
    x_res2 = self.fconv_block_flow_2(x_res2)
    x = self.channel_matcher_res_2(to_spatial_size_of(x, x_res2))
    x_res2 = self.ln_res_2(x_res2 + x)
    x = x_res2

    x_res1 = self.fconv_block_flow_1(x_res1)
    x = self.channel_matcher_res_1(to_spatial_size_of(x, x_res1))
    x_res1 = self.ln_res_1(x_res1 + x)

    return (x_res1, x_res2, x_res3, x_res4, x_res5)

class FConvBlockUpsample(hk.Module):
  """FConvBlock followed by a stride-2 transposed 2-D conv (vmapped over batch) and ReLU."""
  def __init__(self, channels, is_training, name=None):
    super(FConvBlockUpsample, self).__init__(name=name)
    self.fourier_conv = FConvBlock(channels, is_training)
    self.upsample_conv = hk.ConvNDTranspose(2, channels, 2, stride=2)

  def __call__(self, x):
    features = self.fourier_conv(x)
    upsampled = jax.vmap(self.upsample_conv)(features)
    return jax.nn.relu(upsampled)

class Upsampler(hk.Module):
  """Decoder that merges the two input frames and walks back up the feature pyramid.

  At every level, the 2-frame axis of the skip features is collapsed with an
  elementwise max (keeping a singleton frame axis). The result is concatenated
  with the decoded coarser features and passed through that level's block.
  The last output is concatenated with the merged finest-level features and
  returned.
  """
  def __init__(self, is_training, name=None):
    super(Upsampler, self).__init__(name=name)
    # Creation order is unchanged, so Haiku parameter names (and checkpoints) are unaffected.
    self.fconv_block_upsample_5 = FConvBlockUpsample(224, is_training)
    self.fconv_block_upsample_4 = FConvBlockUpsample(128, is_training)
    # NOTE: these two do not upsample; `to_spatial_size_of` below brings their
    # outputs to the next level's resolution.
    self.fconv_block_upsample_3 = FConvBlock(64, is_training)
    self.fconv_block_upsample_2 = FConvBlock(32, is_training)

  def __call__(self, x_res1, x_res2, x_res3, x_res4, x_res5):
    def merge_frames(x):
      # (N, 2, H, W, C) -> (N, 1, H, W, C) via elementwise max over the two frames.
      return jnp.maximum(x[:, 0], x[:, 1])[:, None]

    def to_spatial_size_of(x, ref):
      # Bilinear resize of H/W (axes 2-3) to `ref`'s size; no-op when they match.
      # Previously a Padder zero-padded here. After the non-upsampling blocks
      # that centred e.g. a 64x64 map inside a 128x128 zero canvas, so the
      # borders of the output saw no decoded features.
      if x.shape[2:4] == ref.shape[2:4]:
        return x
      return jax.image.resize(x, x.shape[:2] + ref.shape[2:4] + x.shape[4:], "bilinear")

    x = self.fconv_block_upsample_5(merge_frames(x_res5))
    x = to_spatial_size_of(x, x_res4)

    x = self.fconv_block_upsample_4(jnp.concatenate([merge_frames(x_res4), x], axis=-1))
    x = to_spatial_size_of(x, x_res3)

    x = self.fconv_block_upsample_3(jnp.concatenate([merge_frames(x_res3), x], axis=-1))
    x = to_spatial_size_of(x, x_res2)

    x = self.fconv_block_upsample_2(jnp.concatenate([merge_frames(x_res2), x], axis=-1))
    x = to_spatial_size_of(x, x_res1)

    return jnp.concatenate([merge_frames(x_res1), x], axis=-1)

class ImageGenerator(hk.Module):
  """Full interpolation network: Downsampler -> Combiner -> Upsampler -> output conv.

  The final 3-channel transposed conv is followed by tanh, so predictions lie
  in (-1, 1), matching the [-1, 1] scaling of the input frames.
  """
  def __init__(self, is_training, name=None):
    super(ImageGenerator, self).__init__(name=name)
    # Keep this creation order: Haiku derives submodule names from it.
    self.downsampler = Downsampler(is_training)
    self.combiner = Combiner(is_training)
    self.upsampler = Upsampler(is_training)
    self.final_conv = hk.ConvNDTranspose(3, 3, [1,3,3], 1)

  def __call__(self, x):
    pyramid = self.downsampler(x)
    fused_pyramid = self.combiner(*pyramid)
    decoded = self.upsampler(*fused_pyramid)
    return jnp.tanh(self.final_conv(decoded))

In [10]:
def tree_shape(xs):
  """Returns a pytree with the same structure as `xs`, each leaf replaced by its `.shape`."""
  def leaf_shape(leaf):
    return leaf.shape

  return jax.tree_util.tree_map(leaf_shape, xs)

class InterpolatorState(NamedTuple):
  """Training state threaded through `FrameInterpolator.update_gen` between steps."""
  params: Any     # Haiku generator parameters
  states: Any     # Haiku mutable state (BatchNorm statistics) from transform_with_state
  opt_state: Any  # optax optimizer state

def MSSSIML1_loss_vectorized(image1, image2, sigmas=(0.5, 1., 2., 4., 8.), filter_size=11, C1=.01, C2=.03, alpha=0.84):
  """Mixed MS-SSIM + Gaussian-weighted L1 loss.

  Images are mapped from [-1, 1] to [0, 1] and every colour channel is treated
  as a separate single-channel plane.
  - MS-SSIM: the luminance term at the widest Gaussian window, multiplied by
    the product of the contrast-structure terms over all windows.
  - L1: the absolute difference blurred with the widest window.

  Args:
    image1, image2: arrays of identical shape (N, 1, H, W, C) in [-1, 1].
    sigmas: standard deviations of the Gaussian windows, one per scale.
    filter_size: side length of the square Gaussian windows.
    C1, C2: SSIM stability constants before squaring.
    alpha: weight of the MS-SSIM term; (1 - alpha) weights the L1 term.

  Returns:
    Scalar loss.
  """
  ## experiment with different alpha values (maybe adjust the alpha at a certain iteration?)
  ## L2 instead of L1 would be an interesting test that's easy to implement
  c1 = C1 ** 2
  c2 = C2 ** 2
  n_scales = len(sigmas)
  batch, height, width, channels = (image1.shape[i] for i in (0, 2, 3, 4))

  def to_planes(img):
    # (N, 1, H, W, C) in [-1, 1] -> (N*C, H, W, 1) in [0, 1].
    img = jnp.moveaxis(jnp.reshape(img, (batch, height, width, channels)), 3, 1)
    img = (img + 1) / 2
    return jnp.reshape(img, (batch * channels, height, width, 1))

  x = to_planes(image1)
  y = to_planes(image2)
  abs_diff = jnp.abs(y - x)

  # One normalized 2-D Gaussian window per sigma: (S, k, k).
  offsets = jnp.arange(-(filter_size // 2), filter_size // 2 + 1)
  profile = jnp.exp((-1. * offsets ** 2)[None, :] / (2 * jnp.array(sigmas) ** 2)[:, None])
  windows = profile[:, :, None] @ profile[:, None, :]
  windows = windows / jnp.sum(windows, axis=(1, 2))[:, None, None]
  # OIHW kernel (S, 1, k, k). The H/W swap reproduces the original pair of
  # transposes (a no-op for these symmetric windows).
  kernels = jnp.swapaxes(jnp.reshape(windows, (n_scales, 1, filter_size, filter_size)), 2, 3)

  x = jnp.transpose(x, [0, 3, 1, 2])
  y = jnp.transpose(y, [0, 3, 1, 2])

  def blur(img):
    # (P, 1, H, W) -> (P, S, H, W): one Gaussian-filtered map per scale.
    return jax.lax.conv(img, kernels, (1, 1), "SAME")

  mu_x = blur(x)
  mu_y = blur(y)
  var_x = blur(x ** 2) - mu_x ** 2
  var_y = blur(y ** 2) - mu_y ** 2
  cov_xy = blur(x * y) - mu_x * mu_y
  luminance = (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1)
  contrast_structure = (2 * cov_xy + c2) / (var_x + var_y + c2)

  cs_product = jnp.prod(contrast_structure, axis=1)
  # L1 error weighted by the widest Gaussian window.
  weighted_l1 = jnp.mean(jax.lax.conv(jnp.transpose(abs_diff, [0, 3, 1, 2]), kernels[-1, :, :, :][None, :, :, :], (1, 1), "SAME"))
  ms_ssim = jnp.mean(luminance[:, -1, :, :] * cs_product)

  return alpha * (1 - ms_ssim) + (1 - alpha) * weighted_l1

class FrameInterpolator:
  """Bundles the Haiku-transformed ImageGenerator, its optax optimizer, and jitted init/train steps."""

  def __init__(self, is_training, total_steps=1800, peak_lr=1e-3, lr_alpha=0.95):
    """Builds the network transform and the optimizer.

    Args:
      is_training: forwarded to ImageGenerator, and from there to its BatchNorm layers.
      total_steps: decay horizon (in optimizer steps) of the cosine learning-rate schedule.
      peak_lr: initial learning rate of the schedule.
      lr_alpha: final learning rate as a fraction of `peak_lr`. The default of
        0.95 keeps the rate nearly constant.
    """
    # The generator contains BatchNorm layers, whose statistics are Haiku
    # state, hence `transform_with_state`. apply() needs no RNG.
    self.gen_transform = hk.without_apply_rng(
        hk.transform_with_state(lambda *args: ImageGenerator(is_training)(*args)))

    scheduler = optax.cosine_decay_schedule(peak_lr, decay_steps=total_steps, alpha=lr_alpha)

    ## using SGD instead of adam might help for generalization
    self.optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip the gradient by its global norm.
        optax.scale_by_adam(),  # Adam-normalized updates.
        optax.scale_by_schedule(scheduler),  # Learning rate from the schedule.
        # optax.apply_updates is additive; negate to descend on the loss.
        optax.scale(-1.0)
    )

  @partial(jax.jit, static_argnums=0)
  def initial_state(self, rng, surrounding_frames, goal_frames):
    """Initializes generator params/state and the optimizer state.

    Only the shape of `surrounding_frames` is used: the network is initialized
    on a ones tensor of that shape. `goal_frames` is accepted for call-site
    symmetry with `update_gen` and is not used.

    Returns:
      InterpolatorState(params, states, opt_state).
    """
    dummy_surrounding_frames = jnp.ones(surrounding_frames.shape)
    gen_params, gen_state = self.gen_transform.init(rng, dummy_surrounding_frames)
    # Inside jit this prints once, at trace time, which is enough to show the shapes.
    print("Generator: \n\n{}\n".format(tree_shape(gen_params)))

    gen_opt_state = self.optimizer.init(gen_params)
    return InterpolatorState(params=gen_params, states=gen_state, opt_state=gen_opt_state)

  def create_image(self, gen_params, gen_state, surrounding_frames):
    """Predicts middle frames from a batch of surrounding-frame pairs.

    Train/eval behaviour is fixed by the `is_training` flag given to the
    constructor.

    Returns:
      (predicted_frames, new_gen_state).
    """
    return self.gen_transform.apply(gen_params, gen_state, surrounding_frames)

  def gen_loss(self, gen_params, gen_state, surrounding_frames, goal_frames):
    """MS-SSIM + L1 loss of the predicted middle frames against `goal_frames`.

    Returns:
      (loss, (new_gen_state, predicted_frames)). This is the aux layout
      `update_gen` unpacks from `jax.value_and_grad(..., has_aux=True)`.
    """
    fake_batch, gen_state = self.create_image(gen_params, gen_state, surrounding_frames)
    loss = MSSSIML1_loss_vectorized(fake_batch, goal_frames)
    return loss, (gen_state, fake_batch)

  @partial(jax.jit, static_argnums=0)
  def update_gen(self, interpolator_state, surrounding_frames, goal_frames):
    """Runs one optimizer step.

    Returns:
      (new InterpolatorState, loss, predicted_frames).
    """
    (loss, (new_gen_state, fake_batch)), grads = jax.value_and_grad(self.gen_loss, has_aux=True)(
        interpolator_state.params,
        interpolator_state.states,
        surrounding_frames,
        goal_frames)
    updates, new_opt_state = self.optimizer.update(grads, interpolator_state.opt_state)
    new_params = optax.apply_updates(interpolator_state.params, updates)
    new_state = InterpolatorState(params=new_params, states=new_gen_state, opt_state=new_opt_state)
    return new_state, loss, fake_batch

In [None]:
# --- Training configuration ---
num_steps = 1200
log_every = 20
## using a different batch size would likely help
train_batch_size = 12
ckpt_root = "/content/gdrive/MyDrive/frame_interpolation/gen"

# Python's `random` drives video skipping, frame selection and augmentation in
# make_batch/process_data; seed it so runs are reproducible.
random.seed(42)

# Display hardware
print(f"Number of devices: {jax.device_count()}")
print("Device:", jax.devices()[0].device_kind)
print("")

# The training dataset
dataset = make_dataset(shuffle_amount=1)

# Top-level RNG (network initialization).
rng = jax.random.PRNGKey(42)

losses = []

# Dummy batch for init. Parameter shapes do not depend on the batch size,
# because the batch axis is vmapped inside the network.
surrounding_frames, goal_frames = make_batch(dataset, True, True, 1)

# The model.
interpolator = FrameInterpolator(is_training=True)

# Initialize the network and optimizer.
interpolator_state = interpolator.initial_state(rng, surrounding_frames, goal_frames)

# To resume from a previously saved model instead:
#gen_params = restore(os.path.join(ckpt_root, "params3"))
#gen_state = restore(os.path.join(ckpt_root, "model_state3"))
#gen_opt_state = interpolator.optimizer.init(gen_params)
#interpolator_state = InterpolatorState(params=gen_params, states=gen_state, opt_state=gen_opt_state)

for step in range(num_steps + 1):
  surrounding_frames, goal_frames = make_batch(dataset, True, True, train_batch_size)
  interpolator_state, interpolator_loss, images_generated = interpolator.update_gen(
      interpolator_state, surrounding_frames, goal_frames)
  # device_get moves the scalar to host so `losses` doesn't keep device buffers alive.
  losses.append(jax.device_get(interpolator_loss))

  if step % log_every == 0:
    interpolator_loss = np.mean(losses)
    losses = []
    print(f"Step {step}: "
          f"train loss = {interpolator_loss:.7f}")

  if step % (3 * log_every) == 0:
    # First generated frame next to its ground truth, both mapped from [-1, 1] to [0, 1].
    generated = np.squeeze((jax.device_get(images_generated[0]) + 1) / 2)
    target = np.squeeze((np.asarray(goal_frames[0, 0]) + 1) / 2)
    fig, (ax_gen, ax_goal) = plt.subplots(1, 2, figsize=(8, 4))
    ax_gen.imshow(np.clip(generated, 0, 1))
    ax_gen.set_title(f"generated (step {step})")
    ax_goal.imshow(np.clip(target, 0, 1))
    ax_goal.set_title("goal")
    for ax in (ax_gen, ax_goal):
      ax.axis("off")
    plt.show()
    plt.close(fig)

  if step % (6 * log_every) == 0:
    save(os.path.join(ckpt_root, "params3"), jax.device_get(interpolator_state.params))
    save(os.path.join(ckpt_root, "model_state3"), jax.device_get(interpolator_state.states))
    # Optimizer state is not saved (the pickle is >1GB), so resuming re-initializes Adam:
    #with open(os.path.join(ckpt_root, "opt_state", "opt_state.pkl"), "wb") as output_file:
    #  pickle.dump(jax.device_get(interpolator_state.opt_state), output_file)

In [None]:
# Rebuild the model in inference mode (BatchNorm layers get is_training=False)
# and load the weights and BatchNorm state saved by the training loop.
interpolator = FrameInterpolator(is_training=False)
gen_params, gen_state = (
    restore("/content/gdrive/MyDrive/frame_interpolation/gen/params3"),
    restore("/content/gdrive/MyDrive/frame_interpolation/gen/model_state3"),
)

In [None]:
# Rough check of the plain L1 contribution to the loss, using the restored
# model (`interpolator`, `gen_params`, `gen_state` from the restore cell).
num_eval_batches = 60
eval_dataset = make_dataset(data, 1, 1)
# create_image takes (params, state, frames); the train/eval mode is fixed when
# FrameInterpolator is constructed. Passing an extra `False` raised a TypeError.
predict = jax.jit(interpolator.create_image)
total_loss = 0.0
for _ in range(num_eval_batches):
  surrounding_frames, goal_frames = make_batch(eval_dataset, True, True, 1, skip_prob=0)
  pred_image, _ = predict(gen_params, gen_state, surrounding_frames)
  total_loss += float(jnp.mean(jnp.abs(pred_image - goal_frames)))
print(total_loss / num_eval_batches)

In [None]:
# Example: double a video's frame rate with the restored model (`interpolator`,
# `gen_params`, `gen_state` from the restore cell, so this also works on a fresh
# kernel). Frames are resized to `frame_size`, so the output video has that size
# too; change it if using a different sized video.
input_path = "video.mp4"  # change this to your video
output_path = "outputvideo.mp4"
frame_size = (256, 256)

videodata = jnp.array(skvideo.io.vread(input_path))
num_input_frames = videodata.shape[0]
# create_image takes (params, state, frames). Passing an extra `False`, as
# before, raised a TypeError.
predict = jax.jit(interpolator.create_image)
# Full-frame boxes for both frames of a pair: crop_and_resize only resizes.
crops = jnp.reshape(jnp.tile(jnp.array((0, 0, 1, 1)), 2), (2, 4))

frames_out = []
for i in range(num_input_frames - 1):
  pair = videodata[i:i + 2, :, :, :] / 255. * 2 - 1  # uint8 [0, 255] -> [-1, 1]
  pair = tf.image.crop_and_resize(pair, crops, jnp.arange(2), frame_size).numpy()
  surrounding_frames = jnp.reshape(pair, (1, 2) + tuple(frame_size) + (3,))
  pred_image, _ = predict(gen_params, gen_state, surrounding_frames)
  frames_out.append(surrounding_frames[0, 0][None])
  frames_out.append(pred_image[0])
if num_input_frames >= 2:
  # Close the sequence with the last original frame.
  frames_out.append(surrounding_frames[0, 1][None])

# Single concatenate (growing the array inside the loop was quadratic).
if frames_out:
  final_frames = jnp.concatenate(frames_out)
else:
  final_frames = jnp.empty((0,) + tuple(frame_size) + (3,))
# [-1, 1] -> [0, 255]; round and clip before the uint8 cast instead of truncating.
outputdata = np.asarray(jnp.clip(jnp.round((final_frames + 1) / 2 * 255), 0, 255)).astype(np.uint8)

writer = skvideo.io.FFmpegWriter(output_path)
try:
  for frame in outputdata:
    writer.writeFrame(frame)
finally:
  writer.close()