In [None]:
import os
import sys
import time
import pickle

import numpy as np
import xlsxwriter
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import data_loader

from utils import *
from tensorflow import keras
from tensorflow.keras.layers import Layer, Dense, BatchNormalization, ReLU, Conv2D, Reshape
from tensorflow.keras import Model , regularizers

tfd = tfp.distributions
tfb = tfp.bijectors
tfk = tf.keras
tfkd= tf.keras.datasets
tf.keras.backend.set_floatx('float32')
os.makedirs('./ring_samples_sigmoid/100_set', exist_ok=True)

# Pin the run to one GPU. The notebook was written for the third GPU (index 2);
# guard the lookup so it falls back to default device placement on machines
# with fewer GPUs instead of raising IndexError.
GPU_INDEX = 2
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > GPU_INDEX:
    tf.config.set_visible_devices(physical_devices[GPU_INDEX], 'GPU')
else:
    print(f'GPU index {GPU_INDEX} not available ({len(physical_devices)} GPU(s) found); '
          'using default device placement.')

In [None]:
num_classes = 32
temp_val = np.linspace(0.05,2.05,num_classes)

# Each temperature occupies a contiguous run of SAMPLES_PER_TEMP lattices in the file.
SAMPLES_PER_TEMP = 10000
# Divide by 2*pi so values lie in units of a full turn (assumes the pickle
# stores angles in radians -- TODO confirm against data_loader).
xy_data = np.float32(data_loader.load_data_mh_generated('./data/8x8_gibbslattices.pkl'))/(2*np.pi)

# Earlier experiments used 5120 training samples and a 2048-sample validation set.
# A second set of experiments used 1024 training samples and 1024 validation samples.
# This configuration keeps 100 lattices per temperature in each split (see slices below).

trainset  = []
testset   = []
for i in range(num_classes):
    block_start = SAMPLES_PER_TEMP*i
    # every 50th lattice from the first 5000 of this temperature block
    trainset.append(xy_data[block_start:block_start+5000:50])
    # every 10th lattice from the last 1000 of this temperature block
    testset.append(xy_data[block_start+9000:block_start+SAMPLES_PER_TEMP:10])

# Derive the per-temperature label counts from the slices so temperatures stay
# aligned with lattices if the strides above are changed.
train_per_temp = len(trainset[0])
test_per_temp  = len(testset[0])

trainset = np.reshape(np.array(trainset),(-1,8,8,1))
testset  = np.reshape(np.array(testset),(-1,8,8,1))

train_temp = np.repeat(temp_val,train_per_temp)
test_temp  = np.repeat(temp_val,test_per_temp)
assert len(train_temp) == len(trainset) and len(test_temp) == len(testset)

batch_size = 256

# Temperature broadcast over every lattice site, used as the conditioning input.
train_T    = tf.cast(np.repeat(train_temp,8*8).reshape(-1,8,8,1),dtype = tf.float32)
test_T     = tf.cast(np.repeat(test_temp,8*8).reshape(-1,8,8,1),dtype = tf.float32)
# Scalar temperature per sample, used for the Boltzmann weight in the reverse KL.
train_temp = tf.cast(train_temp.reshape(-1,),dtype=tf.float32)

training_dataset = tf.data.Dataset.from_tensor_slices((trainset,train_T,train_temp))
training_dataset = training_dataset.shuffle(buffer_size = 1024).batch(batch_size)

test_dataset = tf.data.Dataset.from_tensor_slices((testset,test_T))
test_dataset = test_dataset.batch(batch_size)
    


In [None]:
# Shape of the validation lattices: (samples, height, width, channels).
np.shape(testset)

In [None]:
def checkerboard(height, width, reverse=False, dtype=tf.float32):
    """Return an alternating 0/1 mask of shape (1, height, width, 1).

    Site (row, col) holds (row + col) % 2; with ``reverse=True`` the pattern is
    inverted (1 - mask). The result is cast to ``dtype``.
    """
    pattern = [[(row + col) % 2 for col in range(width)] for row in range(height)]
    mask = tf.convert_to_tensor(pattern, dtype=dtype)
    if reverse:
        mask = 1 - mask
    return tf.cast(tf.reshape(mask, (1, height, width, 1)), dtype=dtype)


