In [1]:
using Flux
using Plots
using Parameters
import Random
using StatsBase # for random choice
Random.seed!(1234);

###### Data

In [2]:
in_features = 2
out_features = 1
N = 48

π_32 = Float32(π)
t = range(0.0f0,stop=3π_32, length = N)
sin_t = sin.(t)
cos_t = cos.(t)
data_x = reshape(hcat(sin_t, cos_t), (1, N, 2))
data_y = reshape(sin.(range(0.0f0,stop=6π_32, length = N)), (1, N, 1))#.astype(np.float32)

println(size(data_x))
println(size(data_y))

# Plots
#plot(data_x[:,:,1]')
#plot!(data_x[:,:,2]')
#plot!(data_y[1,:,1])

(1, 48, 2)
(1, 48, 1)


###### Wiring

In [3]:
# Macro for inserting fields 
macro def(name, definition) 
  return quote
      macro $(esc(name))()
          esc($(Expr(:quote, definition)))
      end
  end
end

# Wiring
# Do we really need the exact code as one for the biology?
abstract type Wiring end
@def wiring_fields begin
     units
     adjacency_matrix
     sensory_adjacency_matrix
     input_dim
     output_dim
end

# Outer constructor
#Wiring(units::Int64) = Wiring(units, zeros((units, units)), Nothing, Nothing)

# Methods
function _build(wiring::Wiring, input_shape) # for Wiring type
    _input_dim = convert(Int, input_shape[2])
    # do assert check
    @show wiring.input_dim
    if wiring.input_dim === nothing
        set_input_dim(wiring, _input_dim)
    end
    @show wiring.input_dim
end

function add_synapse(adjacency_matrix, src, dest, polarity)
    adjacency_matrix[src, dest] = polarity
    #adjacency_matrix
end

function _init_add_synapse(units, adjacency_matrix, self_conn)
    for src in (1:units)
        for dest in (1:units)
            if src === dest && !self_conn
                continue
            end
            polarity = StatsBase.sample([-1,1,1])
            add_synapse(adjacency_matrix, src, dest, polarity)
            #println(src, dest)
        end
    end
    #@show adjacency_matrix
end

function add_sensory_synapse(sensory_adjacency_matrix, src, dest, polarity)
    sensory_adjacency_matrix[src, dest] = polarity
    #adjacency_matrix
end

# not used?
function set_input_dim(wiring::Wiring, _input_dim)
    wiring.input_dim = _input_dim
    units = wiring.units
    wiring.sensory_adjacency_matrix = zeros((_input_dim, units))
end 

function _erev_initializer(wiring::Wiring, shape=nothing, dtype=nothing) # dtype?
    copy(wiring.adjacency_matrix)
end

function _sensory_erev_initializer(wiring::Wiring, shape=nothing, dtype=nothing) # dtype?
    copy(wiring.sensory_adjacency_matrix)
end

function _state_size(wiring::Wiring)
    return wiring.units
end
# Test 
#set_input_dim(wiring, 2)

_state_size (generic function with 1 method)

###### Fully Connected

In [4]:
mutable struct FullyConnected <: Wiring
    @wiring_fields
    self_conns # default value ?
    # Inner constructor
    function FullyConnected(units, _input_dim = nothing, _output_dim = nothing, self_conns = true) #arguments order and call?
        adjacency_matrix = zeros((units, units))
        sensory_adjacency_matrix = nothing
        if _output_dim === nothing
            output_dim = units
        else
            output_dim = _output_dim
        end
        _init_add_synapse(units, adjacency_matrix, self_conns)
        return new(units, adjacency_matrix, sensory_adjacency_matrix, _input_dim,
                   output_dim, self_conns)
    end
end

In [5]:
# Methods
function build(wiring::FullyConnected, input_shape)
    _build(wiring, (nothing, in_features)) # from Wiring
    @show wiring
    input_dim = wiring.input_dim
    units = wiring.input_dim
    sensory_adjacency_matrix = wiring.sensory_adjacency_matrix
    for src in (1:input_dim)
        for dest in (1:units)
