In [1]:
using Revise

In [2]:
using JLD2
using FileIO
import LearningMPC



In [3]:
# Launch 10 additional worker processes; the `@everywhere` and `@parallel`
# cells further down run on them.
addprocs(10)

10-element Array{Int64,1}:
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11

In [None]:
@everywhere using Flux
@everywhere using MLDataPattern
@everywhere using CoordinateTransformations
@everywhere using ProgressMeter
@everywhere import FluxExtensions

In [5]:
# Read the "samples" entry out of the stored grid-search results file.
samples = load("../2018-02-07-hopper-smaller-grid/grid_search.jld2")["samples"];

In [6]:
# Discard, in place, every sample whose second state entry exceeds its first or
# whose `uJ` contains a NaN, then report how many samples are left.
filter!(s -> s.state[2] <= s.state[1] && !any(isnan, s.uJ), samples);
length(samples)

59931

In [7]:
# Split one sample into (state, first column of `uJ`, remaining columns of `uJ`).
function features(sample::LearningMPC.Sample)
    uJ = sample.uJ
    return (sample.state, uJ[:, 1], uJ[:, 2:end])
end

# Featurize every sample, shuffle the observations, and split them 85% / 15%
# into `train_data` and `test_data`.
data = map(features, samples);
train_data, test_data = splitobs(shuffleobs(data); at=0.85);

In [8]:
# Build the attention network together with its training loss.
#
# Returns `(model, loss)`. `loss(x, u, J)` is the mean squared error between
# `model(x)` and `u[2]`; the `J` argument is accepted but not used.
@everywhere function attention_model()
    signal_net = Chain(Dense(4, 16))
    # The weighting branch rescales its input by 0.2, applies two ELU layers and
    # finishes with a softmax.
    weight_net = Chain(
        LinearMap(UniformScaling(0.2)),
        Dense(4, 16, elu),
        Dense(16, 16, elu),
        softmax)
    net = FluxExtensions.Attention(signal_net, weight_net)
    # Sanity check: a 4-element input must produce a single output value.
    @assert size(net(zeros(4))) == (1,)

    loss(x, u, J) = Flux.mse(net(x), u[2])
    return net, loss
end

# Build the plain feed-forward network together with its training loss.
#
# Returns `(model, loss)`. `loss(x, u, J)` is the mean squared error between
# `model(x)` and `u[2]`; the `J` argument is accepted but not used.
@everywhere function layered_model()
    # Input rescaling by 0.2, two ELU layers, then a linear layer with one output.
    net = Chain(
        LinearMap(UniformScaling(0.2)),
        Dense(4, 16, elu),
        Dense(16, 16, elu),
        Dense(16, 1))

    # Sanity check: a 4-element input must produce a single output value.
    @assert size(net(zeros(4))) == (1,)

    loss(x, u, J) = Flux.mse(net(x), u[2])
    return net, loss
end

# Build the attention network wrapped in a `TangentPropagator`, plus its losses.
#
# `weight` blends the two terms of the training loss:
# `(1 - weight) * mse(uhat, u[2]) + weight * mse(Jhat, J[2:2, :])`.
#
# Returns `(model, loss, test_loss)`. `model(x)` yields the pair `(uhat, Jhat)`;
# `test_loss` measures only the error of `uhat` against `u[2]`.
@everywhere function attention_model_tangent(weight=0.05)
    signals = Chain(Dense(4, 16))
    weights = Chain(
        LinearMap(UniformScaling(0.2)),
        Dense(4, 16, elu),
        Dense(16, 16, elu),
        softmax,
    )
    model = FluxExtensions.TangentPropagator(FluxExtensions.Attention(signals, weights))

    # Sanity check: a 4-element input gives a 1-element output and a 1×4 second output.
    @assert size.(model(zeros(4))) == ((1,), (1, 4))

    function loss(x, u, J)
        uhat, Jhat = model(x)
        # `Jhat` is 1×4 (asserted above), so the target row is sliced with `2:2`
        # to keep it 1×4 as well. `J[2, :]` would be a length-4 vector, and a 1×4
        # matrix minus a length-4 vector broadcasts to 4×4, which compares every
        # predicted entry with every target entry instead of matching them up.
        (1 - weight) * Flux.mse(uhat, u[2]) + weight * Flux.mse(Jhat, J[2:2, :])
    end
    function test_loss(x, u, J)
        uhat, _ = model(x)
        Flux.mse(uhat, u[2])
    end
    model, loss, test_loss