In [None]:
filters = 64

def Coupling(input_shape):
    """Build the conditioner network of one affine coupling layer.

    The model takes two inputs of shape ``input_shape`` (the masked lattice and
    the conditioning tensor), concatenates them on the channel axis, applies two
    ReLU 3x3 convolutions with ``filters`` channels, and returns ``[s, t]``:
    a tanh-bounded log-scale and a shift, each with one channel.

    Every 'valid' 3x3 convolution is preceded by ``periodic_padding(..., 1)`` so
    the spatial size is preserved under periodic boundaries (assumes
    ``periodic_padding`` from ``utils`` pads each side by the given width --
    TODO confirm).
    """
    input1 = keras.layers.Input(shape=input_shape)
    input2 = keras.layers.Input(shape=input_shape)
    # Named `merged` to avoid shadowing the builtin `input`.
    merged = tf.concat([input1,input2],axis=-1)

    layer1 = keras.layers.Conv2D(filters,3, activation="relu",padding = 'valid',name = 'layer1')(periodic_padding(merged,1))
    layer2 = keras.layers.Conv2D(filters,3, activation="relu",padding = 'valid',name = 'layer2')(periodic_padding(layer1,1))
    # Pad width given explicitly, matching the layers above.
    t_layer= keras.layers.Conv2D(1,3,padding = 'valid',name = 't_layer')(periodic_padding(layer2,1))
    s_layer= keras.layers.Conv2D(1,3,activation = 'tanh',padding = 'valid', name = 's_layer')(periodic_padding(layer2,1))

    return keras.Model(inputs=[input1,input2], outputs=[s_layer, t_layer])

In [None]:
class SimpleNormal:
    """Factorised normal prior over a lattice.

    Args:
        loc: per-site mean; its shape is the shape of one sample.
        var: per-site value passed to ``tfd.Normal`` as its second argument,
            i.e. the *scale* (standard deviation), despite the name.
    """
    def __init__(self, loc, var):
        self.dist = tfd.Normal(tf.reshape(loc,(-1,)), tf.reshape(var,(-1,)))
        self.shape = loc.shape

    def log_prob(self, x):
        """Return the summed per-site log-density of each sample, shape (batch,)."""
        # Dynamic batch size: the static x.shape[0] is None in relaxed
        # tf.function traces and would break the reshape.
        logp = self.dist.log_prob(tf.reshape(x,(tf.shape(x)[0], -1)))
        return tf.reduce_sum(logp, axis=1)

    def sample_n(self, batch_size , seed = None):
        """Draw ``batch_size`` samples, reshaped to (batch_size, *loc.shape)."""
        x = self.dist.sample((batch_size,),seed = seed)
        return tf.reshape(x,(batch_size, *self.shape))

In [None]:
def xy_action(lattice):
    """XY-model energy H = -sum_<ij> cos(theta_i - theta_j) for each configuration.

    Values are treated as fractions of a full turn (hence the 2*pi factor).
    Neighbours are taken with periodic boundaries along axes 1 and 2. Summing
    over all four neighbours of every site counts each bond twice, so the total
    is halved. The sum runs over axes 1 and 2 only.
    """
    neighbours = (
        tf.roll(lattice, shift=1, axis=2),
        tf.roll(lattice, shift=-1, axis=2),
        tf.roll(lattice, shift=1, axis=1),
        tf.roll(lattice, shift=-1, axis=1),
    )
    bond_terms = [tf.math.cos(2*np.pi*(nb - lattice)) for nb in neighbours]
    site_energy = -1*(bond_terms[0] + bond_terms[1] + bond_terms[2] + bond_terms[3])
    return tf.reduce_sum(site_energy, axis=[1,2])/2

