In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python import math as tfp_math
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import cache_util
from tensorflow_probability.python.internal import prefer_static
import numpy as np 

# Short module aliases used throughout the notebook.
tfd, tfb = tfp.distributions, tfp.bijectors

In [None]:
# Conditional variant of TFP's `tfb.ffjord.trace_jacobian_hutchinson`: the same
# estimator, except that `ode_fn` additionally receives a fixed `condition`.
def trace_jacobian_hutchinson(
    ode_fn,
    state_shape,
    dtype,
    condition,
    sample_fn=tf.random.normal,
    num_samples=1,
    seed=None):
  """Builds an augmented ODE function that also estimates the Jacobian trace.

  Args:
    ode_fn: Callable `(time, state, condition) -> state_time_derivative`.
    state_shape: Shape of the ODE state; probes have shape
      `[num_samples] + state_shape`.
    dtype: dtype of the probe vectors.
    condition: Conditioning tensor passed unchanged to every `ode_fn` call. It
      is not watched by the inner tape: the trace only needs derivatives with
      respect to `state` (model parameters still receive gradients through any
      outer tape).
    sample_fn: Callable `(shape, dtype=..., seed=...)` that draws the probes.
    num_samples: Number of probe vectors averaged at every evaluation.
    seed: Seed forwarded to `sample_fn`.

  Returns:
    Callable `(time, (state, log_det_jac)) -> (state_time_derivative,
    trace_estimate)` for integrating the augmented state. `trace_estimate` has
    the shape of `state` and holds the per-element terms `eps * (eps^T J)`
    averaged over probes; their sum is the Hutchinson estimate `eps^T J eps`
    of `tr(J)`.
  """
  # Probes are drawn once, here, and reused at every solver step.
  probe_shape = prefer_static.concat([[num_samples], state_shape], axis=0)
  probes = sample_fn(probe_shape, dtype=dtype, seed=seed)

  def augmented_ode_fn(time, state_log_det_jac):
    """Returns (d state / dt, per-element trace estimate) at `time`."""
    # The accumulated log-det component does not feed back into the dynamics.
    state, _ = state_log_det_jac
    # Persistent: `gradient` is called once per probe below.
    with tf.GradientTape(watch_accessed_variables=False,
                         persistent=True) as inner_tape:
      inner_tape.watch(state)
      derivative = ode_fn(time, state, condition)

    def probe_contribution(probe):
      # Vector-Jacobian product probe^T (d derivative / d state), times probe.
      vjp = inner_tape.gradient(derivative, state, output_gradients=probe)
      return probe * vjp

    per_probe = tf.map_fn(probe_contribution, probes)
    return derivative, tf.reduce_mean(per_probe, axis=0)

  return augmented_ode_fn

