# C-SWM 

In [1]:
#Libs
import Base: iterate, length, GC
using HDF5
using Knet
using Statistics: mean,std
using Random
using LinearAlgebra
using Images
using Plots

#Datatype
atype=KnetArray{Float32}

#Includes
include("datasets.jl")
include("cswm.jl")

## Model Definition

In [2]:
#Params
# Hyperparameters handed, in this order, to initContrastiveSWMSmall in the next cells.
input_ch = 3        # number of input image channels
hidden_dim = 512    # width of the hidden layers
num_objects = 5     # number of object slots
embedding_dim = 2   # size of each object's embedding
action_dim = 4      # dimensionality of the action input
sigma = 0.5         # forwarded to the constructor; NOTE(review): presumably the loss scale — confirm in cswm.jl
hinge = 1.0         # forwarded to the constructor; NOTE(review): presumably the contrastive margin — confirm in cswm.jl

1.0

In [3]:
Knet.seed!(42)

MersenneTwister(UInt32[0x0000002a], Random.DSFMT.DSFMT_state(Int32[964434469, 1073036706, 1860149520, 1073503458, 1687169063, 1073083486, -399267803, 1072983952, -909620556, 1072836235  …  -293054293, 1073002412, -1300127419, 1073642642, 1917177374, -666058738, -337596527, 1830741494, 382, 0]), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], UInt128[0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000  …  0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x0000000000

In [4]:
model = initContrastiveSWMSmall(input_ch, hidden_dim, num_objects, embedding_dim, action_dim, sigma, hinge)

ContrastiveSWM(EncoderCNNSmall(Any[P(KnetArray{Float32,4}(10,10,3,32)), P(KnetArray{Float32,4}(1,1,32,5))], Any[P(KnetArray{Float32,4}(1,1,32,1)), P(KnetArray{Float32,4}(1,1,5,1))], Any[Knet.BNMoments(0.1, nothing, nothing, zeros, ones), K32(64)[1.0⋯]], Knet.sigm, NNlib.relu), EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,25)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-5), NNlib.relu), TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-5), NNlib.relu), NodeMLP(Para

## Training Part

In [5]:
# HDF5 file with the training transitions (absolute path, specific to the author's machine).
TRAIN_DATASET_PATH = "/home/cagan/dev/datasets/C-SWM/shapes_train.h5"
# Batch size handed to the dataset builder.
TRAIN_BATCH_SIZE = 1024
# Build the training dataset that the training loop iterates over.
# NOTE(review): the meaning of the `true` flag is defined in datasets.jl — confirm there.
dtrn = buildStateTransitionDataset(TRAIN_DATASET_PATH, true, TRAIN_BATCH_SIZE);

Dataset loaded. Building dataset indexing...
Done.


In [6]:
"""
    initopt!(model::ContrastiveSWM)

Give every parameter of `model` its own Adam optimizer (learning rate 0.005) and print
each parameter as it is configured. Returns `nothing`.
"""
function initopt!(model::ContrastiveSWM)
    foreach(params(model)) do p
        # Each parameter carries its own optimizer state in its `opt` field.
        p.opt = Adam(; lr=0.005, gclip=0, beta1=0.9, beta2=0.999, eps=1e-8)
        println(p)
    end
end;
initopt!(model)

P(KnetArray{Float32,4}(10,10,3,32))
P(KnetArray{Float32,4}(1,1,32,5))
P(KnetArray{Float32,4}(1,1,32,1))
P(KnetArray{Float32,4}(1,1,5,1))
P(KnetArray{Float32,2}(512,25))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(2,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(2))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,2}(512,4))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,2}(512,518))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(2,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(2))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))


In [7]:
# Print a progress line once every VERBOSE batches
VERBOSE =  10

# Number of full passes over the training dataset
NUM_EPOCHS = 100

println("Starting training...")

