# Different perturbations of FORCE networks
This notebook trains 30 instances of FORCE networks for each perturbation type shown in Figure 3. Our implementation of the FORCE algorithm is not optimized for speed, so this can take a while to run.

In [None]:
include("../forceNetworks.jl")
using PyCall
using ProgressMeter
@pyimport numpy
@pyimport pandas
using PyPlot

In [None]:
"""
    trainAndRun!(network, input)

Drive `network` through every column of `input` (one column per time step),
calling `ForceNetwork.step!` exactly once per step.

- Target: on every 10th step, and only while the step index is below 90% of
  the run length, the current input column is passed as the target; on all
  other steps the target is `nothing`.
- `recFrac` (last argument of `step!`): `0.0` during the first 3000 steps of
  every 10000-step window; otherwise `max(1 - k/(0.8*nSteps), 0.0)`, which
  shrinks linearly and hits zero at 80% of the run.
  NOTE(review): what `recFrac` controls is defined by `ForceNetwork.step!` —
  confirm there.

`network` is updated through `ForceNetwork.step!`; callers read its trained
matrices afterwards.
"""
function trainAndRun!(network, input)
    nSteps = size(input, 2)
    @ProgressMeter.showprogress 1 "Simulating..." for k in 1:nSteps
        isTargetStep = (k % 10 == 0) && (k < 0.9nSteps)
        target = isTargetStep ? input[:, k] : nothing
        # Separate slice for the input, so target and input never alias.
        recFracOff = k % 10000 < 3000
        recFrac = recFracOff ? 0.0 : max(1 - k/(0.8nSteps), 0.0)
        ForceNetwork.step!(network, target, input[:, k], recFrac)
    end
end