end

# Build the feed-forward network wrapped in a `TangentPropagator`, plus its losses.
#
# `weight` blends the two terms of the training loss:
# `(1 - weight) * mse(uhat, u[2]) + weight * mse(Jhat, J[2:2, :])`.
#
# Returns `(model, loss, test_loss)`. `model(x)` yields the pair `(uhat, Jhat)`;
# `test_loss` measures only the error of `uhat` against `u[2]`.
@everywhere function layered_model_tangent(weight=0.05)
    model = FluxExtensions.TangentPropagator(Chain(
        LinearMap(UniformScaling(0.2)),
        Dense(4, 16, elu),
        Dense(16, 16, elu),
        Dense(16, 1)))

    # Sanity check: a 4-element input gives a 1-element output and a 1×4 second output.
    @assert size.(model(zeros(4))) == ((1,), (1, 4))

    function loss(x, u, J)
        uhat, Jhat = model(x)
        # `Jhat` is 1×4 (asserted above), so the target row is sliced with `2:2`
        # to keep it 1×4 as well. `J[2, :]` would be a length-4 vector, and a 1×4
        # matrix minus a length-4 vector broadcasts to 4×4, which compares every
        # predicted entry with every target entry instead of matching them up.
        (1 - weight) * Flux.mse(uhat, u[2]) + weight * Flux.mse(Jhat, J[2:2, :])
    end
    function test_loss(x, u, J)
        uhat, _ = model(x)
        Flux.mse(uhat, u[2])
    end
    model, loss, test_loss
end

In [None]:
# Number of `Flux.train!` calls made on each fold's training split.
iterations = 200

# 10-fold cross-validation of the attention model. Each fold contributes
# `(model, history)`, where `history[i]` is a 1-tuple holding the mean
# validation loss after the i-th `Flux.train!` call.
attention_losses = @parallel (vcat) for (train, validation) in collect(kfolds(train_data; k=10))
    net, objective = attention_model()
    optimizer = Flux.ADADelta(params(net))
    history = map(1:iterations) do epoch
        Flux.train!(objective, train, optimizer)
        (mean(obs -> Flux.Tracker.value(objective(obs...)), validation),)
    end
    (net, history)
end
@show attention_losses

# Same procedure for the plain layered model.
layered_losses = @parallel (vcat) for (train, validation) in collect(kfolds(train_data; k=10))
    net, objective = layered_model()
    optimizer = Flux.ADADelta(params(net))
    history = map(1:iterations) do epoch
        Flux.train!(objective, train, optimizer)
        (mean(obs -> Flux.Tracker.value(objective(obs...)), validation),)
    end
    (net, history)
end
@show layered_losses

# Tangent variant of the attention model. Each history entry is the pair
# (mean validation training loss, mean validation `test_loss`).
attention_losses_tangent = @parallel (vcat) for (train, validation) in collect(kfolds(train_data; k=10))
    net, objective, held_out = attention_model_tangent()
    optimizer = Flux.ADADelta(params(net))
    history = map(1:iterations) do epoch
        Flux.train!(objective, train, optimizer)
        blended = mean(obs -> Flux.Tracker.value(objective(obs...)), validation)
        output_only = mean(obs -> Flux.Tracker.value(held_out(obs...)), validation)
        (blended, output_only)
    end
    (net, history)
end
@show attention_losses_tangent

# Tangent variant of the layered model, recorded the same way.
layered_losses_tangent = @parallel (vcat) for (train, validation) in collect(kfolds(train_data; k=10))
    net, objective, held_out = layered_model_tangent()
    optimizer = Flux.ADADelta(params(net))
    history = map(1:iterations) do epoch
        Flux.train!(objective, train, optimizer)
        blended = mean(obs -> Flux.Tracker.value(objective(obs...)), validation)
        output_only = mean(obs -> Flux.Tracker.value(held_out(obs...)), validation)
        (blended, output_only)
    end
    (net, history)
