In [1]:
%load_ext line_profiler
#%matplotlib inline
from kusanagi.ghost.algorithms import ExperienceDataset
from kusanagi import utils
from kusanagi.ghost.regression import GP
utils.set_logfile('/dev/null')
from matplotlib import pyplot as plt
from matplotlib import animation, rc
import numpy as np
from IPython.display import HTML

# Shared GP regressor used as the trajectory interpolator: get_animation()
# refits it per episode with the time stamps as the single input column and
# the two selected state dimensions as outputs (hence idims=1, odims=2).
gp = GP(idims=1,odims=2)

In [2]:
# Saved experience datasets, listed in the same order as the names they are
# bound to below. NOTE(review): these are absolute paths on the author's
# machine -- adjust them before running elsewhere.
_DATASET_FILES = (
    '/home/juancamilog/workspace/kusanagi/examples/learned_policies/cartpole/PILCO_SSGP_UI_Cartpole_RBFPolicy_sat_dataset',
    '/home/juancamilog/workspace/kusanagi/examples/learned_policies/cartpole_serial/PILCO_SSGP_UI_SerialPlant_RBFPolicy_sat_dataset',
    '/home/juancamilog/workspace/kusanagi/examples/learned_policies/target_27g_run_2/TrajectoryMatching_SSGP_UI_SerialPlant_AdjustedPolicy_dataset',
    '/home/juancamilog/workspace/kusanagi/examples/learned_policies/target_45g_run_2/TrajectoryMatching_SSGP_UI_SerialPlant_AdjustedPolicy_dataset',
    '/home/juancamilog/workspace/kusanagi/examples/learned_policies/target_90g_run_2/TrajectoryMatching_SSGP_UI_SerialPlant_AdjustedPolicy_dataset',
    '/home/juancamilog/workspace/kusanagi/examples/learned_policies/target_180g_run_1/TrajectoryMatching_SSGP_UI_SerialPlant_AdjustedPolicy_dataset',
)

# Load every dataset in order and bind each one to its notebook-level name.
(exp_sim, exp_pilco, exp_27g_2,
 exp_45g_2, exp_90g_2, exp_180g_1) = map(
    lambda dataset_path: ExperienceDataset(filename=dataset_path),
    _DATASET_FILES)

In [74]:
def init_lines(ax, exp, ep_indices,
               linestyles = ['--','x--', 'o--', 'v--', '^--'],
               c0=[0.95,0.2,0.2,1.0],
               c1=[0.0,0.9,0.0,1.0],
               max_timesteps=-1,
               markevery=1
              ):
    """Create one empty line per episode on ``ax`` and gather its timing data.

    Args:
        ax: object whose ``plot`` method returns a one-element sequence
            holding the created line (e.g. a matplotlib Axes).
        exp: object exposing ``states`` and ``time_stamps``, both indexable
            by episode index.
        ep_indices: episode indices to create lines for.
        linestyles: plot format strings, cycled over the episodes.
        c0, c1: RGBA colours; episode ``i`` of ``n`` is drawn with the blend
            ``c1*i/n + c0*(n-i)/n``.
        max_timesteps: index into an episode's time stamps that marks the
            end of the horizon (``-1`` selects the last stamp).
        markevery: forwarded to ``ax.plot``.

    Returns:
        ``(lines, x, ticks, t0, H)``, five lists with one entry per episode:
        the line artists, the episode states, the time stamps relative to
        the episode start as ``(N, 1)`` arrays, the start times, and the
        horizon durations (end stamp minus start time).

    Side effects:
        When the first time stamp of an episode is negative it is overwritten
        in ``exp.time_stamps`` with a value extrapolated backwards from the
        second and third stamps.
    """
    start_rgba = np.array(c0)
    end_rgba = np.array(c1)
    x = [exp.states[ep] for ep in ep_indices]
    n_eps = len(x)

    lines, H, t0, ticks = [], [], [], []
    for k, ep_idx in enumerate(ep_indices):
        # linear blend from c0 (first episode) towards c1
        rgba = end_rgba*k*1.0/n_eps + start_rgba*(n_eps-k)*1.0/n_eps
        fmt = linestyles[k % len(linestyles)]
        (line,) = ax.plot([], [], fmt,
                          color=tuple(rgba.tolist()),
                          linewidth=2,
                          markevery=markevery,
                          markerfacecolor='None',
                          markeredgecolor=rgba,
                          mew=2,
                          ms=9)
        lines.append(line)

        stamps = exp.time_stamps[ep_idx]
        t_end = stamps[max_timesteps]
        t_start = stamps[0]
        if t_start < 0:
            # replace the negative first stamp by stepping one sampling
            # interval back from the second stamp, and store it in place
            t_start = stamps[1] - (stamps[2] - stamps[1])
            stamps[0] = t_start

        H.append(t_end - t_start)
        t0.append(t_start)
        # column vector of stamps, shifted so the episode starts at zero
        ticks.append(np.array(stamps)[:, None] - t_start)

    return lines, x, ticks, t0, H

