# C-SWM 

In [1]:
#Libs
import Base: iterate, length, GC
using HDF5
using Knet
using Statistics: mean,std
using Random
using LinearAlgebra
using Images
using Plots

# Array type alias: Knet arrays with Float32 elements.
# NOTE(review): not referenced in the cells visible here; presumably read by the
# files included below, so it is defined before them. Confirm.
atype = KnetArray{Float32}

#Includes
include("datasets.jl")
include("cswm.jl")

## Model Definition

In [2]:
# Model hyper-parameters.
# NOTE(review): none of these are referenced again in the visible cells; the
# model itself is loaded from disk further down. Presumably they mirror the
# configuration the checkpoint was trained with. Confirm.
input_ch, hidden_dim, num_objects = 3, 512, 5
embedding_dim, action_dim = 2, 4
# Trailing `;` keeps the notebook from echoing the cell value.
sigma, hinge = 0.5, 1.0;

In [3]:
# Seed the random number generator through Knet so that any random draws made
# later in the notebook are repeatable. The call returns the RNG object, which
# the notebook echoes as the cell output.
Knet.seed!(42)

MersenneTwister(UInt32[0x0000002a], Random.DSFMT.DSFMT_state(Int32[964434469, 1073036706, 1860149520, 1073503458, 1687169063, 1073083486, -399267803, 1072983952, -909620556, 1072836235  …  -293054293, 1073002412, -1300127419, 1073642642, 1917177374, -666058738, -337596527, 1830741494, 382, 0]), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], UInt128[0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000  …  0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x0000000000

## Load the pretrained model

In [4]:
# Load the object stored under the key "model" from the JLD2 checkpoint. The
# path is relative, so the file must be in the notebook's working directory.
# The echoed value shows the loaded object is a ContrastiveSWM.
model = Knet.load("model_mvmnist_moveone.jld2", "model")

ContrastiveSWM(EncoderCNNLarge(Param{KnetArray{Float32,4}}[P(KnetArray{Float32,4}(3,3,3,32)), P(KnetArray{Float32,4}(3,3,32,32)), P(KnetArray{Float32,4}(3,3,32,32)), P(KnetArray{Float32,4}(3,3,32,5))], Param{KnetArray{Float32,4}}[P(KnetArray{Float32,4}(1,1,32,1)), P(KnetArray{Float32,4}(1,1,32,1)), P(KnetArray{Float32,4}(1,1,32,1)), P(KnetArray{Float32,4}(1,1,5,1))], Knet.sigm, NNlib.relu), EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,2500)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], NNlib.relu), TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], NNlib.relu), NodeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Floa

In [5]:
#Params
# Absolute path to the evaluation data file (.h5). It is machine-specific, so
# adjust it before running elsewhere.
EVAL_DATASET_PATH = "/home/cagan/dev/datasets/moving_mnist/movingmnist_tst_moveone.h5"
# Batch size passed to the dataset builder.
EVAL_BATCH_SIZE = 100
# Number of steps per sample: passed to the dataset builder and used as the
# upper bound of the action loop in the evaluation cell.
STEP_SIZE = 10

10

In [6]:
# Build the evaluation dataset object. It is iterated in the evaluation cell
# as (obs, action, next_obs) batches. Trailing `;` suppresses the cell output.
dtst = buildMovingMNISTMultiStepPathDataset(EVAL_DATASET_PATH, EVAL_BATCH_SIZE,STEP_SIZE);

In [7]:
# Multi-step latent rollout and ranking.
#
# For every batch the model is rolled forward STEP_SIZE steps in latent space,
# and the result is compared with the encoding of the target observation.
# Afterwards, each predicted state is ranked against every encoded target by
# squared Euclidean distance. The cell leaves `num_samples`, `labels` and
# `indices` defined for the metric cells that follow.

# Typed containers (one host array per batch) instead of Any[].
pred_states = Array{Float32}[]
next_states = Array{Float32}[]

