In [None]:
import os
import tensorflow as tf
import keras
from keras.layers import Add,Multiply,Softmax,Input,TimeDistributed,Dense,Average,GlobalAveragePooling1D,Concatenate,Lambda,RepeatVector, Conv2D,ConvLSTM2D, MaxPooling2D,BatchNormalization,Flatten,Reshape,UpSampling2D
from keras.models import Model, load_model
from keras.optimizers import Adam
from keras.utils import plot_model
import numpy as np
import matplotlib.pyplot as plt
import math
import time
import pylab as pl
from IPython import display
from IPython.core.display import HTML
from IPython.core.display import display as html_width
# Inject a CSS rule into the notebook page that sets elements of class
# "container" to 90% width (overriding other rules via !important).
html_width(HTML("<style>.container { width:90% !important; }</style>"))
import tensorflow_probability as tfp
import matplotlib.image as mpimg
from matplotlib.gridspec import GridSpec
import imageio
from tqdm import tqdm
from keras.utils.vis_utils import plot_model


from keras.backend.tensorflow_backend import set_session
#physical_devices = tf.config.experimental.list_physical_devices('GPU')
#tf.config.experimental.set_memory_growth(physical_devices[0], True)
# Session configuration for the Keras TensorFlow backend:
#   * allow_growth        -> GPU memory is allocated on demand instead of up front
#   * log_device_placement -> log the device each operation is placed on
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(allow_growth=True),
    log_device_placement=True,
)
# Create the session and register it as the one Keras uses.
sess = tf.Session(config=config)
set_session(sess)

In [None]:
from tensorflow.python.client import device_lib
def get_available_devices():
    """Return the ``name`` of every entry reported by ``device_lib.list_local_devices()``."""
    return [device.name for device in device_lib.list_local_devices()]

# Show the visible devices and the TensorFlow version in the cell output.
print(get_available_devices())
print(tf.__version__)

In [None]:
# Upper bound on the number of observed frames drawn per training sample
# (get_train_sample draws n uniformly from 1..obs_max).
obs_max = 5
# Number of demonstrations that get_train_sample picks from.
train_N = 40
# Random ordering of the demonstration ids 0..train_N-1. Derived from train_N
# so the two cannot drift apart when the dataset size is changed.
train_p = np.random.permutation(train_N)

# Deep Modality Blending Networks

In [None]:
def get_train_sample(action_type = -1, coef = -1):
    """Build one random training example from a recorded demonstration.

    A demonstration ``d`` is picked at random from ``train_p[:train_N]``, then
    ``n`` (uniform in 1..obs_max) distinct frames are used as observations and
    one further distinct frame is used as the prediction target.

    Parameters
    ----------
    action_type : str or int, optional
        Sub-directory of the demonstration to read (e.g. ``'move'``).
        The integer codes 0 / 1 are translated to ``'move'`` / ``'grasp'``.
        The default ``-1`` always selects ``'move'``.
    coef : float, optional
        Weight given to the image representation; the pose representation
        gets ``1 - coef``. The default ``-1`` draws it uniformly from [0, 1).

    Returns
    -------
    inputs : list
        ``[observation, observation_pose, target_X, img_coef, pose_coef]`` with
        shapes (1,n,128,128,4), (1,n,8), (1,1), (1,128), (1,128). Channel 0 of
        ``observation`` and column 0 of ``observation_pose`` hold the
        normalised frame time; the remaining entries hold the image / pose.
    targets : list
        ``[target_Y, target_Y_pose]`` with shapes (1,128,128,6) and (1,14).
        Only the first half of the last axis is filled (image / pose of the
        target frame); the second half stays zero and is discarded by
        ``custom_loss``, which splits ``y_true`` in two.
    d : int
        Id of the demonstration that was sampled.
    target_frame : int
        Index of the target frame inside that demonstration.

    Raises
    ------
    IndexError
        If the demonstration has fewer than ``n + 1`` frames.
    """
    n = np.random.randint(0,obs_max)+1
    d = train_p[np.random.randint(0, train_N)]

    # The default is pinned to 'move'. The original code drew a random action
    # and then immediately overwrote it with 0, so 'grasp' was never sampled;
    # that behaviour is kept, minus the wasted random draw.
    if action_type == -1:
        action_type = 0
    # Accept the integer coding used above as well as directory names;
    # anything else is passed through unchanged.
    action_type = {0: 'move', 1: 'grasp'}.get(action_type, action_type)

    if coef == -1:
        coef = np.random.rand()
    img_coef = np.ones((1,128)) * coef
    pose_coef = np.ones((1,128)) * (1-coef)

    observation = np.zeros((1,n,128,128,4))
    observation_pose = np.zeros((1,n,8))
    target_X = np.zeros((1,1))
    target_Y = np.zeros((1,128,128,6))
    target_Y_pose = np.zeros((1,14))

    sample_dir = '../data/data2020/%d/%s'%(d,action_type)
    # One row per frame; the assignments below expect 7 values per row.
    pose = np.loadtxt('%s/joint_%d.txt'%(sample_dir,d))
    pose[:,-1] *= 10  # rescale the last pose column by a factor of 10
    time_len = pose.shape[0]
    if time_len <= n:
        # n observations plus one target frame are needed.
        raise IndexError('demonstration %d/%s has %d frames, need at least %d'
                         %(d,action_type,time_len,n+1))

    # Frame times normalised to [0, 1]; a random permutation gives n distinct
    # observation frames followed by a distinct target frame.
    times = np.linspace(0,1,time_len)
    perm = np.random.permutation(time_len)
    target_frame = perm[n]

    for i, frame in enumerate(perm[:n]):
        observation[0,i,:,:,0] = times[frame]
        observation[0,i,:,:,1:] = mpimg.imread('%s/%d.jpeg'%(sample_dir,frame))/255.
        observation_pose[0,i,0] = times[frame]
        observation_pose[0,i,1:] = pose[frame]

    target_X[0,0] = times[target_frame]
    target_Y[0,:,:,:3] = mpimg.imread('%s/%d.jpeg'%(sample_dir,target_frame))/255.
    target_Y_pose[0,:7] = pose[target_frame]
    return [observation, observation_pose, target_X, img_coef, pose_coef], [target_Y, target_Y_pose], d, target_frame

