### Embed to Control: A Locally Linear Latent Dynamics Model for Control from Raw Images
#### by Manuel Watter, Jost Tobias Springenberg, Joschka Boedecker, Martin Riedmiller
##### http://arxiv.org/abs/1506.07365

### The MIT License (MIT)

##### Implementation Copyright (c) 2015, John-Alexander M. Assael (iassael@gmail.com) & Marc P. Deisenroth. All rights reserved.

# Parameters

In [None]:
-- Inside an iTorch notebook there is no command line; parse an empty
-- argument list so every option keeps its default value.
if itorch then
    arg = {}
end

-- The default save path is the directory that contains this script.
require 'sys'
dname, fname = sys.fpath()

-- Option table: each entry is {flag, default, help text}, in help order.
local option_specs = {
    -- general options
    {'-seed', 1, 'initial random seed'},
    {'-threads', 4, 'number of threads'},
    -- gpu
    {'-cuda', false, 'cuda'},
    -- model
    {'-lambda', 0.25, 'lambda'},
    {'-action_size', 1, 'action size'},
    -- training
    {'-batch_size', 25, 'batch size'},
    {'-hist_len', 2, 'history length'},
    {'-learningRate', 3e-4, 'learning rate'},
    -- persistence / logging
    {'-save', dname, 'save path'},
    {'-load', false, 'load pretrained model'},
    {'-v', false, 'be verbose'},
}

cmd = torch.CmdLine()
cmd:text()
cmd:text('Options')
for _, spec in ipairs(option_specs) do
    cmd:option(spec[1], spec[2], spec[3])
end
cmd:text()

opt = cmd:parse(arg)

# Import Packages

In [None]:
require 'hdf5'
require 'image'
require 'nngraph'
require 'optim'
require 'nn'
require 'unsup'
Plot = require 'itorch.Plot'

require 'modules/base'
require 'modules/KLDistCriterion'
require 'modules/KLDCriterion'
require 'modules/LinearO'
require 'modules/AddCons'
require 'modules/Reparametrize'

-- Cuda initialisation: load the GPU backends and select device 1.
if opt.cuda then
    require 'cutorch'
    require 'cunn'
    cutorch.setDevice(1)
    print(cutorch.getDeviceProperties(1))
end

-- Seed every random generator in use so runs are reproducible.
-- torch.manualSeed only seeds the CPU generator; cutorch keeps its own
-- per-device generator, which must be seeded separately.
torch.manualSeed(opt.seed)
if opt.cuda then
    cutorch.manualSeed(opt.seed)
end
torch.setnumthreads(opt.threads)
-- Set float as default type
torch.setdefaulttensortype('torch.FloatTensor') 

# Load Data

In [None]:
--- Show a single flattened frame in the notebook (no-op outside iTorch).
-- If opt.y_mean is set, the frame is first de-standardised with
-- g_destandarize(img, opt.y_mean, opt.y_std). It is then reshaped to
-- opt.img_w x opt.img_h and scaled to 256 pixels for display.
-- @param img tensor holding one frame's pixels
function disp_img(img)
    if not itorch then
        return
    end

    local pixels = img
    if opt.y_mean ~= nil then
        pixels = g_destandarize(pixels:float(), opt.y_mean, opt.y_std)
    end

    local frame = pixels:float():reshape(opt.img_w, opt.img_h)
    itorch.image(image.scale(frame, 256))
end

In [None]:
-- Read the frames (train_y) and the actions (train_u) from HDF5; the action
-- tensor is reshaped to one row of opt.action_size values per frame.
local myFile = hdf5.open('data/single_pendulum_nogravity.h5', 'r')

local y_all = myFile:read('train_y'):all():float()
local u_all = myFile:read('train_u'):all():float():reshape(y_all:size(1), opt.action_size)

myFile:close()

-- Scale images
-- local new_size = 10
-- local prev_size = torch.sqrt(y_all:size(2))
-- y_all = image.scale(y_all:reshape(y_all:size(1), prev_size,prev_size), new_size, new_size):reshape(y_all:size(1), new_size^2)