In [None]:
class RealNVP(keras.Model):
    """Conditional RealNVP flow between a normal latent lattice and (0, 1)-valued lattices.

    ``num_coupling_layers`` affine coupling layers each update the sites outside
    their checkerboard mask, conditioned on the masked sites and on a
    conditioning tensor. The forward direction ends with a sigmoid rescaled by
    ``data_constraint``; the reverse direction starts with the matching logit.

    Args:
        num_coupling_layers: number of coupling layers. Masks alternate between
            the checkerboard and its complement; odd counts are supported.
        input_shape: (height, width, channels) of one lattice. The prior is
            built over (height, width).
        data_constraint: alpha in x' = alpha + (1 - 2*alpha) * x, keeping the
            logit finite at x = 0 or 1.
    """
    def __init__(self, num_coupling_layers,input_shape , data_constraint):
        super().__init__()

        self.num_coupling_layers = num_coupling_layers
        # Lattice size comes from input_shape rather than a hard-coded 8x8.
        self.lattice_height = int(input_shape[0])
        self.lattice_width = int(input_shape[1])
        lattice_hw = (self.lattice_height, self.lattice_width)
        self.distribution = SimpleNormal(tf.zeros(lattice_hw), tf.ones(lattice_hw))
        # One mask per layer, alternating. The former `[m, r] * (n // 2)` built
        # only n-1 masks for odd n, making forward()/reverse() raise IndexError.
        self.masks = [checkerboard(self.lattice_height, self.lattice_width, reverse=bool(i % 2))
                      for i in range(num_coupling_layers)]

        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.layers_list = [Coupling(input_shape) for i in range(num_coupling_layers)]
        self.data_constraint = data_constraint

    def call(self, x, forward=True):
        """Run the flow.

        forward=True: ``x = (z, cond)``; returns ``((samples, cond), logq)``,
            where ``samples`` lie in [0, 1] and ``logq`` is their model log-density.
        forward=False: ``x = (samples, cond)``; returns ``((z, cond), logq)``,
            where ``logq`` is the model log-density of the given samples.
        """
        lattice = (-1, self.lattice_height, self.lattice_width)
        if forward:
            (x1,x2) = x
            alpha = tf.constant(self.data_constraint)
            logq = self.distribution.log_prob(tf.reshape(x1,lattice))
            x,ldj1 = self.forward(x)
            (x1,x2)= x
            # -log|d output / d y| for output = (sigmoid(y) - alpha)/(1 - 2*alpha):
            # softplus(y) + softplus(-y) + log(1 - 2*alpha).
            ldj2 = tf.math.softplus(x1) + tf.math.softplus(-x1) + tf.math.log(tf.constant(1.-2*self.data_constraint))
            ldj2 = tf.reduce_sum(ldj2,[1,2,3])
            logq = logq - ldj1 + ldj2
            x1   = (tf.math.sigmoid(x1) - alpha)/(1-2*alpha)
            x = (x1,x2)
            return x, logq
        else:
            (x1,x2) = x
            x1 = self.data_constraint + (1-2*self.data_constraint)*x1
            x1 = tf.math.log(x1/(1.-x1))
            # Log-determinant of the Jacobian of the initial (logit) transform.
            ldj1 = tf.math.softplus(x1) + tf.math.softplus(-x1) + tf.math.log(tf.constant(1.-2*self.data_constraint))
            ldj1 = tf.reduce_sum(ldj1,[1,2,3])
            x = (x1,x2)
            x,ldj2 = self.reverse(x)
            (x1,x2) = x
            logq = self.distribution.log_prob(tf.reshape(x1,lattice))
            logq = logq + ldj2 + ldj1
            return x , logq

    def forward(self, x):
        """Apply all coupling layers in order (latent -> pre-sigmoid space).

        Returns the transformed ``(x1, cond)`` pair and the summed log-scale
        ``sum(s)`` per sample (log-determinant of the coupling Jacobians).
        """
        ldj = 0
        for i in range(self.num_coupling_layers):
            (x1,x2) = x

            x_frozen = x1 * self.masks[i]
            reversed_mask = 1 - self.masks[i]
            x_active = x1 * reversed_mask
            s, t = self.layers_list[i]([x_frozen,x2])
            # Only the active (unmasked) sites are scaled and shifted.
            s *= reversed_mask
            t *= reversed_mask

            fx1 = t + x_active *tf.exp(s) + x_frozen
            fx2 = x2
            ldj += tf.reduce_sum(s, [1,2,3])
            x = (fx1,fx2)
        return x, ldj

    def reverse(self, fx):
        """Invert the coupling layers in reverse order (pre-sigmoid -> latent space).

        Returns the latent ``(x1, cond)`` pair and ``-sum(s)`` per sample.
        """
        ldj = 0
        for i in reversed(range(self.num_coupling_layers)):
            (fx1,fx2)= fx
            fx_frozen = fx1*self.masks[i]
            reversed_mask = 1 - self.masks[i]
            fx_active = fx1*reversed_mask
            s, t = self.layers_list[i]([fx_frozen,fx2])
            s *= reversed_mask
            t *= reversed_mask

            x1 = (fx_active - t) *tf.exp(-s) + fx_frozen
            x2 = fx2
            ldj -= tf.reduce_sum(s, [1,2,3])
            fx = (x1,x2)
        return fx,ldj
        
     
    
   

