# C-SWM 

In [1]:
#Libs
import Base: iterate, length, GC
using HDF5
using Knet
using Statistics: mean,std
using Random
using LinearAlgebra
using Images
using Plots

#Datatype
# Array type alias: Knet's `KnetArray` with Float32 elements. It is not referenced
# again in the visible cells; presumably the included files read it — TODO confirm.
atype=KnetArray{Float32}

#Includes
include("datasets.jl")
include("cswm.jl")

## Model Definition

In [2]:
#Params
# Hyperparameters forwarded, in this order, to `initContrastiveSWMSmall` in the
# model-construction cell.
# NOTE(review): the per-line meanings below are inferred from the names and from the
# parameter shapes printed for the model — confirm against cswm.jl.
input_ch = 3        # presumably image channels (first conv weight prints as 10x10x3x32)
hidden_dim = 512    # presumably MLP hidden width (512-sized weights in the printout)
num_objects = 5     # presumably number of object slots (second conv prints as 1x1x32x5)
embedding_dim = 2   # presumably per-object latent size (encoder output weight is 2x512)
action_dim = 4      # presumably size of the action encoding
sigma = 0.5         # presumably a loss scaling term — TODO confirm
hinge = 1.0         # presumably the hinge margin of the contrastive loss — TODO confirm

1.0

In [3]:
# Seed the random number generator with a fixed value so repeated runs start from
# the same state.
Knet.seed!(42)

MersenneTwister(UInt32[0x0000002a], Random.DSFMT.DSFMT_state(Int32[964434469, 1073036706, 1860149520, 1073503458, 1687169063, 1073083486, -399267803, 1072983952, -909620556, 1072836235  …  -293054293, 1073002412, -1300127419, 1073642642, 1917177374, -666058738, -337596527, 1830741494, 382, 0]), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], UInt128[0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000  …  0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x0000000000

In [4]:
# Build the model from the hyperparameters defined above. The constructor is not
# defined in this notebook (presumably it comes from the included cswm.jl).
model = initContrastiveSWMSmall(input_ch, hidden_dim, num_objects, embedding_dim, action_dim, sigma, hinge)

ContrastiveSWM(EncoderCNNSmall(Any[P(KnetArray{Float32,4}(10,10,3,32)), P(KnetArray{Float32,4}(1,1,32,5))], Any[P(KnetArray{Float32,4}(1,1,32,1)), P(KnetArray{Float32,4}(1,1,5,1))], Any[Knet.BNMoments(0.1, nothing, nothing, zeros, ones), K32(64)[1.0⋯]], Knet.sigm, NNlib.relu), EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,25)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-5), NNlib.relu), TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-5), NNlib.relu), NodeMLP(Para

## Training Part

In [5]:
# Location of the training data. NOTE(review): this is an absolute, machine-specific
# path — adjust it when running elsewhere.
TRAIN_DATASET_PATH = "/home/cagan/dev/datasets/C-SWM/shapes_train.h5"
# Number of samples requested per training batch.
TRAIN_BATCH_SIZE = 1024
# Build the training iterator. The meaning of the `true` flag is not visible here
# (possibly "shuffle" — TODO confirm in datasets.jl). The trailing `;` suppresses
# the notebook's display of the returned object.
dtrn = buildStateTransitionDataset(TRAIN_DATASET_PATH, true, TRAIN_BATCH_SIZE);

Dataset loaded. Building dataset indexing...
Done.


In [6]:
"""
    initopt!(model::ContrastiveSWM)

Give every parameter returned by `params(model)` its own `Adam` optimizer
(`lr=0.0005`, `gclip=0`, `beta1=0.9`, `beta2=0.999`, `eps=1e-8`) by storing it in the
parameter's `opt` field, and print each parameter as it is configured.

Returns `nothing`.
"""
function initopt!(model::ContrastiveSWM)
    foreach(params(model)) do p
        # A separate optimizer instance per parameter: each one is stored on,
        # and later read back from, the parameter it belongs to.
        p.opt = Adam(; lr=0.0005, gclip=0, beta1=0.9, beta2=0.999, eps=1e-8)
        println(p)
    end
end;
initopt!(model)

P(KnetArray{Float32,4}(10,10,3,32))
P(KnetArray{Float32,4}(1,1,32,5))
P(KnetArray{Float32,4}(1,1,32,1))
P(KnetArray{Float32,4}(1,1,5,1))
P(KnetArray{Float32,2}(512,25))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(2,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(2))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,2}(512,4))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,2}(512,518))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(2,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(2))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))


In [7]:
# Print a progress line every VERBOSE batches
VERBOSE = 10

# Number of passes over the training iterator
NUM_EPOCHS = 100