end
@show layered_losses_tangent



In [18]:
# Write all four cross-validation results into one JLD2 file, one key per run.
save(
    "losses2.jld2",
    "attention", attention_losses,
    "layered", layered_losses,
    "attention_tangent", attention_losses_tangent,
    "layered_tangent", layered_losses_tangent
)

In [19]:
# Load Plots and switch it to the GR backend.
using Plots; gr()

Plots.GRBackend()

In [20]:
# Inspect the type of the first entry collected in `attention_losses`.
typeof(attention_losses[1])

Tuple{FluxExtensions.Attention{Flux.Chain,Flux.Chain},Array{Tuple{Float64},1}}

In [28]:
# Overlay, for every fold of every run, the last component of each recorded
# loss tuple; each run gets its own line color.
plt = plot(legend=nothing)
for (results, line_color) in ((attention_losses, "red"),
                              (layered_losses, "blue"),
                              (attention_losses_tangent, "pink"),
                              (layered_losses_tangent, "skyblue"))
    for (_, history) in results
        plot!(plt, last.(history), linecolor=line_color)
    end
end
plt

In [24]:
# Final recorded loss tuple from each entry of `attention_losses`.
last.(last.(attention_losses))

10-element Array{Tuple{Float64},1}:
 (21.9033,)
 (19.6214,)
 (18.2502,)
 (21.8882,)
 (18.1163,)
 (19.6783,)
 (17.9095,)
 (19.0924,)
 (17.264,) 
 (20.5518,)

In [25]:
# Final recorded loss tuple from each entry of `layered_losses`.
last.(last.(layered_losses))

10-element Array{Tuple{Float64},1}:
 (33.1317,)
 (21.0533,)
 (18.466,) 
 (23.562,) 
 (25.8672,)
 (23.8732,)
 (32.903,) 
 (19.8653,)
 (20.4046,)
 (20.7992,)



In [None]:
# 10-fold cross-validation of the layered tangent model with `weight = 0.0`.
# Each fold contributes `(model, history)`; every history entry is the pair
# (mean validation training loss, mean validation `test_loss`).
layered_losses_tangent_0 = @parallel (vcat) for (train, validation) in collect(kfolds(train_data; k=10))
    net, objective, held_out = layered_model_tangent(0.0)
    optimizer = Flux.ADADelta(params(net))
    history = map(1:iterations) do epoch
        Flux.train!(objective, train, optimizer)
        blended = mean(obs -> Flux.Tracker.value(objective(obs...)), validation)
        output_only = mean(obs -> Flux.Tracker.value(held_out(obs...)), validation)
        (blended, output_only)
    end
    (net, history)
end
@show layered_losses_tangent_0

In [32]:
# Overlay the zero-weight tangent run (red) and the plain layered run (blue),
# plotting the last component of each recorded loss tuple for every fold.
plt = plot(legend=nothing)
for (results, line_color) in ((layered_losses_tangent_0, "red"),
                              (layered_losses, "blue"))
    for (_, history) in results
        plot!(plt, last.(history), linecolor=line_color)
    end
end
plt

In [35]:
# Loss history (second tuple element) of the first entry in `layered_losses_tangent_0`.
layered_losses_tangent_0[1][2]

200-element Array{Tuple{Float64,Float64},1}:
 (69.2365, 69.2365)
 (51.4347, 51.4347)
 (38.8203, 38.8203)
 (34.6861, 34.6861)
 (31.8719, 31.8719)
 (29.9075, 29.9075)
 (28.6035, 28.6035)
 (27.9436, 27.9436)
 (27.5948, 27.5948)
 (27.3777, 27.3777)
 (27.2521, 27.2521)
 (27.1708, 27.1708)
 (27.111, 27.111)  
 ⋮                 
 (21.2752, 21.2752)
 (21.291, 21.291)  
 (21.2961, 21.2961)
 (21.2972, 21.2972)
 (21.3027, 21.3027)
 (21.3099, 21.3099)
 (21.3188, 21.3188)
 (21.33, 21.33)    
 (21.3442, 21.3442)
 (21.3607, 21.3607)
 (21.3788, 21.3788)
 (21.3979, 21.3979)