In [None]:
# 24 coupling layers on 8x8x1 lattices; alpha = 1e-4 for the logit transform.
model = RealNVP(num_coupling_layers=24, input_shape=(8, 8, 1), data_constraint=1.e-4)

In [None]:
# Piecewise-constant learning rate: multiplied by `decay` after every 1250
# optimizer steps, ten drops in total (len(values_gen) == len(boundaries) + 1).
boundaries = [1250*k for k in range(1, 11)]

gen_lr = 5.e-5
decay  = 0.95

values_gen = [gen_lr] + [gen_lr*(decay)**k for k in range(1, len(boundaries) + 1)]

learning_rate_fn_gen = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=boundaries, values=values_gen)
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn_gen)

# Running means reported and reset once per epoch by the training loop.
train_FKL_loss   = tf.keras.metrics.Mean('train_FKL_loss', dtype=tf.float32)
train_RKL_loss   = tf.keras.metrics.Mean('train_RKL_loss', dtype=tf.float32)
train_total_loss = tf.keras.metrics.Mean('train_total_loss', dtype=tf.float32)

nll_loss = tf.keras.metrics.Mean('nll_loss', dtype=tf.float32)

In [None]:
# Object-based checkpoint tracking the flow weights and optimizer state.
checkpoint_directory = './ring_samples_sigmoid/100_set/training_checkpoints'
checkpoint_name = "ckpt-10-2-6"
checkpoint_prefix = os.path.join(checkpoint_directory, checkpoint_name)
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 generator=model)

In [None]:
# Restore weights and optimizer state saved by an earlier run.
# NOTE(review): the returned status object is not checked, so a partially
# matching checkpoint may go unnoticed -- consider
# `.assert_existing_objects_matched()` once the expected contents are confirmed.
checkpoint.restore(checkpoint_prefix)

In [None]:
@tf.function
def train_step(x,y,T,model):
    """One optimisation step on 0.5 * reverse KL + 1.0 * forward KL.

    Args:
        x: batch of data lattices.
        y: conditioning tensor (temperature broadcast over each lattice).
        T: per-sample temperature dividing the XY action in the target log-density.
        model: the flow.

    Updates the module-level optimizer and the train_* metrics.
    """
    reverse_weight = 0.5
    forward_weight = 1.0

    with tf.GradientTape() as tape:
        # Reverse KL: sample the flow and compare with the unnormalised target -E/T.
        latent = tf.reshape(model.distribution.sample_n(y.shape[0]),(-1,8,8,1))
        (generated, _), logq_gen = model((latent,y),forward = True)
        logp_gen = -xy_action(tf.reshape(generated,(-1,8,8)))/T
        reverse_loss = tf.reduce_mean(logq_gen - logp_gen)
        # Forward KL: negative log-likelihood of the data batch.
        _, logq_data = model((x,y),forward = False)
        forward_loss = -tf.reduce_mean(logq_data)

        total_loss = reverse_weight * reverse_loss + forward_weight * forward_loss

    variables = model.trainable_variables
    grads = tape.gradient(total_loss, variables)
    generator_optimizer.apply_gradients(zip(grads, variables))
    train_FKL_loss(forward_loss)
    train_RKL_loss(reverse_loss)
    train_total_loss(total_loss)