println("Starting training...")

for i in 1:NUM_EPOCHS

    avg_loss = 0.0   # running sum of batch losses; becomes the mean after the epoch
    it = 0           # batches processed in this epoch
    seen = 0         # samples processed in this epoch

    for (k, (obs, action, next_obs)) in enumerate(dtrn)

        # Evaluate the contrastive loss under @diff so gradients can be queried from J
        J = @diff model(obs, action, next_obs)

        # One update per parameter, using the optimizer stored in `par.opt`
        for par in params(model)
            g = grad(J, par)
            update!(value(par), g, par.opt)
        end

        loss = value(J)

        # Count the samples actually consumed (batch axis is dimension 4 of `obs`)
        # instead of assuming k full batches, so a shorter batch is not over-reported.
        seen += size(obs, 4)

        if k % VERBOSE == 0
            println("Epoch: ", i, ", Iter: ", seen, "/", dtrn.num_steps, ", Loss: ", loss)
        end

        avg_loss += loss
        it = k
    end

    # max(it, 1) keeps an empty iterator from turning the average into NaN
    avg_loss /= max(it, 1)

    println("Avg loss: ", avg_loss)
end

Starting training...
Epoch: 1, Iter: 10240/100000, Loss: 0.42359978
Epoch: 1, Iter: 20480/100000, Loss: 0.29025236
Epoch: 1, Iter: 30720/100000, Loss: 0.20908357
Epoch: 1, Iter: 40960/100000, Loss: 0.1706932
Epoch: 1, Iter: 51200/100000, Loss: 0.1359646
Epoch: 1, Iter: 61440/100000, Loss: 0.11933155
Epoch: 1, Iter: 71680/100000, Loss: 0.11124304
Epoch: 1, Iter: 81920/100000, Loss: 0.10445408
Epoch: 1, Iter: 92160/100000, Loss: 0.09699994
Avg loss: 0.5315361105136036
Epoch: 2, Iter: 10240/100000, Loss: 0.09638713
Epoch: 2, Iter: 20480/100000, Loss: 0.08142391
Epoch: 2, Iter: 30720/100000, Loss: 0.075552
Epoch: 2, Iter: 40960/100000, Loss: 0.074401855
Epoch: 2, Iter: 51200/100000, Loss: 0.07540923
Epoch: 2, Iter: 61440/100000, Loss: 0.06868971
Epoch: 2, Iter: 71680/100000, Loss: 0.0648802
Epoch: 2, Iter: 81920/100000, Loss: 0.068197325
Epoch: 2, Iter: 92160/100000, Loss: 0.06900366
Avg loss: 0.07468998228612635
Epoch: 3, Iter: 10240/100000, Loss: 0.06557751
Epoch: 3, Iter: 20480/100000, 

In [8]:
# Drop the reference to the training iterator, then trigger Knet's and Julia's
# garbage collectors so that memory can be reclaimed before evaluation.
dtrn = nothing
Knet.gc()
GC.gc()

In [10]:
# Persist the trained model under the key "model"; the path is relative to the
# current working directory.
Knet.save("model_2dshapes_lr.jld2", "model", model)

## Evaluation

In [None]:
# Restore a previously saved model from the same file and key used by the save cell.
# This cell has no execution count in the export, i.e. it was not run in the recorded
# session; use it to evaluate without retraining.
model = Knet.load("model_2dshapes_lr.jld2", "model")

In [11]:
#Params
# Location of the evaluation data. NOTE(review): absolute, machine-specific path.
EVAL_DATASET_PATH = "/home/cagan/dev/datasets/C-SWM/shapes_eval.h5"
# Number of samples requested per evaluation batch.
EVAL_BATCH_SIZE = 100

100

In [12]:
# Build the evaluation iterator; the trailing `;` suppresses the notebook display.
dtst = buildPathDataset(EVAL_DATASET_PATH, EVAL_BATCH_SIZE);

In [None]:
# Pull the first evaluation batch for ad-hoc inspection. This cell has no execution
# count in the export (it was not run); the evaluation loop below iterates over
# `dtst` itself and does not rely on these bindings.
obs, action,next_obs = first(dtst)

In [13]:
# Encoded states collected batch by batch (copied into plain Float32 Arrays)
pred_batches = Array{Float32}[]
next_batches = Array{Float32}[]

num_samples = dtst.dataset_size

for (k, (obs, action, next_obs)) in enumerate(dtst)

    if k % 10 == 0
        println("Processed ", k, " batches")
    end

    # Author's shape notes: obs / next_obs => (50,50,3,100),
    # pred-state / next-state => (2,5,100)
    # model(obs, action) : state predicted from the observation and the action
    # model(next_obs)    : state obtained from the following observation alone
    push!(pred_batches, Array{Float32}(model(obs, action)))
    push!(next_batches, Array{Float32}(model(next_obs)))

