In [5]:
import numpy as np
from env import Env
from model import Model
from plotter import Plotter
from funcs import g
from mkvideo import vidManager


### Constants

In [6]:

# Value passed as `mu` to the generative process when it is reset.
real_mu = 0*np.pi
# Initial `mu` and `rho` handed to the generative model's constructor.
model_mu = 0*np.pi
model_rho = -0.35*np.pi
# Number of simulation steps per run (also used as the plot time window).
stime = 20000

# Fixed seed so that "Restart Kernel -> Run All" reproduces the same
# trajectories and videos; change it to explore a different noise realization.
seed = 0
rng = np.random.RandomState(seed)


 ### Simulation

We simulate three cases.

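1. The precision of the gprocess and that of the gmodel are the same.
2. The precision of the gmodel is a little bit higher than that of the gprocess.
3. The precision of the gmodel is higher than that of the gprocess.

Note: the sigma pairs actually used for the three cases are set by `gprocess_sigmas` and `gmodel_sigmas` in the next cell — verify that they match this description.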


In [7]:
%%capture
from IPython.display import display

# Sigma values for each simulated case. Entries are paired by position:
# gprocess_sigmas[i] is set on the environment, gmodel_sigmas[i] on the model.
gprocess_sigmas = 0.2, 0.2, 0.2
gmodel_sigmas = 0.1, 0.11, 0.12

# Raw string: the resulting text is unchanged (the backslashes are kept
# literally), but it no longer relies on the invalid "\_" escape sequence.
title = r"gp\_sigma=%6.4f gm\_sigma=%6.4f"

# Number of simulation steps between two rendered plot/video frames.
frame_interval = 1000

for n_sim, (gprocess_sigma, gmodel_sigma) in enumerate(
        zip(gprocess_sigmas, gmodel_sigmas)):

    # One plotter and one video writer per case; frames go to sim/sim_<n_sim>.
    plotter = Plotter(
        time_window=stime,
        title=title % (gprocess_sigma, gmodel_sigma))

    vidMaker = vidManager(plotter.figure, name="sim_%d" % n_sim,
                          dirname="sim", duration=80)

    # init the generative model (agent) and the generative process
    # (environment)
    gprocess = Env(rng)
    gmodel = Model(rng, mu=model_mu, rho=model_rho)

    gprocess.set_sigma(gprocess_sigma)
    gmodel.set_sigma(gmodel_sigma)

    state = gprocess.reset(mu=real_mu)
    for t in range(stime):

        # Update model via gradient descent and get action
        action = gmodel.update(state)

        # Generated fake sensory state from model
        gstate = gmodel.gstate

        # do action
        state = gprocess.step(action)

        # record both mu values at every step
        plotter.append_mu(gprocess.mu, gmodel.mu)

        # redraw the arms and store a video frame every `frame_interval`
        # steps, and always on the final step
        if t % frame_interval == 0 or t == stime - 1:
            plotter.sensed_arm.update(state[0], state[1:])
            plotter.real_arm.update(gprocess.istate[0], gprocess.istate[1:])
            plotter.generated_arm.update(gstate[0], gstate[1:])
            plotter.target_arm.update(gmodel.rho, g(gmodel.rho))
            plotter.update()
            vidMaker.save_frame()

    # assemble the saved frames of this case into its video
    vidMaker.mk_video()

In [14]:
from IPython.display import HTML, display
def display_gif(fn):
    """Display the path `fn` as a caption, followed by the image it names.

    The caption is `fn` with a trailing colon. The image is rendered through
    an HTML `<img>` tag whose `src` attribute is `fn`, inserted as-is
    (no escaping).
    """
    caption = fn + ":"
    display(caption)
    img_markup = '<img src="{src}">'.format(src=fn)
    display(HTML(img_markup))

import glob

# Show every file matching the pattern below, in lexicographic name order.
gif_pattern = "sim/sim_*gif"
files = glob.glob(gif_pattern)
files.sort()
for file in files:
    display_gif(file)

'sim/sim_0.gif:'

'sim/sim_1.gif:'

'sim/sim_2.gif:'