In [None]:
"""
    testForceCorr(Q, i, pertDescr; h5path="../generatedData/fig3.h5")

Train a fresh FORCE network on a fixed piecewise-constant 2-D signal. Then train
a second network that reuses the trained `Ω`, starts from readout `η*Q`, and
starts from a zero D×N matrix. Save `η` (as `/K`) and `Φ'` (as `/phi`) of both
networks to `h5path` under these keys:

- `/force/<pertDescr>/<i>/original`
- `/force/<pertDescr>/<i>/perturbed`

# Arguments
- `Q`: D×D matrix (D = 2 here) right-multiplied into the trained `η`.
- `i`: instance number; used only to build the HDF5 key.
- `pertDescr`: perturbation label; used only to build the HDF5 key.
- `h5path` (keyword): output HDF5 file.
"""
function testForceCorr(Q, i, pertDescr; h5path="../generatedData/fig3.h5")
    N = 1000
    signal = [0.2 0.6 0.0 -0.7 -0.2; 0.0 -0.3 0.2 0.0 0.5]
    D = size(signal, 1)  # signal / readout dimension (2)
    dt = 5e-5
    # Hold each signal column for Int(0.5/dt) = 10000 steps, then repeat the
    # whole sequence 20 times.
    input = repeat(signal, inner=(1, Int(5e-1/dt)), outer=(1, 20))

    # Write η and Φ' of `network` under /force/<pertDescr>/<i>/<stage>.
    function saveNetwork(network, stage)
        key = @sprintf "/force/%s/%d/%s" pertDescr i stage
        # Column labels must be unique: pandas' fixed HDF5 format rejects
        # duplicate labels, and this matches trainAndTestCorrOuter.
        pandas.DataFrame(network.matrices.η, columns=["K1", "K2"])[:to_hdf](h5path, key * "/K")
        pandas.DataFrame(network.matrices.Φ', columns=["phi1", "phi2"])[:to_hdf](h5path, key * "/phi")
    end

    # 1) Unperturbed network trained from scratch.
    net = ForceNetwork.Network(ForceNetwork.NetworkParameters(),
                               ForceNetwork.NetworkMatrices(N, D, α=.1*dt, G=10.0),
                               ForceNetwork.NetworkState(N, D))
    trainAndRun!(net, input)
    saveNetwork(net, "original")

    # 2) Same Ω, readout perturbed to η*Q, trained again on the same input.
    #    NOTE(review): the zeros(D, N) argument is presumably the initial Φ —
    #    confirm against ForceNetwork.NetworkMatrices.
    qMat = ForceNetwork.NetworkMatrices(net.matrices.Ω, .1*dt*eye(N),
                                        net.matrices.η*Q, zeros(D, N))
    qNet = ForceNetwork.Network(ForceNetwork.NetworkParameters(), qMat,
                                ForceNetwork.NetworkState(N, D))
    trainAndRun!(qNet, input)
    saveNetwork(qNet, "perturbed")
end

In [None]:
# Figure 3 readout perturbations: for each type, train 30 network pairs
# (original, and the same Ω with readout η*Q) via testForceCorr.
@showprogress "identity" for i=1:30
    Q = eye(2)
    testForceCorr(Q, i, "identity")
end
@showprogress "permutation" for i=1:30
    Q = [0 1; 1 0]  # swaps the two readout columns of η
    testForceCorr(Q, i, "permutation")
end
@showprogress "normal" for i=1:30
    Q = randn(2, 2)  # fresh standard-normal matrix for every instance
    testForceCorr(Q, i, "normal")
end
@showprogress "constant" for i=1:30
    q = randn()
    Q = [q q; q q]  # both columns of η*Q equal q*(η[:,1] + η[:,2])
    testForceCorr(Q, i, "constant")
end
# Standard 2-D rotation [cos -sin; sin cos] by θ radians (45°). Using the
# general form keeps it a true rotation if θ is changed.
θ = pi/4
Q = [cos(θ) -sin(θ); sin(θ) cos(θ)]
@showprogress "rotation45" for i=1:30
    testForceCorr(Q, i, "rotation45")
end

In [None]:
"""
    trainAndTestCorrOuter(rowPerm, i, pertDescr="outer"; h5path="../generatedData/fig3.h5")

Train a fresh FORCE network on a fixed piecewise-constant 2-D signal. Then
train a second network that reuses the trained `Ω`, starts from the trained `η`
with its rows reordered by `rowPerm`, and starts from a zero D×N matrix. Save
`η` (as `/K`) and `Φ'` (as `/phi`) of both networks to `h5path` under these keys:

- `/force/<pertDescr>/<i>/original`
- `/force/<pertDescr>/<i>/perturbed`

# Arguments
- `rowPerm`: row indices for η, of length N (= 1000).
- `i`: instance number; used only to build the HDF5 key.
- `pertDescr`: perturbation label; used only to build the HDF5 key.
- `h5path` (keyword): output HDF5 file.

# Throws
- `ArgumentError` if `length(rowPerm) != N`.
"""
function trainAndTestCorrOuter(rowPerm, i, pertDescr="outer"; h5path="../generatedData/fig3.h5")
    N = 1000
    length(rowPerm) == N ||
        throw(ArgumentError("rowPerm must have length $N, got $(length(rowPerm))"))
    signal = [0.2 0.6 0.0 -0.7 -0.2; 0.0 -0.3 0.2 0.0 0.5]
    # D must be a local: it sizes NetworkMatrices/NetworkState and the zero
    # matrix below.
    D = size(signal, 1)
    dt = 5e-5
    # Hold each signal column for Int(0.5/dt) = 10000 steps, then repeat the
    # whole sequence 20 times.
    input = repeat(signal, inner=(1, Int(5e-1/dt)), outer=(1, 20))

    # Write η and Φ' of `network` under /force/<pertDescr>/<i>/<stage>.
    function saveNetwork(network, stage)
        key = @sprintf "/force/%s/%d/%s" pertDescr i stage
        pandas.DataFrame(network.matrices.η, columns=["K1", "K2"])[:to_hdf](h5path, key * "/K")
        pandas.DataFrame(network.matrices.Φ', columns=["phi1", "phi2"])[:to_hdf](h5path, key * "/phi")
    end

    # NOTE(review): `runFix!` is not defined in this notebook (testForceCorr
    # uses `trainAndRun!`); presumably it comes from ../forceNetworks.jl —
    # confirm.
    net = ForceNetwork.Network(ForceNetwork.NetworkParameters(),
                               ForceNetwork.NetworkMatrices(N, D, α=.1*dt, G=10.0),
                               ForceNetwork.NetworkState(N, D))
    runFix!(net, input)
    saveNetwork(net, "original")

    # Same Ω, η with rows reordered by rowPerm, trained again on the same input.
    qMat = ForceNetwork.NetworkMatrices(net.matrices.Ω, .1*dt*eye(N),
                                        net.matrices.η[rowPerm, :], zeros(D, N))
    qNet = ForceNetwork.Network(ForceNetwork.NetworkParameters(), qMat,
                                ForceNetwork.NetworkState(N, D))
    runFix!(qNet, input)
    saveNetwork(qNet, "perturbed")
end

In [None]:
# Outer permutation: move η's rows 501:1000 in front of rows 1:500 (swap halves).
rowPerm = [collect(501:1000); collect(1:500)];
# All 30 instances, matching the other perturbation types.
@showprogress "outerPermutation" for i=1:30
    trainAndTestCorrOuter(rowPerm, i, "outerPermutation")
end
# Optional control run with the identity row order (disabled):
#@showprogress "outerIdentity" for i=1:30
#    trainAndTestCorrOuter(1:1000, i, "outerIdentity")
#end