In [1]:
import sys
assert sys.version_info >= (3, 6), "Sonnet 2 requires Python >=3.6"

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tree
import pandas as pd

try:
  import sonnet.v2 as snt
  tf.enable_v2_behavior()
except ImportError:
  import sonnet as snt

print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))

TensorFlow version 2.3.1
Sonnet version 2.0.0


In [3]:
sys.path.append('../models')

In [4]:
import vae

In [107]:
import importlib
importlib.reload(vae)

<module 'vae' from '../models/vae.py'>

In [108]:
is_data = pd.read_table('/Users/tuckerkj/Google Drive/Research/QML/data/quc_examples/Tutorial1_TrainPosRealWaveFunction/tfim1d_data.txt', delimiter=' ', usecols=range(10)).values
is_train = is_data[0:7999]
is_test = is_data[8000:9999]

In [109]:
original_dim = is_train.shape[1]

In [110]:
# network parameters
learning_rate = 3e-4
input_shape = (original_dim, )
batch_size = 128
epochs = 50
depth = 2

# F = 0.8960
intermediate_dim = [100]
latent_dim = 2

# Train

In [120]:
# network parameters
learning_rate = 3e-4
input_shape = (original_dim, )
batch_size = 128
epochs = 50
depth = 2

# F = 0.9322
#intermediate_dim = [500, 500]
#latent_dim = 10

# F = 0.9255
#intermediate_dim = [100]
#latent_dim = 4

# F = 0.8989
intermediate_dim = [100]
latent_dim = 2

optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

enc = vae.CatEncoder(intermediate_dim, latent_dim, depth)
dec = vae.CatDecoder(intermediate_dim, original_dim, depth)
catvae = vae.CatVAE(enc, dec)

@tf.function
def train_step(data):
    with tf.GradientTape() as tape:
        model_output = catvae(tf.cast(data, tf.int32))
    
    trainable_variables = catvae.trainable_variables
    grads = tape.gradient(model_output['loss'], trainable_variables)
    optimizer.apply(grads, trainable_variables)
    
    return model_output

In [121]:
# Get data sliced for SGD
train_dataset = (
    tf.data.Dataset.from_tensor_slices(is_train)
    .shuffle(1000)
    .repeat(-1)  # repeat indefinitely
    .batch(batch_size, drop_remainder=True)
    .prefetch(-1))

valid_dataset = (
    tf.data.Dataset.from_tensor_slices(is_test)
    .repeat(1)  # 1 epoch
    .batch(batch_size)
    .prefetch(-1))

In [122]:
# Train
num_training_updates = 10000

train_losses = []
recon_losses = []
for step_index, data in enumerate(train_dataset):
    train_results = train_step(data)
    train_losses.append(train_results['loss'])
    recon_losses.append(train_results['x_recon_loss'])
    
    if (step_index + 1) % 100 == 0:
        print('%d loss: %f recon loss: %f' % (step_index+1, np.mean(train_losses[-100:]), np.mean(recon_losses[-100:])))
        
    if (step_index + 1) % num_training_updates == 0:
        break

100 loss: 7.106334 recon loss: 6.937060
200 loss: 6.213257 recon loss: 5.606678
300 loss: 5.912413 recon loss: 5.103334
400 loss: 5.801593 recon loss: 4.939638
500 loss: 5.695118 recon loss: 4.728037
600 loss: 5.595801 recon loss: 4.456376
700 loss: 5.556584 recon loss: 4.302936
800 loss: 5.534577 recon loss: 4.263422
900 loss: 5.522062 recon loss: 4.239207
1000 loss: 5.534912 recon loss: 4.244057
1100 loss: 5.523646 recon loss: 4.226972
1200 loss: 5.510580 recon loss: 4.208869
1300 loss: 5.514606 recon loss: 4.206976
1400 loss: 5.501201 recon loss: 4.196276
1500 loss: 5.499741 recon loss: 4.185254
1600 loss: 5.513560 recon loss: 4.195140
1700 loss: 5.496112 recon loss: 4.176683
1800 loss: 5.500017 recon loss: 4.181344
1900 loss: 5.506378 recon loss: 4.184157
2000 loss: 5.494060 recon loss: 4.173867
2100 loss: 5.497131 recon loss: 4.167235
2200 loss: 5.497489 recon loss: 4.169603
2300 loss: 5.498637 recon loss: 4.163640
2400 loss: 5.496807 recon loss: 4.166883
2500 loss: 5.498518 recon

