## Model
- The model is of the form $y \rightarrow z \rightarrow x$
- $y$ is a categorical random variable that selects the parameters of $z$, so that $z$ follows a mixture-of-Gaussians distribution.
- $x$ is a binary-valued random variable that follows a Bernoulli distribution parameterised by $z$.


In [1]:
import tensorflow as tf

In [None]:
class CURL(tf.keras.models.Model):
    """Generative model of the form y -> z -> x.

    Composes three sub-networks supplied by the caller:
      * ``enc_y_x``  -- maps x to component probabilities for y.
      * ``enc_z_xy`` -- maps x to the mean and scale of the latent z.
      * ``decoder``  -- maps a latent sample z back to input space.
    """

    def __init__(self, enc_y_x, enc_z_xy, decoder):
        super(CURL, self).__init__()
        self.enc_y_x = enc_y_x
        self.enc_z_xy = enc_z_xy
        self.decoder = decoder

    def call(self, x, noise):
        """Run one encode -> sample -> decode pass.

        Keras routes ``model(x, noise)`` to this method, so callers are
        unchanged while the Keras layer machinery (build, name scopes) runs.

        Args:
            x: Input batch.
            noise: Noise used to sample z; presumably standard-normal with the
                same shape as ``mu_z`` (TODO confirm against enc_z_xy output).

        Returns:
            Tuple ``(x_hat, prob_y, y, mu_z, sigma_z)``: the reconstruction,
            the component probabilities, the hard component assignment, and
            the latent mean and scale.
        """
        prob_y = self.enc_y_x(x)
        # Hard component assignment. Reduce over the last axis so that a
        # batched prob_y (assumed (batch, num_components) -- TODO confirm)
        # yields one label per example instead of one per component.
        y = tf.stop_gradient(tf.argmax(prob_y, axis=-1))
        # TODO: condition on y (one set of z parameters per component) once
        # enc_z_xy accepts it; currently only x is passed.
        mu_z, sigma_z = self.enc_z_xy(x)
        # Reparameterisation trick: z = mu + sigma * eps.
        z = mu_z + sigma_z * noise
        x_hat = self.decoder(z)
        return x_hat, prob_y, y, mu_z, sigma_z

Dynamic expansion

In [3]:
class Dnew(object):
    """Buffer of poorly-explained samples used to trigger dynamic expansion.

    Samples whose ``log_likelihood`` falls below ``threshold`` are collected.
    Once ``max_size`` of them have accumulated, they are handed back so the
    caller can act on them (e.g. initialise a new mixture component).
    """

    def __init__(self, threshold, max_size):
        self.threshold = threshold  # samples with log_likelihood below this are buffered
        self.max_size = max_size  # number of buffered samples that triggers a release
        self.samples = []  # pending poorly-fit samples, oldest first

    def expand(self, samples):
        """Buffer the poorly-fit ``samples`` and release a full batch when ready.

        Every sample in ``samples`` is examined, even after the buffer fills,
        so no poorly-fit sample is dropped.

        Args:
            samples: Iterable of objects exposing a ``log_likelihood`` attribute.

        Returns:
            A list of the ``max_size`` oldest buffered samples once the buffer
            holds at least that many; otherwise ``None``. Any surplus samples
            stay buffered for the next call.
        """
        self.samples.extend(
            sample for sample in samples
            if sample.log_likelihood < self.threshold
        )
        if len(self.samples) < self.max_size:
            return None
        ready = self.samples[:self.max_size]
        self.samples = self.samples[self.max_size:]
        return ready

Mixture generative replay

In [24]:
def get_next_inputs(config, model, loader, itr):
    """Return the next training batch for iteration ``itr``.

    With mixture generative replay (``config.mgr``) enabled, odd iterations
    use a batch produced by ``generate_imgs``. Every other case draws the next
    real batch from ``loader``.
    """
    replay_step = config.mgr and itr % 2 != 0
    if replay_step:
        return generate_imgs(config, model, itr)
    return next(loader)
    

def train(config, model, dataloader):
    """Run one pass over ``dataloader``, with optional MGR and dynamic expansion.

    Args:
        config: Run configuration. Attributes read here: ``num_init_components``,
            ``dynamic_expansion``, ``threshold``, ``max_size``, ``mgr``,
            ``dynamic_snapshot``, ``fixed_snapshot``, ``fixed_snapshot_iter``.
        model: Callable model. Its output is assumed to expose ``qy_x`` and
            ``samples`` attributes (TODO confirm against the model's return).
        dataloader: Sized iterable of real input batches.

    Returns:
        The running mean of ``outputs.qy_x`` over all steps. It is only updated
        when ``config.mgr`` is set; otherwise it stays at its initial zeros.
    """
    loader = iter(dataloader)
    qy_x = tf.zeros(config.num_init_components)
    if config.dynamic_expansion:
        D_new = Dnew(config.threshold, config.max_size)
    # With MGR, every other step uses a generated batch (see get_next_inputs),
    # so double the step count to still consume each real batch exactly once.
    steps = len(dataloader) * (2 if config.mgr else 1)
    for i in range(steps):
        inputs = get_next_inputs(config, model, loader, i)
        outputs = model(inputs)

        if config.mgr:
            # Incremental mean of q(y|x) over the i + 1 steps seen so far.
            qy_x = (qy_x * i + outputs.qy_x) / (i + 1)

        if config.dynamic_expansion:
            # Dnew.expand returns None until its buffer of poorly-fit samples
            # is full, so only react when it hands back a batch.
            samples = D_new.expand(outputs.samples)
            if samples is not None:
                # TODO: add a new mixture component initialised from `samples`
                # and, if config.mgr, extend qy_x with an entry for it.
                if config.dynamic_snapshot:
                    # Dynamic snapshots are taken only when the model expands.
                    save_snapshot(config, model, i)

        if config.fixed_snapshot and i % config.fixed_snapshot_iter == 0:
            save_snapshot(config, model, i)

    return qy_x
                    