# C-SWM 

In [2]:
#Libs
import Base: iterate, length, GC
using HDF5
using Knet
using Statistics: mean,std
using Random
using LinearAlgebra

#Datatype
atype=KnetArray{Float32}

#Includes
include("datasets.jl")
include("cswm.jl")

## Model Definition

In [9]:
#Params
input_ch = 3
hidden_dim = 512
num_objects = 5
embedding_dim = 2
action_dim = 4
sigma = 0.5
hinge = 1.0

1.0

In [41]:
model = initContrastiveSWMSmall(input_ch, hidden_dim, num_objects, embedding_dim, action_dim, sigma, hinge)

ContrastiveSWM(EncoderCNNSmall(Any[P(KnetArray{Float32,4}(10,10,3,32)), P(KnetArray{Float32,4}(1,1,32,5))], Any[Knet.BNMoments(0.1, nothing, nothing, zeros, ones), K32(64)[1.0⋯]], Knet.sigm, NNlib.relu), EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,25)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-6), NNlib.relu), TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-6), NNlib.relu), NodeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,518)), P(KnetArray{Floa

## Training Part

In [42]:
TRAIN_DATASET_PATH = "/home/cagan/dev/datasets/C-SWM/shapes_train.h5"
TRAIN_BATCH_SIZE = 1024
dtrn = buildStateTransitionDataset(TRAIN_DATASET_PATH, true, TRAIN_BATCH_SIZE);

Dataset loaded. Building dataset indexing...
Done.


In [43]:
obs, action, next_obs = first(dtrn)

(K32(50,50,3,1024)[0.0⋯], [9, 3, 3, 13, 3, 16, 20, 8, 20, 8  …  18, 18, 19, 3, 6, 14, 5, 1, 20, 10], K32(50,50,3,1024)[0.0⋯])

In [44]:
s = model.obj_extractor(obs);

In [45]:
e = model.obj_encoder(s)

2×5×1024 KnetArray{Float32,3}:
[:, :, 1] =
 0.529608  0.526878  0.538434  0.532567  0.532921
 0.689851  0.650228  0.699094  0.662352  0.660854

[:, :, 2] =
 0.514406  0.490889  0.551916  0.527977  0.509689
 0.700593  0.649146  0.708859  0.675238  0.67035 

[:, :, 3] =
 0.570579  0.506215  0.568469  0.516262  0.52698 
 0.718746  0.651417  0.709324  0.676492  0.672193

...

[:, :, 1022] =
 0.513028  0.576044  0.512759  0.557381  0.564003
 0.640903  0.721222  0.642544  0.685149  0.709116

[:, :, 1023] =
 0.539897  0.55011   0.52866   0.523404  0.559933
 0.674866  0.693411  0.679383  0.674325  0.704763

[:, :, 1024] =
 0.519916  0.515605  0.537421  0.516599  0.529854
 0.669387  0.681358  0.678635  0.650738  0.695294

In [46]:
function initopt!(model::ContrastiveSWM)
    
    for par in params(model)
        par.opt = Adam(;lr=0.005)
        println(par)
    end
end

initopt! (generic function with 1 method)

In [47]:
initopt!(model)

P(KnetArray{Float32,4}(10,10,3,32))
P(KnetArray{Float32,4}(1,1,32,5))
P(KnetArray{Float32,2}(512,25))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(2,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(2))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,2}(512,4))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,2}(512,518))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(2,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(2))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))


In [48]:
#Verbose after x batches
VERBOSE =  20

#Define number of epochs
NUM_EPOCHS = 10

println("Starting training...")

for i in 1:NUM_EPOCHS
    
    avg_loss = 0.0
    it = 0
    for  (k, (obs, action, next_obs)) in enumerate(dtrn)

        #Train by using contrastive loss
        J = @diff model(obs,action,next_obs)
        
        for par in params(model)
            g = grad(J, par)
            update!(value(par), g, par.opt)
        end
        
        batch_size = size(obs,4)

        if k % VERBOSE == 0
            
            println("Epoch: ", i , ", Iter: " , k*batch_size, "/", dtrn.num_steps, ", Loss: ", value(J))

        end
        
        avg_loss += value(J)
        it = k
        
    end
    
    avg_loss /= it
    
    println("Avg loss: " , avg_loss)
end

dtrn = nothing
Knet.gc()
GC.gc()

