In [None]:
import tensorflow as tf

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import os
import datetime
from models import Encoder, Decoder, DisentangleVAE
import constants as c

In [None]:
def npy_header_offset(npy_path):
    """Return the byte offset at which the array data begins in an NPY file.

    The NPY layout is: a 6-byte magic string, a 2-byte (major, minor) version,
    a little-endian header length (2 bytes for format 1.x, 4 bytes for 2.x and
    3.x), and then the header text itself, which is terminated by a newline.

    Args:
        npy_path: Path to the ``.npy`` file (anything ``str()`` turns into a path).

    Returns:
        int: Offset, in bytes from the start of the file, of the first data byte.

    Raises:
        ValueError: If the magic string is wrong, the version is unknown, or
            the file ends before the complete header has been read.
    """
    with open(str(npy_path), 'rb') as f:
        if f.read(6) != b'\x93NUMPY':
            raise ValueError('Invalid NPY file.')
        version = f.read(2)
        if len(version) != 2:
            raise ValueError('Invalid NPY file.')
        version_major, version_minor = version
        if version_major == 1:
            header_len_size = 2
        elif version_major in (2, 3):
            # Format 3.0 only changes the header text encoding (UTF-8); its
            # length field is 4 bytes wide, exactly as in format 2.0.
            header_len_size = 4
        else:
            raise ValueError('Unknown NPY file version {}.{}.'.format(version_major, version_minor))
        header_len_bytes = f.read(header_len_size)
        if len(header_len_bytes) != header_len_size:
            raise ValueError('Invalid NPY file.')
        header_len = int.from_bytes(header_len_bytes, 'little')
        header = f.read(header_len)
        # A short read means the file is truncated; without this length check a
        # partial header that happens to end in a newline would be accepted.
        if len(header) != header_len or not header.endswith(b'\n'):
            raise ValueError('Invalid NPY file.')
        return f.tell()
    
def get_split_dataset(path, header_bytes=None):
    """Stream the frames stored in one ``obs.npy`` file as float32 images.

    The file is read as fixed-length records of 64 * 64 * 3 bytes each, so it
    is assumed to hold C-ordered uint8 image data directly after the NPY
    header (TODO confirm against the code that writes the recordings).

    Args:
        path: Path to the NPY file to read.
        header_bytes: Number of leading bytes to skip before the first record.
            When None (the default) the real header size is parsed from the
            file with ``npy_header_offset`` instead of assuming a fixed 128
            bytes; pass an int to force a specific offset.

    Returns:
        A ``tf.data.Dataset`` whose elements are float32 tensors of shape
        (64, 64, 3), with the raw byte values divided by 255.

    Raises:
        ValueError: If ``header_bytes`` is None and the file is not a valid
            NPY file (raised by ``npy_header_offset``).
    """
    num_feats = 64 * 64 * 3
    if header_bytes is None:
        # The NPY header size depends on the array shape and on the NumPy
        # version that wrote the file, so read it rather than hard-coding it.
        header_bytes = npy_header_offset(path)
    record_bytes = num_feats * tf.uint8.size

    def decode_frame(raw_record):
        # Raw bytes -> uint8 vector -> float32 scaled by 1/255 -> image tensor.
        pixels = tf.cast(tf.io.decode_raw(raw_record, tf.uint8), dtype=tf.float32) / 255.
        return tf.reshape(pixels, (64, 64, 3))

    dataset = tf.data.FixedLengthRecordDataset([path], record_bytes, header_bytes=header_bytes)
    return dataset.map(decode_frame)

In [None]:
# Build one shuffled, batched stream per car (cars 1-5), then zip the five
# streams so every training element is a tuple holding one batch per car.
dataset_list = []
for car_no in range(1, 6):
    # Chain the car's ten recording splits (0-9) end to end, in split order.
    car_dataset = None
    for split in range(10):
        split_path = os.path.abspath('./carracing_data/car{}/{}/obs.npy'.format(car_no, split))
        split_dataset = get_split_dataset(split_path)
        if car_dataset is None:
            car_dataset = split_dataset
        else:
            car_dataset = car_dataset.concatenate(split_dataset)
    dataset_list.append(car_dataset.shuffle(5000).batch(c.BATCH_SIZE))

per_car_batches = tuple(dataset_list)
dataset = tf.data.Dataset.zip(per_car_batches).prefetch(tf.data.experimental.AUTOTUNE)