for (k, (obs, action, next_obs)) in enumerate(dtst)

    if k % 10 == 0
        println("Processed ", k, " batches")
    end

    # Per the original notes: obs / next_obs are (50,50,3,batch).
    # First step: encoded observation plus the model's output for action 1.
    pred_state = model(obs, action[:, :, 1]) + model(obs)

    # Remaining steps are applied directly on the latent state through the
    # transition network.
    for i in 2:STEP_SIZE
        pred_state = pred_state + model.gnn(pred_state, action[:, :, i])
    end

    # Move results to host memory; per the original notes each is (2,5,batch).
    push!(pred_states, Array{Float32}(pred_state))
    push!(next_states, Array{Float32}(model(next_obs)))

end

isempty(pred_states) && error("Evaluation dataset produced no batches")

# Concatenate along the batch axis, then flatten object/feature dimensions:
# [2,5,N] -> [10,N]. Only one array per batch is splatted here.
pred_states = mat(cat(pred_states...; dims=3))
next_states = mat(cat(next_states...; dims=3))

# Use the number of samples actually evaluated so that the distance matrix,
# the labels and the metric denominators always agree with each other.
num_samples = size(pred_states, 2)

# Pairwise squared distances: entry [i,j] = ||pred_i - next_j||^2.
# Broadcasting a (D,N,1) array against a (D,1,N) array replaces the two
# explicit N x N x D `repeat` copies and the hard-coded sample count.
pairwise_distance_matrix = dropdims(
    sum(abs2,
        reshape(pred_states, :, num_samples, 1) .-
        reshape(next_states, :, 1, num_samples);
        dims=1);
    dims=1)

# Augment the matrix: column 1 holds the distance of each prediction to its
# own target, so index 1 is the "correct" label for every row. The same value
# also appears at column i+1; `sortperm` is stable, so column 1 is always
# ranked ahead of its duplicate.
diag_elements = diag(pairwise_distance_matrix)
pairwise_distance_matrix = hcat(diag_elements, pairwise_distance_matrix)

labels = ones(Int, num_samples)

# Row i of `indices` lists the column indices of row i of the augmented
# distance matrix, ordered from nearest to farthest.
indices = Matrix{Int}(undef, num_samples, num_samples + 1)

for i in axes(pairwise_distance_matrix, 1)
    indices[i, :] = sortperm(view(pairwise_distance_matrix, i, :))
end

# Echo the ranking matrix as the cell output.
indices

Processed 10 batches


1000×1001 Array{Int64,2}:
 637   823  624  186    1    2  273  …  168  971  983  625  149  542  209
 599   955  715   69    1    3  915     333  195  240  731  577  353  391
 508   766  283  997  433  520  775     531  264  335   77  526   49  634
   1     5   60  766  639  702  117     364  949  625  526  101   77  971
   1     6  833   83  521  496  243     624  731  403  656  142  634  531
 438   648  612  949  546  149  209  …  457  999  577  552  189  731  624
   1     8   27  521  872  207  833     531  731  634  656  403  380  446
 511   736    1    9  312  850   67     478  209  526  971  625  983  149
 396   340  441  920  974    1   10     511  772  999  724  189  731  624
  28    87  772  533  724  200  397     971  612  168  526  983  209  149
 878   266  779  609   98  262  599  …  157  353  581  630  316  415  391
 353   520  258  680  389  377  445      74  531  142  634  421  101  526
 238     1   14  959  675  350  171     971  612  387  531   76  962  209
   ⋮        

In [8]:
# Hits@1: fraction of rows whose top-ranked column index equals the row's label.
num_matches = count(labels .== view(indices, :, 1))
println("Hits @ 1: ", num_matches / num_samples)

Hits @ 1: 0.298


In [9]:
# Hits@5: fraction of rows whose label appears among the first five ranked
# column indices. `labels` (a vector) broadcasts across the five columns.
num_matches = count(labels .== view(indices, :, 1:5))
println("Hits @ 5: ", num_matches / num_samples)

Hits @ 5: 0.667


In [10]:
# Mean reciprocal rank.
# `findmax` over the Bool match matrix returns, per row, the first position
# where the ranked index equals the label; the second coordinate of that
# CartesianIndex is the rank.
mxval, mxindx = findmax(indices .== labels; dims=2)
ranks = getindex.(mxindx, 2)
reciprocal_ranks = 1 ./ ranks
rr_sum = sum(reciprocal_ranks)
println("MRR: ", rr_sum / num_samples)

MRR: 0.46159227792176416