In [None]:
@tf.function
def test_step(x,y,model,forward = False):
    """Accumulate the mean negative log-likelihood of one batch into ``nll_loss``.

    The model is always run data -> latent (forward=False). The ``forward``
    argument is accepted for compatibility but is not used.
    """
    _, log_density = model((x,y),forward = False)
    nll_loss(-tf.reduce_mean(log_density))

In [None]:
# Train for NUM_EPOCHS epochs, printing the epoch-averaged losses.
NUM_EPOCHS = 3

start = time.time()
for epoch in range(NUM_EPOCHS):
    for x, y, T in training_dataset:
        train_step(x, y, T, model)

    for x, y in test_dataset:
        test_step(x, y, model, forward=False)

    template = 'Epoch {:3d},total_loss: {:.6f},forward_loss: {:.6f},reverse_loss: {:.6f},test_loss: {:.6f}'
    print(template.format(epoch+1, train_total_loss.result(), train_FKL_loss.result(),
                          train_RKL_loss.result(), nll_loss.result()))

    # Reset the running means so each epoch reports its own averages.
    for metric in (train_FKL_loss, train_RKL_loss, train_total_loss, nll_loss):
        metric.reset_states()

stop = time.time()
print('Time: ', stop - start)
            

In [None]:
# Save the trained flow and optimizer state.
# NOTE(review): tf.train.Checkpoint.save appends a save-counter suffix to
# file_prefix, so this presumably writes a new file rather than overwriting the
# checkpoint restored above -- confirm the resulting file name before reusing it.
checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
# Draw 1000 flow samples per temperature (seed 1000*class index) for comparison
# with the MCMC reference lattices.
samples = []
for class_idx in range(num_classes):
    cond_field = temp_val[class_idx]*tf.ones(shape = [1000,8,8,1],dtype=tf.float32)
    latent = tf.reshape(model.distribution.sample_n(1000,seed = 1000*class_idx),(-1,8,8,1))
    (lattices, _), _ = model((latent,cond_field),forward = True)
    samples.append(lattices)
samples = np.array(samples).reshape((-1,8,8,1))
print(samples.shape)

In [None]:
# Collect 1000 MCMC reference lattices (8x8) per temperature for comparison.
# The generated data file stores 10000 lattices per temperature, contiguously.
# NOTE(review): indices 4500..4999 of each block overlap the range from which the
# training subsample is drawn (every 50th lattice of 0..4999), i.e. 10 lattices
# per temperature are shared with the training set.
lat = []
for i in range(num_classes):
    lat.append(xy_data[10000*i+4500:10000*i+5500])

mcmc_samples = np.array(lat).reshape((-1,8,8,1))


In [None]:
# f = open('./cgan/igan.pkl', 'rb')
# cgan_data = pickle.load(f,encoding="latin1")
# cgan_data = np.float32(np.array(cgan_data).reshape((-1,8,8,2)))

In [None]:
# cgan_data = np.reshape(np.arctan2(cgan_data[:,:,:,1],cgan_data[:,:,:,0])/(2*np.pi),(-1,8,8,1))
# comparison_plot(mcmc_samples,cgan_data,1000,0.05,2.05,32,J=1,K=1,name='./ring_samples_sigmoid/igan')

In [None]:
# Compare flow samples against the MCMC reference. The positional arguments
# appear to be samples per temperature (1000), temperature range (0.05..2.05)
# and number of temperatures (32); `name` is presumably an output path prefix
# -- TODO confirm against utils.comparison_plot.
comparison_plot(mcmc_samples,samples,1000,0.05,2.05,32,J=1,K=1,name='./ring_samples_sigmoid/100_set/model_comparison_plot_ckpt_9')

In [None]:
# Quantitative comparison of flow samples vs. MCMC reference (helper from utils;
# the trailing 1, 1 presumably match J=1, K=1 in comparison_plot -- TODO confirm).
evaluation_metrics(mcmc_samples,samples,1,1)

