In [None]:
import os

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
from tensorflow_probability import distributions as tfd

import build_data.video_read as data_utils
import model.model as model_utils

In [None]:
# `tf.test.is_gpu_available` is deprecated; list the GPU devices TensorFlow can
# see instead (an empty list means everything below runs on CPU).
tf.config.list_physical_devices('GPU')

In [None]:
# Route messages through TF's warning logger; fall back to plain `print` if the
# logging module reference is gone (it can be unloaded during __del__).
log_fn = print if logging is None else logging.warning

In [None]:
# ---- Hyperparameters -------------------------------------------------------
# Global TensorFlow seed, set before anything random is created.
seed = 0
tf.random.set_seed(seed)

# Batch / clip / slot geometry shared by every model below.
batch_size = 1
num_frames = 15
num_slots = 8
slot_size = 32
resolution = (128, 128)

# Sampling settings (not referenced by the cells shown in this notebook).
sample_steps_num = 100
learning_sample = 0.04

# Checkpoint roots.
model_dir = "XPL/XPL/"
perception_dir = "XPL/perception/"

# Perception-model variants handed to model_utils.build_percept_model.
encode_type = "ViT"
decode_type = "SBTD"

In [None]:
def get_video_ED(num_frames, resolution, batch_size, num_slots, slot_size,
                 encode_type, decode_type, file):
    """Build the perception model and restore its latest checkpoint.

    Args:
        num_frames, batch_size, num_slots: their product is the batch size the
            perception model is built for (one entry per object image).
        resolution: image resolution passed to the model builder.
        slot_size: slot latent size passed to the model builder.
        encode_type, decode_type: encoder / decoder variants.
        file: checkpoint directory searched by tf.train.CheckpointManager.

    Returns:
        (encoder, decoder, full_model), where encoder and decoder are the
        "ObjectEncoder" and "ObjectDecoder" layers of the full model.
    """
    # data_init flattens (batch, frame, slot) object images into one axis,
    # so the model sees batch_size * num_slots * num_frames items at once.
    per_object_batch = batch_size * num_slots * num_frames
    model_pre = model_utils.build_percept_model(resolution,
                                                per_object_batch,
                                                num_channels=4,
                                                slot_size=slot_size,
                                                decode_type=decode_type,
                                                encode_type=encode_type)
    ckpt_pre = tf.train.Checkpoint(network=model_pre)
    ckpt_manager_pre = tf.train.CheckpointManager(ckpt_pre,
                                                  directory=file,
                                                  max_to_keep=5)
    latest = ckpt_manager_pre.latest_checkpoint
    if latest:
        # expect_partial: the checkpoint may hold values (e.g. optimizer
        # slots) that this inference-only model never consumes.
        ckpt_pre.restore(latest).expect_partial()
        log_fn("Restored from {}".format(latest))
    else:
        # Without weights every reconstruction downstream is meaningless,
        # so make the silent random initialization visible.
        log_fn("Initializing perception model from scratch: "
               "no checkpoint in {}.".format(file))
    model_enc = model_pre.get_layer("ObjectEncoder")
    model_dec = model_pre.get_layer("ObjectDecoder")
    return model_enc, model_dec, model_pre

In [None]:
# Perception encoder/decoder (restored inside get_video_ED).
model_enc, model_dec, model_pre = get_video_ED(num_frames, resolution,
                                               batch_size, num_slots,
                                               slot_size, encode_type,
                                               decode_type, perception_dir)

# Dynamics model plus the two explanation models.
model_reason = model_utils.build_IN_LSTM(batch_size,
                                         num_slots,
                                         slot_size,
                                         num_frames=num_frames,
                                         use_camera=False)
model_fast = model_utils.build_fast_model(batch_size, num_frames,
                                          num_slots, slot_size)
model_new = model_utils.build_fast_model(batch_size, num_frames,
                                         num_slots, slot_size)


