# C-SWM 

In [1]:
#Libs
import Base: iterate, length, GC
using HDF5
using Knet
using Statistics: mean,std
using Random
using LinearAlgebra
using Images
using Plots

# Array type for this notebook: Knet array holding Float32 elements
atype = KnetArray{Float32}

# Project sources (presumably the dataset iterators and the C-SWM model code)
include("datasets.jl")
include("cswm.jl")

## Model Definition

In [2]:
# Model hyperparameters, kept as notebook globals.
# NOTE(review): none of them is read by the visible cells, because the model is
# loaded from disk; they presumably mirror the configuration it was trained with.
input_ch, hidden_dim = 3, 512
num_objects, embedding_dim = 5, 2
action_dim = 4
sigma, hinge = 0.5, 1.0;

In [3]:
# Seed the random number generator with a fixed value (42).
Knet.seed!(42)

MersenneTwister(UInt32[0x0000002a], Random.DSFMT.DSFMT_state(Int32[964434469, 1073036706, 1860149520, 1073503458, 1687169063, 1073083486, -399267803, 1072983952, -909620556, 1072836235  …  -293054293, 1073002412, -1300127419, 1073642642, 1917177374, -666058738, -337596527, 1830741494, 382, 0]), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], UInt128[0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000  …  0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x0000000000

## Load the pretrained model

In [4]:
# Read the object stored under the key "model" from the JLD2 file in the
# current working directory and bind it to `model`.
model = Knet.load("model_2dshapes_moving_mnist.jld2", "model")

ContrastiveSWM(EncoderCNNSmall(Param{KnetArray{Float32,4}}[P(KnetArray{Float32,4}(10,10,3,32)), P(KnetArray{Float32,4}(1,1,32,5))], Param{KnetArray{Float32,4}}[P(KnetArray{Float32,4}(1,1,32,1)), P(KnetArray{Float32,4}(1,1,5,1))], Any[Knet.BNMoments(0.1f0, K32(1,1,32,1)[-0.0828013⋯], K32(1,1,32,1)[0.10184157⋯], zeros, ones), K32(64)[1.0⋯]], Knet.sigm, NNlib.relu), EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,25)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-5), NNlib.relu), TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], LayerNorm(P(Knet

In [5]:
# Evaluation settings.
#
# EVAL_DATASET_PATH : HDF5 file with the evaluation data. The original author's
#                     absolute path is still the default, but it can be
#                     overridden without editing the notebook by setting the
#                     CSWM_EVAL_DATASET_PATH environment variable.
# EVAL_BATCH_SIZE   : batch size handed to the dataset builder.
# STEP_SIZE         : number of action steps the model is rolled forward.
EVAL_DATASET_PATH = get(
    ENV,
    "CSWM_EVAL_DATASET_PATH",
    "/home/cagan/dev/datasets/moving_mnist/moving_mnist_grid_rgb_tst.h5"
)
EVAL_BATCH_SIZE = 100
STEP_SIZE = 10

10

In [6]:
# Build the evaluation dataset from the file path, batch size and step count
# configured above. `buildMultiStepPathDataset` is defined outside this notebook
# (presumably in datasets.jl); the trailing `;` suppresses the cell output.
dtst = buildMultiStepPathDataset(EVAL_DATASET_PATH, EVAL_BATCH_SIZE, STEP_SIZE);

In [11]:
# Multi-step ranking evaluation.
#
# 1. For every batch, apply the model to the observation and the first action,
#    then add the output of `model.gnn` for each remaining action
#    (STEP_SIZE actions in total), and encode the target observation.
# 2. Compute the squared Euclidean distance between every predicted state and
#    every encoded target state.
# 3. Prepend the distance to the *true* target as column 1 and sort every row,
#    so that a sorted index of 1 marks the position (rank) of the true match.
#
# Results left in the notebook's global scope for the metric cells:
#   num_samples              - dtst.dataset_size (metric denominator)
#   pairwise_distance_matrix - N x (N+1), column 1 holds the true-pair distance
#   labels                   - N ones; 1 is the column index of the true pair
#   indices                  - N x (N+1) Int matrix of per-row sort permutations
# where N is the number of states actually collected from `dtst`.
pred_states = Array{Float32}[]
next_states = Array{Float32}[]

num_samples = dtst.dataset_size

for (k, (obs, action, next_obs)) in enumerate(dtst)

    if k % 10 == 0
        println("Processed ", k, " batches")
    end

    # Shapes noted by the original author (not checked here):
    #   obs, next_obs => (50,50,3,100)
    pred_state = model(obs, action[:, 1])

    # Remaining steps: add the `model.gnn` output to the running state.
    for i in 2:STEP_SIZE
        pred_state = pred_state + model.gnn(pred_state, action[:, i])
    end

    # Convert to plain Float32 arrays; author-noted shape => (2,5,100)
    push!(pred_states, Array{Float32}(pred_state))
    push!(next_states, Array{Float32}(model(next_obs)))

end

isempty(pred_states) && error("dataset produced no batches; nothing to evaluate")

# Concatenate the batches along the sample axis and flatten the
# object/feature dimensions: [2,5,N] => [10,N]
pred_states = mat(cat(pred_states..., dims=3))
next_states = mat(cat(next_states..., dims=3))

# Size everything from the data actually collected instead of a literal 1000.
feat_dim, num_eval = size(pred_states)
if num_eval != num_samples
    @warn "Number of evaluated states differs from dtst.dataset_size" num_eval num_samples
end

# Pairwise squared distances, entry [i, j] = |pred_i - next_j|^2.
# Broadcasting a (D,N,1) array against a (D,1,N) array yields the same (D,N,N)
# differences as the former `repeat` calls without materialising two copies.
pairwise_distance_matrix = dropdims(
    sum(abs2,
        reshape(pred_states, feat_dim, num_eval, 1) .-
        reshape(next_states, feat_dim, 1, num_eval);
        dims=1);
    dims=1)

# Augment: put the true-pair distance (the diagonal) in front as column 1.
diag_elements = diag(pairwise_distance_matrix)
pairwise_distance_matrix = hcat(diag_elements, pairwise_distance_matrix)

labels = ones(num_eval)

# Row-wise ranking. `sortperm` is stable, so on ties the true pair in
# column 1 is ranked ahead of the duplicate entry on the diagonal.
indices = Matrix{Int}(undef, size(pairwise_distance_matrix)...)
for i in axes(pairwise_distance_matrix, 1)
    indices[i, :] = sortperm(view(pairwise_distance_matrix, i, :))
end

indices

Processed 10 batches


1000×1001 Array{Int64,2}:
 757  330  978  223  652  750   73  753  …  616  798  592  692  613  593  612
 757  760  220  222  223  601  761  979     280  592  593  798  616  613  612
 978  287  222  220   48  790  221  223     520  932  692  695  593  613  612
 804  423   50  469   49   74    7  422     183  504  375  904  518  903  548
 978   50   48   49  287  804   71   46     668  182  504  784  184  183  548
   1    7  423  296  846  431  422  759  …  692  684  375  183  858  182  181
 652  978   49  921  222  223  218   48     692  668  195  548  667  611  612
 652  222   48  223  221   49  921  978     373  196  548  195  667  668  612
 652  222  978  223  790  221  921  750     613  196  195  611  667  668  612
 652   49  921   48  741  750  223   66      87  548  612  196  668  195  163
 652  921   49  918  653  920  651  947  …  668  548  196  487  348  612  195
 652  654  145  653  144  649  143  218     767  612  488   37  610  348  487
 652  915  336  918  144  750  921  22

In [12]:
# Hits@1: share of samples whose top-ranked column equals the label
num_matches = count(labels .== view(indices, :, 1))
println("Hits @ 1: ", num_matches / num_samples)

Hits @ 1: 0.06


In [13]:
# Hits@5: share of samples whose label appears among the five top-ranked columns
num_matches = count(labels .== view(indices, :, 1:5))
println("Hits @ 5: ", num_matches / num_samples)

Hits @ 5: 0.162


In [14]:
# Mean reciprocal rank.
# `findmax` over the Boolean match mask returns, per row, the first position
# where the sorted index equals the label; its column is that sample's rank.
mxval, mxindx = findmax(indices .== labels, dims=2)
ranks = map(ci -> ci[2], mxindx)
reciprocal_ranks = inv.(ranks)
rr_sum = sum(reciprocal_ranks)
println("MRR: ", rr_sum / num_samples)

MRR: 0.1289516131819102