#             if src === dest && !self_conn
#                 continue
#             end
            polarity = StatsBase.sample([-1,1,1])
            @show polarity
            add_sensory_synapse(sensory_adjacency_matrix, src, dest, polarity)
            #println(src, dest)
        end
    end
end

build (generic function with 1 method)

### LTC cell

###### Definition


In [6]:
struct LTCCell
    _wiring::Wiring        
    _init_ranges
    _input_mapping::String
    _output_mapping::String
    _ode_unfolds::Int
    _epsilon
    _params::Dict # Is this trainable in Dict format?
    state0 # for Recur
    function LTCCell(wiring, in_features = nothing, input_mapping="affine",
            output_mapping="affine", ode_unfolds=6, epsilon=1e-8, params=Dict(),
            state0=zeros(1, 8))
        if in_features !== nothing
            build(wiring, (nothing, in_features))
        end
        # Is this proper place?
        init_ranges = _get_init_ranges()
        params = _allocate_parameters(wiring, params, init_ranges, input_mapping, output_mapping) 
        # How to launch it inside?
        new(wiring, init_ranges, input_mapping, output_mapping, ode_unfolds,
            epsilon, params, state0) # state0:(batch, units)? 
        #
    end
end

In [7]:
function _get_init_ranges()
    _init_ranges = Dict(
            "gleak"=> (0.001, 1.0),
            "vleak"=> (-0.2, 0.2),
            "cm"=> (0.4, 0.6),
            "w"=> (0.001, 1.0),
            "sigma"=> (3, 8),
            "mu"=> (0.3, 0.8),
            "sensory_w"=> (0.001, 1.0),
            "sensory_sigma"=> (3, 8),
            "sensory_mu"=> (0.3, 0.8))
end

_get_init_ranges (generic function with 1 method)

In [8]:
# For initializing fields
function _get_init_value(shape, name, init_ranges)
    minval, maxval = init_ranges[name]
    if minval === maxval
        return ones(shape) * minval
    else
        return rand(Float64, shape) * (maxval - minval) .+ minval
    end
end

# Not all values here need init_ranges! Decompose it
function _init_weights_and_params(wiring, name::String, init_ranges)
    state_size = _state_size(wiring)
    _sensory_size = wiring.input_dim
    if name in ["gleak", "vleak", "cm"]
        return _get_init_value(state_size, name, init_ranges)
    elseif name in ["sigma", "mu", "w"]
        return _get_init_value((state_size, state_size), name, init_ranges)
    elseif name in ["sensory_sigma", "sensory_mu", "sensory_w"]
        return _get_init_value((_sensory_size, state_size), name, init_ranges)
    elseif name in ["erev"]     
         return _erev_initializer(wiring) 
    elseif name in ["sensory_erev"]     
        return _sensory_erev_initializer(wiring) 
    elseif name in ["sparsity_mask"]
        return abs.(wiring.adjacency_matrix)
    elseif name in ["sensory_sparsity_mask"]    
        return abs.(wiring.sensory_adjacency_matrix)
    end
end

# Init all weights
function _allocate_parameters(wiring, params, init_ranges, input_mapping, output_mapping)
    println("alloc!")
    
    _params_keys = ["sigma", "mu", "w", "sensory_sigma", "sensory_mu", "sensory_w",
                    "erev", "sensory_erev", "gleak", "vleak", "cm", "sparsity_mask",
                    "sensory_sparsity_mask"]
    #println(fieldnames(wiring), "We're here")
    _motor_size = wiring.output_dim
    _sensory_size = wiring.input_dim
    for _key in _params_keys
        @show _key
        #_init_weight()
        params[_key] = _init_weights_and_params(wiring, _key, init_ranges)
    end
    # It is new fields !!!
    if input_mapping in ["affine", "linear"]
        params["input_w"] = ones((1, _sensory_size))
    end
    if input_mapping in ["affine"]
        params["input_b"] = zeros((1, _sensory_size))
    end
    if output_mapping in ["affine", "linear"]
        params["output_w"] = ones(_motor_size)
    end
    if output_mapping in ["affine"]
        params["output_b"] = zeros(_motor_size)
    end
    return params
