In [None]:
import sys
#import os
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.25"
import jax
import time
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw
from IPython.display import display, clear_output

from src import util
from src import util_jax
from src.util import tprint
from src.Architecture import get_arch
from src.Architecture import delete_arch
import src.architecture_import as architecture_import

from src.GaussKernel import GaussKernel
from src.LinearKernelCombination import LinearKernelCombination
from src.demo.demo_huetchen import demo_huetchen

from src.steps.GaussInput import GaussInput
from src.steps.NeuralField import NeuralField
from src.steps.StaticGain import StaticGain
from src.steps.Normalization import Normalization
from src.steps.TransferFunction import TransferFunction
from src.steps.Convolution import Convolution
from src.steps.DemoInput import DemoInput
from src.steps.TimedBoost import TimedBoost

In [None]:
# Report which devices JAX exposes to this process. The header is printed
# first, then the device list is queried and printed on its own line.
print("Computing devices found by JAX:")
found_devices = jax.local_devices()
print(found_devices)

In [None]:
# Switch off the "euler_step_static_precompile" entry in the util_jax config,
# then obtain the architecture object used by the cells below.
jax_config = util_jax.get_config()
jax_config["euler_step_static_precompile"] = False
arch = get_arch()

<img src="img/demo4_sketch.png" alt="demo sketch" width="400">

In [None]:
# Drop the current architecture and fetch one again before building
# (presumably so that re-running this cell starts clean -- confirm against
# src.Architecture).
delete_arch()
arch = get_arch()
shape = (294,447)


def _gauss(sigma, amplitude):
    """Return a normalized GaussKernel whose sigma is the pair (sigma, sigma)."""
    return GaussKernel({"sigma": (sigma, sigma), "amplitude": amplitude, "normalized": True})


def _equal_mix(excitatory, inhibitory):
    """Combine two kernels in a LinearKernelCombination with 0.5 / 0.5 entries."""
    # NOTE(review): the key is spelled "wheights" in this notebook and is kept
    # verbatim; confirm it matches the key LinearKernelCombination reads.
    return LinearKernelCombination({"kernels": [excitatory, inhibitory], "wheights": [0.5, 0.5]})


def _field_params(global_inhibition, tau, lateral_kernel):
    """Return a fresh NeuralField parameter dict.

    Only global inhibition, tau and the lateral kernel differ between the
    three fields of this demo; every other entry is shared.
    """
    return {
        "shape": shape,
        "resting_level": -5,
        "global_inhibition": global_inhibition,
        "tau": tau,
        "input_noise_gain": 0.0,
        "sigmoid": "AbsSigmoid",
        "beta": 100,
        "theta": 0,
        "lateral_kernel_convolution": lateral_kernel,
    }


# --- Static steps -----------------------------------------------------------
demo_input = DemoInput(
    "in0",
    {"shape": shape, "sigma": (30, 30), "amplitude": 0, "center": (0, 0)},
)
boost_input = TimedBoost("boost", {"amplitude": 3, "duration": [1.6, 4]})
st0 = StaticGain("st0", {"factor": 6})
st1 = StaticGain("st1", {"factor": 2})

# --- Lateral interaction kernels --------------------------------------------
# Each field kernel pairs a narrow positive-amplitude Gaussian with a wider
# negative-amplitude one.
exc_mem_kernel = _gauss(30, 300)
inh_mem_kernel = _gauss(55, -10)
memory_kernel = _equal_mix(exc_mem_kernel, inh_mem_kernel)
exc_kernel = _gauss(30, 10)
inh_kernel = _gauss(55, -3)
nf_kernel = _equal_mix(exc_kernel, inh_kernel)

# --- Dynamic steps ----------------------------------------------------------
# The action and attention fields share the same nf_kernel object.
nf0 = NeuralField(
    "Memory Field",
    _field_params(global_inhibition=-0.01, tau=0.05, lateral_kernel=memory_kernel),
)
nf1 = NeuralField(
    "Action Field",
    _field_params(global_inhibition=-0.00, tau=0.02, lateral_kernel=nf_kernel),
)
nf2 = NeuralField(
    "Attention Field",
    _field_params(global_inhibition=-0.00, tau=0.05, lateral_kernel=nf_kernel),
)

# --- Connections ------------------------------------------------------------
# input -> attention -> gain(6) -> memory -> gain(2) -> action, with the timed
# boost feeding the action field as well. The chain is kept as one expression
# so the `>>` operators are evaluated exactly as before.
demo_input >> nf2 >> st0 >> nf0 >> st1 >> nf1
boost_input >> nf1

arch.compile()

In [None]:
# Number of steps handed to demo.run() below.
num_steps = 150

# Bind the demo to the compiled architecture; the two keyword arguments name
# the fields created in the build cell ("Attention Field", "Action Field").
demo = demo_huetchen(
    arch,
    eye_field="Attention Field",
    hand_field="Action Field",
)

# Run the demo and keep whatever run() returns for later inspection.
recording = demo.run(num_steps)

In [None]:
# Plot the demo created in the previous cell. Left as the bare last
# expression so the notebook shows whatever plot() returns or renders.
demo.plot()