In [None]:
def custom_loss(y_true, y_predicted):
    """Negative mean log-likelihood of the target under a diagonal Gaussian.

    ``y_predicted`` is split in two along its last axis: the first half is the
    mean, the second half is passed through softplus to obtain the scale.
    ``y_true`` is split the same way and only its first half is used as the
    observed value; the second half is ignored.
    """
    predicted_mean, raw_scale = tf.split(y_predicted, 2, axis=-1)
    observed = tf.split(y_true, 2, axis=-1)[0]
    # softplus maps the raw network output to a non-negative scale
    scale = tf.nn.softplus(raw_scale)
    gaussian = tfp.distributions.MultivariateNormalDiag(loc=predicted_mean, scale_diag=scale)
    return -tf.reduce_mean(gaussian.log_prob(observed))

In [None]:
# ---------------------------------------------------------------------------
# Network definition: two modality encoders (joint pose, image) whose
# representations are blended by a weighted sum, followed by two decoders
# (image, joint pose) conditioned on the query time.
# ---------------------------------------------------------------------------

# Inputs. The observation inputs have a variable-length (None) sequence axis.
# get_train_sample fills channel 0 of each image frame and column 0 of each
# pose row with the frame time; the remaining entries hold the image / pose.
image_layer = Input(shape=(None,128,128,4), name="image_observation") 
joint_layer = Input(shape=(None,8), name="joint_observation") 
# Query time at which the decoders should predict.
target_X_layer = Input(shape=(1,), name = 'target_X')
# Per-sample 128-vectors that scale the image / pose representations.
img_coef_layer = Input(shape=(128,), name = 'image_coef')
pose_coef_layer = Input(shape=(128,), name = 'pose_coef')

# Widths of the per-observation Dense layers that follow the first 32-unit layer.
encoder_joint_sizes = [64,64,128,128,256]

# Joint encoder: the same Dense stack is applied to every observation.
joint_encoder = TimeDistributed(Dense(32, activation = 'relu'))(joint_layer)
for channel_size in encoder_joint_sizes:
    joint_encoder = TimeDistributed(Dense(channel_size, activation = 'relu'))(joint_encoder)

# One 128-d representation per observation, then averaged over the sequence axis.
joint_representations = TimeDistributed(Dense(128, activation='relu'))(joint_encoder) #128
joint_representation = GlobalAveragePooling1D()(joint_representations) 

# Element-wise weighting of the pose representation.
multiplied_joint = Multiply()([joint_representation,pose_coef_layer])

###################################################

# Channel counts of the conv stages that follow the first 32-channel stage.
encoder_img_sizes = [64,64,128,128,256]

# Image encoder: six conv + 2x2 max-pool stages applied to every frame.
# With the default pooling stride each stage halves the spatial size,
# taking 128x128 down to 2x2 (flattened: 2*2*256 = 1024 values).
image_encoder = TimeDistributed(Conv2D(32,(3,3),padding='same',activation='relu'))(image_layer)
image_encoder = TimeDistributed(MaxPooling2D((2,2)))(image_encoder)
for channel_size in encoder_img_sizes:
    image_encoder = TimeDistributed(Conv2D(channel_size,(3,3),padding='same',activation='relu'))(image_encoder)
    image_encoder = TimeDistributed(MaxPooling2D((2,2)))(image_encoder)

