In [None]:
import sys
#import os
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
import jax
import time
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw
from IPython.display import display, clear_output

from src import util
from src import util_jax
from src.util import tprint
from src.Architecture import get_arch
from src.Architecture import delete_arch
import src.architecture_import as architecture_import

from src.GaussKernel import GaussKernel
from src.LinearKernelCombination import LinearKernelCombination
from src.demo.demo1 import demo1

from src.steps.GaussInput import GaussInput
from src.steps.NeuralField import NeuralField
from src.steps.StaticGain import StaticGain
from src.steps.Normalization import Normalization
from src.steps.TransferFunction import TransferFunction
from src.steps.Convolution import Convolution
from src.steps.DemoInput import DemoInput

In [None]:
# List the local devices JAX can see in this kernel, one header line followed by the device list.
print("Computing devices found by JAX:", jax.local_devices(), sep="\n")

In [None]:
# Set the "euler_step_static_precompile" option on the shared util_jax config to False.
# NOTE(review): presumably this turns off ahead-of-time compilation of the Euler step;
# confirm the exact effect in src.util_jax.
jax_config = util_jax.get_config()
jax_config["euler_step_static_precompile"] = False

# Obtain the architecture object from the project's get_arch() accessor.
arch = get_arch()

Sketch of demo 1; the next cell builds the corresponding architecture.

<img src="img/demo1_sketch.png" alt="Sketch of demo 1" width="400">

In [None]:
# Discard the previously obtained architecture and fetch a new one, so that
# re-running this cell starts from a clean slate (presumably — see src.Architecture).
delete_arch()
arch = get_arch()

# Size of the input and of both neural fields.
# NOTE(review): presumably chosen to match the demo's image/scene size — confirm.
shape = (294, 447)

# --- Static steps ---
# Input starts with amplitude 0; presumably demo1 sets the actual stimulus at run time — confirm.
demo_input = DemoInput("in0", {"shape": shape, "sigma": (30, 30), "amplitude": 0, "center": (0, 0)})
# Gain applied to the attention field's output before it reaches the action field.
st0 = StaticGain("st0", {"factor": 6})

# --- Lateral interaction kernel (the same object is used by both fields) ---
exc_kernel = GaussKernel({"sigma": (30, 30), "amplitude": 16, "normalized": True})
inh_kernel = GaussKernel({"sigma": (55, 55), "amplitude": -3, "normalized": True})
# NOTE(review): "wheights" is passed to LinearKernelCombination verbatim. It looks like a
# typo for "weights", but it must match the key the class actually reads — confirm before renaming.
nf_kernel = LinearKernelCombination({"kernels": [exc_kernel, inh_kernel], "wheights": [0.5, 0.5]})


def neural_field_params(tau, shape, kernel):
    """Build the parameter dict for one NeuralField of this demo.

    Both fields share every setting except the time constant ``tau``. The
    common values are therefore written once here instead of being
    copy-pasted per field. A new dict is returned on each call, so the two
    fields never share a mutable parameter object (matching the original
    per-field dict literals).

    Args:
        tau: time constant for this field.
        shape: field shape, e.g. ``(rows, cols)``.
        kernel: lateral interaction kernel passed as ``lateral_kernel_convolution``.

    Returns:
        dict of NeuralField parameters, with keys in the original order.
    """
    return {
        "shape": shape,
        "resting_level": -5,
        "global_inhibition": -0.00,
        "tau": tau,
        "input_noise_gain": 0,
        "sigmoid": "AbsSigmoid",
        "beta": 100,
        "theta": 0,
        "lateral_kernel_convolution": kernel,
    }


# --- Dynamic steps: the two fields differ only in tau (0.05 vs 0.1) ---
nf0 = NeuralField("Attention Field", neural_field_params(tau=0.05, shape=shape, kernel=nf_kernel))
nf1 = NeuralField("Action Field", neural_field_params(tau=0.1, shape=shape, kernel=nf_kernel))

# --- Connections: input -> attention field -> gain -> action field ---
demo_input >> nf0
nf0 >> st0
st0 >> nf1

# Kept as the last expression so the cell output is unchanged.
arch.compile()

In [None]:
# Run demo 1 on the compiled architecture.
num_steps = 60  # number of steps handed to demo.run

# The field names below must match the NeuralField names used when building the
# architecture (presumably demo1 looks the fields up by name — confirm).
demo = demo1(
    arch,
    eye_field="Attention Field",
    hand_field="Action Field",
)
recording = demo.run(num_steps)

In [None]:
# Visualize the demo run via demo1's plot method. It is left as the bare last
# expression (no trailing ';') so that any object demo.plot() returns is still
# rendered as the cell output.
demo.plot()