# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model

class RealNVP(layers.Layer):
	"""Stack of conditional affine coupling layers in the RealNVP style.

	Each coupling layer leaves the features selected by a 0/1 mask unchanged
	and maps the remaining features through ``x * exp(s) + t``. Here ``s`` and
	``t`` are predicted by a small MLP from the masked features concatenated
	with ``condition``. Masks alternate between layers, so every feature is
	transformed by some layer.

	Args:
		num_coupling_layers: Number of coupling layers to stack.
		input_dim: Size of the last axis of ``x``.
		condition_dim: Size of the last axis of ``condition``.
		**kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
	"""

	def __init__(self, num_coupling_layers, input_dim, condition_dim, **kwargs):
		super().__init__(**kwargs)
		self.num_coupling_layers = num_coupling_layers
		self.input_dim = input_dim
		self.condition_dim = condition_dim
		self.masks = [self._create_mask(i) for i in range(num_coupling_layers)]
		self.s_t_networks = [self._create_s_t_network() for _ in range(num_coupling_layers)]

	def _create_mask(self, layer_index):
		"""Return the float32 0/1 mask of length ``input_dim`` for one layer.

		Even layers use [0, 1, 0, 1, ...] and odd layers use its complement.
		Positions equal to 1 pass through the coupling layer unchanged.
		"""
		mask = tf.cast(tf.range(self.input_dim) % 2, tf.float32)
		return mask if layer_index % 2 == 0 else 1.0 - mask

	def _create_s_t_network(self):
		"""Build the MLP mapping [masked x, condition] to concatenated (s, t)."""
		return tf.keras.Sequential([
			# tf.keras.Input works in both Keras 2 and Keras 3; the `input_shape`
			# argument of InputLayer is deprecated in Keras 3.
			tf.keras.Input(shape=(self.input_dim + self.condition_dim,)),
			layers.Dense(128, activation='relu'),
			layers.Dense(128, activation='relu'),
			layers.Dense(2 * self.input_dim),
		])

	def call(self, x, condition, reverse=False):
		"""Apply the flow (or its inverse) and accumulate log|det J|.

		Args:
			x: Tensor whose last axis has size ``input_dim``.
			condition: Tensor whose last axis has size ``condition_dim``. It is
				concatenated with ``x`` along the last axis, so its leading
				dimensions must match those of ``x``.
			reverse: If False, apply coupling layers 0..n-1 in the forward
				direction. If True, apply their inverses in reverse order.

		Returns:
			Tuple ``(output, log_det_jacobian)``. ``log_det_jacobian`` has the
			shape of ``x`` without its last axis.
		"""
		# Start from per-example zeros rather than the scalar 0.0, so the result
		# has a consistent shape even when there are no coupling layers.
		log_det_jacobian = tf.zeros_like(x[..., 0])
		if reverse:
			for i in reversed(range(self.num_coupling_layers)):
				x, log_det = self._inverse_coupling_layer(x, condition, i)
				log_det_jacobian += log_det
		else:
			for i in range(self.num_coupling_layers):
				x, log_det = self._forward_coupling_layer(x, condition, i)
				log_det_jacobian += log_det
		return x, log_det_jacobian

	def _scale_and_shift(self, masked, condition, layer_index):
		"""Predict the bounded log-scale ``s`` and the shift ``t`` for one layer."""
		network_input = tf.concat([masked, condition], axis=-1)
		s_t = self.s_t_networks[layer_index](network_input)
		s, t = tf.split(s_t, 2, axis=-1)
		# tanh bounds the log-scale to [-1, 1], keeping exp(s) well-conditioned.
		return tf.tanh(s), t

	def _forward_coupling_layer(self, x, condition, layer_index):
		"""Forward coupling: keep masked features, affinely transform the rest."""
		mask = self.masks[layer_index]
		x_masked = x * mask
		s, t = self._scale_and_shift(x_masked, condition, layer_index)
		y = x_masked + (1 - mask) * (x * tf.exp(s) + t)
		# s and t depend only on masked inputs, so dy_i/dx_i = exp(s_i) on
		# transformed positions: log|det J| is the sum of s there.
		log_det_jacobian = tf.reduce_sum((1 - mask) * s, axis=-1)
		return y, log_det_jacobian

	def _inverse_coupling_layer(self, y, condition, layer_index):
		"""Inverse coupling: undo the affine map on the unmasked features."""
		mask = self.masks[layer_index]
		# Masked features are unchanged by the forward map, so they reproduce
		# the same s and t that were used going forward.
		y_masked = y * mask
		s, t = self._scale_and_shift(y_masked, condition, layer_index)
		x = y_masked + (1 - mask) * (y - t) * tf.exp(-s)
		log_det_jacobian = -tf.reduce_sum((1 - mask) * s, axis=-1)
		return x, log_det_jacobian

class ConditionalNormalizingFlow(Model):
	"""Conditional density model: a RealNVP flow onto a standard normal base.

	Args:
		num_coupling_layers: Number of coupling layers in the flow.
		input_dim: Dimensionality of the modelled variable ``x``.
		condition_dim: Dimensionality of the conditioning input.
		**kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
	"""

	class _StandardNormalDiag:
		"""N(0, I) over ``dim`` features, implemented in plain TensorFlow.

		Replaces ``tfp.distributions.MultivariateNormalDiag(zeros, ones)``, which
		the original code referenced without importing ``tfp`` (a NameError).
		Only the ``log_prob`` and ``sample`` methods this model uses are provided.
		"""

		_LOG_TWO_PI = 1.8378770664093453  # log(2 * pi)

		def __init__(self, dim):
			self.dim = dim

		def log_prob(self, z):
			"""Return the standard-normal log density, summed over the last axis."""
			return -0.5 * tf.reduce_sum(tf.square(z) + self._LOG_TWO_PI, axis=-1)

		def sample(self, num_samples):
			"""Return a float32 tensor of shape ``(num_samples, dim)`` drawn from N(0, I)."""
			return tf.random.normal([num_samples, self.dim])

	def __init__(self, num_coupling_layers, input_dim, condition_dim, **kwargs):
		super().__init__(**kwargs)
		self.input_dim = input_dim
		self.real_nvp = RealNVP(num_coupling_layers, input_dim, condition_dim)
		self.base_distribution = self._StandardNormalDiag(input_dim)

	def call(self, x, condition):
		"""Return log p(x | condition) for each example.

		Change of variables: log p(x|c) = log N(z; 0, I) + log|det dz/dx|,
		where z is the forward image of x under the flow.
		"""
		z, log_det_jacobian = self.real_nvp(x, condition)
		return self.base_distribution.log_prob(z) + log_det_jacobian

	def sample(self, num_samples, condition):
		"""Draw ``num_samples`` samples of x given ``condition``.

		``condition`` may be a single vector ``(condition_dim,)`` or a batch of one
		``(1, condition_dim)``; both are broadcast to every sample. It may also be
		a full batch ``(num_samples, condition_dim)``.
		"""
		z = self.base_distribution.sample(num_samples)
		# Match the base sample dtype so the concat inside the flow succeeds.
		condition = tf.cast(condition, z.dtype)
		condition = tf.broadcast_to(condition, [num_samples, tf.shape(condition)[-1]])
		x, _ = self.real_nvp(z, condition, reverse=True)
		return x