In [3]:
import ShapeRetrieval: ShapeRetrieval as SR
using ShapeRetrieval: LearnedTimeDiffusionBlock
import ShapeRetrieval: Mesh, heat_integrator
using LinearAlgebra
using SparseArrays
using Flux

### Visualization Tools

In [4]:
bunny = SR.load_obj("./meshes/gourd.obj")
bunny = SR.normalize_mesh(bunny)
println("Area: ", sum(bunny.vertex_area))

Area: 2.681867885523121


In [5]:
function spectral_loss(model, x, λ, ϕ, A, y) 
    norm(model(x, λ, ϕ, A) - y)
end

spectral_loss (generic function with 1 method)

## Example Prediction using Spectral Mode

In [6]:
m = LearnedTimeDiffusionBlock(2, :spectral)
heat_signal = zeros(bunny.nv)
heat_signal[[1, 300]] .= 1.0
λ, ϕ, A = SR.get_diffusion_inputs(bunny)
m(heat_signal, λ, ϕ, A)

326×2 Matrix{Float64}:
 8.32397e-6  8.29702e-6
 8.31679e-6  8.29011e-6
 8.33114e-6  8.30392e-6
 8.31452e-6  8.28793e-6
 8.33978e-6  8.31223e-6
 8.34207e-6  8.31444e-6
 8.30344e-6  8.27726e-6
 8.33781e-6  8.31033e-6
 8.31065e-6  8.2842e-6
 8.32873e-6  8.3016e-6
 ⋮           
 8.53269e-6  8.49781e-6
 8.53261e-6  8.49773e-6
 8.53266e-6  8.49778e-6
 8.53662e-6  8.50159e-6
 8.53666e-6  8.50163e-6
 8.53692e-6  8.50188e-6
 8.53712e-6  8.50207e-6
 8.53711e-6  8.50206e-6
 8.53771e-6  8.50264e-6

#### Sample Training 

In [7]:
data = [(heat_signal, 0.0),]
λ, ϕ, A = SR.get_diffusion_inputs(bunny)

opt_state = Flux.setup(Adam(), m);

@time for i=1:1000
    for d in data
        x, y  = d
        grad = gradient(spectral_loss, m, x, λ, ϕ, A, y)
        Flux.update!(opt_state, m, grad[1])
    end
end
println(m(heat_signal, λ, ϕ, A))

└ @ Optimisers /Users/ian/.julia/packages/Optimisers/1x8gl/src/interface.jl:173


MethodError: MethodError: no method matching -(::Matrix{Float64}, ::Float64)
For element-wise subtraction, use broadcasting with dot syntax: array .- scalar

Closest candidates are:
  -(!Matched::T, ::T) where T<:Union{Float16, Float32, Float64}
   @ Base float.jl:409
  -(!Matched::GeometryBasics.HyperRectangle{N, T}, ::Number) where {N, T}
   @ GeometryBasics ~/.julia/packages/GeometryBasics/Du43H/src/primitives/rectangles.jl:266
  -(!Matched::SIMD.Vec{N, T}, ::T2) where {N, T2<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}, T<:Union{Float32, Float64}}
   @ SIMD ~/.julia/packages/SIMD/7eukp/src/simdvec.jl:398
  ...


In [6]:
m(heat_signal, SR.get_diffusion_inputs(bunny)...)
println(m.diffusion_time)

(200, 2)


[2.9274950637401447, 0.7993044171182087]


### AD with Implicit Mode

In [7]:
m = LearnedTimeDiffusionBlock(2, :implicit)
Flux.trainable(m::LearnedTimeDiffusionBlock) = (m.diffusion_time,)

function implicit_loss(model, x, L, M, A::Vector{Float64}, y)
    norm(model(x, L, M, A) - y)
end

inputs = SR.get_diffusion_inputs(bunny,:implicit)
m(heat_signal, inputs...)

0.30148325561767775

In [49]:
heat_signal = zeros(bunny.nv)
heat_signal[[1, 300]] .= 1.0
L, M, A = SR.get_diffusion_inputs(bunny, :implicit)
display(m(heat_signal, L, M, A))
data = [(heat_signal, bunny, 0.0),]

opt_state = Flux.setup(Adam(), m)

for i=1:1000
    for d in data
        x, mesh, y  = d
        grad = gradient(implicit_loss, m, x, L,M,A, y)
        Flux.update!(opt_state, m, grad[1])
    end
end

display(m(heat_signal, L, M, A))

0.3014815281221147

0.301481517244275