-- Train/test split: the last `n_test` frames are held out for testing.
-- Derived from the file length so the whole file is used and a shorter
-- file fails with a clear message instead of an indexing error.
local n_total = y_all:size(1)
local n_test = 100
assert(n_total > n_test,
    string.format('dataset needs more than %d frames, got %d', n_test, n_total))
local n_train = n_total - n_test

local y = y_all[{{1, n_train}}]
local u = u_all[{{1, n_train}}]

local ys = y_all[{{n_train + 1, n_total}}]
local us = u_all[{{n_train + 1, n_total}}]

-- Update parameters
-- Frames are stored flattened; a square image of side sqrt(#pixels) is assumed.
opt.img_w = torch.sqrt(y:size(2))
opt.img_h = torch.sqrt(y:size(2))
opt.max_seq_length = y:size(1) - 1

-- Store data
state_train = {
  x = transfer_data(y),
  u = transfer_data(u)
}

state_test = {
  x = transfer_data(ys),
  u = transfer_data(us)
}

print('Train=' .. state_train.x:size(1) .. ' Test=' .. state_test.x:size(1) .. ' (' .. opt.img_w .. 'x' .. opt.img_h .. ')')

In [None]:
-- Sanity check: display training frame number `idx` (disp_img only renders
-- when running under iTorch).
idx=1
disp_img(state_train.x[idx])

# Define Model Architecture

### Network

In [None]:
--- Build the Embed-to-Control network as a single nngraph gModule.
-- Wires an encoder, a locally linear transition network and two
-- weight-sharing decoders together, then calls create_links on the result.
-- Inputs of the returned gModule:  {x_t, u_t}
-- Outputs of the returned gModule: {z_t, decoded x_t,
--   {dynamics_mean, dynamics_var} of the predicted next latent,
--   decoded prediction from that next latent}
-- Side effects: sets opt.latent_dims and the globals `encoder` and `model`
-- (create_links then overwrites `encoder` and sets further globals).
-- NOTE(review): create_links looks modules up by hard-coded forwardnodes
-- indices, so the order in which graph nodes are created below must not change.
-- @return the gModule (also stored in the global `model`)
function create_network()
    
    opt.latent_dims = 3
    local enc_dims = 100
    local trans_dims = 100
    
    -- Model Specific parameters
    local f_maps_1 = 64
    local f_size_1 = 7
    local f_maps_2 = 32
    local f_size_2 = 5
    -- NOTE(review): f_maps_3 / f_size_3 are not used anywhere below.
    local f_maps_3 = 16
    local f_size_3 = 3
    
    -- Encoder
    --layer 1
    -- Global (not local): create_links later re-points it at the model's node.
    encoder = nn.Sequential()
    -- The opt.hist_len stacked frames become the input channels of a batch of one.
    encoder:add(nn.Reshape(1, opt.hist_len, opt.img_w, opt.img_h))
    encoder:add(nn.SpatialConvolutionMM(opt.hist_len, f_maps_1, f_size_1, f_size_1))
    encoder:add(nn.ReLU())
    encoder:add(nn.SpatialMaxPooling(2,2,2,2))
    
    --layer 2
    encoder:add(nn.SpatialConvolutionMM(f_maps_1, f_maps_2, f_size_2, f_size_2))
    encoder:add(nn.ReLU())
    encoder:add(nn.SpatialMaxPooling(2,2,2,2))
    
    -- 32*6*6 is hard-coded and must equal the flattened conv output
    -- (32 == f_maps_2). E.g. a 40x40 frame gives 40-6=34 -> pool 17 ->
    -- 17-4=13 -> pool 6, i.e. 6x6.
    encoder:add(nn.Reshape(32*6*6))
    encoder:add(nn.LinearO(32*6*6, enc_dims))
    encoder:add(nn.ReLU())
        
    encoder:add(nn.LinearO(enc_dims, enc_dims))
    encoder:add(nn.ReLU())
    
    -- Two linear heads. Head 1 is used as the latent mean below; head 2
    -- goes through Exp in dynamics_var, i.e. it is treated as a log-variance.
    local z = nn.ConcatTable()
    z:add(nn.LinearO(enc_dims, opt.latent_dims))
    z:add(nn.LinearO(enc_dims, opt.latent_dims))
    encoder:add(z)
       
    -- Decoder
    local decoder = nn.Sequential()
    -- Reparametrize is a project module; presumably it draws a latent
    -- sample from {mean, log-variance} (see modules/Reparametrize).
    decoder:add(nn.Reparametrize(opt.latent_dims))
    decoder:add(nn.LinearO(opt.latent_dims, enc_dims))
    decoder:add(nn.ReLU())

    decoder:add(nn.LinearO(enc_dims, enc_dims))
    decoder:add(nn.ReLU())    
    
    decoder:add(nn.LinearO(enc_dims, 32*9*9))
    decoder:add(nn.ReLU())
    
    -- Spatial sizes: 9 -> up 18 -> conv5 14 -> up 28 -> conv5 24 -> up 48 -> conv9 40.
    decoder:add(nn.Reshape(32, 9, 9))
    
    -- layer 2
    decoder:add(nn.SpatialUpSamplingNearest(2))
    decoder:add(nn.SpatialConvolutionMM(f_maps_2, f_maps_2, f_size_2, f_size_2))
    decoder:add(nn.ReLU())
    
    -- layer 1
    decoder:add(nn.SpatialUpSamplingNearest(2))
    decoder:add(nn.SpatialConvolutionMM(f_maps_2, f_maps_1, f_size_2, f_size_2))
    decoder:add(nn.ReLU())
    
    -- Output has 2 maps (hard-coded, not opt.hist_len); kernel f_size_2+4 = 9.
    decoder:add(nn.SpatialUpSamplingNearest(2))
    decoder:add(nn.SpatialConvolutionMM(f_maps_1, 2, f_size_2+4, f_size_2+4))
    
    -- Sigmoid keeps pixel outputs in (0, 1) for the BCE criterion.
    decoder:add(nn.Sigmoid())
    -- NOTE(review): size-less nn.View(); presumably used to flatten the output.
    decoder:add(nn.View())
    
    
    -- transition
    -- Maps a (sampled) latent state to trans_dims features from which the
    -- local linear dynamics A, B and o are predicted.
    local trans = nn.Sequential()
    trans:add(nn.Reparametrize(opt.latent_dims))
    trans:add(nn.View(opt.latent_dims))
    trans:add(nn.LinearO(opt.latent_dims, trans_dims))
    trans:add(nn.ReLU())
    trans:add(nn.LinearO(trans_dims, trans_dims))
    trans:add(nn.ReLU())
            
    
    -- Define model
    local x_t = nn.Identity()():annotate{name = 'x_t'}
    local u_t = nn.Identity()():annotate{name = 'u_t'}
    
    -- Define Encoder Module
    local z_t = encoder(x_t):annotate{name = 'z_t'}
    
    -- Define Transition Matrices
    local h_trans = trans(z_t):annotate{name = 'h_trans'}
    
    
    -- transition a
    -- Two latent-sized vectors v and r are predicted from h_trans and combined
    -- into A = AddCons(I)( (v r^T) / (1 - v^T r) ).
    -- AddCons is a project module; presumably it adds the identity constant.
    local matrix_a_vr = nn.Sequential() 
    local z = nn.ConcatTable()
    z:add(nn.LinearO(trans_dims, opt.latent_dims))
    z:add(nn.LinearO(trans_dims, opt.latent_dims))
    matrix_a_vr:add(z)

    local vr = matrix_a_vr(h_trans)
    local v = nn.View(opt.latent_dims,1)(nn.SelectTable(1)(vr))
    local r = nn.View(opt.latent_dims,1)(nn.SelectTable(2)(vr))
    -- Scalar v^T r.
    local vtr = nn.View(1,1)(nn.MM()({nn.Transpose({1,2})(v), r}))
    -- (1 - v^T r) replicated to latent x latent so CDivTable can divide elementwise.
    local alpha = nn.Reshape(opt.latent_dims,opt.latent_dims)(
        nn.Replicate(opt.latent_dims*opt.latent_dims)(
            nn.AddConstant(1)(nn.MulConstant(-1)(vtr))
        ))
    -- Outer product v r^T.
    local vrt = nn.View(opt.latent_dims,opt.latent_dims)(nn.MM()({v, nn.Transpose({1,2})(r)}))
    local matrix_a = nn.View(opt.latent_dims,opt.latent_dims)(
        nn.AddCons(torch.eye(opt.latent_dims,opt.latent_dims))(
            nn.CDivTable()({vrt, alpha})
        )
    )
    
    -- Alternative (disabled): predict a dense A directly from h_trans.
    -- local matrix_a = nn.Sequential()
    -- matrix_a:add(nn.LinearO(trans_dims, opt.latent_dims*opt.latent_dims))
    -- matrix_a:add(nn.View(opt.latent_dims, opt.latent_dims))
        
    -- transition b
    -- B: latent_dims x action_size control matrix.
    local matrix_b = nn.Sequential()
    matrix_b:add(nn.LinearO(trans_dims, opt.latent_dims*opt.action_size))
    matrix_b:add(nn.View(opt.latent_dims, opt.action_size))
    
    -- transition o
    -- o: latent_dims x 1 offset vector.
    local matrix_o = nn.Sequential()
    matrix_o:add(nn.LinearO(trans_dims, opt.latent_dims))
    matrix_o:add(nn.View(opt.latent_dims, 1))
    

    -- matrix_a is already a graph node (built from h_trans above), so it is
    -- annotated directly instead of being called on h_trans.
    local a_t = matrix_a:annotate{name = 'a_t'}
    -- local a_t = matrix_a(h_trans):annotate{name = 'a_t'}
    -- Elementwise square of A, used to propagate the variance.
    local a_t2 = nn.Power(2)(a_t):annotate{name = 'a_t2'}
    local b_t = matrix_b(h_trans):annotate{name = 'b_t'}
    local o_t = matrix_o(h_trans):annotate{name = 'o_t'}
    
    -- A * mean(z_t), with the mean taken as head 1 of the encoder.
    local a_t_z_t = nn.MM()({
            a_t, nn.View(opt.latent_dims, 1)(nn.SelectTable(1)(z_t))
        }):annotate{name = 'a_t_z_t'}
    
    -- B * u_t
    local b_t_u_t = nn.MM()({
            b_t, nn.View(opt.action_size, 1)(u_t)
        }):annotate{name = 'b_t_u_t'}
    
    -- Define Dynamics Model
    -- Predicted next mean: (A z + B u + o), transposed to a 1 x latent row.
    local dynamics_mean = nn.Transpose({1,2})(
        nn.CAddTable()({a_t_z_t, b_t_u_t, o_t})
    ):annotate{name = 'dynamics_mean'}
    
    -- Predicted next log-variance: log((A .^ 2) * exp(head 2 of z_t)),
    -- transposed to a row.
    local dynamics_var = nn.Transpose({1,2})(nn.Log()(
        nn.MM()({
            a_t2, nn.View(opt.latent_dims, 1)(nn.Exp()(nn.SelectTable(2)(z_t)))
    }))):annotate{name = 'dynamics_var'}
    
    local dynamics_all = nn.Identity()({dynamics_mean, dynamics_var}):annotate{name = 'dynamics'}
    
    -- Define Output
    -- decoder2 shares weight/bias/gradWeight/gradBias with decoder, so the
    -- current and the predicted latents are decoded by the same weights.
    local output_t1 = decoder(dynamics_all):annotate{name = 'decoder_x_t1'}
    local decoder2 = decoder:clone("weight", "bias", "gradWeight", "gradBias")
    local output_t = decoder2(z_t):annotate{name = 'decoder_x_t'}
    
    -- Create model
    model = nn.gModule({x_t, u_t}, {z_t, output_t, dynamics_all, output_t1})
    
    -- Create Links to modules
    create_links(model)
    
    return model
end

--- Derive helper globals from a built (or freshly loaded) E2C gModule.
-- Sets the globals:
--   node_encoder, node_dynamics, node_decoder, node_decoder2: forwardnodes indices
--   dynamics: a clone of `model` whose nodes use the model's own module
--             instances, with the encoder and both decoder nodes replaced by
--             nn.Identity
--   encoder, decoder: the modules at node_encoder / node_decoder of `model`
--   ae: nn.Sequential of encoder followed by decoder
-- NOTE(review): the indices are hard-coded and presumably match the node
-- order nngraph produces for create_network's graph; re-check them whenever
-- that graph changes.
-- @param model gModule built by create_network (or loaded from disk)
function create_links(model)

    -- Create links to decoder and dynamics model
    node_encoder = 5
    node_dynamics = 44  -- NOTE(review): not used in this function
    node_decoder = 6
    node_decoder2 = 45
    
    -- Clone Dynamics model only and share parameters
    dynamics = model:clone("weight", "bias", "gradWeight", "gradBias")
    -- Clear the encoder/decoder nodes first so the loop below skips them.
    dynamics.forwardnodes[node_encoder].data.module = nil
    dynamics.forwardnodes[node_decoder].data.module = nil
    dynamics.forwardnodes[node_decoder2].data.module = nil
    -- Point every remaining node at the module instance used by `model`,
    -- so both graphs literally share modules.
    for indexNode,node in ipairs(dynamics.forwardnodes) do
         if dynamics.forwardnodes[indexNode].data.module then
            dynamics.forwardnodes[indexNode].data.module = model.forwardnodes[indexNode].data.module
        end
    end
    -- The cleared nodes become pass-throughs in the dynamics graph.
    dynamics.forwardnodes[node_encoder].data.module = nn.Identity()
    dynamics.forwardnodes[node_decoder].data.module = nn.Identity()
    dynamics.forwardnodes[node_decoder2].data.module = nn.Identity()

    -- These are the very module instances inside `model` (not copies).
    encoder = model.forwardnodes[node_encoder].data.module
    decoder = model.forwardnodes[node_decoder].data.module
    
    -- Plain autoencoder path: the encoder's output table feeds the decoder.
    ae = nn.Sequential()
    ae:add(encoder)
    ae:add(decoder)
    
end

### Setup Network function

In [None]:
--- Build a fresh network and the training criteria.
-- Sets the globals model, params, gradParams (flattened parameters and
-- gradients), criterion (BCE), KLD and KLDist; every criterion returns
-- summed rather than averaged losses.
function setup()
    print("Creating Conv-Net.")
    model = create_network()
    params, gradParams = model:getParameters()

    criterion = nn.BCECriterion()
    KLD = nn.KLDCriterion()
    KLDist = nn.KLDistCriterion()

    -- Sum over elements instead of averaging.
    for _, crit in ipairs({criterion, KLD, KLDist}) do
        crit.sizeAverage = false
    end
end

--- Restore a trained network from disk and rebuild the training criteria.
-- load_model replaces the globals model, opt, optim_config, train_err and
-- test_err. This then re-derives the helper globals with create_links,
-- re-flattens the parameters, marks opt as loaded and points opt.save at
-- this script's directory.
function setup_load()
    print("Loading Conv-Net.")

    load_model()
    create_links(model)
    params, gradParams = model:getParameters()

    -- opt now comes from the checkpoint: flag it as loaded and save next to
    -- this script.
    opt.load = true
    dname, fname = sys.fpath()
    opt.save = dname

    criterion = nn.BCECriterion()
    KLD = transfer_data(nn.KLDCriterion())
    KLDist = nn.KLDistCriterion()

    -- Sum over elements instead of averaging.
    criterion.sizeAverage = false
    KLD.sizeAverage = false
    KLDist.sizeAverage = false
end

### Save model

In [None]:
-- Checkpoint location, relative to the save directory (opt.save).
local MODEL_FILE = 'model/relu_single_nogravity_e2c_conv.t7'

-- Quote `s` as a single POSIX shell word: wrap it in single quotes and
-- escape embedded single quotes, so paths with spaces or shell
-- metacharacters cannot break (or inject into) the command.
local function shell_quote(s)
    return "'" .. (s:gsub("'", "'\\''")) .. "'"
end

--- Save the network, options, optimiser state and error history.
-- Writes {model, opt, optim_config, train_err, test_err} to
-- <opt.save>/model/relu_single_nogravity_e2c_conv.t7. An existing checkpoint
-- is first renamed to <file>.old; if that rename fails a warning is printed
-- and the old checkpoint is overwritten.
function save_model()
    -- save/log current net
    local filename = paths.concat(opt.save, MODEL_FILE)
    os.execute('mkdir -p ' .. shell_quote(paths.dirname(filename)))
    if paths.filep(filename) then
        -- os.rename needs no shell; like `mv`, it replaces an existing .old
        -- file on POSIX systems.
        local ok, rename_err = os.rename(filename, filename .. '.old')
        if not ok then
            print('<trainer> warning: could not back up ' .. filename .. ': ' .. tostring(rename_err))
        end
    end
    -- print('<trainer> saving network to '..filename)
    torch.save(filename, {model, opt, optim_config, train_err, test_err})
end

--- Restore a checkpoint written by save_model.
-- Reads from the same path save_model writes to (under opt.save) rather than
-- a path relative to the current working directory. Overwrites the globals
-- model, opt, optim_config, train_err and test_err.
function load_model()
    local filename = paths.concat(opt.save, MODEL_FILE)
    model, opt, optim_config, train_err, test_err = unpack(torch.load(filename))
end

# Initialize Network

In [None]:
print("Network parameters:")
print(opt)

if not opt.load then
    -- Fresh run: build the network and start with empty error histories.
    setup()
    -- A negative learning rate turns Adam's descent step into ascent on the
    -- value returned by train()'s feval. That feval negates the BCE terms, so
    -- the two sign conventions must stay in sync.
    optim_config = {learningRate = -opt.learningRate, beta2 = 0.9}
    train_err, test_err = {}, {}
else
    -- Resume: model, opt, optim_config and the histories come from disk.
    setup_load()
end

-- Number of epochs already completed.
epoch = #train_err

# Test Function

In [None]:
--- Mean one-step reconstruction and prediction BCE over `dataset`.
-- For every interior index idx, the model receives the frames (idx-1, idx)
-- concatenated, plus the action u[idx].
--   * The decoded x_t is scored against that concatenated input.
--   * The second opt.img_w^2-element chunk of x_t1 is scored against
--     frame idx+1.
-- Both sums are divided by the number of interior indices,
-- dataset.x:size(1) - 2.
-- @param dataset table with fields x (frames) and u (actions)
-- @return mean reconstruction BCE, mean next-frame BCE
function run_test(dataset)
    local n_frames = dataset.x:size(1)
    local n_steps = n_frames - 2
    -- Number of pixels in one frame.
    local frame_size = opt.img_w^2

    local err_cur = 0
    local err_next = 0
    
    local loss = nn.BCECriterion()

    for idx = 2, n_frames - 1 do

        local x_prev = dataset.x[idx-1]
        local x_cur = dataset.x[idx]
        local x_next = dataset.x[idx+1]
        local batch_u = dataset.u[idx]
        
        local batch_x = torch.cat(x_prev, x_cur)

        -- Forward; the latent outputs are not needed here.
        local _, x_t, _, x_t1 = unpack(model:forward({batch_x, batch_u}))
        err_cur = err_cur + loss:forward(x_t, batch_x)
        -- Only the second frame of the prediction is compared. To score both
        -- frames instead, compare x_t1 with torch.cat(x_cur, x_next).
        err_next = err_next + loss:forward(x_t1:split(frame_size)[2], x_next)

    end
    
    return err_cur / n_steps, err_next / n_steps
end

# Train Function

In [None]:
-- Build the mini-batches for the training set up front.
-- NOTE(review): train() calls g_create_batch again at the start of every
-- epoch, so this call presumably only matters for inspecting
-- state_train.batch before training starts.
g_create_batch(state_train)

In [None]:
--- Train the model for one epoch over `dataset` using Adam.
-- Rebuilds the mini-batches of `dataset` (g_create_batch) and visits them in
-- random order. For each mini-batch it accumulates the gradients of four
-- terms, then takes one optim.adam step:
--   bce   : reconstruction of x_t              (criterion on x_t, negated)
--   kl    : KL term on the encoding z_t        (KLD)
--   bce_1 : prediction of x_t+1                (criterion on x_t1, negated)
--   kld   : opt.lambda * KLDist between the predicted next latent and the
--           encoding of the target frames
-- Uses the globals model, encoder, criterion, KLD, KLDist, params,
-- gradParams, optim_config and opt, and increments the global `epoch`.
-- Per-batch losses are printed only when opt.v (-v) is set.
-- @param dataset table with field `x` and, after g_create_batch, `batch`.
--   Each batch entry is {inputs, actions, targets}: three lists of equal length.
-- @return table {all, bce, bce_1, kl, kld}, each summed over the epoch and
--   divided by dataset.x:size(1) - 2
function train(dataset)

    -- epoch tracker
    epoch = epoch or 0
    
    -- (Re)build the mini-batches of the dataset we were actually given.
    g_create_batch(dataset)

    -- local vars
    local err = {all=0, bce=0, bce_1=0, kl=0, kld=0}

    -- shuffle at each epoch
    local shuffle = torch.randperm(#dataset.batch):long()

    for t = 1,#dataset.batch do

        -- create mini batch
        local batch = dataset.batch[shuffle[t]]
        local batch_x = batch[1]
        local batch_u = batch[2]
        local batch_y = batch[3]

        -- create closure to evaluate f(X) and df/dX
        local feval = function(x)
            
            -- get new parameters
            if x ~= params then
                params:copy(x)
            end

            -- reset gradients
            gradParams:zero()
            
            -- reset errors
            local bce_err, bce_1_err, kl_err, kld_err = 0, 0, 0, 0
            
            -- evaluate function for complete mini batch
            for i = 1,#batch_x do
                
                -- Encode the target frames. `encoder` is the very module used
                -- inside `model`, so the table it returns is its output buffer
                -- and would be overwritten by model:forward below. Copy it so
                -- the KLDist target really is the encoding of batch_y[i].
                local enc_out = encoder:forward(batch_y[i])
                local enc_y = {enc_out[1]:clone(), enc_out[2]:clone()}
                
                -- model:forward must run after encoder:forward so the encoder's
                -- forward-pass buffers belong to batch_x[i] during model:backward.
                local z_t, x_t, z_t1, x_t1 = unpack(model:forward({batch_x[i], batch_u[i]}))
                
                -- BCE x_t. The BCE terms are negated here and optim_config
                -- uses a negative learning rate; keep both signs in sync.
                bce_err = bce_err - criterion:forward(x_t, batch_x[i])
                local d_x_t = criterion:backward(x_t, batch_x[i]):clone():mul(-1)
                
                -- KL Divergence z_t
                kl_err = kl_err + KLD:forward(z_t, batch_x[i])
                local d_z_t = KLD:backward(z_t, batch_x[i])
                
                -- BCE x_t+1
                bce_1_err = bce_1_err - criterion:forward(x_t1, batch_y[i])
                local d_x_t1 = criterion:backward(x_t1, batch_y[i]):clone():mul(-1)
                
                -- KL Divergence z^_t+1 ~ z_t+1, weighted by lambda
                kld_err = kld_err + KLDist:forward(z_t1, enc_y) * opt.lambda
                local d_z_t1 = KLDist:backward(z_t1, enc_y)
                d_z_t1[1]:mul(opt.lambda)
                d_z_t1[2]:mul(opt.lambda)
                
                -- Backpropagate
                model:backward({batch_x[i], batch_u[i]}, {
                        d_z_t,
                        d_x_t,
                        d_z_t1,
                        d_x_t1
                    })
                
            end
            
            -- Accumulate errors
            err.bce = err.bce + bce_err
            err.bce_1 = err.bce_1 + bce_1_err
            err.kl = err.kl + kl_err
            err.kld = err.kld + kld_err
            err.all = err.all + bce_err + bce_1_err + kl_err + kld_err
            
            -- normalize gradients and f(X)
            local batcherr = (bce_err + bce_1_err + kl_err + kld_err) / #batch_x
            gradParams:div(#batch_x)
            
            -- Per-batch breakdown only in verbose mode (-v).
            if opt.v then
                print(bce_err/#batch_x, bce_1_err/#batch_x, kl_err/#batch_x, kld_err/#batch_x)
            end
            
            -- return f and df/dX
            return batcherr, gradParams
        end
        
        if #batch_x > 0 then
            optim.adam(feval, params, optim_config)
            -- optim.adagrad(feval, params, optim_config)
            -- optim.rmsprop(feval, params, optim_config)
        end
        
    end
    
    -- Normalise errors (only existing keys are reassigned during traversal).
    local n_samples = dataset.x:size(1) - 2
    for key, total in pairs(err) do
        err[key] = total / n_samples
    end
    
    collectgarbage()

    epoch = epoch + 1

    return err
end

# Train network

In [None]:
-- epochs to run
opt.max_epoch = 3

-- Reporting and checkpoint intervals, in epochs. Both default to every
-- epoch and are independent of each other.
local print_every = opt.print_every or 1
local save_every = opt.save_every or 1

-- start time
local beginning_time = torch.tic()

-- iterate through epochs
for _ = 1, opt.max_epoch do
    
    -- local vars
    local time = sys.clock()
    
    -- train for 1 epoch (train() increments the global `epoch`)
    local err = train(state_train)
        
    train_err[#train_err+1] = err
    
    -- time taken
    time = sys.clock() - time
    
    -- display stats
    if epoch % print_every == 0 then
        local last = train_err[#train_err]
        local since_beginning = g_d(torch.toc(beginning_time) / 60)
        print('epoch=' .. (epoch) ..
          ', Train err=' .. g_f3(last.all) ..
          ', bce=' .. g_f3(last.bce) ..
          ', bce_1=' .. g_f3(last.bce_1) ..
          ', kl=' .. g_f3(last.kl) ..
          ', kld=' .. g_f3(last.kld) ..
          -- ', Test err=' .. g_f5(test_err[#test_err]) ..
          ', t/epoch = ' .. g_f3(time) .. ' sec' ..
          ', since beginning = ' .. since_beginning .. ' mins.')
    end

    -- checkpoint
    if epoch % save_every == 0 then
        save_model()
    end
end

### Plot Performance

In [None]:
--- Extract one error component from a list of per-epoch error tables.
-- @param err list of tables as returned by train()
-- @param criterion key to read from each table (defaults to 'all')
-- @return 1-D tensor of length #err holding the NEGATED values
function get_error(err, criterion)
    local key = criterion or 'all'
    local n_epochs = #err
    local values = torch.zeros(n_epochs)
    for epoch_idx = 1, n_epochs do
        local record = err[epoch_idx]
        values[epoch_idx] = -record[key]
    end
    return values
end

In [None]:
colors = {'blue', 'green', 'red', 'purple', 'orange', 'magenta', 'cyan'}

-- One curve per error component: {key in each train_err entry, legend label}.
-- The k-th series is drawn in colors[k].
local series = {
    {'all', 'All'},
    {'bce', 'BCE'},
    {'bce_1', 'BCE_1'},
    {'kl', 'KL'},
    {'kld', 'KLD'},
}

plot = Plot()
plot:title(string.format('Neural Net Performance in %d epochs', #train_err))
for k, entry in ipairs(series) do
    local key, label = entry[1], entry[2]
    plot = plot:line(torch.range(1,#train_err), get_error(train_err, key), colors[k], label)
end
plot:legend(true):redraw()