# One 128-d representation per frame, then averaged over the sequence axis.
image_flatten = TimeDistributed(Flatten())(image_encoder)
img_representations = TimeDistributed(Dense(128, activation='relu'))(image_flatten)
img_representation = GlobalAveragePooling1D()(img_representations) 

# Element-wise weighting of the image representation.
multiplied_img = Multiply()([img_representation,img_coef_layer])

# Blend the two modalities: weighted sum of the image and pose representations.
general_representation = Add()([multiplied_joint,multiplied_img])

# Append the query time to the blended representation (128 + 1 values).
merged_layer = Concatenate(axis=-1, name='merged')([general_representation,target_X_layer])

####################################################

# Shared decoder trunk; its 1024 outputs are reshaped to 2x2x256 for the image decoder.
decoder_representation = Dense(1024, activation='relu') (merged_layer)

" =============== Image Decoder =============== "
decoder_img = Reshape([2,2,256])(decoder_representation)
decoder_img_sizes = [256,128,128,64,64,32]

# Six conv + 2x upsampling stages: 2x2 -> 128x128.
for channel_size in decoder_img_sizes:
    decoder_img = Conv2D(channel_size, (3,3), padding='same', activation='relu')(decoder_img)
    decoder_img = UpSampling2D((2, 2))(decoder_img)

# Final 6 channels: custom_loss splits them into 3 mean channels and
# 3 channels that are passed through softplus to give the scale.
img_output = Conv2D(16, (3,3), padding='same', activation='relu')(decoder_img)
img_output = Conv2D(8, (3,3), padding='same', activation='relu')(img_output)
img_output = Conv2D(6, (3,3), padding='same', activation='sigmoid')(img_output)
" =============== Image Decoder =============== "

" =============== Joint Decoder =============== "
# Dense stack down to 14 linear outputs: custom_loss splits them into
# 7 means and 7 values that are passed through softplus to give the scale.
decoder_joint = Dense(512, activation='relu')(decoder_representation)
decoder_joint = Dense(216, activation='relu')(decoder_joint)
decoder_joint = Dense(128, activation='relu')(decoder_joint)
decoder_joint = Dense(32, activation='relu')(decoder_joint)
joint_output = Dense(14)(decoder_joint)
" =============== Joint Decoder =============== "
# Full model: five inputs -> [image prediction, joint prediction].
model = Model([image_layer, joint_layer, target_X_layer, img_coef_layer, pose_coef_layer],[img_output,joint_output])
# Same inputs and layers, but outputs the blended representation instead.
latent_model = Model([image_layer, joint_layer, target_X_layer, img_coef_layer, pose_coef_layer],general_representation)
# custom_loss is applied to both outputs; the joint loss is weighted by 0.01.
model.compile(optimizer = Adam(lr = 1e-4),loss=custom_loss, loss_weights=[1,0.01])
model.summary()
#plot_model(model)

In [None]:
# Checkpoint intervals (in training steps) and bookkeeping values.
# Only plot_checkpoint and max_training_step are used by the loop below.
loss_checkpoint = 1000
plot_checkpoint = 1000
validation_checkpoint = 1000
validation_error = 9999999
validation_step = -1
max_training_step = 1000000

dataset = ['image','joint']

# Print numpy floats with four decimals so the pose vectors are easy to compare.
float_formatter = "{:.4f}".format
np.set_printoptions(formatter={'float_kind':float_formatter})

# Number of freshly sampled examples shown at every plot checkpoint.
preview_count = 2

for step in range(max_training_step):
    # One gradient update on a single randomly drawn sample.
    inp, out, _, _ = get_train_sample()
    callback = model.fit(inp,out)

    # NOTE: no validation is performed yet; validation_checkpoint is unused.

    if step % plot_checkpoint == 0:
        #clearing output cell
        display.clear_output(wait=True)
        # NOTE(review): this shows the current (possibly empty) figure before
        # the previews are drawn - confirm it is still wanted.
        display.display(pl.gcf())
        print(step)

        # Show ground truth next to the prediction for a few new samples.
        for preview_idx in range(preview_count):
            inp, out, d_id, target_t = get_train_sample()
            # Single forward pass; the model returns [image, joint] predictions.
            predicted_img, predicted_pose = model.predict(inp)

            # Target image, then the predicted mean image (first 3 channels).
            plt.imshow(out[0][0,:,:,:3])
            plt.show()
            plt.imshow(predicted_img[0,:,:,:3])
            plt.show()

            # Target pose, then the predicted mean pose (first 7 values).
            print(out[1][0,:7])
            print(predicted_pose[0,:7])