#  _dMasif for ADP binding site prediction_

Submission for the Geometric Deep Learning assessment, HT23.
Candidate Number: 1045801

This notebook contains a stand-alone implementation of the dMasif model for protein binding site prediction. 


In [None]:
#@title Download dataset
# Fetch the zipped dataset from the project repository.
!wget https://github.com/candidate1045801/miniproject/raw/main/data.zip
# Unpack it; the absolute path assumes the working directory is /content — TODO confirm.
!unzip /content/data.zip
# The archive is no longer needed once extracted.
!rm /content/data.zip


## Julia notebook installation instructions
1. Change runtime type to use GPU acceleration
2. Execute the following cell to install the Julia kernel together with all the necessary packages. This may take a couple of minutes.
3. Reload this page and continue to the next section.


In [None]:
# Credit for this code: https://github.com/ageron/julia_notebooks

# Assuming python3 runtime
%%shell
# Stop at the first command that fails.
set -e

#---------------------------------------------------#
# User-tunable settings for the installation below.
# NOTE(review): only IJulia (and CUDA when a GPU is detected) are installed here, while
# later cells import further packages (Flux, BioStructures, ...) — confirm where those
# are meant to be installed.
JULIA_VERSION="1.8.5" # any version ≥ 0.7.0
JULIA_PACKAGES="IJulia"
JULIA_PACKAGES_IF_GPU="CUDA" # or CuArrays for older Julia versions
JULIA_NUM_THREADS=2
#---------------------------------------------------#

# Everything below runs only when no `julia` executable is found on PATH.
if [ -z `which julia` ]; then
  # Install Julia
  # JULIA_VER keeps the first two dot-separated fields of the version (used in the URL).
  JULIA_VER=`cut -d '.' -f -2 <<< "$JULIA_VERSION"`
  echo "Installing Julia $JULIA_VERSION on the current Colab Runtime..."
  BASE_URL="https://julialang-s3.julialang.org/bin/linux/x64"
  URL="$BASE_URL/$JULIA_VER/julia-$JULIA_VERSION-linux-x86_64.tar.gz"
  wget -nv $URL -O /tmp/julia.tar.gz # -nv means "not verbose"
  # Unpack into /usr/local, dropping the archive's top-level directory.
  tar -x -f /tmp/julia.tar.gz -C /usr/local --strip-components 1
  rm /tmp/julia.tar.gz

  # Install Packages
  # GPU=1 when `nvidia-smi -L` succeeds, 0 otherwise.
  nvidia-smi -L &> /dev/null && export GPU=1 || export GPU=0
  if [ $GPU -eq 1 ]; then
    JULIA_PACKAGES="$JULIA_PACKAGES $JULIA_PACKAGES_IF_GPU"
  fi
  # Add and precompile each package in turn; package manager output is discarded.
  for PKG in `echo $JULIA_PACKAGES`; do
    echo "Installing Julia package $PKG..."
    julia -e 'using Pkg; pkg"add '$PKG'; precompile;"' &> /dev/null
  done

  # Install kernel and rename it to "julia"
  echo "Installing IJulia kernel..."
  julia -e 'using IJulia; IJulia.installkernel("julia", env=Dict(
      "JULIA_NUM_THREADS"=>"'"$JULIA_NUM_THREADS"'"))'
  KERNEL_DIR=`julia -e "using IJulia; print(IJulia.kerneldir())"`
  KERNEL_NAME=`ls -d "$KERNEL_DIR"/julia*`
  mv -f $KERNEL_NAME "$KERNEL_DIR"/julia  

  echo ''
  echo "Successfully installed `julia -v`!"
  echo "Please reload this page (press Ctrl+R, ⌘+R, or the F5 key) then"
  echo "jump to the 'Check the Installation' section."
fi

In [1]:
#@title Check the installation
# Print Julia version and platform details.
display(versioninfo())
println()
# Loading CUDA is used as the GPU probe: any error raised by `using CUDA` is reported as
# "No GPU found."; otherwise the `else` branch runs `nvidia-smi`.
# NOTE(review): a successful `using CUDA` may not by itself prove that a usable GPU is
# present — confirm whether a stricter check is wanted.
try
    using CUDA
catch
    println("No GPU found.")
else
    run(`nvidia-smi`) 
end

# Make the cell evaluate to `nothing`.
nothing

Julia Version 1.8.1
Commit afb6c60d69 (2022-09-06 15:09 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 12 × Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, skylake)
  Threads: 4 on 12 virtual cores
Environment:
  JULIA_DEPOT_PATH = D:\Programs\julia_depot
  JULIA_NUM_THREADS = 4





Mon Apr 10 20:25:11 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 496.76       Driver Version: 496.76       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0 Off |                  N/A |
| N/A   47C    P8    N/A /  N/A |     75MiB /  4096MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

Process(`[4mnvidia-smi[24m`, ProcessExited(0))

In [2]:
# @title Import the required packages

# Machine Learning
using CUDA
using Flux
using OneHotArrays