In [None]:
def serial_sample_generator(model,temp, action, batch_size, N_samples, seed = 500):
    """Yield ``N_samples`` flow proposals at temperature ``temp``, one at a time.

    Proposals are generated ``batch_size`` at a time and yielded as
    ``(x, logq, logp)``. Here ``logq`` is the model log-density and
    ``logp = -action(x) / temp`` is the unnormalised target log-density.

    Each refill uses the seed ``seed + batch_index``. The previous
    ``seed + batch_i`` was always ``seed + 0`` at the refill point, so every
    batch was requested with the same seed. ``seed=None`` leaves sampling unseeded.
    """
    x1, logq, logp = None, None, None

    for i in range(N_samples):
        batch_i = i % batch_size
        if batch_i == 0:
            # Out of proposals: generate a new batch with its own seed.
            batch_index = i // batch_size
            batch_seed = None if seed is None else seed + batch_index
            z = tf.reshape(model.distribution.sample_n(batch_size,seed = batch_seed),(-1,8,8,1))
            cond = temp*tf.ones((batch_size,8,8,1))
            x, logq = model((z,cond),forward = True)
            (x1,x2) = x
            logp = -action(x1)/temp
        yield x1[batch_i], logq[batch_i], logp[batch_i]
        
def make_mcmc_ensemble(model, action, batch_size, N_samples,seed, temperatures=None):
    """Run an independence Metropolis-Hastings chain at each temperature.

    For every temperature, a fresh chain of length ``N_samples`` is built from flow
    proposals (``serial_sample_generator``). A proposal x' replacing state x is
    accepted with probability
        min(1, exp[(logp(x') - logq(x')) - (logp(x) - logq(x))]);
    the first proposal of each chain is always accepted.

    Args:
        model: flow used to propose samples.
        action: energy function; the target log-density is -action(x) / T.
        batch_size: proposals generated per model call.
        N_samples: chain length per temperature.
        seed: seeds both the proposal draws and the acceptance uniforms.
        temperatures: temperatures to sample. Defaults to
            np.linspace(0.05, 2.05, num_classes), as before.

    Returns:
        dict of lists 'x', 'logq', 'logp', 'accepted', 'logw' holding every
        temperature's chain, concatenated in temperature order.
    """
    # Acceptance uniforms now follow `seed` (previously fixed at 1000, so runs with
    # different seeds shared the same draws). seed=1000 reproduces the old stream.
    rs = np.random.RandomState(seed=seed)
    if temperatures is None:
        temperatures = np.linspace(0.05,2.05,num_classes)
    history_for_all_temp = {'x' : [],'logq' : [],'logp' : [],'accepted' : [],'logw' : []}

    for temp in temperatures:
        history = {key: [] for key in history_for_all_temp}
        sample_gen = serial_sample_generator(model,temp, action, batch_size, N_samples,seed = seed)

        for new_x, new_logq, new_logp in sample_gen:
            if len(history['logp']) == 0:
                # Always accept the first proposal; the chain must start somewhere.
                accepted = True
            else:
                last_logp = history['logp'][-1]
                last_logq = history['logq'][-1]
                p_accept = tf.math.exp((new_logp - new_logq) - (last_logp - last_logq))
                p_accept = min(1, p_accept)
                draw = rs.rand()  # ~ U[0, 1)
                accepted = bool(draw < p_accept)
                if not accepted:
                    # Rejected: repeat the current state.
                    new_x = history['x'][-1]
                    new_logp = last_logp
                    new_logq = last_logq
            history['logp'].append(new_logp)
            history['logq'].append(new_logq)
            history['x'].append(new_x)
            history['accepted'].append(accepted)
            history['logw'].append(new_logp - new_logq)

        for key, values in history.items():
            history_for_all_temp[key].extend(values)
    return history_for_all_temp

# One 1000-step MH chain per temperature, proposals generated 256 at a time.
ensemble_size = 1000
xy_ens = make_mcmc_ensemble(model, xy_action, batch_size=256, N_samples=ensemble_size, seed=1000)
acceptance_pct = 100*np.mean(xy_ens['accepted'])
print("Accept rate:", acceptance_pct)



In [None]:
# Stack the MH chain states of all temperatures into one array
# (presumably (num_classes * ensemble_size, 8, 8, 1) -- see the next cell).
mh_samples  = np.array(xy_ens['x'])

In [None]:
# Shape of the stacked MH samples.
np.shape(mh_samples)

In [None]:
# Same metrics as for the raw flow samples, now on the MH-corrected chains.
evaluation_metrics(mcmc_samples,mh_samples,1,1)