def _restore_latest(network, directory):
    """Restore the newest checkpoint of `network` stored in `directory`.

    Returns the (tf.train.Checkpoint, tf.train.CheckpointManager) pair. If the
    directory has no checkpoint, the network keeps its fresh initialization and
    a message naming the directory is logged.
    """
    checkpoint = tf.train.Checkpoint(network=network)
    manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                         directory=directory,
                                         max_to_keep=1)
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        log_fn("Restored from {}".format(manager.latest_checkpoint))
    else:
        log_fn("Initializing from scratch: no checkpoint in {}.".format(
            directory))
    return checkpoint, manager


# os.path.join avoids the "XPL/XPL//dynamic" that plain concatenation produced
# with model_dir's trailing slash.
ckpt, ckpt_manager = _restore_latest(model_reason,
                                     os.path.join(model_dir, "dynamic"))
ckpt_F, ckpt_manager_F = _restore_latest(model_fast,
                                         os.path.join(model_dir, "explain1"))
ckpt_N, ckpt_manager_N = _restore_latest(model_new,
                                         os.path.join(model_dir, "explain2"))

In [None]:
def renormalize(x):
    """Map a tensor from the [-1, 1] range to [0, 1].

    Values outside [-1, 1] are clipped first, so the output always lies in
    [0, 1] and can be handed straight to `imshow`.
    """
    clipped = tf.clip_by_value(x, -1.0, 1.0)
    return 0.5 + clipped / 2.


def data_init(img, mask):
    """Split a batch into per-object images and mask summaries.

    Args:
        img: image tensor, broadcast against `mask` after inserting a slot
            axis at position 2 (presumably (B, F, H, W, C) — TODO confirm).
        mask: per-slot masks with the slot axis at 2 (presumably
            (B, F, N, H, W, 1) — TODO confirm).

    Returns:
        object_img: each slot's image (img inside its mask, 1.0 elsewhere),
            flattened over the first three axes to [-1] + trailing dims.
        new_image: img inside the union of all masks, 1.0 elsewhere.
        mask_sum: masks summed over the slot axis, clipped to [0, 1].
        mask_mean: 1.0 where a mask's sum over axes 3 and 4 exceeds 4,
            otherwise 0.0.
    """
    # One image per slot: keep that slot's pixels, fill the rest with 1.0.
    per_slot = tf.expand_dims(img, axis=2) * mask + (1 - mask)
    per_slot = tf.reshape(per_slot,
                          shape=[-1] + per_slot.shape.as_list()[3:])

    # Union of all slot masks and the correspondingly masked full image.
    coverage = tf.clip_by_value(tf.reduce_sum(mask, axis=2), 0.0, 1.0)
    composite = img * coverage + (1 - coverage)

    # Binary visibility flag from the mask area over axes 3 and 4.
    area = tf.reduce_sum(mask, axis=[3, 4])
    visible = tf.where(area > 4, 1.0, 0.0)

    return per_slot, composite, coverage, visible


def decode_objects_occ(objects, mask_sum):
    """Decode slot latents with `model_dec` and composite them per frame.

    Args:
        objects: slot latents of shape (B, F, N, V).
        mask_sum: combined foreground mask. It is expanded at axis 2 to weight
            the slots and multiplied onto the composite (presumably
            (B, F, H, W, 1) — TODO confirm).

    Returns:
        recon_combined: composited frames, 1.0 wherever mask_sum is 0.
        recons: per-slot decoder images reshaped to (B, F, N, ...).
        depth: per-slot decoder depths reshaped to (B, F, N, ...).
    """
    # A large negative scale makes the softmax over depth an almost-hard
    # "smallest depth wins" selection across slots.
    DM_factor = -1000.0
    (_, F, N, V) = objects.shape
    recons, depth, _ = model_dec(tf.reshape(objects, shape=[-1, V]))
    recons = tf.reshape(recons, shape=[-1, F, N] + recons.shape.as_list()[1:])
    depth = tf.reshape(depth, shape=[-1, F, N] + depth.shape.as_list()[1:])
    # Only the first N - 1 slots take part in the composite; the last slot is
    # left out. (Previously hard-coded as [7, 1], valid only for N == 8.)
    recons_occ, _ = tf.split(recons, [N - 1, 1], axis=2)
    depth_occ, _ = tf.split(depth, [N - 1, 1], axis=2)
    masks = tf.nn.softmax(depth_occ * DM_factor, axis=2)
    masks = masks * tf.expand_dims(mask_sum, axis=2)
    recon_combined = tf.reduce_sum(recons_occ * masks, axis=2)
    recon_combined = recon_combined * mask_sum + 1.0 * (1 - mask_sum)
    return recon_combined, recons, depth


