### Imports

In [1]:
using Pkg; Pkg.activate("..")

[32m[1mActivating[22m[39m environment at `~/Code/Research/nw2vec/Project.toml`


In [2]:
include("layers.jl")
include("utils.jl")
include("vae.jl")
using .Utils, .Layers, .VAE

using Flux
using LightGraphs
using Colors
using AbstractPlotting
using Makie
using CairoMakie
using BSON: @load
using NPZ
using Random
using StatsBase
using Distributions
using LinearAlgebra
using DataFrames
using Combinatorics

┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/Flux/QdkVy.ji for Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1240
┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/LightGraphs/Xm08G.ji for LightGraphs [093fc24a-ae57-5d10-9952-331d41423f4d]
└ @ Base loading.jl:1240
┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/Distributions/xILW0.ji for Distributions [31c24e10-a181-5473-b8eb-7969acd0382f]
└ @ Base loading.jl:1240
┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/AbstractPlotting/6fydZ.ji for AbstractPlotting [537997a7-5e4e-5d89-9595-2241ea00577e]
└ @ Base loading.jl:1240
┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/Makie/iZ1Bl.ji for Makie [ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a]
└ @ Base loading.jl:1240
┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/CairoMakie/9mSey.ji for CairoMakie [13f3f980-e62b-5c42-98c6-ff1f3baf88f0]
└ @ Base loading.jl:1240
┌ Info

## Loading the dataset

In [3]:
"""
    dataset(args)

Read the `.npz` file at `args["dataset"]` and return `(g, features, classes)`.

The file must provide `"features"`, `"labels"`, `"adjdata"`, `"adjrow"` and
`"adjcol"`. Features and labels are transposed so that each column corresponds
to one node of the graph `g`, which is built from the 0-based, unweighted edge
list after dropping self-loops.

`args["forced-correlation"]` sets how much of the node/feature association is
kept: `round((1 - correlation) * nv(g))` randomly chosen nodes have their
feature and class columns permuted among themselves.
"""
function dataset(args)
    data = npzread(args["dataset"])

    # Stored arrays are transposed so that nodes run along the second dimension.
    features = convert(Array{Float32}, transpose(data["features"]))
    classes = transpose(data["labels"])

    # Only unweighted graphs are supported: every stored edge weight must be 1.
    @assert Set(data["adjdata"]) == Set([1])

    # Keep off-diagonal entries only (i.e. drop self-loops).
    rows, cols = data["adjrow"], data["adjcol"]
    offdiag = rows .!= cols
    rows, cols = rows[offdiag], cols[offdiag]
    # Indices are expected to be 0-based in the file.
    @assert minimum(rows) == minimum(cols) == 0

    # Shift to 1-based vertex ids and build the graph.
    g = SimpleGraphFromIterator(LightGraphs.SimpleEdge.(rows .+ 1, cols .+ 1))

    # One feature column per vertex.
    @assert nv(g) == size(g, 1) == size(g, 2) == size(features, 2)

    # Permute the columns of a random subset of nodes; the subset size grows as
    # the requested correlation decreases.
    nshuffle = round(Int, (1 - args["forced-correlation"]) * nv(g))
    picked = StatsBase.sample(1:nv(g), nshuffle, replace = false)
    permuted = shuffle(picked)
    features[:, picked] = features[:, permuted]
    classes[:, picked] = classes[:, permuted]

    @assert eltype(features) == Float32
    return g, features, classes
end

dataset (generic function with 1 method)

In [4]:
"""
    make_dataset()

Build a small hand-made dataset and return `(g, features)`.

`g` is a 9-vertex graph with 8 edges. `features` is a 10×9 `Float32` matrix:
a first row of ones stacked on top of the 9×9 identity, so each column (node)
has a constant entry plus a one-hot encoding of its own index.
"""
function make_dataset()
    nnodes = 9
    edgelist = ((1, 2), (1, 3), (1, 4), (1, 5), (5, 6), (5, 7), (5, 8), (8, 9))

    g = SimpleGraph(nnodes)
    for (src, dst) in edgelist
        add_edge!(g, src, dst)
    end

    constant_row = ones(Float32, (1, nnodes))
    identity_rows = Array(Diagonal(ones(Float32, nnodes)))
    features = vcat(constant_row, identity_rows)

    return g, features
end

make_dataset (generic function with 1 method)

## Plotting model state

In [5]:
"""
    adims(a, dims)

Return a vector holding row `a[i, :]` of `a` for each `i` in `dims`, in order.
Used to splat selected embedding dimensions as coordinates into `scatter`.
"""
function adims(a, dims)
    return [a[row, :] for row in dims]
end

"""
    plotstate(; enc, vae, x, refx, g, dims, colors)

Assemble a two-column scene comparing the model's output with the reference.

One column holds an empty `Scene`, the heatmap of `σ.(logitÂ)` and the heatmap
of `softmax(unormF̂)`, both taken from `vae(x)`. The other holds a scatter of the
embedding means `enc(x)[1]` restricted to rows `dims` (2 or 3 of them), the
adjacency matrix of `g`, and the heatmap of `refx`.

`colors` is forwarded to `scatter` as the marker colour.
"""
function plotstate(; enc, vae, x, refx, g, dims, colors)
    @assert length(dims) in [2, 3]
    embμ, emblogσ = enc(x)
    logitÂ, unormF̂ = vae(x)

    # Model side: predicted adjacency and predicted features.
    predicted = vbox(
        Scene(),
        heatmap(σ.(logitÂ).data, colorrange = (0, 1)),
        heatmap(1:size(x, 1), 1:size(x, 2), softmax(unormF̂).data, colorrange = (0, 1)),
        sizes = [.45, .45, .1]
    )

    # Reference side: embeddings, true adjacency and reference features.
    reference = vbox(
        scatter(adims(embμ, dims)..., color = colors, markersize = Utils.markersize(embμ)),
        heatmap(Array(adjacency_matrix(g)), colorrange = (0, 1)),
        heatmap(1:size(x, 1), 1:size(x, 2), refx, colorrange = (0, 1)),
        sizes = [.45, .45, .1]
    )

    return hbox(predicted, reference)
end

"""
    plotweights(layers...)

Stack one row per layer, each pairing a heatmap of the layer's weight data
(`layer.W.data`) with a text label showing `repr(layer)`.
"""
function plotweights(layers...)
    labeltheme = Theme(align = (:left, :bottom), raw = true, camera = campixel!)
    rows = map(layers) do layer
        hbox(heatmap(layer.W.data), text(labeltheme, repr(layer)))
    end
    return vbox(rows...)
end

plotweights (generic function with 1 method)

## Model parameters

In [6]:
# Model and dataset settings.
# - "dataset" is commented out, so `dataset(args)` cannot be called with this
#   dict as-is (it reads `args["dataset"]`).
# - "forced-correlation" = 1.0 makes `dataset` shuffle zero nodes.
# - The remaining entries are handed to `VAE.make_vae` through `args`; their
#   exact meaning is defined in vae.jl, not in this notebook.
# NOTE(review): the later `@load ... weights args` cell appears to rebind
# `args` from the BSON file, which would supersede this dict — confirm.
args = Dict(
    #"dataset" => "../data/twitter/git=679a9eb593-csv_to_npz-mt=5-tmw=3-w2v_dim=50-w2v_iter=10-cho=True-nclusters=10,20,50,80,100/dataset=retweetsrange-nclusters=10.npz",
    "forced-correlation" => 1.0, # default
    "label-distribution" => VAE.label_distributions["bernoulli"],
    "diml1enc" => 32,
    "diml1dec" => 32,
    "dimxiadj" => 16,
    "dimxifeat" => 16,
    "overlap" => 8,
    "bias" => false,
    "sharedl1" => false,
    "decadjdeep" => true,
    "initb" => VAE.Layers.nobias
)

Dict{String,Any} with 11 entries:
  "diml1dec"           => 32
  "sharedl1"           => false
  "label-distribution" => Bernoulli
  "diml1enc"           => 32
  "bias"               => false
  "overlap"            => 8
  "initb"              => nobias
  "decadjdeep"         => true
  "dimxiadj"           => 16
  "forced-correlation" => 1.0
  "dimxifeat"          => 16

## Load the model and plot its state

In [7]:
# Use the hand-built 9-node toy graph; the `.npz` loader call is kept commented
# out as the alternative.
#g, _features, _ = dataset(args)
g, _features = make_dataset()
# The raw features double as the reference labels.
labels = _features
feature_size = size(_features, 1)
label_size = feature_size
# `_features` stays raw; `features` is the output of the normaliser built from
# it (presumably a per-feature standardisation — see Utils.normaliser).
fnormalise = Utils.normaliser(_features)
features = fnormalise(_features);

In [8]:
# Restore the saved `weights` and `args` from the BSON file.
# NOTE(review): this rebinds the global `args` defined in the "Model
# parameters" cell — confirm that is intended.
@load "../data/twitter/git=679a9eb593-an2vec-diml1enc=32-diml1dec=32-dimxiadj=16-dimxifeat=16-overlap=0,8,16-bias=false-sharedl1=false-decadjdeep=true-nepochs=200/dataset=retweetsrange-nclusters=10-ld=bernoulli-dimxi=24-weights.bson" weights args

# Build encoder, sampler and decoder for the current graph and feature sizes.
enc, sampleξ, dec, paramsenc, paramsdec = VAE.make_vae(
    g = g, feature_size = feature_size, label_size = label_size, args = args)
# Full stochastic pass: encode, sample ξ, decode.
vae(x) = dec(sampleξ(enc(x)...))

# Gather encoder and decoder parameters in one collection and overwrite them
# with the loaded weights.
paramsvae = Tracker.Params()
push!(paramsvae, paramsenc..., paramsdec...)
loadparams!(paramsvae, weights)

Info: using unshared l1 encoder
Info: using deep adjacency decoder
Info: using boolean feature decoder


In [9]:
# Load the saved training history; it is indexed by series name below.
history = npzread("../data/twitter/git=679a9eb593-an2vec-diml1enc=32-diml1dec=32-dimxiadj=16-dimxifeat=16-overlap=0,8,16-bias=false-sharedl1=false-decadjdeep=true-nepochs=200/dataset=retweetsrange-nclusters=10-ld=bernoulli-dimxi=24-history.npz")

# One row per series: a line plot of the values against their index, next to a
# text label with the series name. "ap" and "auc" are currently disabled.
theme = Theme(align = (:left, :bottom), raw = true, camera = campixel!)
scene = vbox([hbox(lines(1:length(history[name]), history[name], color = color), text(theme, name))
        for (name, color) in [
                ("total loss", :blue),
                ("kl", :red),
                ("reg", :red),
                ("adj", :green),
                ("feat", :green),
                #("ap", :cyan),
                #("auc", :cyan)
            ]]...)
#Makie.save("training-history.png", scene)
display(scene)

GLMakie.Screen(...)

In [10]:
#communities = [c for c in 1:args["l"] for i in 1:args["k"]]
#palette = distinguishable_colors(args["l"])
#colors = map(i -> getindex(palette, i), communities)

# Show the model state. The model input `x` is the normalised `features`, as in
# every other cell that calls `enc`/`vae`/`ae`; the raw features (`labels`) are
# only used as the reference heatmap. All nodes are drawn in black because the
# community colouring above is disabled.
display(
    plotstate(enc = enc, vae = vae, x = features, refx = labels,
        g = g, dims = 1:3, colors = "black")
)

GLMakie.Screen(...)

## Distributions of gradients

In [9]:
"""
    ae(x)

Deterministic pass through the model: decode the first output of `enc(x)`
(the embedding means) directly, without sampling.
"""
function ae(x)
    μ = first(enc(x))
    return dec(μ)
end

ae (generic function with 1 method)

In [10]:
"""
    make_onehot(coord, dims)

Return a `Float32` array of size `dims` that is zero everywhere except for a
single 1 at index `coord`.
"""
function make_onehot(coord, dims)
    out = fill(zero(Float32), dims)
    setindex!(out, one(Float32), coord)
    return out
end

make_onehot (generic function with 1 method)

### w.r.t. features

In [36]:
# Forward pass on the normalised features. `Apred` is σ applied to the first
# output of `ae` (the adjacency logits); `back` is the pullback that is later
# called with a one-hot seed to get gradients with respect to `features`.
Apred, back = Tracker.forward(x -> σ.(ae(x)[1]), features)

(Float32[0.97156286 0.88115096 … 0.8797154 0.6768692; 0.9004984 0.7865321 … 0.73746455 0.5788719; … ; 0.9144334 0.7934744 … 0.9269219 0.8289111; 0.7496242 0.6469048 … 0.8417549 0.7637189] (tracked), getfield(Tracker, Symbol("##21#23")){getfield(Tracker, Symbol("##18#19")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}}(Core.Box((Float32[1.8973668 1.8973668 … 1.8973664 1.8973664; 1.8973668 -0.4743417 … -0.4743416 -0.4743416; … ; -0.4743417 -0.4743417 … 1.8973664 -0.4743416; -0.4743417 -0.4743417 … -0.4743416 1.8973664] (tracked),)), getfield(Tracker, Symbol("##18#19")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}(Params([Float32[1.8973668 1.8973668 … 1.8973664 1.8973664; 1.8973668 -0.4743417 … -0.4743416 -0.4743416; … ; -0.4743417 -0.4743417 … 1.8973664 -0.4743416; -0.4743417 -0.4743417 … -0.4743416 1.8973664] (tracked)]), Float32[0.97156286 0.88115096 … 0.8797154 0.6768692; 0.9004984 0.7865321 … 0.73746455 0.5788719; … ; 0.9144334 0.7934744 … 0.9269219 0.8289111; 0.7496242 0.6469

Pre-compute all shortest path distances

In [13]:
# distances[u, v] is the shortest-path length between nodes u and v in `g`.
distances = floyd_warshall_shortest_paths(g).dists

9×9 Array{Int64,2}:
 0  1  1  1  1  2  2  2  3
 1  0  2  2  2  3  3  3  4
 1  2  0  2  2  3  3  3  4
 1  2  2  0  2  3  3  3  4
 1  2  2  2  0  1  1  1  2
 2  3  3  3  1  0  2  2  3
 2  3  3  3  1  2  0  2  3
 2  3  3  3  1  2  2  0  1
 3  4  4  4  2  3  3  1  0

Get all gradients w.r.t. involved neighbours for each link prediction

In [30]:
# For every unordered node pair (u, v), take the gradient of the predicted link
# probability Apred[u, v] with respect to the input features, and store one
# DataFrame row per node n lying within `nhops` hops of u or v.
nnodes = size(Apred)[1]
@assert nnodes == size(Apred)[2]
nfeatures = size(features)[1]
nhops = 2
# Seed matrix passed to `back`; it holds a single 1 at `coord` inside the loop.
onehot = zeros((nnodes, nnodes))
coord = CartesianIndex(1, 1)

# Columns grad1..grad<nfeatures> hold the gradient w.r.t. the features of node
# n; dist_uv is the u-v shortest-path distance and dist_uorv the distance from
# n to the closer of u and v.
df = DataFrame(;((Symbol("grad$i"), Float32[]) for i in 1:nfeatures)...,
    u = Int64[], v = Int64[], n = Int64[], dist_uv = Int64[], dist_uorv = Int64[])

# NOTE(review): `@showprogress` comes from ProgressMeter, which is not imported
# in the visible cells — presumably brought into scope elsewhere; confirm.
@showprogress for (u, v) in combinations(1:nnodes, 2)
    #print("$u - $v: ")

    # Get gradients for the u-v prediction: move the single 1 of the seed from
    # the previous pair to (u, v) instead of reallocating the matrix.
    global coord
    onehot[coord] = 0
    coord = CartesianIndex(u, v)
    onehot[coord] = 1
    @assert onehot == make_onehot(CartesianIndex(u, v), (nnodes, nnodes))
    grads = Tracker.data(back(onehot)[1])

    # Get the gradient for each nhops-hop neighbour of u, v
    dist_uv = distances[u, v]
    neighbours_uv = collect(union(neighborhood(g, u, nhops), neighborhood(g, v, nhops)))
    non_neighbours_uv = setdiff(1:nnodes, neighbours_uv)

    # Every neighbour must receive a non-null gradient (at least one non-zero
    # component in its column)
    @assert all(sum(grads[:, neighbours_uv] .!= 0, dims = 1) .> 0)
    # All gradients for non-neighbours should be null: every component of every
    # non-neighbour column must be exactly zero, not merely one of them
    @assert all(grads[:, non_neighbours_uv] .== 0)

    #println("$(length(neighbours_uv)) neighbours to both")
    for n in neighbours_uv
        push!(df, (grads[:, n]..., u, v, n, dist_uv, min(distances[u, n], distances[v, n])))
    end
end

In [35]:
# Raise the display width so that the wide gradient table is not truncated.
ENV["COLUMNS"] = 300
df

Unnamed: 0_level_0,grad1,grad2,grad3,grad4,grad5,grad6,grad7,grad8,grad9,grad10,u,v,n,dist_uv,dist_uorv
Unnamed: 0_level_1,Float32,Float32,Float32,Float32,Float32,Float32,Float32,Float32,Float32,Float32,Int64,Int64,Int64,Int64,Int64
1,0.0645079,-0.021225,0.00150802,0.00318302,-0.0319299,0.00660334,0.00870657,0.0306246,-0.00391502,-0.0416326,1,2,1,1,0
2,0.0653441,-0.0255881,-0.0114106,0.0108165,-0.0316851,-0.00136565,-0.00199008,0.0298564,0.00248509,-0.0377283,1,2,2,1,0
3,0.036182,-0.0138953,0.00824023,-0.00308798,-0.0130371,0.0059515,0.00703448,0.0184967,-0.00519035,-0.0281207,1,2,3,1,1
4,0.0292288,-0.0099957,0.00241025,0.00671407,-0.00923298,0.00159172,0.00227589,0.0133913,-0.00464004,-0.0223713,1,2,4,1,1
5,0.0139868,-0.00648846,-0.00108487,0.00196581,-0.00564524,0.00110447,0.000616112,0.00652828,-0.00118362,-0.0109085,1,2,5,1,1
6,0.00515707,-0.00153295,-9.54022e-5,-0.00106444,-0.00479411,0.00258529,0.00279811,0.00244054,-0.000862602,-0.00403407,1,2,6,1,2
7,0.00515707,-0.00153295,-9.54022e-5,-0.00106444,-0.00479411,0.00258529,0.00279811,0.00244054,-0.000862602,-0.00403407,1,2,7,1,2
8,0.00421073,-0.00125165,-7.78955e-5,-0.000869113,-0.00391437,0.00211088,0.00228465,0.00199269,-0.000704312,-0.0032938,1,2,8,1,2
9,0.0807522,-0.0227874,0.0254424,-0.0167195,-0.0220161,0.0101796,0.0140722,0.0501589,-0.00739399,-0.0722438,1,3,1,1,0
10,0.0340761,-0.0103437,-0.00257617,0.00722817,-0.0177132,0.00453164,0.00870965,0.00973111,-0.00574675,-0.0162742,1,3,2,1,1


**TODO**

Plot facets of that data.

### w.r.t. embeddings

In [111]:
# Forward pass through the decoder only, starting from the embedding means
# `enc(features)[1]`; `back_emb` is the pullback giving gradients of the
# predicted adjacency with respect to those embeddings.
Apred_emb, back_emb = Tracker.forward(x -> σ.(dec(x)[1]), enc(features)[1])

(Float32[0.97156286 0.88115096 … 0.8797154 0.6768692; 0.9004984 0.78653204 … 0.73746455 0.57887185; … ; 0.9144334 0.7934744 … 0.9269219 0.8289111; 0.7496242 0.6469048 … 0.8417549 0.76371896] (tracked), getfield(Tracker, Symbol("##21#23")){getfield(Tracker, Symbol("##18#19")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}}(Core.Box(([1.1418794037884052 0.8853957612198733 … 0.4948227551326396 -0.021783616590958732; 0.9882851674121734 0.7086697417288617 … 0.2995249605165839 -0.14853460307849584; … ; 1.6789900392252963 0.8560976209638216 … -0.26693793304649993 -0.6054415825060152; -0.191996930972175 -0.3271772580544453 … -0.5457029432684755 -0.8966404562162545] (tracked),)), getfield(Tracker, Symbol("##18#19")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}(Params([[1.1418794037884052 0.8853957612198733 … 0.4948227551326396 -0.021783616590958732; 0.9882851674121734 0.7086697417288617 … 0.2995249605165839 -0.14853460307849584; … ; 1.6789900392252963 0.8560976209638216 … -0.2669379330464

In [113]:
# Gradient of the predicted link (1, 9) with respect to the embeddings. The
# seed is built with `make_onehot`: the global `onehot` is a matrix, not a
# function, once the feature-gradient cell has run.
back_emb(make_onehot(CartesianIndex(1, 9), (9, 9)))[1]

Tracked 24×9 Array{Float64,2}:
  0.0506319   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.10493   
 -0.0837624   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.234714  
  0.0126053   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.028999  
 -0.0347072   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.223553  
  0.111203    0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0901013 
  0.0937184   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0224072 
 -0.024323    0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.119372  
  0.0506258   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.102253  
 -0.00647627  0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0671987 
  0.0500753   0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.00589014
 -0.00719932  0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0790304 
  0.00181941  0.0  0.0  0.0  0.0  0.0  0.0  0.0  -0.0755756 
 -0.0445624   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0485049 
  0.00678587  0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.15691   
  0.0739295   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.0631873 
 -0.0275066   0.0  0.0  0.0  0.0  0.0  0.0  0.0   0.05

**TODO**

Precompute all mutual distances.

Then:
- for each pair of nodes u, v (linked or not):
  - get the gradients w.r.t. embeddings
  - for both u and v, store the gradient in a DataFrame along with the distance between u and v

Then look at the average gradient, faceted by the distance between u and v

## Saving the plots

In [None]:
# Run the encoder and the full VAE on the normalised features; the decoder
# outputs (adjacency logits, unnormalised feature predictions) feed the plots
# saved below.
embμ, emblogσ = enc(features)
logitÂ, unormF̂ = vae(features);

In [None]:
# Adjacency reconstruction: heatmap of σ.(logitÂ) saved to Apred.png.
# NOTE(review): 15000×15000 px looks sized for the full dataset rather than the
# 9-node toy graph — confirm before running.
scene = Scene(resolution = (15000, 15000))
heatmap!(scene, σ.(logitÂ).data, colorrange = (0, 1))
Makie.save("Apred.png", scene);

In [None]:
# Adjacency reference: heatmap of the dense adjacency matrix of `g`, saved to
# Aref.png at the same resolution as the reconstruction.
scene = Scene(resolution = (15000, 15000))
heatmap!(scene, Array(adjacency_matrix(g)), colorrange = (0, 1))
Makie.save("Aref.png", scene);

In [None]:
# Feature reconstruction: heatmap of softmax(unormF̂) saved to Fpred.png.
# NOTE(review): the model-loading cell reports a boolean (Bernoulli) feature
# decoder; confirm that softmax rather than an element-wise σ is the intended
# link here.
scene = Scene(resolution = (150, 15000))
heatmap!(scene, 1:size(_features, 1), 1:size(_features, 2), softmax(unormF̂).data, colorrange = (0, 1))
Makie.save("Fpred.png", scene);

In [None]:
# Feature reference: heatmap of the raw features saved to Fref.png.
scene = Scene(resolution = (150, 15000))
# Draw into `scene` with the mutating `heatmap!`; the non-mutating `heatmap`
# builds a separate scene, which would leave the saved one empty.
heatmap!(scene, 1:size(_features, 1), 1:size(_features, 2), _features, colorrange = (0, 1))
Makie.save("Fref.png", scene);