# C-SWM 

In [1]:
#Libs
import Base: iterate, length, GC
using HDF5
using Knet
using Statistics: mean,std
using Random
using LinearAlgebra
using Images
using Plots

#Datatype
# Array type alias: Knet's KnetArray with Float32 elements.
# NOTE(review): not referenced by any cell visible in this notebook; presumably
# read by the included project files — confirm.
atype=KnetArray{Float32}

#Includes
# Project-local source files evaluated into the current module.
# NOTE(review): presumably these define buildMultiStepPathDataset and the model
# types of the checkpoint loaded later — confirm against the files themselves.
include("datasets.jl")
include("cswm.jl")

## Model Definition

In [2]:
#Params
# NOTE(review): none of these values is read by the cells visible in this
# notebook (the model is loaded from a checkpoint); presumably they mirror the
# configuration the checkpoint was trained with — confirm.

# Network sizes.
input_ch, hidden_dim = 3, 512
num_objects, embedding_dim = 5, 2
action_dim = 4

# NOTE(review): presumably loss hyper-parameters (energy scale / hinge margin) — confirm.
sigma, hinge = 0.5, 1.0;

In [3]:
# Seed the random number generator through Knet with the fixed value 42 so that
# any random draws made afterwards are repeatable between runs.
Knet.seed!(42)

