# C-SWM 

In [28]:
#Libs
using HDF5
using Knet
using Statistics: mean,std
using Random
using LinearAlgebra
import Base: iterate, length


# Datatype for model parameters (single-precision Knet array)
atype = KnetArray{Float32}

# Project sources defining the dataset builder, layer norm, encoders and GNN
include("utils.jl")
include("layernorm.jl")
include("encodercnn.jl")
include("encodermlp.jl")
include("gnn.jl")

# Run configuration
SAVE_FOLDER = "./checkpoints"
NUM_STEPS = 1
# Overridable through the CSWM_TRAIN_DATASET environment variable so the notebook
# runs on other machines; the default is the original location.
TRAIN_DATASET_PATH = get(ENV, "CSWM_TRAIN_DATASET",
                         "/home/cagan/dev/datasets/C-SWM/shapes_train.h5")
BATCH_SIZE = 100
SEED = 0
# SEED used to be declared but never applied. Seed Julia's global RNG so that
# randomness drawn from it (e.g. negative-sample permutations, presumably parameter
# init) is reproducible. NOTE(review): any GPU-side RNG is not seeded here.
Random.seed!(SEED)
NUM_OBJECTS = 5

5

In [29]:
# Build the training data from the .h5 file at TRAIN_DATASET_PATH; batches appear to
# hold BATCH_SIZE samples (see the printed sizes below). Iterating `dtrn` yields
# (obs, action, next_obs) triples.
# NOTE(review): the meaning of the positional `true` is defined in utils.jl
# (possibly shuffling) — confirm there.
dtrn = buildDataset(TRAIN_DATASET_PATH, true, BATCH_SIZE);

Dataset loaded. Building dataset indexing...
Done.


In [30]:
# Grab the first batch and print the size of each of its three components,
# one per line, to sanity-check the data layout.
obs, action, next_obs = first(dtrn)
foreach(component -> println(size(component)), (obs, action, next_obs))

(50, 50, 3, 100)
(100,)
(50, 50, 3, 100)


In [31]:
# Model hyperparameters. These globals are read by `energy` (sigma), the model
# constructor (input_ch ... action_dim) and the contrastive loss (hinge).
input_ch = 3                # channels of an observation (obs printed as 50×50×3×B)
hidden_dim = 512
num_objects = NUM_OBJECTS   # reuse the setup value so the two cannot drift apart
embedding_dim = 2
action_dim = 4
sigma = 0.5                 # energy scale: energies are multiplied by 0.5 / sigma^2
hinge = 1.0                 # margin of the hinge applied to negative-sample energies



TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-6), NNlib.relu), NodeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,518)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-6), NNlib.relu), false, false, 4, 2, nothing, 0)

In [34]:
"""
    energy(state, action, next_state, trans; transition_model=nothing, σ=nothing)

Return one C-SWM energy per sample: `0.5 / σ^2` times the squared difference,
summed over the embedding dimension and averaged over objects.

Assumes `state`/`next_state` are laid out as `(embedding_dim, num_objects, batch)`;
the contrastive loss indexes batches as `state[:, :, perm]` — TODO confirm against
the encoder output.

# Arguments
- `trans`: when `true`, NO transition is applied and `diff = state - next_state`;
  when `false`, `diff = state + transition_model(state, action) - next_state`.
  (The naming is inverted but kept for existing callers.)

# Keywords
- `transition_model`: callable `(state, action) -> Δstate`, used only when
  `trans == false`. Defaults to the global `gnn` for backward compatibility.
- `σ`: energy scale; defaults to the global `sigma`.

# Returns
A vector of length `size(state, 3)`, one energy per batch element.
"""
function energy(state, action, next_state, trans; transition_model=nothing, σ=nothing)
    scale = σ === nothing ? sigma : σ
    coeff = 0.5 / (scale^2)

    if trans
        diff = state - next_state
    else
        # Resolve the global lazily so the no-transition path never touches it.
        tmodel = transition_model === nothing ? gnn : transition_model
        diff = state + tmodel(state, action) - next_state
    end

    # Sum over the embedding axis (dims=1) and average over objects (dims=2),
    # keeping samples separate. The old code summed over dims=3 (the batch), which
    # mixed samples together before the per-sample hinge in the loss.
    per_object = sum(diff .^ 2; dims=1)
    return vec(coeff * mean(per_object; dims=2))
end

energy (generic function with 1 method)

In [35]:
# Contrastive Structured World Model: bundles the three sub-networks that the
# forward passes and the contrastive loss call in sequence
# (obj_extractor -> obj_encoder -> gnn).
mutable struct ContrastiveSWM
    
    obj_extractor::EncoderCNNSmall   # applied to raw observations: m.obj_extractor(obs)
    obj_encoder::EncoderMLP          # turns extracted objects into embeddings ("state")
    gnn::TransitionGNN               # called as m.gnn(state, action) for the transition
    
end