end

# Additional
function _state_size(ltc::LTCCell)
    return _state_size(ltc._wiring)
end

_state_size (generic function with 2 methods)

### ODE

In [9]:
function _map_inputs(ltc::LTCCell, inputs)
    if ltc._input_mapping in ["affine", "linear"]
        inputs = inputs .* ltc._params["input_w"] # Element-wise 
    end
    if ltc._input_mapping === "affine"
        inputs = inputs .+ ltc._params["input_b"]
    end
    @assert size(inputs) === (1,2)
    return inputs
end

function _map_outputs(ltc::LTCCell, state)
    output = state
    _motor_size = ltc._wiring.output_dim

    if _motor_size < _state_size(ltc)
        println("!oups")
        #output = output[:, 0 : self.motor_size]  # slice
    end
    if ltc._output_mapping in ["affine", "linear"]
        output = output .* reshape(ltc._params["output_w"], (1, size(ltc._params["output_w"])...)) # Element-wise 
    end
    if ltc._output_mapping === "affine"
        output = output .+ reshape(ltc._params["output_w"], (1, size(ltc._params["output_b"])...)) # Element-wise 
    end
    #@assert size(inputs) === (1,2)
    return output    
end

_map_outputs (generic function with 1 method)

In [10]:
function _sigmoid(v_pre, mu, sigma)
    v_pre = reshape(v_pre, (size(v_pre)...,1)) # add batch dim
    mu = reshape(mu, 1, size(mu)...) # for dims match ?
    
    mues = v_pre .- mu 
  
    x = map(x_ -> (sigma .* x_), eachslice(mues, dims=1))[end]
    x = reshape(x, 1, size(x)...) # for dims match
    return σ.(x)
end

_sigmoid (generic function with 1 method)

In [11]:
function complicated_prod(a, b; dim=1)
    # This is for keeping some dim restriction during broadcasting
    out = map(x -> a .* x, eachslice(b, dims=dim))[end]
    reshape(out, 1, size(out)...) # for dims match ?
end

complicated_prod (generic function with 1 method)

In [12]:
function _ode_solver_(ltc::LTCCell, inputs, state, elapsed_time)
    v_pre = state
    # We can pre-compute the effects of the sensory neurons here
    println("Again we need slice(((")
    sensory_w_activation = complicated_prod(ltc._params["sensory_w"],
                                            _sigmoid(
                                                     inputs, ltc._params["sensory_mu"],
                                                     ltc._params["sensory_sigma"]))    
    #@show size(sensory_w_activation)
    #@show size(ltc._params["sensory_sparsity_mask"])
    sensory_w_activation =  complicated_prod(ltc._params["sensory_sparsity_mask"], 
                                             sensory_w_activation,)
    
    sensory_rev_activation = complicated_prod(ltc._params["sensory_erev"], sensory_w_activation)

    # Reduce over dimension 1 (=source sensory neurons)
    w_numerator_sensory = dropdims(sum(sensory_rev_activation, dims=2), dims=2)
    w_denominator_sensory =  dropdims(sum(sensory_w_activation, dims=2), dims=2)


# cm/t is loop invariant
cm_t = ltc._params["cm"] / (elapsed_time / ltc._ode_unfolds)
# Unfold the multiply ODE multiple times into one RNN step
    for t = 1:ltc._ode_unfolds
        w_activation =  complicated_prod(ltc._params["w"], 
                                        _sigmoid(
                                                v_pre, ltc._params["mu"],
                                        ltc._params["sigma"]))
        w_activation = complicated_prod(ltc._params["sparsity_mask"], w_activation) 

        rev_activation = complicated_prod(ltc._params["erev"], w_activation)
        # Reduce over dimension 1 (=source neurons)
        w_numerator = dropdims(sum(rev_activation, dims=2), dims=2) .+ w_numerator_sensory
        #w_numerator = sum(rev_activation, dim=2) + w_numerator_sensory
        #w_denominator = sum(w_activation, dim=2) + w_denominator_sensory
        w_denominator = dropdims(sum(w_activation, dims=2), dims=2) .+ w_denominator_sensory        
 
        numerator = (
            reshape(cm_t, (1, size(cm_t)...)) .* v_pre 
            .+ reshape(ltc._params["gleak"], (1, size(ltc._params["gleak"])...))  
            .* reshape(ltc._params["vleak"], (1, size(ltc._params["vleak"])...))
            .+ w_numerator
        )
        denominator = reshape(cm_t, (1, size(cm_t)...)) 
             .+ reshape(ltc._params["gleak"], (1, size(ltc._params["gleak"])...))  
             .+ w_denominator
        # Avoid dividing by 0
        v_pre = numerator ./ (denominator .+ ltc._epsilon)
        #@show size(v_pre)
    end
    return v_pre
