In [8]:
using Knet
using Plots
using Reparam
using LinearAlgebra

In [338]:
module dev
using LinearAlgebra: norm
using Knet

function squaresum(x)
    out = x[1]^2
    for i in 2:length(x)
        out += x[i]^2
    end
    out
end


struct Qmap{T<:Function, TF<:Real}; c::T; h::TF; end
(q::Qmap)(x) = sqrt(norm(0.5*(q.c(x+q.h) - q.c(x-q.h)) / q.h)) * q.c(x)


struct SineLayer{T<:AbstractVector}; N::Int; c::Param{T}; end 
SineLayer(c::Vector{T})  where {T<:Real} = SineLayer(length(c), Param(c))

function (S::SineLayer)(x, y)
    z = sin(π*x) * S.c[1]
    y1 = π * cos(π*x) * S.c[1]
    for n in 2:S.N
        z += sin(n*π*x) * S.c[n]
        y1 += n * π * cos(n*π*x) * S.c[n]
    end
    return x + z, (1. + y1 )* y
end

(S::SineLayer)(x::Real) = S(x, 1.)
(S::SineLayer)(args) = S(args...)


Loss(q::Vector{T}, r::Vector{T}) where {T<:Real} = squaresum(q-r)
Loss(Q::Vector{T}, R::Vector{T}) where {T<:AbstractVector} = sum(Loss.(Q, R))

end # module



Main.dev

In [335]:
Param([1, 2, 3])

3-element Param{Array{Int64,1}}:
 1
 2
 3

In [336]:
supertype{}

LoadError: TypeError: in Type{...} expression, expected UnionAll, got a value of type typeof(supertype)

In [333]:
@doc Param

Usage:

```
x = Param([1,2,3])          # user declares parameters with `Param`
x => P([1,2,3])             # `Param` is just a struct wrapping a value
value(x) => [1,2,3]         # `value` returns the thing wrapped
sum(x .* x) => 14           # Params act like regular values
y = @diff sum(x .* x)       # Except when we differentiate using `@diff`
y => T(14)                  # you get another struct
value(y) => 14              # which carries the same result
params(y) => [x]            # and the Params that it depends on 
grad(y,x) => [2,4,6]        # and the gradients for all Params
```

`Param(x)` returns a struct that acts like `x` but marks it as a parameter you want to compute gradients with respect to.

`@diff expr` evaluates an expression and returns a struct that contains the result (which should be a scalar) and gradient information.

`grad(y, x)` returns the gradient of `y` (output by @diff) with respect to any parameter `x::Param`, or  `nothing` if the gradient is 0.

`value(x)` returns the value associated with `x` if `x` is a `Param` or the output of `@diff`, otherwise returns `x`.

`params(x)` returns an iterator of Params found by a recursive search of object `x`.

Alternative usage:

```
x = [1 2 3]
f(x) = sum(x .* x)
f(x) => 14
grad(f)(x) => [2 4 6]
gradloss(f)(x) => ([2 4 6], 14)
```

Given a scalar valued function `f`, `grad(f,argnum=1)` returns another function `g` which takes the same inputs as `f` and returns the gradient of the output with respect to the argnum'th argument. `gradloss` is similar except the resulting function also returns f's output.


In [327]:
c(t) = [cos(2π*t), sin(2π*t)]
γ(t) = 0.9t^2 + 0.1t 

c2 = c ∘ γ

#62 (generic function with 1 method)

In [328]:
q = dev.Qmap(c, 1e-4)
r = dev.Qmap(c2, 1e-4)

Main.dev.Qmap{Base.var"#62#63"{typeof(c),typeof(γ)},Float64}(Base.var"#62#63"{typeof(c),typeof(γ)}(c, γ), 0.0001)

In [346]:
X = range(0, 1, length=1024)

0.0:0.0009775171065493646:1.0

In [347]:
S = dev.SineLayer(zeros(2))

Main.dev.SineLayer{Array{Float64,1}}(2, P(Array{Float64,1}(2)))

In [371]:
myparams = @diff S.(X)

out = S.(X)

Q = q.(X)
R = [sqrt(y) * r(z) for (z, y) in out]

loss = @diff dev.Loss(Q, R)

6465.8123731737805

In [372]:
loss = @diff sum(R)

2-element Array{Float64,1}:
 190.211243502146
 233.40672876467266

In [377]:
gradloss = grad(dev.Loss)

(::AutoGrad.var"#gradfun#7"{AutoGrad.var"#gradfun#6#8"{typeof(Main.dev.Loss),Int64,Bool}}) (generic function with 1 method)

In [378]:
grad(loss)

LoadError: MethodError: no method matching grad(::Array{Float64,1})
Closest candidates are:
  grad(::Any, !Matched::Any) at /home/jorgen/.julia/packages/AutoGrad/VFrAv/src/core.jl:215
  grad(!Matched::AutoGrad.Tape, !Matched::AutoGrad.Tracked) at /home/jorgen/.julia/packages/AutoGrad/VFrAv/src/core.jl:216
  grad(!Matched::Function) at /home/jorgen/.julia/packages/AutoGrad/VFrAv/src/core.jl:219
  ...

In [373]:
params(loss)

Param[]