# C-SWM 

In [1]:
using HDF5
using Knet
using Statistics: mean,std
using Random

In [2]:
atype = KnetArray{Float32}

SAVE_FOLDER = "./checkpoints"
NUM_STEPS = 1
TRAIN_DATASET_PATH = "/home/cagan/dev/datasets/C-SWM/shapes_train.h5"
BATCH_SIZE = 100
SEED = 0
NUM_OBJECTS = 5

5

In [3]:
function loadh5file(DATASET_PATH)
    f_e = h5open(DATASET_PATH,"r")
    dict = read(f_e)
    close(f_e)
    return dict
end

loadh5file (generic function with 1 method)

In [4]:
struct StateTransitionDataset
    """Create dataset of (o_t, a_t, o_{t+1}) transitions from replay buffer."""
   
    experience_buffer;
    # Build table for conversion between linear idx -> episode/step idx
    idx2episode;
    
    #Container to hold total number of steps
    num_steps;
    
    #Read array
    batch_idxs;
    
end

In [5]:
function buildDataset(DATASET_PATH, d_shuffle)
    
    experience_buffer = loadh5file(DATASET_PATH)   
    step = 0
    
    println("Dataset loaded. Building dataset indexing.")
    
    idx2episode = []
    
    
    for ep in 1:length(experience_buffer)
        
        ep_key = string(ep-1)
        num_steps = length(experience_buffer[ep_key]["action"])
        
        for i in 1:num_steps
           
            push!(idx2episode,(ep_key,i))
            
        end 
        
        step += num_steps
        
    end
         
    batch_idxs = collect(1:step)
    
    if d_shuffle
        batch_idxs = shuffle(batch_idxs)
    end
        
    
    return  StateTransitionDataset(experience_buffer,idx2episode,step, batch_idxs)
    
end

buildDataset (generic function with 1 method)

In [6]:
dtrn = buildDataset(TRAIN_DATASET_PATH,true);

Dataset loaded. Building dataset indexing.


In [7]:
dtrn.experience_buffer["1"]["obs"]

50×50×3×100 Array{Float32,4}:
[:, :, 1, 1] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0     
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.301961  0.301961  0.301961
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.301961  0.301961  0.301961
 0.0  0.0  0.0  0.0

In [8]:
TOTAL_SET_SIZE = dtrn.num_steps
net_threshold = TOTAL_SET_SIZE ÷ BATCH_SIZE

1000

In [9]:
function getitem(s,idx)
    
    ep_key, step = s.idx2episode[idx]
    obs = s.experience_buffer[ep_key]["obs"][:,:,:,step]
    action = s.experience_buffer[ep_key]["action"][step]
    next_obs = s.experience_buffer[ep_key]["next_obs"][:,:,:,step]
    
    return obs,action,next_obs
    
end

(s::StateTransitionDataset)(idx) = getitem(s,idx)

In [10]:
function prepareBatch(s,idx_1, idx_2)
    """Lazy loader to GPU."""
    
    minibatch = s.batch_idxs[idx_1:idx_2]
    minibatch_batch_size = size(minibatch,1)
    
    #Read
    b_obs = zeros(50,50,3,minibatch_batch_size)
    b_next_obs = zeros(50,50,3,minibatch_batch_size)
    b_action = zeros(minibatch_batch_size)
    
    for i in 1:length(minibatch_batch_size)
        
        idx = minibatch[i]
        obs, action, next_obs = s(idx) 
        
        #Insert obs
        b_obs[:,:,:,i] = obs
        
        #Assign action
        b_action[i] = action
        
        #Insert next_obs
        b_next_obs[:,:,:,i] =  next_obs
        
    end
    
    return atype(b_obs), b_action, atype(b_next_obs)
    
end

prepareBatch (generic function with 1 method)

In [11]:
function batchIdxHelper(batch_idx)

    net_threshold = TOTAL_SET_SIZE ÷ BATCH_SIZE
    remainder = TOTAL_SET_SIZE % BATCH_SIZE
    
    idx_1 = -1
    idx_2 = -1
    
    if batch_idx <= net_threshold
        
        idx_1 = 1 + BATCH_SIZE*(batch_idx-1)
        idx_2 = BATCH_SIZE*batch_idx
        
    elseif batch_idx == net_threshold
        
        idx_1 = 1 + BATCH_SIZE*(batch_idx-1)
        idx_2 = TOTAL_SET_SIZE
        
    end

    return idx_1, idx_2
    
end

batchIdxHelper (generic function with 1 method)

In [12]:
idx_1, idx_2 = batchIdxHelper(20)

(1901, 2000)

In [13]:
obs, b_action, next_obs = prepareBatch(dtrn,idx_1,idx_2);

In [14]:
println(size(obs))
println(size(b_action))
println(size(next_obs))

(50, 50, 3, 100)
(100,)
(50, 50, 3, 100)


In [15]:
function getBatch(dataset,idx)
    
    idx_1, idx_2 = batchIdxHelper(idx)
    
    if idx_1 == -1 && idx_2 == -1
    
        println("Invalid batch index") 
        return
        
    end
    
    obs, action, next_obs = prepareBatch(dataset,idx_1,idx_2)
    
    return obs, action, next_obs