for epoch in 1:NUM_EPOCHS

    # Running sum of the batch losses and number of batches seen in this epoch
    loss_total = 0.0
    batches_done = 0

    for (step, (obs, action, next_obs)) in enumerate(dtrn)

        # Record the forward pass of the contrastive loss so gradients can be taken
        tape = @diff model(obs, action, next_obs)
        batch_loss = value(tape)

        # Update every parameter with the optimizer stored on that parameter
        for w in params(model)
            update!(value(w), grad(tape, w), w.opt)
        end

        if step % VERBOSE == 0
            # Samples processed so far are estimated from the current batch's size
            println("Epoch: ", epoch, ", Iter: ", step * size(obs, 4), "/", dtrn.num_steps,
                    ", Loss: ", batch_loss)
        end

        loss_total += batch_loss
        batches_done = step
    end

    # Turn the accumulated sum into the mean loss per batch
    loss_total /= batches_done

    println("Avg loss: ", loss_total)
end

#dtrn = nothing
#Knet.gc()
#GC.gc()

Starting training...
Epoch: 1, Iter: 10240/100000, Loss: 1.2729872
Epoch: 1, Iter: 20480/100000, Loss: 0.33176392
Epoch: 1, Iter: 30720/100000, Loss: 0.21588787
Epoch: 1, Iter: 40960/100000, Loss: 0.12814604
Epoch: 1, Iter: 51200/100000, Loss: 0.11091716
Epoch: 1, Iter: 61440/100000, Loss: 0.09546113
Epoch: 1, Iter: 71680/100000, Loss: 0.09327096
Epoch: 1, Iter: 81920/100000, Loss: 0.09457105
Epoch: 1, Iter: 92160/100000, Loss: 0.0839488
Avg loss: 1.4810186971708672
Epoch: 2, Iter: 10240/100000, Loss: 0.08037728
Epoch: 2, Iter: 20480/100000, Loss: 0.0748291
Epoch: 2, Iter: 30720/100000, Loss: 0.075512275
Epoch: 2, Iter: 40960/100000, Loss: 0.068265975
Epoch: 2, Iter: 51200/100000, Loss: 0.06986205
Epoch: 2, Iter: 61440/100000, Loss: 0.06875594
Epoch: 2, Iter: 71680/100000, Loss: 0.069059744
Epoch: 2, Iter: 81920/100000, Loss: 0.066677
Epoch: 2, Iter: 92160/100000, Loss: 0.071466304
Avg loss: 0.07233673165139463
Epoch: 3, Iter: 10240/100000, Loss: 0.072230816
Epoch: 3, Iter: 20480/10000

In [10]:
# Drop the reference to the training dataset, then run Knet's collector and Julia's
# garbage collector.
dtrn = nothing
Knet.gc()
GC.gc()

In [8]:
# Write the current `model` to model_2dshapes.jld2 under the key "model".
Knet.save("model_2dshapes.jld2", "model", model)

## Evaluation

In [None]:
# Read back the object stored under the key "model" in model_2dshapes.jld2.
model = Knet.load("model_2dshapes.jld2", "model")

In [9]:
#Params
EVAL_DATASET_PATH = "/home/cagan/dev/datasets/C-SWM/shapes_eval.h5"
EVAL_BATCH_SIZE = 100

100

In [11]:
# Build the evaluation dataset that the evaluation loop iterates over.
dtst = buildPathDataset(EVAL_DATASET_PATH, EVAL_BATCH_SIZE);

In [None]:
obs, action,next_obs = first(dtst)

In [12]:
"""
    pairwise_sqdist(pred, next)

Return the matrix `D` with `D[i, j] = sum(abs2, pred[:, i] - next[:, j])`, i.e. the squared
Euclidean distance between column `i` of `pred` and column `j` of `next`.

The result is filled one column at a time, so the only temporary alive at any moment has
the size of `pred` (instead of a dense `rows × n × n` array).
"""
function pairwise_sqdist(pred::AbstractMatrix, next::AbstractMatrix)
    size(pred, 1) == size(next, 1) ||
        throw(DimensionMismatch("pred and next must have the same number of rows"))
    T = promote_type(eltype(pred), eltype(next))
    dist = Matrix{T}(undef, size(pred, 2), size(next, 2))
    for j in axes(next, 2)
        # Distance of every predicted column to the j-th target column
        dist[:, j] = vec(sum(abs2, pred .- view(next, :, j); dims=1))
    end
    return dist
end