using IJulia
using BenchmarkTools
using Logging
global_logger(ConsoleLogger())
using LinearAlgebra
using Distributions
using Distances
using Distributed
using LogExpFunctions: logsumexp as lsumexp
# using Plots

using Random 
Random.seed!(1234);

# Part 1: Generate surface data from atom data

In [3]:
using BioStructures: resnameselector, collectresidues, coordarray, element, collectatoms,
                     read, PDB, StructuralElement, standardselector
                     
using Bio3DView
#@title Load input atom data

# Residue selector that keeps only residues named "ADP".
adpselector(res) = resnameselector(res, ["ADP"])
# Element symbols known to the pipeline; the position in this list is the element's id.
const ELEMENTS = ["C", "H", "N", "O", "S", "SE"]
# Lookup table from element symbol to its 1-based position in `ELEMENTS`.
const ELEMENT_IDS = Dict(symbol => id for (id, symbol) in enumerate(ELEMENTS))
# Van der Waals radii, index-aligned with `ELEMENTS`.
const RADII = [1.7, 1.1, 1.52, 1.55, 1.80, 1.90]

"""
    get_atom_data(struc::StructuralElement)

Extract the atom-level inputs of the pipeline from a parsed structure.

# Returns
A tuple `(coords, types, adp_coords)`:
- `coords`: coordinate array of the residues matched by `standardselector`.
- `types`: for each atom of those residues, its id in `ELEMENT_IDS`; elements that are
  not in the table fall back to id `2` (the position of `"H"` in `ELEMENTS`).
- `adp_coords`: coordinate array of the residues matched by `adpselector`. When the
  structure contains no such residue, a single far-away placeholder point is returned
  so that nearest-ligand queries still work and never fall within binding range.
"""
function get_atom_data(struc::StructuralElement)
    aminoacids = collectresidues(struc, standardselector)
    coords = coordarray(aminoacids)

    types = get.(Ref(ELEMENT_IDS), element.(collectatoms(aminoacids); strip=true), 2)

    adps = collectresidues(struc, adpselector)
    adp_coords = coordarray(adps)

    # Proteins without ADP: substitute one placeholder point. It is built as a 3×1 matrix
    # (not a 3-element vector) so the return value has the same layout on every path.
    if isempty(adp_coords)
        adp_coords = fill(1e6, 3, 1)
    end
    return coords, types, adp_coords
end

# Parse one example PDB file and pull out the atom data used by the following cells.
struc = read("data/1AE4.pdb", PDB)
coords, types, adp_coords = get_atom_data(struc)
display(struc)

ProteinStructure 1AE4.pdb with 1 models, 1 chains (A), 324 residues, 434 atoms

