In [27]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np

from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

In [3]:
! nvidia-smi -L # list GPUs available

GPU 0: NVIDIA A100-SXM4-40GB (UUID: GPU-0fa65b59-0cce-b2b8-7a13-88f718c609a3)
GPU 1: NVIDIA A100-SXM4-40GB (UUID: GPU-7dea5574-6b67-d7b8-4aad-cc2475290e80)
GPU 2: NVIDIA A100-SXM4-40GB (UUID: GPU-1f635c9d-7b9e-c8eb-0269-01c657ac9855)
GPU 3: NVIDIA A100-SXM4-40GB (UUID: GPU-60c99f33-4c24-fb2d-1465-14a922b43714)


In [4]:
class ContextMLP(tf.keras.Model):
    '''
    Standard MLP for mixing features with true data features via \sigma(a * \phi(x) + b)
    '''
    def __init__(self, dim_in, dim_out, dim_ctx):
        super(ContextMLP, self).__init__()
        self.layer = tf.keras.layers.Dense(dim_out)
        self.ctx_bias = tf.keras.layers.Dense(dim_out, use_bias = False)
        self.ctx_gate = tf.keras.layers.Dense(dim_out)
        
    '''
    Inputs should consist of [x, ctx] first the features and then the data features
    '''
    def call(self, inputs):
        x = inputs[0]
        ctx = inputs[1]
        
        gate = tf.sigmoid(self.ctx_gate(ctx))
        bias = self.ctx_bias(ctx)
        out = self.layer(x) * gate + bias
        return out

In [39]:
class VarianceScheduler():
    def __init__(self, num_steps, initial_beta, final_beta, mode = 'linear'):
        super().__init__()
        assert mode in ('linear',)
        self.num_steps = num_steps
        self.initial_beta = initial_beta
        self.final_beta = final_beta
        self.mode = mode
        
        if mode == 'linear':
            betas = tf.linspace(initial_beta, final_beta, num_steps)
            
        betas = tf.concat([tf.zeros([1]), betas], axis = 0) # Padding
        
        alphas = 1. - betas
        log_alphas = tf.math.log(alphas)
        for i in range(1, log_alphas.size): #[1, T]#
            #log_alphas[i] += log_alphas[i - 1]
            log_alphas[i].assign(log_alphas[i] + log_alphas[i - 1])
            
        alpha_bars = log_alphas.exp()
        
        sigma_flex = tf.math.sqrt(betas)
        sigmas_inflex = tf.zeros_like(sigmas_flex)
        for i in range(1, sigmas_flex.size):
            sigmas_inflex[i] = ((1. - alpha_bars[i - 1]) / (1. - alpha_bars[i])) * betas[i]
        sigmas_inflex = tf.math.sqrt(sigmas_inflex)
        
    def uniform_sample_t(self, batch_size):
        t_values = np.random_choice(np.arange(1, self.num_steps + 1), batch_size)
        return t_values.tolist()
    
    def get_sigmas(self, t, flexibility):
        assert 0 <= flexibility and flexbility <= 1
        sigmas = self.sigmas_flex[t] * flexbility + self.sigmas_inflex[t] * (1. - flexibility)
        return sigmas

In [24]:
class PointwiseNet(tf.keras.Model):
    def __init__(self, point_dim, context_dim, residual):
        super().__init__()
        self.act = tf.keras.layers.LeakyReLU
        self.residual = residual
        self.layer_list = [
            ContextMLP(3, 128, context_dim + 3),
            ContextMLP(128, 256, context_dim + 3),
            ContextMLP(256, 512, context_dim + 3),
            ContextMLP(512, 256, context_dim + 3),
            ContextMLP(256, 128, context_dim + 3),
            ContextMLP(128, 3, context_dim + 3)
        ]
        
    def call(self, inputs):
        """
        Args:
            x:  Point clouds at some timestep t, (B, N, d).
            beta:     Time. (B, ).
            context:  Shape latents. (B, F).
        """
        
        x = inputs[0]
        beta = inputs[1]
        context = inputs[2]
        
        batch_size = x.size(0)
        beta = tf.reshape(beta, [batch_size, 1, 1])  # (B, 1, 1)
        context = tf.reshape(context, [batch_size, 1, -1])  # (B, 1, F)

        time_emb = tf.concat([beta, tf.math.sin(beta), tf.math.cos(beta)], dim = -1)  # (B, 1, 3)
        ctx_emb = tf.concat([time_emb, context], dim = -1)  # (B, 1, F+3)

        out = x
        for i in range(len(self.layer_list)):
            layer = self.layer_list[i]
            out = layer(ctx = ctx_emb, x = out)
            if i < len(self.layer_list) - 1:
                out = self.act(out)

        if self.residual:
            return x + out
        else:
            return out

In [7]:
class DiffusionPoint():
    def __init__(self, net, var_sched: VarianceScheduler):
        #super().__init__()
        self.net = net
        self.var_sched = var_sched
        self.mse = tf.keras.losses.MeanSquaredError(reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)

    def get_loss(self, x_0, context, t = None):
        """
        Args:
            x_0:  Input point cloud, (B, N, d).
            context:  Shape latent, (B, F).
        """
        batch_size, _, point_dim = x_0.size()
        if t is None:
            t = self.var_sched.uniform_sample_t(batch_size)
        alpha_bar = self.var_sched.alpha_bars[t]
        beta = self.var_sched.betas[t]

        c0 = tf.math.sqrt(alpha_bar).view(-1, 1, 1)  # (B, 1, 1)
        c1 = tf.math.sqrt(1 - alpha_bar).view(-1, 1, 1)  # (B, 1, 1)

        e_rand = tf.random.normal(x_0)  # (B, N, d)
        e_theta = self.net([c0 * x_0 + c1 * e_rand, beta, context])

        loss = self.mse(e_theta.view(-1, point_dim), e_rand.view(-1, point_dim))
        return loss

    def sample(self, num_points, context, point_dim = 3, flexibility = 0.0, ret_traj = False):
        batch_size = context.size(0)
        x_T = tf.random.normal([batch_size, num_points, point_dim]).to(context.device)
        traj = {self.var_sched.num_steps: x_T}
        for t in range(self.var_sched.num_steps, 0, -1):
            z = tf.random.normal(x_T) if t > 1 else tf.zeros_like(x_T)
            alpha = self.var_sched.alphas[t]
            alpha_bar = self.var_sched.alpha_bars[t]
            sigma = self.var_sched.get_sigmas(t, flexibility)

            c0 = 1.0 / tf.math.sqrt(alpha)
            c1 = (1 - alpha) / tf.math.sqrt(1 - alpha_bar)

            x_t = traj[t]
            beta = self.var_sched.betas[[t] * batch_size]
            e_theta = self.net([x_t, beta, context])
            x_next = c0 * (x_t - c1 * e_theta) + sigma * z
            traj[t - 1] = x_next.detach()  # Stop gradient and save trajectory.
            traj[t] = traj[t].cpu()  # Move previous output to CPU memory.
            if not ret_traj:
                del traj[t]

        if ret_traj:
            return traj
        else:
            return traj[0]

In [40]:
diffusion = DiffusionPoint(
            net = PointwiseNet(point_dim=3, context_dim=10, residual=True),
            var_sched = VarianceScheduler(
                num_steps=10,
                initial_beta=0.01,
                final_beta=0.1,
                mode='linear'
            )
        )

AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'assign'