class FFJORD(bijector.Bijector):
  """Conditional FFJORD bijector: a copy of `tfb.FFJORD` with an extra `cond`.

  `cond` is routed in through bijector kwargs keyed by this bijector's name,
  e.g. `chain.inverse(y, ffjord={'cond': c})` or, through a distribution,
  `bijector_kwargs={'ffjord': {'cond': c}}`.

  Forward integrates `dx/dt = f(t, x, cond)` from `initial_time` to
  `final_time`. The log-det-Jacobian is integrated alongside the state by
  `trace_augmentation_fn` and returned elementwise
  (`forward_min_event_ndims=0`).
  """

  def __init__(
      self,
      state_time_derivative_fn,
      ode_solve_fn=None,
      trace_augmentation_fn=trace_jacobian_hutchinson,
      initial_time=0.,
      final_time=1.,
      validate_args=False,
      dtype=tf.float32,
      name='ffjord'):
    """Creates the bijector.

    Args:
      state_time_derivative_fn: Callable `(time, state, cond) -> d state/dt`.
      ode_solve_fn: Callable with the signature of
        `tfp.math.ode.Solver.solve`; defaults to a `DormandPrince` solver.
      trace_augmentation_fn: Callable
        `(ode_fn, state_shape, dtype, cond) -> augmented_ode_fn` whose
        augmented state also carries the log-det-Jacobian.
      initial_time: Start of the integration interval.
      final_time: End of the integration interval.
      validate_args: Passed to `Bijector`.
      dtype: Passed to `Bijector`.
      name: Bijector name; also the key under which `cond` is routed.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._initial_time = initial_time
      self._final_time = final_time
      self._ode_solve_fn = ode_solve_fn
      if self._ode_solve_fn is None:
        self._ode_solver = tfp_math.ode.DormandPrince()
        self._ode_solve_fn = self._ode_solver.solve
      self._trace_augmentation_fn = trace_augmentation_fn
      self._state_time_derivative_fn = state_time_derivative_fn

      def inverse_state_time_derivative(time, state, condition):
        # Time-reversed dynamics. As `time` runs initial_time -> final_time,
        # the original clock runs final_time -> initial_time. This holds for
        # any initial_time, not only 0.
        reversed_time = self._initial_time + self._final_time - time
        return -state_time_derivative_fn(reversed_time, state, condition)

      self._inv_state_time_derivative_fn = inverse_state_time_derivative
      super(FFJORD, self).__init__(
          forward_min_event_ndims=0,
          dtype=dtype,
          validate_args=validate_args,
          parameters=parameters,
          name=name)

  def _setup_cache(self):
    """Builds a cache that stores the LDJs computed alongside forward/inverse.

    This is the hook the `Bijector` base class uses to create its cache, as in
    upstream `tfb.FFJORD`. A single ODE solve then yields both the transformed
    value and its LDJ. For example, `log_prob` (inverse + ILDJ) solves the ODE
    once instead of twice, and both results come from the same Hutchinson
    probes.
    """
    return cache_util.BijectorCacheWithGreedyAttrs(
        forward_impl=self._augmented_forward,
        inverse_impl=self._augmented_inverse)

  def _solve_ode(self, ode_fn, state):
    """Integrates `state` over [initial_time, final_time]; returns the end state."""
    integration_result = self._ode_solve_fn(
        ode_fn=ode_fn,
        initial_time=self._initial_time,
        initial_state=state,
        solution_times=[self._final_time])
    # Only one solution time was requested, so take the last entry per component.
    final_state = tf.nest.map_structure(
        lambda x: x[-1], integration_result.states)
    return final_state

  def _augmented_forward(self, x, cond):
    """Computes forward and forward_log_det_jacobian transformations."""
    # prefer_static.shape also works when the batch dimension is unknown at
    # trace time (e.g. tf.function over a tf.data pipeline); x.shape does not.
    augmented_ode_fn = self._trace_augmentation_fn(
        self._state_time_derivative_fn, prefer_static.shape(x), x.dtype, cond)
    augmented_x = (x, tf.zeros_like(x))
    y, fldj = self._solve_ode(augmented_ode_fn, augmented_x)
    return y, {'ildj': -fldj, 'fldj': fldj}

  def _augmented_inverse(self, y, cond):
    """Computes inverse and inverse_log_det_jacobian transformations."""
    augmented_inv_ode_fn = self._trace_augmentation_fn(
        self._inv_state_time_derivative_fn, prefer_static.shape(y), y.dtype,
        cond)
    augmented_y = (y, tf.zeros_like(y))
    x, ildj = self._solve_ode(augmented_inv_ode_fn, augmented_y)
    return x, {'ildj': ildj, 'fldj': -ildj}

  def _forward(self, x, cond):
    y, _ = self._augmented_forward(x, cond)
    return y

  def _inverse(self, y, cond):
    x, _ = self._augmented_inverse(y, cond)
    return x

  def _forward_log_det_jacobian(self, x, cond):
    # Only reached when the cache has no 'fldj' for (x, cond).
    return self._augmented_forward(x, cond)[1]['fldj']

  def _inverse_log_det_jacobian(self, y, cond):
    # Only reached when the cache has no 'ildj' for (y, cond).
    return self._augmented_inverse(y, cond)[1]['ildj']

In [None]:
# ODE function

class ConcatSquash(tf.keras.Model):
  """Condition-gated dense layer: `dense(x) * sigmoid(gate(cond)) + shift(cond)`.

  `cond` is whatever conditioning features the caller provides (ODEFnc passes
  the time feature concatenated with the labels).
  """

  def __init__(self, nhidden=512):
    super().__init__()
    # Construction order is kept (dense, bias, gate) so layer naming and
    # initialization match earlier versions of this notebook.
    self.dense = tf.keras.layers.Dense(nhidden)
    # Additive, condition-dependent shift (no bias term of its own).
    self._hyper_bias = tf.keras.layers.Dense(nhidden, use_bias=False)
    # Multiplicative gate in (0, 1) computed from the condition.
    self._hyper_gate = tf.keras.layers.Dense(nhidden, activation='sigmoid')

  def call(self, t, x, cond):
    # `t` is unused here: time reaches this layer only through `cond`.
    projected = self.dense(x)
    gate = self._hyper_gate(cond)
    shift = self._hyper_bias(cond)
    return projected * gate + shift


class ODEFnc(tf.keras.Model):
  """Conditional ODE dynamics f(t, x, cond) used by the FFJORD bijector.

  A stack of `ConcatSquash` layers. The time `t` is appended to `cond`, so
  every layer is gated by both time and condition.

  Args:
    nhidden: Width of every ConcatSquash layer. The last layer's output is the
      state derivative, so this must equal the state (w) dimension.
    stack: Number of ConcatSquash layers.
    activation: Nonlinearity applied between consecutive layers (not after the
      last one). Accepts anything `tf.keras.activations.get` accepts; `None`
      gives the purely linear stack, whose dynamics are affine in `x`.
  """

  def __init__(self, nhidden=512, stack=4, activation='tanh'):
    super(ODEFnc, self).__init__()
    self.concatsquash = [ConcatSquash(nhidden=nhidden) for _ in range(stack)]
    self._activation = tf.keras.activations.get(activation)

  def call(self, t, x, cond):
    # Labels may arrive as ints or float64. The concat below and the Dense
    # layers need a single float dtype, so match the state's dtype.
    cond = tf.cast(cond, x.dtype)
    t = tf.broadcast_to(tf.cast(t, x.dtype), tf.shape(cond))
    cond = tf.concat([t, cond], axis=-1)
    last_index = len(self.concatsquash) - 1
    for index, lyr in enumerate(self.concatsquash):
      x = lyr(t, x, cond)
      if index < last_index:
        # Without a nonlinearity between layers the stack collapses to a
        # single (gated) affine map of x.
        x = self._activation(x)
    return x

In [None]:
# Set up the ODE network, the ODE solver, the FFJORD bijector and the StyleFlow
# distribution (a standard Gaussian pushed through BatchNorm, FFJORD, BatchNorm).
W_DIM = 512  # dimensionality of a StyleGAN w vector

# The ODE net's last layer outputs d w / dt, so its width must equal W_DIM.
mlp_model = ODEFnc(nhidden=W_DIM)
solver = tfp.math.ode.DormandPrince(atol=1e-5)
ode_solve_fn = solver.solve
ffjord = FFJORD(
    state_time_derivative_fn=mlp_model,
    ode_solve_fn=ode_solve_fn,
    trace_augmentation_fn=trace_jacobian_hutchinson)
# Despite its name, this is the complete flow bijector, not just the BatchNorms.
moving_batchnorm = tfb.Chain(
    [tfb.BatchNormalization(), ffjord, tfb.BatchNormalization()])
base_loc = np.zeros((W_DIM,), dtype=np.float32)
base_sigma = np.ones((W_DIM,), dtype=np.float32)
base_distribution = tfd.MultivariateNormalDiag(loc=base_loc, scale_diag=base_sigma)
styleflow = tfd.TransformedDistribution(
    distribution=base_distribution, bijector=moving_batchnorm)

In [None]:
# Example: passing a condition to the model. TFP routes per-bijector arguments
# through `bijector_kwargs`, keyed by the bijector's name ('ffjord').
example_kwargs = {'bijector_kwargs': {'ffjord': {'cond': tf.random.normal((1, 2))}}}
example_w = styleflow.sample(1, **example_kwargs)
styleflow.log_prob(example_w, **example_kwargs)

**Training**

In [None]:
BATCH_SIZE = 64
SHUFFLE_BUFFER = 10_000  # covers the suggested dataset size below (10k-20k)
# No labelled data is bundled with this notebook. To build it:
# - Use StyleGAN to generate 10k-20k images (the number suggested by the
#   StyleFlow authors) and keep each image's w vector, shape (512,).
# - Run those images through pretrained attribute models (an age regressor, a
#   gender classifier, or any other classifier) to get each image's labels.
# TODO: load `w` (num_samples, 512) and `cond` (num_samples, label_size) here.
# This cell raises a NameError until both are defined.

def augment(w, cond):
    """Per-example preprocessing of a (w, label) pair.

    Put label preprocessing here, e.g. one-hot encoding, depending on the labels
    and how you want them represented. After batching, `cond` must have shape
    (batch_size, label_size), which is the layout ODEFnc expects. Both tensors
    are cast to float32 to match the float32 base distribution and ODE network.
    """
    return tf.cast(w, tf.float32), tf.cast(cond, tf.float32)

# Shuffle before batching; the order is reshuffled on every epoch.
train_data = (
    tf.data.Dataset.from_tensor_slices((w, cond))
    .shuffle(SHUFFLE_BUFFER)
    .map(augment)
    .batch(BATCH_SIZE)
)

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

@tf.function()
def train_step(x):
    """Runs one maximum-likelihood update on a batch.

    Args:
        x: A `(w, cond)` pair from `train_data`: the batch of w vectors and the
            batch of conditioning labels.

    Returns:
        The batch-mean negative log-likelihood, i.e. the value being minimized.
    """
    w_batch, cond_batch = x
    with tf.GradientTape() as tape:
        nll = -styleflow.log_prob(
            w_batch, bijector_kwargs={'ffjord': {'cond': cond_batch}})
        # Minimize the mean rather than the implicit sum over the per-example
        # vector. The optimized objective then matches the reported loss and
        # does not scale with batch size (the last batch may be smaller).
        loss = tf.reduce_mean(nll)
    # Trainable variables read inside the tape are the flow's parameters.
    variables = tape.watched_variables()
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss


In [None]:
# Track the optimizer too, so Adam's moment estimates survive a restart. Then
# resume from the newest checkpoint in ./styleflow if one exists. restore() is
# deferred for variables that have not been created yet.
ckpt = tf.train.Checkpoint(model=styleflow, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, './styleflow', max_to_keep=3)
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Restored from', ckpt_manager.latest_checkpoint)

In [None]:
# Train for 10 epochs. Progress overwrites a single line ('\r'), and a
# checkpoint is written at the end of every epoch.
for epoch in range(10):
    for step, batch in enumerate(train_data):
        batch_loss = train_step(batch).numpy()
        print('epoch: ', epoch, 'iter: ', step, ' loss: ', batch_loss, '**', end='\r')
    ckpt_manager.save()


In [None]:
### Trained! Everything below is inference.
### `labels`, `given_w`, `w_corresponding_labels` and `new_labels` are
### placeholders: define them before running this cell.

### Random sampling with a fixed condition: draw w for the given labels, then
### feed it to StyleGAN to generate an image. It is named `w_sampled` so that
### it does not overwrite the training data `w`.
w_sampled = styleflow.sample(1, bijector_kwargs={'ffjord': {'cond': labels}})

### Conditional editing: change a given person's attributes.
### `tfb.BatchNormalization()` defaults to training=True. In that mode `inverse`
### normalizes with the statistics of the current batch (degenerate for a
### single w) and updates the moving averages. For editing, wrap the same
### trained layers with training=False so the stored moving statistics are used.
bn_outer, _, bn_inner = styleflow.bijector.bijectors
inference_bijector = tfb.Chain([
    tfb.BatchNormalization(batchnorm_layer=bn_outer.batchnorm, training=False),
    ffjord,
    tfb.BatchNormalization(batchnorm_layer=bn_inner.batchnorm, training=False),
])

### First, encode the w you want to change (w -> base-space code) with its
### current labels.
code = inference_bijector.inverse(given_w, ffjord={'cond': w_corresponding_labels})

### Second, decode that code with the new labels (code -> w). This is the
### forward direction.
w_prime = inference_bijector.forward(code, ffjord={'cond': new_labels})

### Finally, use w_prime to generate the image in StyleGAN with style mixing.