end

getBatch (generic function with 1 method)

In [16]:
obs, action, next_obs = getBatch(dtrn,5)

(K32(50,50,3,100)[0.0⋯], [12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], K32(50,50,3,100)[0.0⋯])

In [17]:
println(size(obs))
println(size(b_action))
println(size(next_obs))

(50, 50, 3, 100)
(100,)
(50, 50, 3, 100)


In [18]:
function toOneHot(action_dim, idx)
   
    vec = zeros(action_dim)
    
    vec[idx] = 1.0
    
    return atype(vec)
    
end

toOneHot (generic function with 1 method)

In [19]:
toOneHot(20,3)

20-element KnetArray{Float32,1}:
 0.0
 0.0
 1.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

# C-SWM

### EncoderCNNSmall

TODO: Add bias

In [20]:
struct EncoderCNNSmall
   
    weights
    bn_vars
    act_fn
    act_fn_hid
    
end

function initEncoderCNNSmall(input_dim, hidden_dim, num_objects, act_fn, act_fn_hid)
    
    weights = Any[param(10,10,3,hidden_dim), param(1,1,hidden_dim,num_objects)]
    bn_vars = Any[bnmoments(), atype(bnparams(hidden_dim))]
    
    return EncoderCNNSmall(weights, bn_vars, act_fn, act_fn_hid)
    
end

function (e_cnn::EncoderCNNSmall)(x)
    
    #w(w_x,w_y,in_ch,out_ch)
    x1 = conv4(e_cnn.weights[1],x; stride = 10)
    x2 = e_cnn.act_fn_hid.(batchnorm(x1, e_cnn.bn_vars[1], e_cnn.bn_vars[2]))
    y = e_cnn.act_fn.(conv4(e_cnn.weights[2],x2;stride=1))
    
    return y
    
end

In [21]:
obj_extractor = initEncoderCNNSmall(3,64,6,sigm,relu)

EncoderCNNSmall(Any[P(KnetArray{Float32,4}(10,10,3,64)), P(KnetArray{Float32,4}(1,1,64,6))], Any[Knet.BNMoments(0.1, nothing, nothing, zeros, ones), K32(128)[1.0⋯]], Knet.sigm, NNlib.relu)

In [22]:
masks = obj_extractor(obs)

5×5×6×100 KnetArray{Float32,4}:
[:, :, 1, 1] =
 0.5       0.494915  0.5  0.481787  0.5     
 0.514914  0.5       0.5  0.5       0.5     
 0.471963  0.5       0.5  0.5       0.494682
 0.5       0.5       0.5  0.5       0.5     
 0.5       0.5       0.5  0.5       0.5     

[:, :, 2, 1] =
 0.5       0.431381  0.5  0.451126  0.5     
 0.473642  0.5       0.5  0.5       0.5     
 0.462689  0.5       0.5  0.5       0.488769
 0.5       0.5       0.5  0.5       0.5     
 0.5       0.5       0.5  0.5       0.5     

[:, :, 3, 1] =
 0.5       0.517078  0.5  0.491744  0.5     
 0.478384  0.5       0.5  0.5       0.5     
 0.491275  0.5       0.5  0.5       0.504319
 0.5       0.5       0.5  0.5       0.5     
 0.5       0.5       0.5  0.5       0.5     

[:, :, 4, 1] =
 0.5       0.478908  0.5  0.45986  0.5     
 0.473645  0.5       0.5  0.5      0.5     
 0.485945  0.5       0.5  0.5      0.506316
 0.5       0.5       0.5  0.5      0.5     
 0.5       0.5       0.5  0.5      0.5     

[:, :, 5,

### LayerNorm Layer (Taken from: https://github.com/denizyuret/Knet.jl/issues/492)

In [23]:
struct LayerNorm; a; b; ϵ; end

function LayerNorm(dmodel; eps=1e-6)
    a = param(dmodel; init=ones)
    b = param(dmodel; init=zeros)
    return LayerNorm(a, b, eps)
end

function (l::LayerNorm)(x, o...)
    μ = mean(x,dims=1)
    σ = std(x,mean=μ,dims=1)
    return l.a .* (x .- μ) ./ (σ .+ l.ϵ) .+ l.b                                                         
end

In [24]:
struct EncoderMLP
    
    weights
    biases
    layer_norm
    act_fn

end


function initEncoderMLP(input_dim, hidden_dim, output_dim, num_objects, act_fn)
    
    weights = [param(hidden_dim,input_dim),param(hidden_dim, hidden_dim),param(output_dim,hidden_dim )]
    biases =  [param0(hidden_dim), param0(hidden_dim), param0(output_dim)]
    layer_norm = LayerNorm(hidden_dim)
    
    return EncoderMLP(weights, biases, layer_norm, act_fn)
    
end

function (e_mlp::EncoderMLP)(x)
    
    dims = size(x)
    println(dims)
    x0 = reshape(x, dims[1]*dims[2], dims[3]*dims[4])
    println(size(x0))
    x1 = e_mlp.act_fn.(e_mlp.weights[1] * x0 .+ e_mlp.biases[1])
    println(size(x1))
    x2 = e_mlp.weights[2]* x1 .+ e_mlp.biases[2]
    println(size(x2))
    x3 = e_mlp.act_fn.(e_mlp.layer_norm(x2))
    println(size(x3))
    x4 = e_mlp.act_fn.(e_mlp.weights[3]*x3 .+ e_mlp.biases[3])
    x5 = reshape(x4, :, dims[3],dims[4])
    
    return x5
    
end

In [25]:
# Input dim = 25
# Hidden dim = 512
# Output dim = 2

obj_encoder = initEncoderMLP(25,512,2,5,relu)

EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,25)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-6), NNlib.relu)

