In [None]:
import sys
#import os
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
import jax
import time
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw
from IPython.display import display, clear_output

from src import util
from src import util_jax
from src.util import tprint
from src.Architecture import get_arch
from src.Architecture import delete_arch
import src.architecture_import as architecture_import

from src.GaussKernel import GaussKernel
from src.LinearKernelCombination import LinearKernelCombination
from src.demo.demo2 import demo2

from src.steps.GaussInput import GaussInput
from src.steps.NeuralField import NeuralField
from src.steps.StaticGain import StaticGain
from src.steps.Normalization import Normalization
from src.steps.TransferFunction import TransferFunction
from src.steps.Convolution import Convolution
from src.steps.DemoInput import DemoInput

In [None]:
# Report which devices JAX can see in this process before building anything.
print("Computing devices found by JAX:")
local_devices = jax.local_devices()
print(local_devices)

In [None]:
# Disable the `euler_step_static_precompile` option in the shared JAX config,
# then obtain the architecture object.
# NOTE(review): presumably this flag has to be set before the architecture is
# compiled — confirm in src/util_jax.
jax_config = util_jax.get_config()
jax_config["euler_step_static_precompile"] = False
arch = get_arch()

In [None]:
# Reset and fetch the architecture so this cell builds a fresh graph each run.
# NOTE(review): assumes delete_arch() discards previously created steps so that
# re-running this cell does not duplicate them — confirm in src/Architecture.
delete_arch()
arch = get_arch()
# Resolution (rows, cols) shared by every input and field below.
# NOTE(review): the origin of these exact values is not visible here — document it.
shape = (294, 447)

# Prepare Input
# Three demo inputs that differ only in name ("in0", "in1", "in2") and center;
# all use the same sigma and start with amplitude 0. They are constructed in
# index order (in0, in1, in2).
demo_input_centers = [(0, 40), (120, 40), (240, 40)]
demo_input_0, demo_input_1, demo_input_2 = [
    DemoInput(f"in{idx}", {"shape": shape, "sigma": (30, 30), "amplitude": 0, "center": center})
    for idx, center in enumerate(demo_input_centers)
]

# Prepare Neural Fields
# Lateral interaction kernel: a narrow excitatory Gaussian (sigma 30,
# amplitude 24) combined with a wider inhibitory one (sigma 55, amplitude -18),
# with weights 0.5 / 0.5.
# NOTE(review): the key is spelled "wheights"; it is left unchanged because it
# must match whatever LinearKernelCombination reads — confirm the spelling there.
exc_kernel = GaussKernel({"sigma": (30, 30), "amplitude": 24, "normalized": True})
inh_kernel = GaussKernel({"sigma": (55, 55), "amplitude": -18, "normalized": True})
nf_kernel = LinearKernelCombination({"kernels": [exc_kernel, inh_kernel], "wheights": [0.5, 0.5]})


def _neural_field_params(field_shape, kernel, global_inhibition, input_noise_gain):
    """Return a fresh NeuralField parameter dict.

    Both fields share shape, resting level, tau, sigmoid settings and the
    lateral kernel; only ``global_inhibition`` and ``input_noise_gain`` vary
    between them, so those are the only per-field arguments.
    """
    return {
        "shape": field_shape,
        "resting_level": -5,
        "global_inhibition": global_inhibition,
        "tau": 0.03,
        "input_noise_gain": input_noise_gain,
        "sigmoid": "AbsSigmoid",
        "beta": 100,
        "theta": 0,
        "lateral_kernel_convolution": kernel,
    }


nf0 = NeuralField(
    "Attention Field",
    _neural_field_params(shape, nf_kernel, global_inhibition=-0.00025, input_noise_gain=0.5),
)
# The Action Field uses zero global inhibition and less input noise.
nf1 = NeuralField(
    "Action Field",
    _neural_field_params(shape, nf_kernel, global_inhibition=-0.0, input_noise_gain=0.1),
)

# Static steps
st0 = StaticGain("st0", {"factor": 5})

# Connections: every demo input feeds the Attention Field, whose output is
# passed through the static gain st0 into the Action Field.
for demo_input in (demo_input_0, demo_input_1, demo_input_2):
    demo_input >> nf0
nf0 >> st0
st0 >> nf1

arch.compile()

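The cell above wires three demo inputs into the Attention Field, whose output drives the Action Field through a static gain of 5. A sketch of the demo setup:

<img src="img/demo2_sketch.png" alt="Sketch of the demo 2 setup" width="400">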

In [None]:
# Run demo 2 on the compiled architecture. The demo's "eye" role is mapped to
# the Attention Field and its "hand" role to the Action Field; whatever
# demo.run returns is kept in `recording`.
num_steps = 70
demo = demo2(
    arch,
    hand_field="Action Field",
    eye_field="Attention Field",
)
recording = demo.run(num_steps)

In [None]:
# Visualize the demo via demo2's plot method.
# NOTE(review): presumably this plots the run executed in the previous cell —
# confirm what is drawn in src/demo/demo2.
demo.plot()