MersenneTwister(UInt32[0x0000002a], Random.DSFMT.DSFMT_state(Int32[964434469, 1073036706, 1860149520, 1073503458, 1687169063, 1073083486, -399267803, 1072983952, -909620556, 1072836235  …  -293054293, 1073002412, -1300127419, 1073642642, 1917177374, -666058738, -337596527, 1830741494, 382, 0]), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], UInt128[0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000  …  0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x0000000000

## Load the pretrained model

In [4]:
# Load the object stored under the key "model" from the JLD2 checkpoint file.
# The path is relative, so the file must be reachable from the working directory.
model = Knet.load("model_2dshapes_lr.jld2", "model")

ContrastiveSWM(EncoderCNNSmall(Param{KnetArray{Float32,4}}[P(KnetArray{Float32,4}(10,10,3,32)), P(KnetArray{Float32,4}(1,1,32,5))], Param{KnetArray{Float32,4}}[P(KnetArray{Float32,4}(1,1,32,1)), P(KnetArray{Float32,4}(1,1,5,1))], Any[Knet.BNMoments(0.1f0, K32(1,1,32,1)[-0.24739124⋯], K32(1,1,32,1)[0.5332115⋯], zeros, ones), K32(64)[1.0⋯]], Knet.sigm, NNlib.relu), EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,25)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-5), NNlib.relu), TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], LayerNorm(P(Knet

In [5]:
#Params
# Location of the evaluation dataset (.h5 file). The default is the original
# machine-specific path; set the CSWM_EVAL_DATASET_PATH environment variable to
# point at another copy without editing the notebook.
EVAL_DATASET_PATH = get(
    ENV,
    "CSWM_EVAL_DATASET_PATH",
    "/home/cagan/dev/datasets/C-SWM/shapes_eval.h5",
)
# Batch size handed to buildMultiStepPathDataset.
EVAL_BATCH_SIZE = 100
# Number of consecutive actions rolled out per sample during evaluation.
STEP_SIZE = 5

5

In [6]:
# Build the evaluation dataset from the file at EVAL_DATASET_PATH with the
# configured batch size and step count. The evaluation cell iterates the result
# as (obs, action, next_obs) tuples and reads its `dataset_size` field.
# NOTE(review): buildMultiStepPathDataset presumably comes from datasets.jl — confirm.
dtst = buildMultiStepPathDataset(EVAL_DATASET_PATH, EVAL_BATCH_SIZE, STEP_SIZE);

In [7]:
# Multi-step ranking evaluation.
#
# Every sample's latent state is rolled forward STEP_SIZE actions and compared
# (squared Euclidean distance) with the encoded target observation of *every*
# sample. The candidates of each sample are then ranked by that distance.
# Globals produced for the metric cells: `num_samples`, `labels`, `indices`
# (plus `pred_states`, `next_states`, `pairwise_distance_matrix`).

"""
    rollout_latents(model, dataset, steps)

Iterate `dataset` (yielding `(obs, action, next_obs)` batches), roll the model
forward `steps` actions in latent space and encode `next_obs`.

Returns `(pred, target)`, two host `Float32` matrices with one column per
sample (object and feature dimensions flattened together).
"""
function rollout_latents(model, dataset, steps)
    preds = Array{Float32,3}[]
    targets = Array{Float32,3}[]

    for (k, (obs, action, next_obs)) in enumerate(dataset)
        if k % 10 == 0
            println("Processed ", k, " batches")
        end
        #Obs => (50,50,3,100)
        #Next obs => (50,50,3,100)

        # First step starts from the observation; column i of `action` is the
        # action applied at step i.
        # NOTE(review): model(obs, a) presumably returns the predicted next
        # latent state and model.gnn(z, a) a latent delta — confirm in cswm.jl.
        pred_state = model(obs, action[:, 1])
        for i in 2:steps
            pred_state = pred_state + model.gnn(pred_state, action[:, i])
        end

        #Pred-state => (2,5,100)
        #Next state => (2,5,100)
        push!(preds, Array{Float32}(pred_state))
        push!(targets, Array{Float32}(model(next_obs)))
    end

    isempty(preds) && throw(ArgumentError("dataset yielded no batches"))

    # Concatenate along the batch axis, then flatten object/feature dimensions:
    # (2,5,N) => (10,N)
    return mat(cat(preds...; dims=3)), mat(cat(targets...; dims=3))
end

"""
    pairwise_sqdist(pred, target)

Return the matrix `D` with `D[i, j] = sum((pred[:, i] .- target[:, j]).^2)`.

Only the `size(pred, 2) × size(target, 2)` result is allocated; no
`(d, N, N)` intermediate is materialised.
"""
function pairwise_sqdist(pred::AbstractMatrix, target::AbstractMatrix)
    size(pred, 1) == size(target, 1) || throw(DimensionMismatch(
        "feature dimensions differ: $(size(pred, 1)) vs $(size(target, 1))"))
    T = promote_type(eltype(pred), eltype(target))
    D = Matrix{T}(undef, size(pred, 2), size(target, 2))
    # `i` is the inner loop so that D is written in column-major order.
    for j in 1:size(target, 2), i in 1:size(pred, 2)
        acc = zero(T)
        for k in 1:size(pred, 1)
            acc += abs2(pred[k, i] - target[k, j])
        end
        D[i, j] = acc
    end
    return D
end

"""
    rank_rows(scores)

Return an `Int` matrix of the same size as `scores` whose row `i` is
`sortperm(scores[i, :])`, i.e. column indices ordered by ascending score.
`sortperm`'s default algorithm is stable, so ties keep their column order.
"""
function rank_rows(scores::AbstractMatrix)
    ranked = Matrix{Int}(undef, size(scores))
    for i in 1:size(scores, 1)
        ranked[i, :] = sortperm(scores[i, :])
    end
    return ranked
end

pred_states, next_states = rollout_latents(model, dtst, STEP_SIZE)

# Sample count taken from the data actually processed instead of a hard-coded 10000.
num_samples = size(pred_states, 2)

#Calculate pairwise distances
pairwise_distance_matrix = pairwise_sqdist(pred_states, next_states)

#Augment pairwise distance matrix
# Column 1 holds each sample's distance to its own target, so after sorting the
# correct match is the candidate with index 1 (it wins ties by stability).
diag_elements = diag(pairwise_distance_matrix)
pairwise_distance_matrix = hcat(diag_elements, pairwise_distance_matrix)

# Expected candidate index for every sample (the augmented first column).
labels = ones(num_samples)

indices = rank_rows(pairwise_distance_matrix)

Processed 10 batches
Processed 20 batches
Processed 30 batches
Processed 40 batches
Processed 50 batches
Processed 60 batches
Processed 70 batches
Processed 80 batches
Processed 90 batches
Processed 100 batches


10000×10001 Array{Int64,2}:
 7675      1     2  4355  3048  9196  …  8483  2918  5816   740  4121  9315
    1      3  5979  1099  5029  6743     9220  4527  8404  5395  6776  8579
    1      4  7942  2040  5992  7744     2366  8652  2793   244  4265  4917
    1      5  7674  4424  8945  8861     5473  4275  6039  6005  2329  5219
    1      6  9560  8208  6840  5706     1728  6542  4593  9097   785  9742
 5326      1     7  2536  3863  2704  …  6673  1197  1774  8191  4539  3351
 9685      1     8  7750   547  5005     8810  9825  5759  1267  9771  5835
    1      9  4195  1882  8082  9924     7406  4655  6835  8974  4781  2735
  908   8726  1525     1    10  5025     2725  2741  8817  4396  4087  2460
    1     11   158  7368  7466  8945     5219  2805  1301  9808  6039  6005
    1     12  5755  7058  4736  5339  …  8539  3310  6516  4500  9787  8870
 6915      1    13  5713  6387  2663      471  5400  1646  2856  1358  2875
    1     14  1295  3126  2631  3893      612  5767  4275  8

In [8]:
# Hits@1: fraction of samples whose best-ranked candidate equals the expected label.
num_matches = count(labels .== view(indices, :, 1))
println("Hits @ 1: ", num_matches / num_samples)

Hits @ 1: 0.9094


In [9]:
# Hits@5: fraction of samples whose expected label appears among the five
# best-ranked candidates (the label occurs at most once per row).
num_matches = count(labels .== view(indices, :, 1:5))
println("Hits @ 5: ", num_matches / num_samples)

Hits @ 5: 0.9845


In [10]:
# Mean reciprocal rank: for each row locate the first position holding the
# expected label (findmax returns the first maximal entry), take that column
# number as the rank and average 1/rank over all samples.
mxval, mxindx = findmax(labels .== indices; dims=2)
ranks = map(hit -> hit[2], mxindx)
reciprocal_ranks = inv.(ranks)
rr_sum = sum(reciprocal_ranks)
println("MRR: ", rr_sum / num_samples)

MRR: 0.9447837728883008
