In [5]:
import tensorflow as tf
from keras.layers import Dense, LSTM, Activation, Bidirectional, Dropout
from keras.models import Sequential
import keras.backend as K

In [None]:
def squareError(xTrue, xPred):
    """Element-wise squared difference between a target and a prediction.

    Args:
        xTrue: Reference values.
        xPred: Predicted values; subtracted from ``xTrue``.

    Returns:
        ``K.square(xTrue - xPred)``. No reduction is applied here; callers
        are responsible for averaging.
    """
    residual = xTrue - xPred
    return K.square(residual)


In [None]:
def discriminateEncodedError(f_w, encoder, sample): # log f_w(E_theta_i(x_j)) from the paper.
    """Log of the discriminator's output on the encoded sample.

    Computes ``log f_w(E_theta_i(x_j))`` from the paper.

    Args:
        f_w: Discriminator, called with the encoder output as its only argument.
        encoder: Encoder ``E_theta_i``, called with ``sample``.
        sample: Input ``x_j`` fed to the encoder.

    Returns:
        ``K.log`` of the discriminator output, with that output clamped from
        below at ``K.epsilon()``.
    """
    discriminator_output = f_w(encoder(sample))
    # A discriminator output of exactly 0 would make the log -inf and poison
    # the gradients with NaNs; clamp from below so the loss stays finite.
    # Values above epsilon pass through unchanged.
    safe_output = K.maximum(discriminator_output, K.epsilon())
    return K.log(safe_output)


In [None]:
def reconstructionLoss(sample, encoder, decoder, f_w, weight): # (L_1 from the paper)
    """Autoencoder loss: squared reconstruction error plus a weighted discriminator term.

    Args:
        sample: Input passed through ``encoder`` and then ``decoder``.
        encoder: Callable applied to ``sample``.
        decoder: Callable applied to the encoder output.
        f_w: Discriminator forwarded to ``discriminateEncodedError``.
        weight: Multiplier on the discriminator term.

    Returns:
        ``K.mean(..., axis=0)`` of the squared error between ``sample`` and its
        reconstruction, plus ``weight`` times ``discriminateEncodedError``.
    """
    reconstruction = decoder(encoder(sample))
    squared_error = squareError(sample, reconstruction)
    adversarial_term = weight * discriminateEncodedError(f_w, encoder, sample)
    return K.mean(squared_error + adversarial_term, axis=0)


In [None]:
def divergenceLoss(f_w, encoder, sample, z_j, n_j): # Mean of log f_w(E_theta_i(x_j)) + log (1-f_w(z_j, n_j)) from the paper (L_2).
    """Discriminator objective L_2 from the paper.

    Mean over axis 0 of ``log f_w(E_theta_i(x_j)) + log(1 - f_w(z_j, n_j))``.

    Args:
        f_w: Discriminator.
        encoder: Encoder ``E_theta_i`` applied to ``sample``.
        sample: Input ``x_j``.
        z_j: First argument of the second discriminator call.
        n_j: Second argument of the second discriminator call.

    Returns:
        ``K.mean(..., axis=0)`` of the sum of the two log terms.

    NOTE(review): ``f_w`` is called with a single encoded tensor inside
    ``discriminateEncodedError`` but with two arguments ``(z_j, n_j)`` here.
    Confirm the discriminator supports both call forms.
    """
    encoded_term = discriminateEncodedError(f_w, encoder, sample)
    # If the discriminator outputs exactly 1, 1 - f_w(...) is 0 and the log is
    # -inf (NaN gradients). Clamp from below so the loss stays finite; values
    # above epsilon pass through unchanged.
    complement = K.maximum(1 - f_w(z_j, n_j), K.epsilon())
    return K.mean(encoded_term + K.log(complement), axis=0)


In [None]:
def CreateDecoder():
    """Placeholder for the decoder builder; not implemented yet.

    Planned design: the decoder takes in two inputs, n and z, and outputs
    samples. For now the function does nothing and returns ``None``.
    """
    return None

