# C-SWM 

In [26]:
#Libs
import Base: iterate, length, GC
using HDF5
using Knet
using Statistics: mean,std
using Random
using LinearAlgebra
using Images
using Plots

#Datatype
atype=KnetArray{Float32}

#Includes
include("datasets.jl")
include("cswm.jl")

## Model Definition

In [32]:
#Params
input_ch = 3
hidden_dim = 512
num_objects = 5
embedding_dim = 2
action_dim = 4
sigma = 0.5
hinge = 1.0

1.0

In [33]:
Knet.seed!(42)

MersenneTwister(UInt32[0x0000002a], Random.DSFMT.DSFMT_state(Int32[964434469, 1073036706, 1860149520, 1073503458, 1687169063, 1073083486, -399267803, 1072983952, -909620556, 1072836235  …  -293054293, 1073002412, -1300127419, 1073642642, 1917177374, -666058738, -337596527, 1830741494, 382, 0]), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], UInt128[0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000  …  0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x0000000000

In [34]:
model = initContrastiveSWMSmall(input_ch, hidden_dim, num_objects, embedding_dim, action_dim, sigma, hinge)

ContrastiveSWM(EncoderCNNSmall(Any[P(KnetArray{Float32,4}(10,10,3,32)), P(KnetArray{Float32,4}(1,1,32,5))], Any[P(KnetArray{Float32,4}(1,1,32,1)), P(KnetArray{Float32,4}(1,1,5,1))], Any[Knet.BNMoments(0.1, nothing, nothing, zeros, ones), K32(64)[1.0⋯]], Knet.sigm, NNlib.relu), EncoderMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,25)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(2,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(2))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-5), NNlib.relu), TransitionGNN(EdgeMLP(Param{KnetArray{Float32,2}}[P(KnetArray{Float32,2}(512,4)), P(KnetArray{Float32,2}(512,512)), P(KnetArray{Float32,2}(512,512))], Param{KnetArray{Float32,1}}[P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512))], LayerNorm(P(KnetArray{Float32,1}(512)), P(KnetArray{Float32,1}(512)), 1.0e-5), NNlib.relu), NodeMLP(Para

## Training Part

In [14]:
TRAIN_DATASET_PATH = "/home/cagan/dev/datasets/C-SWM/shapes_train.h5"
TRAIN_BATCH_SIZE = 1024
dtrn = buildStateTransitionDataset(TRAIN_DATASET_PATH, true, TRAIN_BATCH_SIZE);

Dataset loaded. Building dataset indexing...
Done.


In [38]:
function initopt!(model::ContrastiveSWM)
    
    for par in params(model)
        par.opt = Adam(;lr=0.05, gclip=0, beta1=0.9, beta2=0.999, eps=1e-8)
        println(par)
    end
end
initopt!(model)

P(KnetArray{Float32,4}(10,10,3,32))
P(KnetArray{Float32,4}(1,1,32,5))
P(KnetArray{Float32,4}(1,1,32,1))
P(KnetArray{Float32,4}(1,1,5,1))
P(KnetArray{Float32,2}(512,25))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(2,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(2))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,2}(512,4))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,2}(512,518))
P(KnetArray{Float32,2}(512,512))
P(KnetArray{Float32,2}(2,512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(2))
P(KnetArray{Float32,1}(512))
P(KnetArray{Float32,1}(512))


In [39]:
#Verbose after x batches
VERBOSE =  5

#Define number of epochs
NUM_EPOCHS = 20

println("Starting training...")

for i in 1:NUM_EPOCHS
    
    avg_loss = 0.0
    it = 0
    for  (k, (obs, action, next_obs)) in enumerate(dtrn)

        #Train by using contrastive loss
        J = @diff model(obs,action,next_obs)
        
        for par in params(model)
            g = grad(J, par)
            println("Param: ", par, " Grad: ", g)
            update!(value(par), g, par.opt)
        end
        
        batch_size = size(obs,4)

        if k % VERBOSE == 0
            
            println("Epoch: ", i , ", Iter: " , k*batch_size, "/", dtrn.num_steps, ", Loss: ", value(J))

        end
        
        avg_loss += value(J)
        it = k
        
    end
    
    avg_loss /= it
    
    println("Avg loss: " , avg_loss)
end

#dtrn = nothing
#Knet.gc()
#GC.gc()

Starting training...
x4(2, 5120)
x4(2, 5120)
Param: P(KnetArray{Float32,4}(10,10,3,32)) Grad: K32(10,10,3,32)[0.019052802⋯]
Param: P(KnetArray{Float32,4}(1,1,32,5)) Grad: K32(1,1,32,5)[0.08232628⋯]
Param: P(KnetArray{Float32,4}(1,1,32,1)) Grad: K32(1,1,32,1)[5.9604645e-8⋯]
Param: P(KnetArray{Float32,4}(1,1,5,1)) Grad: K32(1,1,5,1)[0.19374993⋯]
Param: P(KnetArray{Float32,2}(512,25)) Grad: K32(512,25)[0.019287977⋯]
Param: P(KnetArray{Float32,2}(512,512)) Grad: K32(512,512)[0.0021520536⋯]
Param: P(KnetArray{Float32,2}(2,512)) Grad: K32(2,512)[0.0035380237⋯]
Param: P(KnetArray{Float32,1}(512)) Grad: K32(512)[0.031910717⋯]
Param: P(KnetArray{Float32,1}(512)) Grad: K32(512)[-1.3969839e-9⋯]
Param: P(KnetArray{Float32,1}(2)) Grad: K32(2)[0.0063588936⋯]
Param: P(KnetArray{Float32,1}(512)) Grad: K32(512)[0.0012128043⋯]
Param: P(KnetArray{Float32,1}(512)) Grad: K32(512)[8.897006e-5⋯]
Param: P(KnetArray{Float32,2}(512,4)) Grad: K32(512,4)[0.0058242483⋯]
Param: P(KnetArray{Float32,2}(512,512)) Grad

InterruptException: InterruptException:

In [40]:
obs,action, next_obs = first(dtrn)

(K32(50,50,3,1024)[0.0⋯], [20, 11, 5, 1, 7, 19, 7, 17, 3, 5  …  8, 3, 5, 10, 13, 16, 10, 15, 17, 4], K32(50,50,3,1024)[0.0⋯])

In [43]:
model.obj_extractor(obs)

5×5×5×1024 KnetArray{Float32,4}:
[:, :, 1, 1] =
 0.0186069  0.0186069   7.55767e-6  8.28097e-5  0.0186069  
 0.0186069  0.00379225  0.0186069   0.0186069   0.0186069  
 0.0186069  0.0186069   0.0186069   0.00129456  0.0186069  
 0.0186069  0.0186069   0.0186069   0.0186069   0.000655774
 0.0186069  0.0186069   0.0186069   0.0186069   0.0186069  

[:, :, 2, 1] =
 0.686676  0.686676  0.0842681  0.204305  0.686676
 0.686676  0.468991  0.686676   0.686676  0.686676
 0.686676  0.686676  0.686676   0.550421  0.686676
 0.686676  0.686676  0.686676   0.686676  0.295758
 0.686676  0.686676  0.686676   0.686676  0.686676

[:, :, 3, 1] =
 0.0247255  0.0247255   4.11881e-7  9.28247e-7  0.0247255  
 0.0247255  0.00163626  0.0247255   0.0247255   0.0247255  
 0.0247255  0.0247255   0.0247255   3.92528e-5  0.0247255  
 0.0247255  0.0247255   0.0247255   0.0247255   0.000369288
 0.0247255  0.0247255   0.0247255   0.0247255   0.0247255  

[:, :, 4, 1] =
 0.966281  0.966281  0.999931  0.999955  0.966281

In [None]:
dims = size(objs)
x0 = reshape(objs, dims[1]*dims[2], dims[3],dims[4]);
x0 = reshape(x0, dims[1]*dims[2],:);

In [None]:
dtrn = nothing
Knet.gc()
GC.gc()

In [None]:
Knet.save("model.jld2", "model", model)

## Evaluation

In [None]:
model = Knet.load("model.jld2", "model")

In [None]:
#Params
EVAL_DATASET_PATH = "/home/cagan/dev/datasets/C-SWM/shapes_eval.h5"
EVAL_BATCH_SIZE = 100

In [None]:
dtst = buildPathDataset(EVAL_DATASET_PATH, EVAL_BATCH_SIZE);

In [None]:
obs, action,next_obs = first(dtst)

In [None]:
model(obs)

In [None]:
pred_states = Any[]
next_states = Any[]

num_samples = dtst.dataset_size

for  (k, (obs, action, next_obs)) in enumerate(dtst)
    
    if k % 10 == 0
        
        println("Processed ", k ," batches")
        
    end
    #Obs => (50,50,3,100)
    #Next obs => (50,50,3,100)
    
    pred_state = Array{Float32}(model(obs,action))
    next_state = Array{Float32}(model(next_obs))
    
    #Pred-state => (2,5,100)
    #Next state => (2,5,100)
    #println(pred_state)
    #println(next_state)
    
    push!(pred_states, pred_state)
    push!(next_states, next_state)
    
end

#Pred state cat => [2,5,10000]
#Next state cat => [2,5,10000]
pred_states = cat(pred_states...,dims=3)
next_states = cat(next_states...,dims=3)
    
#Flatten object/feature dimensions
pred_states = mat(pred_states)  #[10,10000]
next_states = mat(next_states)  #[10,10000]

#Calculate pairwise distances
sizes_1 = (size(pred_states)...,1)
sizes_2 = (sizes_1[1], sizes_1[3], sizes_1[2])

pred_states = reshape(pred_states, sizes_1)
next_states = reshape(next_states, sizes_2)
pred_states = repeat(pred_states, outer=[1,1,10000])
next_states = repeat(next_states, outer=[1,10000,1])

pairwise_distance_matrix = sum((pred_states - next_states).^2, dims=1)[1,:,:]

#Augment pairwise distance matrix
diag_elements = diag(pairwise_distance_matrix)
pairwise_distance_matrix = vcat(pairwise_distance_matrix, diag_elements')


labels = ones(num_samples)
hits_at_1 = 0

indices = []

for i=1:10000
    
    row = pairwise_distance_matrix[i,:]
    ind = sortperm(row)
    
    push!(indices, ind)

end

indices = vcat(indices'...)

In [None]:
pairwise_distance_matrix

In [None]:
num_matches = sum(labels .== indices[:,1])
hits_at_1 += num_matches
println("Hits @ 1: ", hits_at_1/num_samples)

In [None]:
mxval, mxindx = findmax(indices .== labels,dims=2)
ranks = [ i[2] for i in mxindx ]
reciprocal_ranks = 1 ./ranks
rr_sum = sum(reciprocal_ranks)
println("MRR: ", rr_sum/num_samples)

In [None]:
model(obs)