## VAE

In [None]:
# Assemble the VAE from its two halves; Encoder, Decoder and DisentangleVAE
# are all imported from models.py.
# NOTE(review): mu_only=False presumably makes the encoder return more than
# the mean (later cells unpack three values from model.encode) — confirm in models.py.
encoder = Encoder(mu_only=False)
decoder = Decoder()
model = DisentangleVAE(encoder, decoder)

In [None]:
# Compile with Adam at the initial learning rate from constants.py. No loss is
# passed here — presumably DisentangleVAE computes its own loss internally
# (TODO confirm in models.py).
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=c.INIT_LR))

In [None]:
# Train for a single pass over the zipped per-car dataset.
model.fit(dataset, epochs=1)

In [None]:
# Save the encoder and decoder weights to separate HDF5 files.
# NOTE(review): these filenames carry a '_2' suffix, while the load cell further
# down reads './encoder_weights.h5' / './decoder_weights.h5' — confirm which
# checkpoint is meant to be reloaded.
model.encoder.save_weights('./encoder_weights_2.h5')
model.decoder.save_weights('./decoder_weights_2.h5')

In [None]:
# Build both sub-models with explicit input shapes: batches of 64x64x3 images
# for the encoder and 24-wide vectors for the decoder.
# NOTE(review): presumably needed so the variables exist before load_weights is
# called on an untrained model; 24 is assumed to be latent size + class-code
# size — confirm in models.py.
model.encoder.build([None, 64, 64, 3])
model.decoder.build([None, 24])

In [None]:
# Restore encoder and decoder weights from HDF5 files.
# NOTE(review): these are NOT the files written by the save cell above (which
# uses a '_2' suffix) — confirm this older checkpoint is the intended one.
model.encoder.load_weights('./encoder_weights.h5')
model.decoder.load_weights('./decoder_weights.h5')

In [None]:
# Pull the first element of the zipped dataset as NumPy arrays: a tuple with
# one batch per car. The builtin next() drives the iterator protocol directly
# instead of relying on the iterator's Python-2-style .next() method.
a = next(dataset.take(1).as_numpy_iterator())

In [None]:
# Shape of the first car's batch; expected (c.BATCH_SIZE, 64, 64, 3).
a[0].shape

In [None]:
# Stack the per-car batches into one tensor along a new leading (car) axis.
# This rebinds `a` from a tuple of arrays to a single tensor, and requires
# every car's batch to have the same shape.
a = tf.stack(a)

In [None]:
# Shape of the stacked tensor; the leading axis indexes the car.
a.shape

In [None]:
# After stacking, a[0] is still the first car's batch (now a tensor slice).
a[0].shape

In [None]:
# Encode the first car's batch; model.encode returns three values, unpacked
# here as the latent mean, log-sigma and class code.
mu, logsigma, classcode = model.encode(a[0])

In [None]:
# Encode the second car's batch the same way, for comparison with the first.
mu2, logsigma2, classcode2 = model.encode(a[1])

In [None]:
# Show input frame 3 from the second car's batch.
plt.imshow(a[1][3])

In [None]:
# Show input frame 3 from the first car's batch.
plt.imshow(a[0][3])

In [None]:
# Reconstruct frame 3 of the second car's batch from its own mean and class
# code. The [3:4] slices keep the batch axis for decode; [0] drops it for imshow.
plt.imshow(model.decode(mu2[3:4], classcode2[3:4])[0])

In [None]:
# Compare two decodes that share the second car's class code but use different
# means (mu2[3] vs mu[3]): sum of the element-wise differences.
# NOTE(review): the differences are signed, so positive and negative values can
# cancel; a value near zero does not by itself mean the two images match.
tf.reduce_sum(model.decode(mu2[3:4], classcode2[3:4])[0] - model.decode(mu[3:4], classcode2[3:4])[0])

In [None]:
# Draw a random standard-normal vector of width 16 to use in place of the
# encoder mean, and decode it with the class code of frame 3 from the first
# car's batch. No seed is set, so the image differs on every run.
mu_r = tf.random.normal(shape=[1, 16])
plt.imshow(model.decode(mu_r, classcode[3:4])[0])

In [None]:
# Display the encoded mean for frame 3 of the second car's batch.
mu2[3]