"""
    rank_rows(scores)

Return an `Int` matrix of the same size as `scores` whose row `i` is
`sortperm(scores[i, :])`, i.e. the column positions of row `i` in ascending score order.
"""
function rank_rows(scores::AbstractMatrix)
    order = Matrix{Int}(undef, size(scores))
    for i in axes(scores, 1)
        order[i, :] = sortperm(scores[i, :])
    end
    return order
end

# Per-batch embeddings, copied to CPU arrays
pred_states = Array{Float32}[]
next_states = Array{Float32}[]

num_samples = dtst.dataset_size

for (k, (obs, action, next_obs)) in enumerate(dtst)

    if k % 10 == 0
        println("Processed ", k ," batches")
    end

    # Author's shape notes: obs / next_obs => (50,50,3,100); states => (2,5,100)
    # Predicted next state = predicted transition + current state embedding
    pred_state = Array{Float32}(model(obs,action) + model(obs))
    next_state = Array{Float32}(model(next_obs))

    push!(pred_states, pred_state)
    push!(next_states, next_state)

end

# Concatenate the batches along the sample dimension, then flatten the object/feature
# dimensions so that every sample is one column (author's note: [10,10000]).
pred_states = mat(cat(pred_states..., dims=3))
next_states = mat(cat(next_states..., dims=3))

# Squared distance between every predicted state (row) and every true next state (column).
# The sample count is taken from the data rather than hard-coded.
pairwise_distance_matrix = pairwise_sqdist(pred_states, next_states)

# Augment pairwise distance matrix: prepend each sample's distance to its own true next
# state as column 1, so that index 1 marks the correct candidate in every row.
diag_elements = diag(pairwise_distance_matrix)
pairwise_distance_matrix = hcat(diag_elements, pairwise_distance_matrix)

# The correct candidate sits in column 1 for every sample
labels = ones(num_samples)
hits_at_1 = 0

# Row i lists the candidate columns of sample i ordered from closest to farthest
indices = rank_rows(pairwise_distance_matrix)

Processed 10 batches
Processed 20 batches
Processed 30 batches
Processed 40 batches
Processed 50 batches
Processed 60 batches
Processed 70 batches
Processed 80 batches
Processed 90 batches
Processed 100 batches


10000×10001 Array{Int64,2}:
 1      2  8965  7190  6440  5990  …  8159  1743  6752  8857  7338  4951
 1      3  9612  3738  2390  8390     8978  6776  3375  9555  8604   883
 1      4  5277  1870  2475  8553     5061  4637  4265  4519  4821  7183
 1      5  1189  3715  8081  6116     4769  2051  6642  7141  3900  5075
 1      6  9560  4509  2771  6326      808  8877  8265   353  7263  3265
 1      7  7404  9565   183  3358  …  3265  7580  6765  6642  6054  4643
 1      8  5005   748  2635  5198     5759  1141  8589  1267  9708  1146
 1      9  6304  5850  3592  5958     9118  8432  7093  3098  2735  4693
 1     10  6050  4065  8034  1643     8000  9464  3297  6340  5301   392
 1     11   495  7878  3159  4001     4258   786  4680  4218  6907  6898
 1     12  2305  5498  6451  3559  …  2336  8681  2437  2796  2234  3390
 1     13  7733  1995  1305  3850     3013  1425  5934   741  3802  2959
 1     14  3446  8219  1219  3536     7690  5200  5555  5648  5767   612
 ⋮                     

In [13]:
# Hits@1: fraction of samples whose top-ranked candidate equals the sample's label.
# The count is assigned rather than accumulated, so re-running this cell cannot inflate
# the metric; a view avoids copying the first column of `indices`.
num_matches = count(labels .== view(indices, :, 1))
hits_at_1 = num_matches
println("Hits @ 1: ", hits_at_1/num_samples)

Hits @ 1: 0.9997


In [14]:
# Mean reciprocal rank: for every row, find the column of `indices` that matches the row's
# label, then average the reciprocals of those column positions over `num_samples`.
mxval, mxindx = findmax(indices .== labels, dims=2)
# Keep only the column component of each Cartesian index
ranks = map(ci -> ci[2], mxindx)
reciprocal_ranks = inv.(ranks)
rr_sum = sum(reciprocal_ranks)
println("MRR: ", rr_sum / num_samples)

MRR: 0.99985