In [None]:
# Same comparison plot as for the raw flow samples, now on the MH-corrected chains.
comparison_plot(mcmc_samples,mh_samples,1000,0.05,2.05,32,J=1,K=1,name='./ring_samples_sigmoid/100_set/mh_comparison_plot_ckpt_9')

In [None]:
def compute_ess(logp, logq):
    """Effective sample size per configuration from importance weights.

    With log-weights logw = logp - logq this returns
        (sum w)^2 / (N * sum w^2),
    evaluated in log space via logsumexp for numerical stability. It is at most
    1, with equality when all weights are equal.
    """
    log_weights = logp - logq
    log_numerator = 2*tf.math.reduce_logsumexp(log_weights, axis=0)
    log_denominator = tf.math.reduce_logsumexp(2*log_weights, axis=0)
    return tf.math.exp(log_numerator - log_denominator) / len(log_weights)

In [None]:
# ESS and acceptance rate of the MH chains over several independent seeds.
# Importance weights of all temperatures' chains are pooled into one ESS value.
N_ESS_RUNS = 5
ess_mh = []
AR     = []

for j in range(N_ESS_RUNS):
    # Separate name so the `xy_ens` used by earlier cells is not overwritten.
    run_ens = make_mcmc_ensemble(model, xy_action , 256,1000,seed = 1000*j)
    logp_chain = np.array(run_ens['logp']).reshape((-1,))
    logq_chain = np.array(run_ens['logq']).reshape((-1,))
    ess_mh.append(compute_ess(logp_chain, logq_chain).numpy())
    AR.append(100*np.mean(run_ens['accepted']))

print('ESS-MH:',ess_mh)
print('AR :',AR)
print('ESS-MH Mean : ',np.mean(np.array(ess_mh)))
print('AR Mean :',np.mean(np.array(AR)))

In [None]:
# Importance-sampling ESS of the raw flow, pooled over all temperatures, for 5 runs.
ess = []
for i in range(5):
    logp_array = []
    logq_array = []
    for j in range(num_classes):
        # Temperature must follow the class index j. The previous code used
        # temp_val[i] (the run index), so each run evaluated only one temperature.
        temp = temp_val[j]
        cond = temp*tf.ones(shape = [1000,8,8,1],dtype=tf.float32)
        z = tf.reshape(model.distribution.sample_n(1000,seed = 1000*i+j),(-1,8,8,1))
        (flow_samples, _), logq = model((z,cond),forward = True)
        logp_array.append(-xy_action(flow_samples)/temp)
        logq_array.append(logq)

    # No sample list kept here: the global `samples` used by earlier cells
    # is left untouched.
    logp = np.array(logp_array).reshape((-1,))
    logq = np.array(logq_array).reshape((-1,))
    ess.append(compute_ess(logp, logq).numpy())

print('ESS:',ess)
print('ESS Mean : ',np.mean(np.array(ess)))

In [None]:
# Held-out set for the NLL: the last 1000 lattices of every temperature block.
test_data = []
for i in range(num_classes):
    test_data.append(xy_data[10000*i+9000:10000*(i+1)])

# Derive the label count from the slice rather than hard-coding it.
nll_per_temp = len(test_data[0])
test_data  = np.reshape(np.array(test_data),(-1,8,8,1))
cond_temp  = np.repeat(temp_val,nll_per_temp)

# Separate name so the training `batch_size` defined earlier is not overwritten.
nll_batch_size = 1000

test_cond    = tf.cast(np.repeat(cond_temp,8*8).reshape(-1,8,8,1),dtype = tf.float32)
test_set     = tf.data.Dataset.from_tensor_slices((test_data,test_cond))
test_set     = test_set.batch(nll_batch_size)

In [None]:
# Mean negative log-likelihood of the held-out set, averaged over batch means
# (equal to the overall mean when all batches have the same size).
batch_nll = []
for lattices, cond_field in test_set:
    _, log_density = model((lattices, cond_field), forward=False)
    batch_nll.append((-tf.reduce_mean(log_density)).numpy())
NLL = np.array(batch_nll)
print('NLL Mean :',np.mean(NLL))
    