Starting training...
Epoch: 1, Iter: 20480/100000, Loss: 8.886596e-6
Epoch: 1, Iter: 40960/100000, Loss: 1.1911742e-6
Epoch: 1, Iter: 61440/100000, Loss: 2.1982966e-7
Epoch: 1, Iter: 81920/100000, Loss: 1.0755513e-7
Avg loss: 0.0009590856510199726
Epoch: 2, Iter: 20480/100000, Loss: 7.105548e-8
Epoch: 2, Iter: 40960/100000, Loss: 5.6801078e-8
Epoch: 2, Iter: 61440/100000, Loss: 5.125508e-8
Epoch: 2, Iter: 81920/100000, Loss: 5.2524427e-8
Avg loss: 6.104858056877458e-8
Epoch: 3, Iter: 20480/100000, Loss: 4.573629e-8
Epoch: 3, Iter: 40960/100000, Loss: 4.0252253e-8
Epoch: 3, Iter: 61440/100000, Loss: 3.7990397e-8
Epoch: 3, Iter: 81920/100000, Loss: 4.1271853e-8
Avg loss: 4.300886461176551e-8
Epoch: 4, Iter: 20480/100000, Loss: 3.7364224e-8
Epoch: 4, Iter: 40960/100000, Loss: 3.421092e-8
Epoch: 4, Iter: 61440/100000, Loss: 3.227399e-8
Epoch: 4, Iter: 81920/100000, Loss: 3.5455002e-8
Avg loss: 3.582301491412847e-8
Epoch: 5, Iter: 20480/100000, Loss: 3.2199114e-8
Epoch: 5, Iter: 40960/10000

In [49]:
dtrn = nothing
Knet.gc()
GC.gc()

In [50]:
Knet.save("model.jld2", "model", model)

## Evaluation

In [14]:
model = Knet.load("model.jld2", "model")

ContrastiveSWM(EncoderCNNSmall(Param{KnetArray{Float32,4}}[P(KnetArray{Float32,4}(10,10,3,32)), P(KnetArray{Float32,4}(1,1,32,5))], Any[Knet.BNMoments(0.1f0, K32(1,1,32,1)[0.50246775⋯], K32(1,1,32,1)[1.1001533⋯], zeros, ones), K32(64)[1.0⋯]], Knet.sigm, NNlib.relu), EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,25)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-6), NNlib.relu), TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-6), NNlib.relu), NodeMLP(Param{KnetArray

In [15]:
#Params
EVAL_DATASET_PATH = "/home/cagan/dev/datasets/C-SWM/shapes_eval.h5"
EVAL_BATCH_SIZE = 100

100

In [16]:
dtst = buildPathDataset(EVAL_DATASET_PATH, EVAL_BATCH_SIZE);

In [17]:
pred_states = Any[]
next_states = Any[]

num_samples = dtst.dataset_size

for  (k, (obs, action, next_obs)) in enumerate(dtst)
    
    if k % 10 == 0
        
        println("Processed ", k ," batches")
        
    end
    #Obs => (50,50,3,100)
    #Next obs => (50,50,3,100)
    
    pred_state = model(obs,action)
    next_state = model(next_obs)
    
    #Pred-state => (2,5,100)
    #Next state => (2,5,100)
        
    push!(pred_states, Array{Float32}(pred_state))
    push!(next_states, Array{Float32}(next_state))
    
end

#Pred state cat => [2,5,10000]
#Next state cat => [2,5,10000]
pred_states = cat(pred_states...,dims=3)
next_states = cat(next_states...,dims=3)
    
#Flatten object/feature dimensions
pred_states = mat(pred_states)  #[10,10000]
next_states = mat(next_states)  #[10,10000]

#Calculate pairwise distances
sizes_1 = (size(pred_states)...,1)
sizes_2 = (sizes_1[1], sizes_1[3], sizes_1[2])

pred_states = reshape(pred_states, sizes_1)
next_states = reshape(next_states, sizes_2)
pred_states = repeat(pred_states, outer=[1,1,10000])
next_states = repeat(next_states, outer=[1,10000,1])

pairwise_distance_matrix = sum((pred_states - next_states).^2, dims=1)[1,:,:]

#Augment pairwise distance matrix
diag_elements = diag(pairwise_distance_matrix)
pairwise_distance_matrix = vcat(pairwise_distance_matrix, diag_elements')


labels = ones(num_samples)
hits_at_1 = 0

indices = []

for i=1:10000
    
    row = pairwise_distance_matrix[i,:]
    ind = sortperm(row)
    
    push!(indices, ind)

end

indices = vcat(indices'...)

Processed 10 batches
Processed 20 batches
Processed 30 batches
Processed 40 batches
Processed 50 batches
Processed 60 batches
Processed 70 batches
Processed 80 batches
Processed 90 batches
Processed 100 batches


10000×10000 Array{Int64,2}:
 4870  9759  6734  9659  6496  1676  …   115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676      115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676      115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676      115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676     3940  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676  …   115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676     3940  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676      115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676      115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676      115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676  …   115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676      115  8095  6285  8452  4813  6230
 4870  9759  6734  9659  6496  1676      115  8095  6285  8452  4813  62

In [18]:
num_matches = sum(labels .== indices[:,1])
hits_at_1 += num_matches
println("Hits @ 1: ", hits_at_1/num_samples)

Hits @ 1: 0.0


In [12]:
mxval, mxindx = findmax(indices .== labels,dims=2)
ranks = [ i[2] for i in mxindx ]
reciprocal_ranks = 1 ./ranks
rr_sum = sum(reciprocal_ranks)
println("MRR: ", rr_sum/num_samples)

MRR: 0.003463678106726192