In [81]:
def get_animation(source_exp, target_exp, source_episodes, target_episodes,
                  source_title, target_title, gp, max_timesteps_source=30, 
                  max_timesteps_target=35, anim_dt=0.05,
                  data_dt=0.1, dims= [0,3], animate_source = True):
    """Build an animation that draws source and target episodes one by one.

    Every episode is resampled every ``anim_dt`` seconds by fitting ``gp`` to
    (time stamp, state[:, dims]) pairs and querying it on a regular grid. The
    first selected dimension is drawn on the x axis, the second on the y axis.

    Args:
        source_exp, target_exp: experience datasets passed to ``init_lines``.
        source_episodes, target_episodes: episode indices to draw from each
            dataset; ``target_episodes`` must support indexing.
        source_title, target_title: axes titles shown while the first source
            episode / the first target episode is being drawn.
        gp: regressor exposing ``set_dataset``, ``init_params``, ``train``
            and ``predict``; it is refit once per episode.
        max_timesteps_source, max_timesteps_target: time-stamp index that
            ends the animated horizon of each source / target episode.
        anim_dt: animation time step in seconds (frame interval and
            resampling period).
        data_dt: unused; kept for backward compatibility.
        dims: the two state columns to interpolate and plot.
        animate_source: when False the source episodes are drawn in full from
            the first frame and only the target episodes are animated.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object.

    Side effects:
        Closes the current pyplot figure, creates a new one, and sets the
        ``animation.html`` rc parameter to ``'html5'``.
    """
    # pause between episodes, in frames (currently always zero)
    episode_delay = int(0.0/anim_dt)
    # create figure 
    plt.close()
    fig, ax = plt.subplots(figsize=(8,8),dpi=80)

    # create lines for distribution of source trajectories
    ret_s = init_lines(ax,source_exp,source_episodes,max_timesteps=max_timesteps_source,
                       linestyles=['--'], c0=(0.2,0.2,0.2,0.5),c1=(0.2,0.2,0.2,0.5))

    # create lines for each episode in the target domain
    ret_t = init_lines(ax,target_exp,target_episodes,max_timesteps=max_timesteps_target)
    n_source = len(ret_s[0])

    # concatenate source and target lines
    lines,trajs,ticks,t0,H = [ret_s_i+ret_t_i for ret_s_i, ret_t_i in zip(ret_s,ret_t)]

    # steps[k] is the first frame of episode k (cumulative frame counts).
    # Written as a plain loop: `reduce` is not a builtin on Python 3.
    steps = [0]
    for horizon in H:
        steps.append(steps[-1] + int(horizon/anim_dt))
    start_frame = 0 if animate_source else steps[n_source] + episode_delay*(n_source - 1)
    total_frames = int(steps[-1] + episode_delay*(len(steps) - 1)) - start_frame 

    # one-element lists act as mutable cells shared by the closures below
    curr_ep = [0]
    interp_data = [None]
    decorated = [False]

    def get_interpolated():
        """Return the per-episode resampled trajectories, computing them once."""
        if interp_data[0] is None:
            interpolated = []
            for i, ep in enumerate(trajs):
                utils.print_with_stamp('Processing episode %d of %d'%(i+1,len(trajs)),same_line=True, use_log=False)
                # convert per episode: episodes may differ in length, so they
                # cannot be stacked into a single rectangular array
                ep = np.asarray(ep)
                gp.set_dataset(X_dataset=ticks[i],Y_dataset=ep[:,dims])
                gp.init_params()
                gp.train()
                query_times = np.arange(0,H[i],anim_dt)[:,None]
                interpolated.append(np.array([gp.predict(t)[0] for t in query_times]))
            # keep a list (not an ndarray) so ragged episodes are supported
            interp_data[0] = interpolated
        return interp_data[0]

    def legend_labels():
        """Return one legend label per line, in the same order as ``lines``."""
        labels = []
        for i in range(len(trajs)):
            if i < n_source:
                labels.append("Source trajs. (prior examples)")
            elif target_episodes[0]==0 and i==n_source:
                labels.append("Adjusted traj. (SRC)")
            else:
                labels.append("Adjusted traj. (A%d)"%(target_episodes[i-n_source]))
        return labels

    # init function for animation
    def init_anim():
        print('init_anim...')
        curr_ep[0] = 0 if animate_source else n_source
        data = get_interpolated()

        # reset every line; static source lines are drawn in full
        for i, line in enumerate(lines):
            if i < n_source and not animate_source:
                line.set_data(data[i][:,0],data[i][:,1])
            else:
                line.set_data([],[])

        ax.set_xlim(-1.2,1.5)
        ax.set_ylim(-1,13)
        ax.set_xlabel('Cart Position (m)', size=20)
        ax.set_ylabel('Pole Angle (rad)', size=20)
        ax.grid(True)
        ax.set_title(target_title, fontweight='bold')

        # init_anim can run more than once per animation; add the start/goal
        # markers and the legend only the first time so they do not pile up
        if not decorated[0]:
            origin, = ax.plot([0],[0],'bo',ms=15,markerfacecolor=(0.5,0.5,0.5),markeredgecolor='b')
            goal, = ax.plot([0],[np.pi],'bv',ms=15,markerfacecolor=(0.5,0.5,0.5),markeredgecolor='b')
            # show a single legend entry for all source lines (the last one);
            # clamp at 0 so an empty source set does not slice from the end
            first = max(n_source - 1, 0)
            labels = legend_labels()[first:] + ["Start Location", "Goal Location"]
            handles = lines[first:] + [origin, goal]
            ax.legend(handles,labels,numpoints=1)
            decorated[0] = True

        print('\ndone!')
        return lines

    # update function
    def update_anim(i):
        # NOTE: the current episode is tracked in curr_ep, so frames are
        # expected to arrive in increasing order after init_anim has run.
        ep = curr_ep[0]
        t = start_frame + int(i - (steps[ep] + ep*episode_delay))
        utils.print_with_stamp('Processing frame %d of %d (ep: %d t: %d)'%(i+1,total_frames,ep+1,t),same_line=True, use_log=False)
        data = get_interpolated()
        H_steps = data[ep].shape[0]
        lines[ep].set_data(data[ep][:t,0],data[ep][:t,1])

        # fade the target episodes that have already been drawn
        for ii in range(n_source, ep):
            lines[ii].set_alpha(0.25)
            lines[ii].set_markerfacecolor('None')
            lines[ii].set_markerfacecoloralt('None')

        # move on to the next episode once the current one is fully drawn
        if t == H_steps + episode_delay - 1:
            curr_ep[0] = (curr_ep[0]+1)%len(trajs)
        if ep == n_source:
            ax.set_title(target_title, fontweight='bold', fontsize=20)
        elif ep == 0:
            ax.set_title(source_title, fontweight='bold', fontsize=20)
        return lines

    rc('animation', html='html5')
    anim = animation.FuncAnimation(fig, update_anim, init_func=init_anim, frames=total_frames, interval=1000*anim_dt, blit=True)
    return anim