In [None]:
def CreateEncoder(input_num, shared_output_num, remaining_output_num, hyperparams):
    """Build the recurrent encoder: two bidirectional LSTM layers and a linear head.

    Args:
        input_num: Number of features per time step. NOTE(review): inferred
            from the parameter name; confirm against the caller.
        shared_output_num: Size of the shared part of the code (z).
        remaining_output_num: Size of the remaining part of the code (n).
        hyperparams: Currently unused; layer sizes and dropout rates are
            hard-coded below.

    Returns:
        A ``Sequential`` model whose final ``Dense`` layer has
        ``shared_output_num + remaining_output_num`` units. The output is a
        single tensor; splitting it into z and n is still a TODO.
    """
    # TODO MAYBE: Add in more regularization or different than dropout?
    # TODO make two outputs: n and z.
    latent_size = shared_output_num + remaining_output_num
    model = Sequential([
        # input_shape goes on the wrapper (the model's first layer) and leaves
        # out the batch dimension; None lets the sequence length vary.
        Bidirectional(LSTM(32, activation='tanh', return_sequences=True),
                      input_shape=(None, input_num)),
        Dropout(0.2),
        Bidirectional(LSTM(32, activation='tanh', dropout=0.2, return_sequences=False)),
        Dropout(0.5),
        Dense(latent_size)
    ])

    # Default compile only; the training loop in this notebook applies
    # gradients itself through external optimizers.
    model.compile()
    return model

In [None]:
# One independent Adam optimizer per network role (encoder, decoder,
# discriminator), all created with the same 5e-4 step size.
enc_optimizer, dec_optimizer, disc_optimizer = (
    tf.keras.optimizers.Adam(5e-4) for _ in range(3)
)

In [None]:
# TODO IMPORTANT: Currently assuming P_Z is known, but it is NOT. Must alter algorithm as in (3.2) to support unknown P_Z.
def trainAutoencoders(k, encoders, decoders, samples, discriminator, weight=1.0):
    """Train each domain's encoder/decoder pair against a shared discriminator.

    For every domain the loop runs until ``isConverged(encoder, decoder)`` is
    true. Each step minimizes ``reconstructionLoss`` for the encoder and
    decoder, and minimizes the negated ``divergenceLoss`` (gradient ascent)
    for the discriminator.

    Args:
        k: Number of domains.
        encoders: List of encoders, one per domain.
        decoders: List of decoders, one per domain.
        samples: K x N array of samples; the first index is the domain, the
            second is the sample number within that domain.
        discriminator: Discriminator (``f_w`` in the paper), passed to both
            loss functions and updated through its ``trainable_variables``.
        weight: Weight of the discriminator term in the reconstruction loss.

    Returns:
        None. The models are updated in place through the module-level
        ``enc_optimizer``, ``dec_optimizer`` and ``disc_optimizer``.
    """
    for i in range(k):
        encoder = encoders[i]
        decoder = decoders[i]
        while not isConverged(encoder, decoder):
            p_Xi_samples = samples[i, :]
            # NOTE(review): encodes this domain's own batch, the only input
            # consistent with encoder i; confirm this is what is intended.
            p_Z_samples = projectZ(encoder(p_Xi_samples))
            p_Ni_samples = None # TODO Something!

            # One tape per network: each non-persistent tape is queried once.
            with tf.GradientTape() as enc_tape, tf.GradientTape() as dec_tape, tf.GradientTape() as disc_tape:
                reconstruction_loss = reconstructionLoss(p_Xi_samples, encoder, decoder, discriminator, weight)

                # negative b/c gradient ascent.
                divergence_loss = -divergenceLoss(discriminator, encoder, p_Xi_samples, p_Z_samples, p_Ni_samples)

            gradients_of_encoder = enc_tape.gradient(reconstruction_loss, encoder.trainable_variables)
            gradients_of_decoder = dec_tape.gradient(reconstruction_loss, decoder.trainable_variables)
            gradients_of_discriminator = disc_tape.gradient(divergence_loss, discriminator.trainable_variables)

            # NOTE(review): the same enc/dec optimizer instance is applied to a
            # different model for every domain i. Some Keras optimizer versions
            # bind to the first variable set they see; confirm this works on
            # the installed version, or use one optimizer per model.
            enc_optimizer.apply_gradients(zip(gradients_of_encoder, encoder.trainable_variables))
            dec_optimizer.apply_gradients(zip(gradients_of_decoder, decoder.trainable_variables))
            disc_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
            