def cal_all(batch, model_fast, model_new, model_reason):
    """Run the encoder, explanation models and dynamics model on one batch.

    Args:
        batch: dict with 'image' and 'mask' tensors; 'mask' has the slot axis
            at position 2.
        model_fast, model_new: explanation models, applied in that order.
        model_reason: dynamics model whose output is added to its input.

    Returns:
        Three tuples (objects, recon_combined, recons, depth), one each for:
          1. the encoder's slots,
          2. the slots after model_fast -> model_new,
          3. the slots after model_reason, shifted by one frame.
    """
    img = batch['image']
    mask = batch['mask']

    # Encode every per-object image, then regroup as
    # (B, num_frames, num_slots, 2 * slot_size) and squash to (0, 1).
    object_images, _, _, _ = data_init(img, mask)
    pre_slots = model_enc(object_images)
    pre_slots = tf.reshape(pre_slots,
                           shape=[-1, num_frames, num_slots, slot_size * 2])
    pre_slots = tf.nn.sigmoid(pre_slots)

    # Sort slots by total mask area (ascending) for the explanation models;
    # `restore_slots` is the inverse permutation back to the input order.
    slot_area = tf.reduce_sum(mask, axis=[1, 3, 4, 5])
    sort_slots = tf.argsort(slot_area, axis=1)
    restore_slots = tf.argsort(sort_slots, axis=1)
    mask = tf.gather(mask, sort_slots, axis=2, batch_dims=-1)
    pre_slots = tf.gather(pre_slots, sort_slots, axis=2, batch_dims=-1)

    # Per-(frame, slot) visibility flag. A slot that is visible in no frame
    # is treated as visible in every frame.
    slot_visible = tf.where(tf.reduce_sum(mask, axis=[3, 4]) > 4, 1.0, 0.0)
    slot_visible = slot_visible + tf.cast(
        tf.reduce_sum(slot_visible, axis=1, keepdims=True) < 1, tf.float32)

    # Union of all slot masks, clipped to [0, 1] exactly as in data_init so
    # overlapping masks cannot push the compositing weight above 1.
    mask_sum = tf.clip_by_value(tf.reduce_sum(mask, axis=2), 0.0, 1.0)

    # 0 for sorted slot 0, 1 for every other slot. The batch size comes from
    # tf.shape so this also works when the static batch dim is unknown.
    batch_dim = tf.shape(pre_slots)[0]
    mask_axes = tf.concat([
        tf.zeros([1, num_frames, 1, 1]),
        tf.ones([1, num_frames, num_slots - 1, 1])
    ], axis=2)
    mask_axes = tf.tile(mask_axes, [batch_dim, 1, 1, 1])

    # First half of each slot vector, rescaled from (0, 1) to (-3, 3).
    # The second half is not used here.
    objects_pre = pre_slots[:, :, :, :slot_size] * 6.0 - 3.0

    # Stage 1: decode the encoder's slots, in input slot order.
    objects_enc = tf.gather(objects_pre, restore_slots, axis=2, batch_dims=-1)
    recons_image_enc, recons_enc, depth_enc = decode_objects_occ(
        objects_enc, mask_sum)

    # Stage 2: explanation models on the sorted slots, then un-sort.
    objects_init = model_fast((objects_pre, slot_visible), training=False)
    objects_init = model_new((objects_init, mask_axes), training=False)
    objects_init_out = tf.gather(objects_init,
                                 restore_slots,
                                 axis=2,
                                 batch_dims=-1)
    recons_image_fast, recons_fast, depth_fast = decode_objects_occ(
        objects_init_out, mask_sum)

    # Stage 3: residual dynamics update, shifted by one frame along axis 1.
    # NOTE(review): tf.roll wraps around, so frame 0 receives the output
    # computed from the last frame.
    objects_2 = objects_init_out + model_reason(objects_init_out,
                                                training=False)
    objects_2 = tf.roll(objects_2, shift=1, axis=1)
    recons_image, recons, depth = decode_objects_occ(objects_2, mask_sum)

    return ((objects_enc, recons_image_enc, recons_enc, depth_enc),
            (objects_init_out, recons_image_fast, recons_fast, depth_fast),
            (objects_2, recons_image, recons, depth))