end

# Concatenate along the batch axis (dim 3) and flatten the leading axes so that
# every column is one sample.
# NOTE: unlike earlier revisions of this cell, `pred_states` / `next_states` are left
# as these 2-D matrices rather than the repeated 3-D tensors.
pred_states = mat(cat(pred_batches...; dims=3))  #[10,10000] for the shapes above
next_states = mat(cat(next_batches...; dims=3))  #[10,10000] for the shapes above

# Take the sample count from the data instead of hard-coding 10000
num_eval = size(pred_states, 2)

# Augmented pairwise squared-distance matrix:
#   column j+1, row i : squared distance between prediction i and next state j
#   column 1,   row i : copy of the matching-pair distance (i, i)
# Filling one column at a time keeps the temporaries at d x N, avoiding the
# d x N x N intermediates that `repeat` would allocate.
pairwise_distance_matrix = Matrix{eltype(pred_states)}(undef, num_eval, num_eval + 1)

for j in 1:num_eval
    pairwise_distance_matrix[:, j + 1] =
        vec(sum(abs2, pred_states .- view(next_states, :, j); dims=1))
end

for i in 1:num_eval
    pairwise_distance_matrix[i, 1] = pairwise_distance_matrix[i, i + 1]
end

# The matching pair always sits in augmented column 1, hence every label is 1
labels = ones(num_samples)
# Initialised here for the Hits @ 1 cell that follows
hits_at_1 = 0

# Row i holds the candidate columns of sample i ordered by increasing distance.
# `sortperm` is stable by default, so on an exact tie the lower column index
# (in particular column 1) is listed first.
indices = Matrix{Int}(undef, num_eval, num_eval + 1)

for i in 1:num_eval
    indices[i, :] = sortperm(pairwise_distance_matrix[i, :])
end

# Last expression of the cell, so the notebook still displays the ranking matrix
indices

Processed 10 batches
Processed 20 batches
Processed 30 batches
Processed 40 batches
Processed 50 batches
Processed 60 batches
Processed 70 batches
Processed 80 batches
Processed 90 batches
Processed 100 batches


10000×10001 Array{Int64,2}:
 1      2  4355  8965  7675  9196  …  8159  1966  9199  9315   740  6317
 1      3  1889  9950  8907  5423     8604  5395  6355  8400  8579  8404
 1      4  4659  7942  5277  7009     4265  7183  7865  3606   244  4821
 1      5  6116   909  6402  4143     5473   612  7376  9736  6204  1936
 1      6  9560  8208  5517  2771     9175  7741  3335  9742  8877  3265
 1      7  3358  5548  8098  8483  …  1774  2792  6054  5338  8832  4539
 1      8  3234  4958  5005  9685     8052  6469  5459  2035  6786  1267
 1      9  9993  2864  8466  2183     2615  6202  4693  7354  4781  2735
 1     10  7045  7532  8034  3476     6209  6340  8000  8817  5301   392
 1     11  8945  1587  9878  5156     1301  7085  9736  4612   612  1936
 1     12  3919  6580  3036  3455  …  8870  7327  9319  3259  1741  9526
 1     13  1305  7733  6270  2377     7236  5400  9337  1637  6762  2856
 1     14  4812  9072  9926  1189     6728  5767  2904  7376  6204  1936
 ⋮                     

In [14]:
# Column 1 of `indices` is the top-ranked candidate of every sample; it equals the
# label (1) exactly when the duplicated matching-pair distance ranks first.
num_matches = sum(labels .== indices[:,1])
# Assign instead of accumulating, so re-running this cell cannot double-count
hits_at_1 = num_matches
println("Hits @ 1: ", hits_at_1/num_samples)

Hits @ 1: 0.9892


In [15]:
# Mean reciprocal rank: for every row, locate the first position whose ranked index
# equals the label, then average the reciprocals of those positions.
mxval, mxindx = findmax(indices .== labels; dims=2)
# The second component of each CartesianIndex is the rank position within the row
ranks = map(ci -> ci[2], mxindx)
reciprocal_ranks = inv.(ranks)
rr_sum = sum(reciprocal_ranks)
println("MRR: ", rr_sum / num_samples)

MRR: 0.99455


In [17]:
# Hits @ 5: number of rows whose label appears among the five top-ranked candidates
# (each row of `indices` is a permutation, so a label occurs at most once per row).
num_matches = count(labels .== indices[:, 1:5])
println("Hits @ 5: ", num_matches / num_samples)

Hits @ 5: 1.0