end

_ode_solver_ (generic function with 1 method)

### Forward

In [13]:
function (ltc::LTCCell)(state, x) 
    println("inside")
    batch_size = size(x)[1] # batch dim
    seq_len = size(x)[1]
    hidden_state = zeros(batch_size, _state_size(ltc))
    t = 1 # should be loop for sequence or just "." ?
    inputs = x[:, t, :]
    inputs = _map_inputs(ltc, inputs)
    next_state = _ode_solver_(ltc, inputs, hidden_state, 1.0)# The 2nd arg is init state, ltc instead of RNN
    outputs = _map_outputs(ltc, next_state)
    return next_state, outputs # for Recur
end

# Grad test

In [14]:
Flux.trainable(ltc::LTCCell) = (ltc._params["sensory_w"],) 

In [15]:
wiring = FullyConnected(8)
LTC = Flux.Recur(LTCCell(wiring, in_features), rand(1,8)) # toy state
LTC(data_x)

wiring.input_dim = nothing
wiring.input_dim = 2
wiring = FullyConnected(8, [-1.0 1.0 1.0 -1.0 1.0 -1.0 1.0 1.0; 1.0 1.0 -1.0 -1.0 1.0 1.0 -1.0 1.0; 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0; 1.0 -1.0 -1.0 -1.0 -1.0 1.0 1.0 -1.0; 1.0 1.0 1.0 1.0 1.0 -1.0 1.0 -1.0; -1.0 1.0 -1.0 1.0 1.0 1.0 -1.0 1.0; -1.0 1.0 1.0 1.0 1.0 -1.0 1.0 -1.0; -1.0 1.0 1.0 1.0 1.0 1.0 -1.0 1.0], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], 2, 8, true)
polarity = 1
polarity = -1
polarity = -1
polarity = -1
alloc!
_key = "sigma"
_key = "mu"
_key = "w"
_key = "sensory_sigma"
_key = "sensory_mu"
_key = "sensory_w"
_key = "erev"
_key = "sensory_erev"
_key = "gleak"
_key = "vleak"
_key = "cm"
_key = "sparsity_mask"
_key = "sensory_sparsity_mask"
inside
Again we need slice(((


1×8 Array{Float64,2}:
 1.39656  0.782549  1.77031  2.25922  1.36065  1.53868  1.05892  1.79086

In [16]:
@show params(LTC)

grads = Flux.gradient(params(LTC)) do
     @show sum(LTC(data_x))
end

params(LTC) = Params([[0.4248197446008244 0.6908924455877022 0.11536149878993325 0.07144266615163494 0.826478572788968 0.6693851911184779 0.4862827935623121 0.8432499804449927; 0.41455007502452434 0.393331359807225 0.7390817741758764 0.2581449122326856 0.34275247560166827 0.6756128032096728 0.39329941905039223 0.391323832511804]])
inside
Again we need slice(((
sum(LTC(data_x)) = 11.957754241565564


Grads(...)

In [17]:
for p in grads
    println(p)
end

[0.01909806832161306 -0.3189497312710684 0.0 -0.0 0.0 0.0 0.0 0.0; -4.909365361761815 -1.7367534971362426 0.0 -0.0 0.0 0.0 0.0 0.0]


In [18]:
# make all the params as separeate fields