In [None]:
# "blocking" split in file order (no shuffling); drop the first 2000 dataset
# elements and pull a single batch to visualize.
collision_ds = data_utils.load_data(batch_size, split="blocking",
                                    shuffle=False).skip(2000)
iterator = iter(collision_ds)
batch = next(iterator)

In [None]:
# Encoder, explanation and dynamics outputs for the loaded batch.
enc_all, explain_all, dynamic_all = cal_all(batch,
                                            model_fast=model_fast,
                                            model_new=model_new,
                                            model_reason=model_reason)

In [None]:
# Stage 1: input frames, encoder-slot composite, and each decoded slot.
(objects, recon_combined, recons, depth) = enc_all
# squeeze=False keeps `ax` 2-D even when num_frames == 1.
fig, ax = plt.subplots(num_frames, num_slots + 2,
                       figsize=(15, 2 * num_frames), squeeze=False)
fig.suptitle('Encoder slots')
for i in range(num_frames):
    ax[i, 0].imshow(renormalize(batch['image'][0, i][..., :3]))
    ax[i, 1].imshow(renormalize(recon_combined[0, i][..., :3]))
    for j in range(num_slots):
        ax[i, j + 2].imshow(renormalize(recons[0, i, j][..., :3]))
    for cell in ax[i]:
        cell.grid(False)
        cell.axis('off')
ax[0, 0].set_title('Image')
ax[0, 1].set_title('Recon.')
for j in range(num_slots):
    ax[0, j + 2].set_title('Slot %s' % (j + 1))
plt.show()

In [None]:
# Stage 2: input frames, composite after model_fast -> model_new, and each slot.
(objects_fast, recon_combined, recons, depth) = explain_all
# squeeze=False keeps `ax` 2-D even when num_frames == 1.
fig, ax = plt.subplots(num_frames, num_slots + 2,
                       figsize=(15, 2 * num_frames), squeeze=False)
fig.suptitle('Explanation models (model_fast -> model_new)')
for i in range(num_frames):
    ax[i, 0].imshow(renormalize(batch['image'][0, i][..., :3]))
    ax[i, 1].imshow(renormalize(recon_combined[0, i][..., :3]))
    for j in range(num_slots):
        ax[i, j + 2].imshow(renormalize(recons[0, i, j][..., :3]))
    for cell in ax[i]:
        cell.grid(False)
        cell.axis('off')
ax[0, 0].set_title('Image')
ax[0, 1].set_title('Recon.')
for j in range(num_slots):
    ax[0, j + 2].set_title('Slot %s' % (j + 1))
plt.show()

In [None]:
# Stage 3: input frames, composite after model_reason (shifted one frame), and
# each slot.
(objects_reason, recon_combined, recons, depth) = dynamic_all
# squeeze=False keeps `ax` 2-D even when num_frames == 1.
fig, ax = plt.subplots(num_frames, num_slots + 2,
                       figsize=(15, 2 * num_frames), squeeze=False)
fig.suptitle('Dynamics model (model_reason, shifted one frame)')
for i in range(num_frames):
    ax[i, 0].imshow(renormalize(batch['image'][0, i][..., :3]))
    ax[i, 1].imshow(renormalize(recon_combined[0, i][..., :3]))
    for j in range(num_slots):
        ax[i, j + 2].imshow(renormalize(recons[0, i, j][..., :3]))
    for cell in ax[i]:
        cell.grid(False)
        cell.axis('off')
ax[0, 0].set_title('Image')
ax[0, 1].set_title('Recon.')
for j in range(num_slots):
    ax[0, j + 2].set_title('Slot %s' % (j + 1))
plt.show()