In [26]:
o = obj_encoder(masks)

(5, 5, 6, 100)
(25, 600)
(512, 600)
(512, 600)
(512, 600)


2×6×100 KnetArray{Float32,3}:
[:, :, 1] =
 0.551048  0.538736  0.547256  0.539286  0.567929  0.539596
 0.546215  0.502752  0.538293  0.509351  0.526404  0.516437

[:, :, 2] =
 0.557341  0.557341  0.557341  0.557341  0.557341  0.557341
 0.53802   0.53802   0.53802   0.53802   0.53802   0.53802 

[:, :, 3] =
 0.557341  0.557341  0.557341  0.557341  0.557341  0.557341
 0.53802   0.53802   0.53802   0.53802   0.53802   0.53802 

...

[:, :, 98] =
 0.557341  0.557341  0.557341  0.557341  0.557341  0.557341
 0.53802   0.53802   0.53802   0.53802   0.53802   0.53802 

[:, :, 99] =
 0.557341  0.557341  0.557341  0.557341  0.557341  0.557341
 0.53802   0.53802   0.53802   0.53802   0.53802   0.53802 

[:, :, 100] =
 0.557341  0.557341  0.557341  0.557341  0.557341  0.557341
 0.53802   0.53802   0.53802   0.53802   0.53802   0.53802 

In [None]:
struct EdgeMLP
    
    weights
    biases
    layer_norm
    act_fn
  
end


function initEdgeMLP(input_dim, hidden_dim, act_fn)
    
    weights = [param(hidden_dim,input_dim*2),param(hidden_dim, hidden_dim),param(hidden_dim,hidden_dim)]
    biases =  [param0(hidden_dim), param0(hidden_dim), param0(hidden_dim)]    
    layer_norm = LayerNorm(hidden_dim)
    
end


function (edge_mlp::EdgeMLP)(x)
    
    x1 = e_mlp.act_fn.(e_mlp.weights[1] * x .+ e_mlp.biases[1])
    println(size(x1))
    x2 = e_mlp.weights[2]* x1 .+ e_mlp.biases[2]
    println(size(x2))
    x3 = e_mlp.act_fn.(e_mlp.layer_norm(x2))
    println(size(x3))
    x4 = e_mlp.act_fn.(e_mlp.weights[3]*x3 .+ e_mlp.biases[3])
    
    return x4
    
end


In [None]:
struct NodeMLP
    
    weights
    biases
    layer_norm
    act_fn
    
end

function initNodeMLP(node_input_dim, hidden_dim, act_fn, input_dim)
    
    weights = [param(hidden_dim,node_input_dim),param(hidden_dim, hidden_dim),param(input_dim,hidden_dim )]
    biases =  [param0(hidden_dim), param0(hidden_dim), param0(input_dim)]    
    layer_norm = LayerNorm(hidden_dim)
    
    return NodeMLP(weights, biases, layer_norm, act_fn)
    
end


function (node_mlp::NodeMLP)(x)
    
    x1 = e_mlp.act_fn.(e_mlp.weights[1] * x .+ e_mlp.biases[1])
    println(size(x1))
    x2 = e_mlp.weights[2]* x1 .+ e_mlp.biases[2]
    println(size(x2))
    x3 = e_mlp.act_fn.(e_mlp.layer_norm(x2))
    println(size(x3))
    x4 = e_mlp.act_fn.(e_mlp.weights[3]*x3 .+ e_mlp.biases[3])
    
    return
    
end

In [None]:
struct TransitionGNN
    
    edge_mlp
    node_mlp
    
    ignore_action
    copy_action
    
    edge_list
    batch_size
    
end

function initTransitionGNN(input_dim, hidden_dim, action_dim, num_objects, ignore_action, copy_action, act_fn)
    
    if ignore_action
        
        action_dim = 0
        
    end
    
    #Edge MLP
    edge_mlp = initEdgeMLP(input_dim, hidden_dim, act_fn)
    node_input_dim = hidden_dim + input_dim + action_dim
    
    #Node MLP
    node_mlp = initNodeMLP(node_input_dim, hidden_dim, act_fn, input_dim)
    
    edge_list = Any[]
    batch_size = 0
    
    return TransitionGNN(edge_mlp, node_mlp, ignore_action, copy_action, edge_list, batch_size)
    
end


function get_edge_list_fully_connected()
    
end

function (t_gnn::TransitionGNN)(x)
    
    
    #Todo forward prop
    dimensions = size(x)

    
    
end