In [36]:
"""
    initContrastiveSWMSmall(; input_channels=input_ch, hidden=hidden_dim,
                            objects=num_objects, embedding=embedding_dim,
                            actions=action_dim, encoder_input=25)

Build a `ContrastiveSWM` that uses the small CNN object extractor.

Every keyword defaults to the notebook's global hyperparameter (or the previously
hard-coded value). Calling it with no arguments therefore behaves exactly as before.
Passing keywords builds variants without mutating globals.
"""
function initContrastiveSWMSmall(; input_channels=input_ch, hidden=hidden_dim,
                                 objects=num_objects, embedding=embedding_dim,
                                 actions=action_dim, encoder_input=25)
    # The extractor is given hidden ÷ 16 channels (32 for hidden = 512).
    obj_extractor = initEncoderCNNSmall(input_channels, hidden ÷ 16, objects, sigm, relu)
    # encoder_input = 25 is presumably the flattened 5×5 per-object map the small CNN
    # produces from 50×50 observations — TODO confirm in encodercnn.jl.
    obj_encoder = initEncoderMLP(encoder_input, hidden, embedding, objects, relu)
    # The two positional `false` flags are passed through unchanged; see gnn.jl.
    gnn = initTransitionGNN(embedding, hidden, actions, objects, false, false, relu)
    return ContrastiveSWM(obj_extractor, obj_encoder, gnn)
end

initContrastiveSWMSmall (generic function with 1 method)

In [37]:
# Forward propagation with transition
"""
    (m::ContrastiveSWM)(obs, action)

Encode `obs` into object embeddings and run the transition model on them with
`action`. Returns the raw output of `m.gnn`, which `energy` adds to the state as a
predicted change.
"""
function (m::ContrastiveSWM)(obs, action)
    embeddings = m.obj_encoder(m.obj_extractor(obs))
    return m.gnn(embeddings, action)
end

In [38]:
# Forward propagation without transition
"""
    (m::ContrastiveSWM)(obs)

Encode `obs` into object embeddings without applying the transition model.

Returns `m.obj_encoder(m.obj_extractor(obs))`. Previously the method returned
`out`, a name never assigned in this method.
"""
function (m::ContrastiveSWM)(obs)
    # Extract objects
    objs = m.obj_extractor(obs)

    # Obtain embeddings
    state = m.obj_encoder(objs)

    return state
end

In [39]:
# Contrastive loss part
"""
    (m::ContrastiveSWM)(obs, action, next_obs)

Return the C-SWM contrastive hinge loss for one batch.

- Positive term: mean energy between the predicted next state
  `state + m.gnn(state, action)` and the encoding of `next_obs`.
- Negative term: mean of `max(0, hinge - energy(state, neg_state))`. No transition
  is applied, and `neg_state` is `state` with its batch dimension randomly permuted.

Reads the globals `hinge` and `energy` (and, through `energy`, `sigma`). Assumes
embeddings carry the batch along dim 3 — TODO confirm against the encoder.
"""
function (m::ContrastiveSWM)(obs, action, next_obs)
    # Encode both observations with this model's own sub-networks. The old code
    # called undefined globals and never encoded next_obs.
    state = m.obj_encoder(m.obj_extractor(obs))
    next_state = m.obj_encoder(m.obj_extractor(next_obs))

    # Negative samples: permute the batch. `rand(1:B)` drew a single index and
    # collapsed the batch dimension.
    batch_size = size(obs, 4)
    neg_state = state[:, :, randperm(batch_size)]

    # energy(a, _, b, true) evaluates a - b with no transition. Passing the predicted
    # next state as `a` gives state + gnn(state, action) - next_state, without
    # depending on any global transition model.
    predicted = state + m.gnn(state, action)
    pos_loss = mean(energy(predicted, action, next_state, true))

    # Element-wise hinge on negative energies. `max(zeros(()), array)` was not
    # element-wise, and the negatives must be compared against neg_state.
    neg_energy = energy(state, action, neg_state, true)
    neg_loss = mean(relu.(hinge .- neg_energy))

    return pos_loss + neg_loss
end

In [40]:
# Instantiate the model; the constructor reads the global hyperparameters
# input_ch, hidden_dim, num_objects, embedding_dim and action_dim.
model = initContrastiveSWMSmall()

ContrastiveSWM(EncoderCNNSmall(Any[P(KnetArray{Float32,4}(10,10,3,32)), P(KnetArray{Float32,4}(1,1,32,5))], Any[Knet.BNMoments(0.1, nothing, nothing, zeros, ones), K32(64)[1.0⋯]], Knet.sigm, NNlib.relu), EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,25)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-6), NNlib.relu), TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-6), NNlib.relu), NodeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,518)), P(KnetArray{Floa

In [None]:
NUM_ITER = 10000
VERBOSE = 100
LEARNING_RATE = 5e-4   # NOTE(review): assumed from the reference C-SWM setup — tune/confirm

# Attach an optimiser to every parameter. Without this `par.opt` stays `nothing`
# and `update!` cannot take a step. Existing optimisers are kept, so re-running
# this cell does not reset Adam's moment estimates.
for par in params(model)
    par.opt === nothing && (par.opt = Adam(lr=LEARNING_RATE))
end

for (k, (obs, action, next_obs)) in enumerate(dtrn)
    # @diff must wrap the loss computation itself so AutoGrad records a tape.
    # Applying it to an already-computed value yields no gradients.
    J = @diff model(obs, action, next_obs)

    for par in params(model)
        g = grad(J, par)
        g === nothing && continue   # parameter not involved in this loss
        update!(value(par), g, par.opt)
    end

    if k % VERBOSE == 0
        println("iter $k: loss = $(value(J))")
    end

    # Stop after exactly NUM_ITER updates (the old `k > NUM_ITER` ran one extra).
    k >= NUM_ITER && break
end