In [4]:
#@title Optimized sdf functions 
# low memory implementation of SDF
# to make faster - batch it up to enable parallelism
"""
    sdf(x, coords, radii; ideal_dist = 1.05)

Evaluate the smooth distance function of the atom cloud at every column of `x`, using
only O(number of atoms) scratch memory.

For a query point `x_i` the value is `-σ * L - ideal_dist`, where
`L = logsumexp(-dists(x_i, coords) ./ radii)` and `σ` is the mean of `radii` weighted by
`exp(-dists(x_i, coords))`.

# Arguments
- `x`: query points, one per column.
- `coords`: atom centres, one per column.
- `radii`: one radius per atom (vector with as many entries as `coords` has columns).
- `ideal_dist`: offset subtracted from every value.

# Returns
A vector with one value per column of `x`.
"""
function sdf(x, coords, radii; ideal_dist = 1.05)
    a = coords
    vecs = similar(a)
    d = similar(a, 1, size(a, 2))
    v = similar(d)
    # Typed, preallocated result (instead of an untyped `[]` that is grown with push!).
    res = Vector{promote_type(eltype(a), eltype(radii))}(undef, size(x, 2))
    for (i, x_i) in enumerate(eachcol(x))
        vecs .= (x_i .- a).^2
        sum!(d, vecs)                               # d = sqdists(x_i, a)
        d .= .-sqrt.(d)                             # d = -dists(x_i, a)
        v .= d ./ radii'
        L = lsumexp(v)                              # L = logsumexp(-dists(x_i, a) / radii)
        d .= exp.(d)                                # d = exp.(-dists(x_i, a))
        σ = dot(d, radii') / sum(d)                 # σ = mean radius weighted by exp(-dists)
        res[i] = -σ * L
    end
    return res .- ideal_dist
end


"""
    heavy_sdf(x, coords, radii; ideal_dist = 1.05)

Reference implementation of the smooth distance function that materialises the full
(points × atoms) distance matrix. Memory hungry, but simple enough to serve as the
ground truth for `sdf`.
"""
function heavy_sdf(x, coords, radii; ideal_dist = 1.05)
    points = reshape(x, 3, :, 1)
    atoms = reshape(coords, 3, 1, :)
    # dists[i, j] = distance between point i and atom j
    dists = sqrt.(dropdims(sum((points .- atoms).^2; dims=1); dims=1))

    weights = exp.(-dists)
    r = radii'
    # Per point: mean atom radius weighted by exp(-distance)
    softavg_radius = sum(weights .* r; dims=2) ./ sum(weights; dims=2)
    smooth_min = lsumexp(-dists ./ r; dims=2)

    return vec(-softavg_radius .* smooth_min .- ideal_dist)
end

# ∇ₓSDF - low memory - to make faster should operate in bigger batches of columns over x
#
# Analytic gradient of the smooth distance `-σ*L` (see `sdf`) with respect to each column
# of `x`. With ϕ = Σ exp(-dists), ψ = Σ radii .* exp(-dists), σ = ψ/ϕ and
# L = logsumexp(-dists ./ radii), the three accumulation steps in the loop add
#   -σ ∇L,   +ψ L ∇ϕ / ϕ²   and   -ϕ L ∇ψ / ϕ².
# Returns an array shaped like `x` whose column i is the gradient at `x[:, i]`.
function grad_sdf(x, coords, radii)
    a = coords
    # Scratch buffers, reused for every query point
    vecs = similar(a)
    d = similar(a, 1, size(a, 2)) # -dists
    ed = similar(d) # exp(-dists)
    nd = similar(d) # -dists(x_i, a) / radii (normalized dists by radius)
    grads = similar(x)

    for (i, x_i) in enumerate(eachcol(x))
        vecs .= (x_i .- a).^2
        sum!(d, vecs)                               # d = sqdists(x_i, a)
        d .= .-sqrt.(d)                             # d = -dists(x_i, a)
        nd .= d ./ radii'                           # nd = -dists(x_i, a) / radii
        ed .= exp.(d)                               # ed = exp(-dists)

        L = lsumexp(nd)                             # L = logsumexp(-dists(x_i, a)/radii)
        ϕ = sum(ed)
        ψ = dot(ed, radii')
        σ = ψ / ϕ

        # add σ*∇L
        # (x_i .- a) ./ d is the gradient of -dist, since d already holds -dist
        vecs .= exp.(nd)  ./ (d .* radii') .* (x_i .- a)
        @views grads[:, i] .= -σ * sum(vecs;dims=2) / exp(L)

        # add ∇ϕ component
        vecs .= (ed ./ d) .* (x_i .- a) # ∇ϕ component
        @views grads[:, i] .-= -ψ * L * sum(vecs;dims=2) / ϕ^2

        # add ∇ψ component
        # vecs still holds the ∇ϕ terms; scaling them by the radii gives the ∇ψ terms
        vecs .*= radii' # ∇ψ component
        @views grads[:, i] .-= ϕ * L * sum(vecs;dims=2) / ϕ^2
    end

    return grads
end

# Profiling and correctness checks
#
# Times the low-memory `sdf` / `grad_sdf` against the reference `heavy_sdf` (and its
# automatic-differentiation gradient) on random inputs, and asserts that both pairs agree.
function check_sdf()
    # Random query points, atom centres and radii
    xx = rand(3, 4000)
    c = rand(3, 2000)
    r = rand(2000)

    println("My sdf")
    @time sdf(xx, c, r)
    println("Heavy sdf")
    @time heavy_sdf(xx, c, r)

    @assert isapprox(sdf(xx, c, r), heavy_sdf(xx, c, r))

    # Summing the reference values gives a scalar whose gradient w.r.t. x stacks the
    # per-point gradients, which is what grad_sdf returns
    msd(x) = sum(heavy_sdf(x, c, r))
    println("My grad")
    @time grad_sdf(xx, c, r)
    println("Heavy grad")
    @time gradient(msd, xx)[1]

    @assert isapprox(grad_sdf(xx, c, r), gradient(msd, xx)[1])
end
#check_sdf()

check_sdf (generic function with 1 method)

In [5]:
"""
    sample_surface(coords, radii; samples_per_atom=1, num_iters=10, step_size=3.0,
                   error_margin=0.3, batch_size=2000, grid_size=1.0)

Sample an oriented point cloud on the level set `sdf == 0` of the atom cloud.

# Arguments
- `coords`: atom centres, one per column.
- `radii`: one radius per atom.

# Keywords
- `samples_per_atom`: number of initial random samples drawn around each atom.
- `num_iters`: number of gradient-descent iterations on the squared `sdf`.
- `step_size`: gradient-descent step size.
- `error_margin`: samples with `abs(sdf) >= error_margin` after descent are discarded.
- `batch_size`: number of samples handled by one threaded work item.
- `grid_size`: edge length of the cubic cells used for the final thinning step; one
  sample is kept per occupied cell. The default `1.0` reproduces plain flooring of the
  coordinates.

# Returns
`(x, normals)`: the surviving sample positions and their unit normals, one per column.
"""
function sample_surface(coords, radii;
                        samples_per_atom=1, num_iters=10, step_size=3.0, error_margin=0.3,
                        batch_size=2000, grid_size=1.0)
    grid_size > 0 || throw(ArgumentError("grid_size must be positive, got $grid_size"))

    A = size(coords, 2)
    B = samples_per_atom

    mysdf(p) = sdf(p, coords, radii)
    mygrad_sdf(p) = grad_sdf(p, coords, radii)

    # Step 1 - Sample point cloud around atoms
    @info "initial #samples " * string(A*B)
    cloud = rand(Normal(0.0, 1.0), 3, A, B)
    cloud .= cloud .* radii' .+ coords
    # `x` is bound once and never reassigned, so the threaded loop body below captures
    # an ordinary (unboxed) variable.
    x = reshape(cloud, 3, :) # size(x) = (3, A*B)

    # Step 2 - Bring points closer to surface by minimizing the squared sdf
    #          via gradient descent
    @info "Attracting samples to surface via gradient descent"
    batches = collect(Iterators.partition(axes(x, 2), batch_size))
    for i in 1:num_iters
        # Batches cover disjoint columns, so the in-place updates do not overlap
        Threads.@threads for cols in batches
            @views x_batch = x[:, cols]
            x_batch .-= step_size .* mysdf(x_batch)' .* mygrad_sdf(x_batch) # grad(mse(sdf))
        end
        @info "Loss at iter $i: " * string(mean(mysdf(x).^2) / 2)
    end

    # Step 3 - Clean the samples

    # Discard samples far away from the surface
    near_mask = abs.(mysdf(x)) .< error_margin
    near = x[:, near_mask]
    @info "#samples left after distance cull: " * string(count(near_mask))

    # Compute normals as gradient of surface implicit function (aka sdf)
    near_normals = mygrad_sdf(near)
    foreach(normalize!, eachcol(near_normals))

    # Discard samples nested inside the protein
    # i.e. keep a sample only if moving 4 along its normal raises the sdf by more than 0.5
    free_mask = (mysdf(near .+ 4 .* near_normals) .- mysdf(near)) .> 0.5
    free = near[:, free_mask]
    free_normals = near_normals[:, free_mask]
    @info "#samples left after trapped cull: " * string(count(free_mask))

    # Step 4 - subsampling - make sure sampling was uniform in space
    # For each cubic cell of edge `grid_size` keep the first sample that falls into it.
    # This is deterministic rather than random, so it might give slightly biased results.
    grid_loc = floor.(Int, free ./ grid_size)
    unique_idx = unique(i -> grid_loc[:, i], 1:size(free, 2))
    @info "#samples left after subsampling: " * string(length(unique_idx))

    return free[:, unique_idx], free_normals[:, unique_idx]
end

# Look up a radius for every atom from its element id, then sample the surface.
radii = RADII[types]
x, normals = sample_surface(coords, radii);

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39minitial #samples 324


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mAttracting samples to surface via gradient descent


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 1: 0.19370858935230315
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 2: 0.12540116565950285
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 3: 0.08664582197545564
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 4: 0.06127822426842723
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 5: 0.04399845196317627
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 6: 0.03183763669924627
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 7: 0.022322739468219284


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 8: 0.015596193537336952
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 9: 0.01169031235377825
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 10: 0.009135975242368552


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m#samples left after distance cull: 308


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m#samples left after trapped cull: 208
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m#samples left after subsampling: 207


In [6]:
"""
    knearest(x, coords; k=16)

For every column `x_i` of `x`, find the `k` columns of `coords` closest to it.

# Returns
`(ids, dists)`, both of size `(k, size(x, 2))`: `ids[:, i]` holds the column indices into
`coords` of the `k` nearest points to `x_i`, ordered from nearest to farthest, and
`dists[:, i]` the corresponding Euclidean distances.

# Throws
`ArgumentError` if `k` is larger than the number of columns of `coords`.
"""
function knearest(x, coords; k=16)
    num_coords = size(coords, 2)
    k <= num_coords || throw(ArgumentError(
        "k = $k exceeds the number of reference points ($num_coords)"))

    vecs = similar(coords)
    sqdists = similar(coords, num_coords)
    # Concrete element type: the ids are used for indexing in hot loops
    ids = Matrix{Int}(undef, k, size(x, 2))
    dists = similar(x, k, size(x, 2))
    # hot loop
    for (i, x_i) in enumerate(eachcol(x))
        vecs .= (x_i .- coords).^2
        sqdists .= dropdims(sum(vecs; dims=1); dims=1)
        # Partial sort: only the first k entries of the permutation need to be ordered
        @views ids[:, i] .= sortperm(sqdists; alg=PartialQuickSort(k))[1:k]
        @views dists[:, i] .= sqrt.(sqdists[ids[:, i]])
    end
    return ids, dists
end

# For every surface sample, find its 8 nearest atoms and the distances to them.
nbh_atom_ids, dists = knearest(x, coords; k=8)

# Element ids of those atoms, arranged like `nbh_atom_ids`, then one-hot encoded.
t = reshape(types[vec(nbh_atom_ids)], size(nbh_atom_ids))
t_onehot = onehotbatch(t, 1:6, 2)
inv_dists = 1 ./ dists # size = k x num_samples (k = 8 here)

# A sample is labelled positive when its nearest ADP atom is closer than 3.0.
labels = knearest(x, adp_coords; k=1)[2] .< 3.0

# Number of positive samples
count(labels)

0

In [7]:
# geodesic distance

"""
    dists_dots(x, n, nbh_ids)

For every sample `i` and each of its neighbours `j = nbh_ids[:, i]`, compute the Euclidean
distance between `x[:, i]` and `x[:, j]` and the dot product of `n[:, i]` and `n[:, j]`.

# Returns
`(dists, dots)`, both of size `(size(nbh_ids, 1), size(x, 2))`.
"""
function dists_dots(x, n, nbh_ids)
    k = size(nbh_ids, 1)
    vecs = similar(x, 3, k)
    dists = similar(x, k, size(x, 2))
    dots = similar(dists)
    # hot loop; `@views` wraps the whole body so that no slice below allocates a copy
    @views for i in axes(x, 2)
        nbh = nbh_ids[:, i]

        vecs .= (x[:, i] .- x[:, nbh]).^2
        dists[:, i] .= sqrt.(vec(sum(vecs; dims=1)))

        vecs .= n[:, i] .* n[:, nbh]
        dots[:, i] .= vec(sum(vecs; dims=1))
    end
    return dists, dots
end

"""
    quasi_geodesic_dist(x, n, nbh_ids; λ=1)

Euclidean neighbour distances inflated by how much the normals disagree:
`dist * (1 + λ * (1 - nᵢ⋅nⱼ))`. Returns an array of size `(size(nbh_ids, 1), size(x, 2))`.
"""
function quasi_geodesic_dist(x, n, nbh_ids; λ=1)
    euclidean, cosines = dists_dots(x, n, nbh_ids)
    return @. euclidean * (1 + λ * (1 - cosines))
end

"""
    my_geodesic_dist(x, n, nbh_ids)

Rescale each Euclidean neighbour distance by `acos(d) / sqrt(2 - 2d)`, where `d` is the
dot product of the two normals (for unit normals: angle between them divided by the
length of the chord joining them). Returns an array of size
`(size(nbh_ids, 1), size(x, 2))`.
"""
function my_geodesic_dist(x, n, nbh_ids)
    dists, dots = dists_dots(x, n, nbh_ids)
    # Keep the dot products strictly inside (-1, 1): at ±1 the ratio below is 0/0.
    # The margin must be representable in the element type; a fixed 1e-14 rounds away in
    # Float32. For Float64 this is still exactly 1e-14.
    T = eltype(dots)
    margin = max(T(1e-14), eps(T))
    clamp!(dots, -one(T) + margin, one(T) - margin)
    return dists .* acos.(dots) ./ sqrt.(2 .- 2 .* dots)
end

"""
    gaussian_filter!(dists; σ=9)

Overwrite every distance `d` in `dists` with the Gaussian weight `exp(-d^2 / (2σ^2))`
and return `dists`, so it can be used as a convolution window.
"""
function gaussian_filter!(dists; σ=9)
    two_var = 2 * σ^2
    @. dists = exp(-dists^2 / two_var)
    return dists
end

# Neighbourhood of every surface sample: its 100 nearest samples and their distances.
nbh_ids, dists = knearest(x, x; k=100);

In [11]:
#@title 3D geometry methods

"""
    nuv_from_n(n)

Build an arbitrary orthonormal local reference frame `(n, u, v)` for each unit normal.

# Arguments
- `n`: unit normals, one per column (3 rows).

# Returns
An array of size `(3, 3, num_samples)`; `nuv[:, 1, i]` is the normal itself and
`nuv[:, 2, i]`, `nuv[:, 3, i]` are the two tangent vectors.
"""
function nuv_from_n(n)
    @views begin
        x = n[1, :]
        y = n[2, :]
        z = n[3, :]
    end
    # The construction needs a sign that is never zero: `sign(0) == 0` would make the
    # denominator below vanish for normals lying in the z = 0 plane and yield NaNs.
    T = eltype(n)
    s = ifelse.(z .>= 0, one(T), -one(T))
    a = -1.0 ./ (s .+ z)
    b = a .* x .* y

    u = similar(n)
    u[1, :] .= 1 .+ s .* a .* x .* x
    u[2, :] .= s .* b
    u[3, :] .= .- s .* x

    v = similar(n)
    v[1, :] .= b
    v[2, :] .= s .+ a .* y .* y
    v[3, :] .= .-y

    # size nuv = (3, 3, num_samples)
    nuv = cat(reshape(n, 3, 1, :), reshape(u, 3, 1, :), reshape(v, 3, 1, :); dims=2)
    return nuv
end

"""
    local_pos(pos, frame, nbh_ids)

Express the position of every neighbour relative to its centre sample, in that sample's
local reference frame.

# Arguments
- `pos`: sample positions, one per column.
- `frame`: per-sample frame vectors, size `(3, k, num_samples)` with `k ∈ (1, 2, 3)`.
- `nbh_ids`: `nbh_ids[:, i]` are the column indices of the neighbours of sample `i`.

# Returns
An array of size `(k, nbh_size, num_samples)` whose slice `[:, :, i]` is
`frame[:, :, i]' * (pos[:, nbh_ids[:, i]] .- pos[:, i])`.
"""
function local_pos(pos, frame, nbh_ids)
    # size(frame, 1) = 3
    k = size(frame, 2) # ∈ [1, 2, 3]
    nbh_size = size(nbh_ids, 1)
    num_samples = size(pos, 2)

    # Nothing to concatenate for an empty point cloud
    if num_samples == 0
        return zeros(promote_type(eltype(pos), eltype(frame)), k, nbh_size, 0)
    end

    # size res[i] = (k, 3) * (3, nbh_size) = (k, nbh_size)
    @views res = map(i -> frame[:, :, i]' * (pos[:, nbh_ids[:, i]] .- pos[:, i]),
                     axes(pos, 2))

    # `reduce(hcat, …)` joins the blocks without splatting one argument per sample.
    # final size: (k, nbh_size, num_samples)
    return reshape(reduce(hcat, res), k, nbh_size, num_samples)
end

# Build a local frame per sample and express each neighbourhood in it (result discarded).
nuv = nuv_from_n(normals)
local_pos(x, nuv, nbh_ids);

In [12]:
#@title Putting it all together
"""
    process_data_from_pdb(id::String; atom_nbh_size=8, ligand_bind_range=3.0)

Turn the PDB file `data/<id>.pdb` into one training sample.

# Keywords
- `atom_nbh_size`: number of nearest atoms used to describe each surface point.
- `ligand_bind_range`: a surface point is labelled positive when its nearest ADP atom is
  closer than this distance.

# Returns
A named tuple `(pos, normals, feats, labels)`:
- `pos`, `normals`: surface point cloud returned by `sample_surface`.
- `feats`: `7 × atom_nbh_size × num_samples` array; rows 1–6 one-hot encode the element
  of each neighbouring atom and row 7 is the inverse distance to it.
- `labels`: Boolean mask marking the surface points within binding range of ADP.
"""
function process_data_from_pdb(id::String;
                               atom_nbh_size=8, ligand_bind_range=3.0)
    structure = read("data/" * id * ".pdb", PDB)
    atom_coords, atom_types, adp_coords = get_atom_data(structure)

    # Oriented surface point cloud sampled from the atom "metaballs"
    pos, normals = sample_surface(atom_coords, RADII[atom_types]; num_iters=10)

    # Chemical description of each surface point from its closest atoms
    nearest_ids, nearest_dists = knearest(pos, atom_coords; k=atom_nbh_size)
    nearest_types = reshape(atom_types[vec(nearest_ids)], size(nearest_ids))
    onehot = onehotbatch(nearest_types, 1:6, 2)
    inv_dists = reshape(1 ./ nearest_dists, 1, atom_nbh_size, :)
    feats = cat(onehot, inv_dists; dims=1)

    # Binding-site labels from the distance to the closest ligand atom
    _, ligand_dists = knearest(pos, adp_coords; k=1)
    labels = ligand_dists .< ligand_bind_range

    return (pos=pos, normals=normals, feats=feats, labels=labels)
end

# Preprocess the example structure and report how long it takes.
@time data = process_data_from_pdb("1AE4");

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39minitial #samples 324
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mAttracting samples to surface via gradient descent
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 1: 0.18391514102897097
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 2: 0.11200804421740201
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 3: 0.07919686823321234
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 4: 0.058603818255943664
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 5: 0.04469040299882387


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 6: 0.03278820777175682
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 7: 0.023812487786767422
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 8: 0.017212012556286533
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 9: 0.013739309098077964


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mLoss at iter 10: 0.011339167793507451
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m#samples left after distance cull: 302
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m#samples left after trapped cull: 224
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m#samples left after subsampling: 221


  0.722646 seconds (1.65 M allocations: 84.914 MiB, 4.87% gc time, 60.35% compilation time)


# Part 2: Learning over the protein surface data

In [14]:
#@title Chemical Layer - compute chemical embedding of points based on surrounding atoms 

"""
    ChemicalLayer(; input_dim=7, num_atoms=8, hidden_dim=12, emb_dim=6)

Chemical embedding of surface points from the features of their surrounding atoms.

`atomMLP` embeds each neighbouring atom, the atom embeddings are summed per sample, and
`sampleMLP` maps the sum to the final `emb_dim`-dimensional embedding.

The struct is parametric in its field types so that the fields are concretely typed.
"""
struct ChemicalLayer{A,S}
    atomMLP::A
    sampleMLP::S
end

function ChemicalLayer(; input_dim=7, num_atoms=8, hidden_dim=12, emb_dim=6)
    # NOTE(review): this BatchNorm is sized by `num_atoms`, not `hidden_dim`; presumably
    # because it is applied to a 3-D (features, atoms, samples) array — confirm that
    # normalising per atom slot is intended.
    atomMLP = Chain(Dense(input_dim => hidden_dim),
                    BatchNorm(num_atoms, leakyrelu),
                    Dense(hidden_dim => emb_dim))
    sampleMLP = Chain(Dense(emb_dim => hidden_dim),
                      BatchNorm(hidden_dim, leakyrelu),
                      Dense(hidden_dim => emb_dim))
    return ChemicalLayer(atomMLP, sampleMLP)
end

# Forward pass. `x` is assumed to be (input_dim, num_atoms, num_samples) — TODO confirm.
function (m::ChemicalLayer)(x)
    # Embed atoms in chemical space using an MLP
    atom_emb = m.atomMLP(x)
    # Aggregate for each sample (sum over dimension 2), then apply the second MLP
    sample_emb = dropdims(sum(atom_emb; dims=2); dims=2)
    return m.sampleMLP(sample_emb)
end

Flux.@functor ChemicalLayer
# Smoke test: embed the example features (converted to Float32) and bring the result
# back to the CPU.
model = ChemicalLayer() |> gpu
chem_emb = model(Float32.(data[:feats]) |> gpu) |> cpu

6×221 Matrix{Float32}:
 -0.440055  -0.422967  -0.43297   …  -0.422641  -0.414309  -0.430095
 -1.3341    -1.323     -1.32733      -1.32007   -1.31976   -1.32701
 -2.44653   -2.45781   -2.44399      -2.44898   -2.47146   -2.45105
  0.938986   0.932628   0.9327        0.927933   0.933421   0.934243
 -2.93426   -2.92887   -2.92688      -2.92232   -2.9318    -2.92966
  0.176727   0.183002   0.179153  …   0.182902   0.186374   0.180335

In [16]:
"""
    PotentialLayer(; input_dim=6, hidden_dim=16)

Scalar field over the protein surface, computed from each point's embedding by an MLP
with a single output.

Projecting the gradient of this field onto the tangent plane yields a smooth vector
field. This ensures that the gauge of the manifold is smooth and consistent, which is
important because the convolution of the signal depends on the specific coordinate
choice.

The struct is parametric in its field type so that the field is concretely typed.
"""
struct PotentialLayer{C}
    chain::C
end

function PotentialLayer(; input_dim=6, hidden_dim=16)
    chain = Chain(Dense(input_dim => hidden_dim),
                  BatchNorm(hidden_dim, leakyrelu),
                  Dense(hidden_dim => hidden_dim),
                  BatchNorm(hidden_dim, leakyrelu),
                  Dense(hidden_dim => 1))
    return PotentialLayer(chain)
end

# Forward pass: apply the wrapped chain to `x`.
(m::PotentialLayer)(x) = m.chain(x)

Flux.@functor PotentialLayer

# Smoke test: evaluate the potential on the chemical embeddings computed above.
potmodel = PotentialLayer() |> gpu 
pot = potmodel(chem_emb |> gpu) |> cpu

1×221 Matrix{Float32}:
 0.196107  0.212117  0.201956  0.206608  …  0.211434  0.221096  0.205215

In [17]:
# Forward pass of a convolution layer over the surface neighbourhoods (work in progress).
# NOTE(review): `ConvLayer` is not defined anywhere in the visible source, so this method
# definition cannot be evaluated as it stands.
function (m::ConvLayer)(emb, window, p_ij, nbh_ids)
    # Gather, for every sample i, the embedding columns of its neighbours and concatenate
    # the blocks along dimension 2.
    @views emb_ij = cat(map(i -> emb[:, nbh_ids[:, i]], axes(emb, 2))...;dims=2)
    # NOTE(review): `pos` and `uv` are not parameters of this method and `p_ij_uv` is never
    # used below — presumably left over from an earlier draft; confirm and remove.
    p_ij_uv = local_pos(pos, uv, nbh_ids)
    # Weight the neighbour embeddings by `window` and `m.localMLP(p_ij)`, then sum over
    # dimension 2.
    # NOTE(review): check that the shape of `emb_ij` broadcasts against the other factors.
    out = dropdims(sum(window .* m.localMLP(p_ij) .* emb_ij; dims=2);dims=2)
    return out
end

UndefVarError: UndefVarError: ConvLayer not defined

In [41]:
"""
    update_nuv(potentials, pos, nuv, nbh_ids, window)

Re-orient the tangent vectors of every local frame along the windowed gradient of a
scalar potential, keeping the normals unchanged.

# Arguments
- `potentials`: potential per sample, size `(1, num_samples)`.
- `pos`: sample positions, one per column.
- `nuv`: current frames, size `(3, 3, num_samples)` with columns `(n, u, v)`.
- `nbh_ids`: `nbh_ids[:, i]` are the neighbour indices of sample `i`.
- `window`: neighbour weights, size `(nbh_size, num_samples)`.

# Returns
New frames of size `(3, 3, num_samples)`: column 1 is the old normal, column 2 the new
`u` and column 3 the new `v` (`u` rotated by 90° in the old tangent plane).

NOTE(review): the new `u`/`v` are not renormalised here — confirm whether unit-length
tangent vectors are expected downstream.
"""
function update_nuv(potentials, pos, nuv, nbh_ids, window)
    num_samples = size(pos, 2)

    # size window = (1, nbh_size, num_samples)
    window = reshape(window, 1, size(window)...)

    # pots_ij = p_j - p_i, i.e. change of potential when going from i to j
    ps = potentials
    @views pots_ij = reduce(hcat,
                            map(i -> ps[:, nbh_ids[:, i]] .- ps[:, i], axes(pos, 2)))
    # size pots_ij = (1, nbh_size, num_samples)
    pots_ij = reshape(pots_ij, 1, :, num_samples)

    # p_ij_uv gives location of pos_j w.r.t. the tangent plane of point i
    # size p_ij_uv = (2, nbh_size, num_samples)
    @views p_ij_uv = local_pos(pos, nuv[:, 2:3, :], nbh_ids)

    # Windowed mean of (potential change × tangent-plane offset): the new u direction,
    # expressed in the old (u, v) coordinates. size = (2, num_samples)
    u_plane = dropdims(mean(window .* pots_ij .* p_ij_uv; dims=2); dims=2)

    # Lift to the old (n, u, v) coordinates by adding a zero normal component.
    # v is u rotated counter clockwise by 90 degrees inside the plane: (a, b) -> (-b, a).
    # The rotation is taken from the 2-D components, i.e. before the zero row is added,
    # so that the row indices refer to the (u, v) coordinates.
    zero_row = zeros(Float32, 1, num_samples)
    new_u = vcat(zero_row, u_plane)
    new_v = vcat(zero_row, -u_plane[2:2, :], u_plane[1:1, :])

    # size = (3, 2, num_samples), still in local coordinates
    local_uv = cat(reshape(new_u, 3, 1, :), reshape(new_v, 3, 1, :); dims=2)

    # Finally, map new_u, new_v back to global coordinates.
    # This amounts to reversing the local frame transformation used by local_pos.
    @views global_uv = map(i -> inv(nuv[:, :, i])' * local_uv[:, :, i], axes(pos, 2))
    new_uv = reshape(reduce(hcat, global_uv), 3, 2, num_samples)

    @views n = reshape(nuv[:, 1, :], 3, 1, :)
    new_nuv = cat(n, new_uv; dims=2)

    return new_nuv
end

update_nuv (generic function with 1 method)

In [42]:
"""
    DMasif()

Binding-site classifier over an oriented surface point cloud.

Holds a chemical embedding layer, a potential layer used to orient the local frames, and
a per-point classifier. The struct is parametric in its field types so that all fields
are concretely typed.
"""
struct DMasif{L,P,C}
    chem_layer::L
    potential_layer::P
    classifier_layer::C
end

function DMasif()
    chem_layer = ChemicalLayer()
    potential_layer = PotentialLayer()
    classifier_layer = Chain(Dense(6 => 16),
                             BatchNorm(16, leakyrelu),
                             Dense(16 => 1))
    return DMasif(chem_layer, potential_layer, classifier_layer)
end

# Forward pass.
#   pos      sample positions, one per column
#   nuv      local frames as produced by `nuv_from_n`
#   feats    per-sample atom features, the input of the chemical layer
#   nbh_ids  neighbour indices per sample
#   window   neighbour weights per sample
# Returns the output of the classifier layer for every sample.
function (m::DMasif)(pos, nuv, feats, nbh_ids, window)
    # Compute embedding based on chemical properties
    emb = m.chem_layer(feats)

    # Update gauge based on gradient of potential
    potentials = m.potential_layer(emb)
    nuv = update_nuv(potentials, pos, nuv, nbh_ids, window)

    # Neighbour positions in the updated frames. They are the input of the
    # quasi-geodesic convolutions, which are not wired in yet, so `p_ij` is currently
    # computed but does not influence the output.
    p_ij = local_pos(pos, nuv, nbh_ids)

    #emb = m.conv_layers(emb, window, p_ij, nbh_ids)

    # Finally, classify point based on its embedding
    out = m.classifier_layer(emb)
    return out
end

Flux.@functor DMasif

# Assemble the model inputs from the preprocessed example.
pos = data[:pos] 
nuv = nuv_from_n(data[:normals])
feats = data[:feats]
# Indices of the 100 nearest surface points of every surface point.
nbh_ids = copy(knearest(pos, pos; k=100)[1])
# Convolution window: quasi-geodesic distances, converted in place to Gaussian weights.
window = quasi_geodesic_dist(pos, data[:normals], nbh_ids)
gaussian_filter!(window)

# Run one forward pass of a freshly initialised model.
model = DMasif()
affinity = model(pos, nuv, feats, nbh_ids, window) 


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m(221,)


ArgumentError: ArgumentError: number of columns of each array must match (got (221, 1, 1))