In [82]:
# Episode indices shared by all the render cells that follow.
source_episodes = range(11,17)
target_episodes = range(0,5)

# Render the 27g target run; source trajectories are drawn statically.
# NOTE(review): the title says "1.5x mass" here and for the 45g run as well --
# confirm which factor is intended for this dataset.
anim = get_animation(
    source_exp=exp_pilco,
    target_exp=exp_27g_2,
    source_episodes=source_episodes,
    target_episodes=target_episodes,
    source_title="Example trajectories\nunder source dynamics",
    target_title="Controller adjustment\nunder target dynamics (1.5x mass)",
    gp=gp,
    animate_source=False,
)
HTML(anim.to_html5_video())

init_anim...
[2K[2017-05-25 13:11:17.334676] Processing episode 11 of 11
done!
init_anim...
[2K[2017-05-25 13:11:24.307021] Processing episode 11 of 11
done!
[2K[2017-05-25 13:12:00.173495] Processing frame 355 of 355 (ep: 11 t: 70)

In [83]:
# Render the 45g target run; here the source trajectories are animated too.
anim = get_animation(
    source_exp=exp_pilco,
    target_exp=exp_45g_2,
    source_episodes=source_episodes,
    target_episodes=target_episodes,
    source_title="Example trajectories\nunder source dynamics",
    target_title="Controller adjustment\nunder target dynamics (1.5x mass)",
    gp=gp,
    animate_source=True,
)
HTML(anim.to_html5_video())

init_anim...
[2K[2017-05-25 13:12:44.767586] Processing episode 11 of 11
done!
init_anim...
[2K[2017-05-25 13:12:52.510831] Processing episode 11 of 11
done!
[2K[2017-05-25 13:14:05.442635] Processing frame 720 of 720 (ep: 11 t: 70)

In [84]:
# Render the 90g target run; source trajectories are drawn statically.
anim = get_animation(
    source_exp=exp_pilco,
    target_exp=exp_90g_2,
    source_episodes=source_episodes,
    target_episodes=target_episodes,
    source_title="Example trajectories\nunder source dynamics",
    target_title="Controller adjustment\nunder target dynamics (3x mass)",
    gp=gp,
    animate_source=False,
)
HTML(anim.to_html5_video())

init_anim...
[2K[2017-05-25 13:14:57.844346] Processing episode 11 of 11
done!
init_anim...
[2K[2017-05-25 13:15:05.940643] Processing episode 11 of 11
done!
[2K[2017-05-25 13:15:42.037786] Processing frame 355 of 355 (ep: 11 t: 70)

In [85]:
# Render the 180g target run; source trajectories are drawn statically.
anim = get_animation(
    source_exp=exp_pilco,
    target_exp=exp_180g_1,
    source_episodes=source_episodes,
    target_episodes=target_episodes,
    source_title="Example trajectories\nunder source dynamics",
    target_title="Controller adjustment\nunder target dynamics (5x mass)",
    gp=gp,
    animate_source=False,
)
HTML(anim.to_html5_video())

init_anim...
[2K[2017-05-25 13:15:50.321528] Processing episode 11 of 11
done!
init_anim...
[2K[2017-05-25 13:15:59.050450] Processing episode 11 of 11
done!
[2K[2017-05-25 13:16:35.630903] Processing frame 355 of 355 (ep: 11 t: 70)

In [58]:
# Scratch check: the RGBA tuple matplotlib's colour converter returns for
# 'white' with alpha 0.5 (the cell's result is shown below it).
import matplotlib
matplotlib.colors.colorConverter.to_rgba('white', alpha=.5)

(1.0, 1.0, 1.0, 0.5)