# Results

In [123]:
# Look at validation set
model_output = catvae(is_test.astype(int))
print(model_output)

{'x_recon': <tf.Tensor: shape=(1999, 10, 2), dtype=float32, numpy=
array([[[0.715759  , 0.28424105],
        [0.8606288 , 0.13937123],
        [0.9390734 , 0.0609266 ],
        ...,
        [0.96785396, 0.03214609],
        [0.9381952 , 0.06180481],
        [0.83525103, 0.16474903]],

       [[0.21704927, 0.78295076],
        [0.08958674, 0.91041327],
        [0.03436255, 0.9656374 ],
        ...,
        [0.00699824, 0.99300176],
        [0.01736357, 0.98263645],
        [0.08281987, 0.9171801 ]],

       [[0.85000676, 0.14999323],
        [0.93578887, 0.06421115],
        [0.96387565, 0.03612437],
        ...,
        [0.3102007 , 0.6897993 ],
        [0.28209776, 0.71790224],
        [0.34256205, 0.657438  ]],

       ...,

       [[0.36071262, 0.63928735],
        [0.26478264, 0.73521733],
        [0.17304373, 0.8269563 ],
        ...,
        [0.2753079 , 0.7246921 ],
        [0.34822333, 0.6517767 ],
        [0.3925557 , 0.60744435]],

       [[0.556508  , 0.44349197],
        [0

In [139]:
def bit_array(a):
    aa = []
    for c in a:
        if c == '0':
            aa.append(0)
        else:
            aa.append(1)
        
    return np.array(aa)

def bin_to_dec(b):
    dec = 0
    for idx, val in enumerate(b):
        dec += val << (len(b) - idx - 1)
        
    return dec

def update_counts(psi, vae, batch_size):
    latent_dim = vae.encoder.latent_dim
    z = tf.random.normal([batch_size, latent_dim], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32)
    output = vae.decoder(z)
    vdim = output['x_recon'].shape[1]
    
    probs = tf.reshape(output['x_recon'], [-1, output['x_recon'].shape[-1]])
    samples = tf.reshape(tf.random.categorical(tf.math.log(probs), 1), [batch_size, vdim]).numpy()
    
    for ii in range(samples.shape[0]):
        idx = bin_to_dec(samples[ii,:])
        psi[idx] += 1

def get_psi(vae, num_samples):
    n = vae.decoder.vdim
    psi = np.zeros(2**n)
    batch_size = 1000
    total_samples = 0
    while total_samples < num_samples:
        update_counts(psi, vae, batch_size)
        total_samples = total_samples + batch_size
        
    # Normalize
    psi = np.sqrt(psi*(1.0/float(total_samples)))
    
    return psi

import math
def get_psi_loss(vae, num_samples):
    n = vae.decoder.vdim
    norm = 0
    psi = []
    for d in range(2**n):
        dbin = bit_array(np.binary_repr(d, width=n))
        dbin_input = np.tile(dbin, (num_samples,1))
        model_output = vae(dbin_input.astype(float))
        val = np.exp(-0.5*model_output['loss'])
        psi.append(val)
        norm = norm + val*val
    norm = math.sqrt(norm)
    
    for ii in range(len(psi)):
        psi[ii] = psi[ii]/norm
        
    return np.array(psi)

In [140]:
psi = get_psi(catvae, 1000000)

In [141]:
np.dot(psi, psi)

1.0

In [143]:
psi

array([0.28364943, 0.14852609, 0.10748023, ..., 0.10787956, 0.14757032,
       0.28491402])

In [144]:
# Save the wave function
np.savetxt('/Users/tuckerkj/Google Drive/Research/QML/data/quc_examples/Tutorial1_TrainPosRealWaveFunction/ld_results/cat_vae_psi_2d.txt', psi)


# Debug

In [111]:
encoder = vae.CatEncoder(intermediate_dim, latent_dim, depth)

In [112]:
encoder_output = encoder(tf.cast(is_test, tf.dtypes.int32))

In [113]:
encoder_output['z'].shape

TensorShape([1999, 2])

In [114]:
decoder = vae.CatDecoder(intermediate_dim, original_dim, depth)

In [115]:
decoder_output = decoder(encoder_output['z'])

In [116]:
decoder_output

{'x_recon': <tf.Tensor: shape=(1999, 10, 2), dtype=float32, numpy=
 array([[[0.3283945 , 0.6716055 ],
         [0.24753737, 0.7524627 ],
         [0.48549405, 0.51450604],
         ...,
         [0.84731066, 0.15268925],
         [0.59798235, 0.40201768],
         [0.48342746, 0.5165726 ]],
 
        [[0.37522176, 0.6247783 ],
         [0.5033799 , 0.49662015],
         [0.52120245, 0.4787976 ],
         ...,
         [0.42249188, 0.5775081 ],
         [0.35086215, 0.6491378 ],
         [0.71415263, 0.28584743]],
 
        [[0.28186828, 0.7181317 ],
         [0.22574237, 0.77425766],
         [0.50255436, 0.49744564],
         ...,
         [0.8840027 , 0.11599734],
         [0.63462514, 0.3653749 ],
         [0.48483124, 0.5151687 ]],
 
        ...,
 
        [[0.24717675, 0.7528232 ],
         [0.45836413, 0.5416359 ],
         [0.79527956, 0.20472042],
         ...,
         [0.3245397 , 0.67546034],
         [0.07193395, 0.928066  ],
         [0.9683022 , 0.03169781]],
 
        [[

In [117]:
catvae = vae.CatVAE(encoder, decoder)

In [118]:
model_output = catvae(is_test)

In [119]:
model_output

{'x_recon': <tf.Tensor: shape=(1999, 10, 2), dtype=float32, numpy=
 array([[[0.15239543, 0.8476045 ],
         [0.5892857 , 0.4107143 ],
         [0.3920526 , 0.6079474 ],
         ...,
         [0.31885427, 0.6811457 ],
         [0.21401262, 0.78598744],
         [0.893004  , 0.10699599]],
 
        [[0.27435914, 0.7256409 ],
         [0.2201425 , 0.7798576 ],
         [0.50451887, 0.4954811 ],
         ...,
         [0.89069873, 0.10930119],
         [0.6404785 , 0.3595215 ],
         [0.48500305, 0.51499695]],
 
        [[0.349085  , 0.650915  ],
         [0.4758491 , 0.52415085],
         [0.70569   , 0.2943099 ],
         ...,
         [0.398407  , 0.601593  ],
         [0.17368652, 0.8263135 ],
         [0.88728225, 0.11271767]],
 
        ...,
 
        [[0.05984537, 0.9401547 ],
         [0.30925807, 0.69074196],
         [0.623171  , 0.37682903],
         ...,
         [0.96722454, 0.0327754 ],
         [0.7883375 , 0.2116625 ],
         [0.6104322 , 0.38956785]],
 
        [[

In [124]:
vdim = decoder_output['x_recon'].shape[1]
    
probs = tf.reshape(decoder_output['x_recon'], [-1, decoder_output['x_recon'].shape[-1]])

In [125]:
probs

<tf.Tensor: shape=(19990, 2), dtype=float32, numpy=
array([[0.3283945 , 0.6716055 ],
       [0.24753737, 0.7524627 ],
       [0.48549405, 0.51450604],
       ...,
       [0.73668116, 0.26331884],
       [0.5975779 , 0.402422  ],
       [0.536826  , 0.46317393]], dtype=float32)>

In [127]:
samples = tf.random.categorical(tf.math.log(probs), 1)

In [131]:
tf.reshape(samples, [1999, vdim])

<tf.Tensor: shape=(1999, 10), dtype=int64, numpy=
array([[1, 0, 1, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 1, 1],
       [0, 1, 1, ..., 0, 1, 0],
       ...,
       [1, 0, 1, ..., 1, 1, 0],
       [1, 1, 1, ..., 0, 0, 1],
       [0, 0, 0, ..., 1, 0, 0]])>