In [1]:
import Pkg
# Activate the project environment in the parent directory (".." is resolved against the
# current working directory) so the packages loaded below come from that project.
Pkg.activate("..")


[32m[1m  Activating[22m[39m project at `~/Projects/latentplan.jl`


In [3]:
using Knet
using CUDA

In [48]:
import Knet.Ops20.relu
import Base.Math.clamp

In [16]:
"""
    KnetArray(x::CuArray{T,N})

Wrap the device memory of the CUDA array `x` as a `KnetArray{T,N}` of the same size,
reusing `pointer(x)` instead of copying the data.
"""
function KnetArray(x::CuArray{T,N}) where {T,N}
    raw = pointer(x)
    nbytes = sizeof(x)              # length of the buffer in bytes
    devid = Int(device().handle)    # handle of the current CUDA device, as an Int
    # `x` is passed as the final argument; presumably this keeps the CuArray (and hence its
    # buffer) reachable while the KnetPtr is alive — confirm against Knet.KnetPtr.
    kptr = Knet.KnetPtr(Base.bitcast(Knet.Cptr, raw), nbytes, devid, x)
    return KnetArray{T,N}(kptr, size(x))
end

KnetArray

In [26]:
# 10×10 uniform Float64 samples, converted into a `KnetArray{Float32}`.
x = rand(10, 10) |> KnetArray{Float32}

10×10 Knet.KnetArrays.KnetMatrix{Float32}:
 0.683049   0.766974   0.619906  0.796985  …  0.968252  0.498452  0.915062
 0.527889   0.446841   0.683464  0.201916     0.344828  0.390409  0.866548
 0.561447   0.116531   0.327135  0.367339     0.173769  0.576912  0.390739
 0.613648   0.0259174  0.488162  0.709622     0.18652   0.704081  0.591589
 0.0525429  0.131312   0.946495  0.70731      0.523747  0.908029  0.0214718
 0.736442   0.0560524  0.47045   0.538126  …  0.855571  0.304281  0.394135
 0.456228   0.969837   0.753539  0.595684     0.164759  0.705968  0.18111
 0.55199    0.822626   0.81882   0.603223     0.195996  0.72688   0.210995
 0.630951   0.328847   0.752974  0.269828     0.982766  0.377089  0.127255
 0.925757   0.341735   0.966223  0.775777     0.5398    0.481439  0.718391

In [18]:
# Convert `x` to a CuArray, clamp it element-wise to [0.5, 1.0] with CUDA broadcasting,
# then wrap the result back into a KnetArray.
let gpu = CuArray(x)
    KnetArray(clamp.(gpu, 0.5f0, 1.0f0))
end

10×10 Knet.KnetArrays.KnetMatrix{Float32}:
 0.5       0.525387  0.5       0.5       …  0.5       0.538931  0.5
 0.786091  0.5       0.5       0.5          0.888713  0.5       0.5
 0.5       0.5       0.612352  0.858553     0.5       0.5       0.813512
 0.508942  0.5       0.846685  0.5          0.596829  0.767569  0.5
 0.5       0.523146  0.606859  0.651397     0.789887  0.5       0.5
 0.5       0.647527  0.917921  0.517445  …  0.5       0.5       0.5
 0.974556  0.5       0.921046  0.5          0.5       0.738392  0.834936
 0.616743  0.5       0.5       0.619019     0.5       0.5       0.668229
 0.5       0.5       0.5       0.5          0.5       0.777226  0.5
 0.5       0.5       0.5       0.774463     0.939926  0.919856  0.720092

In [23]:
# 10×10 CPU matrix of uniform Float32 samples in [0, 1).
y = rand(Float32, (10, 10))

10×10 Matrix{Float32}:
 0.819375   0.579572   0.00104445  …  0.600373  0.00758874  0.642988
 0.907414   0.0938323  0.20678        0.262016  0.90714     0.777946
 0.469722   0.591381   0.97541        0.708113  0.0651411   0.894889
 0.709947   0.267686   0.68665        0.744418  0.674308    0.00481975
 0.808484   0.526133   0.042076       0.233976  0.171335    0.922951
 0.0931091  0.380405   0.972814    …  0.37886   0.837018    0.972179
 0.9631     0.952596   0.489312       0.123626  0.818212    0.334196
 0.874123   0.158154   0.0999919      0.552401  0.228191    0.907283
 0.710274   0.319345   0.691571       0.195759  0.0946238   0.435144
 0.219605   0.425619   0.105131       0.487798  0.962756    0.0262671

In [24]:
# Element-wise clamp of the CPU matrix `y` into [0.5, 1.0].
@. clamp(y, 0.5f0, 1.0f0)

10×10 Matrix{Float32}:
 0.819375  0.579572  0.5       0.5       …  0.600373  0.5       0.642988
 0.907414  0.5       0.5       0.5          0.5       0.90714   0.777946
 0.5       0.591381  0.97541   0.5          0.708113  0.5       0.894889
 0.709947  0.5       0.68665   0.592469     0.744418  0.674308  0.5
 0.808484  0.526133  0.5       0.5          0.5       0.5       0.922951
 0.5       0.5       0.972814  0.981097  …  0.5       0.837018  0.972179
 0.9631    0.952596  0.5       0.5          0.5       0.818212  0.5
 0.874123  0.5       0.5       0.531506     0.552401  0.5       0.907283
 0.710274  0.5       0.691571  0.5          0.5       0.5       0.5
 0.5       0.5       0.5       0.5          0.5       0.962756  0.5

In [35]:
"""
    clip(x::Real, lo, hi)

Limit the scalar `x` to `[lo, hi]`: return `hi` when `x >= hi`, `lo` when `x <= lo`,
and `x` otherwise. A `NaN` input fails both comparisons and comes back as `NaN`.

The result is always converted to `promote_type(typeof(x), typeof(lo), typeof(hi))`, so
the return type depends only on the argument types, not on which branch was taken
(e.g. `clip(0.25f0, 0.5, 1.0)` and `clip(0.75f0, 0.5, 1.0)` are both `Float64`). This
matches how `Base.clamp` promotes its arguments to a common type.
"""
function clip(x::T, lo, hi) where {T<:Real}
    R = promote_type(T, typeof(lo), typeof(hi))
    # The upper bound is tested first, so when lo > hi any x >= hi yields hi.
    picked = x >= hi ? hi : x <= lo ? lo : x
    return convert(R, picked)
end

clip (generic function with 2 methods)

In [36]:
# Attempt to broadcast the scalar `clip` over the KnetArray `x`. When run, this raised
# `TypeError: non-boolean (Knet.KnetArrays.Bcasted{...}) used in boolean context`: the
# comparison inside `clip` evaluated to a Knet `Bcasted` wrapper instead of a Bool, so
# the ternary could not branch.
clip.(x, 0.5f0, 1.0f0)

LoadError: TypeError: non-boolean (Knet.KnetArrays.Bcasted{Knet.KnetArrays.KnetMatrix{Float32}}) used in boolean context

In [33]:
# Leaky branch for floating-point inputs: cast slope and threshold to x's type so that a
# Float64 slope such as 0.1 does not promote a Float32 input to a Float64 result.
_relu_leak(x::AbstractFloat, slope, threshold) =
    oftype(x, slope) * (x - oftype(x, threshold))
# Leaky branch for other reals (e.g. integers): keep ordinary promoting arithmetic, since
# converting a fractional slope to an integer type would throw an InexactError.
_relu_leak(x::Real, slope, threshold) = slope * (x - oftype(x, threshold))

"""
    relu(x::Real; max_value=Inf, negative_slope=0, threshold=0)

Piecewise rectified linear unit on a scalar `x`:

- `x >= max_value`             → `max_value`, converted to `typeof(x)`
- `threshold <= x < max_value` → `x`
- `x < threshold`              → `zero(x)` if `negative_slope == 0`,
                                  otherwise `negative_slope * (x - threshold)`

For floating-point `x` every branch returns a value of type `typeof(x)`; in particular, a
Float64 `negative_slope` does not turn Float32 inputs below `threshold` into Float64.
For integer `x`, a non-integer `negative_slope` gives a floating-point leaky result.
"""
function relu(x::T; max_value=Inf, negative_slope=0, threshold=0) where {T<:Real}
    x >= max_value && return oftype(x, max_value)
    x >= threshold && return x
    # A zero slope short-circuits to an exact zero (this also avoids 0 * -Inf == NaN).
    iszero(negative_slope) && return zero(T)
    return _relu_leak(x, negative_slope, threshold)
end

relu (generic function with 7 methods)

In [34]:
# Broadcast `relu` over the elements of the KnetArray `x`.
broadcast(relu, x)

10×10 Knet.KnetArrays.KnetMatrix{Float32}:
 0.683049   0.766974   0.619906  0.796985  …  0.968252  0.498452  0.915062
 0.527889   0.446841   0.683464  0.201916     0.344828  0.390409  0.866548
 0.561447   0.116531   0.327135  0.367339     0.173769  0.576912  0.390739
 0.613648   0.0259174  0.488162  0.709622     0.18652   0.704081  0.591589
 0.0525429  0.131312   0.946495  0.70731      0.523747  0.908029  0.0214718
 0.736442   0.0560524  0.47045   0.538126  …  0.855571  0.304281  0.394135
 0.456228   0.969837   0.753539  0.595684     0.164759  0.705968  0.18111
 0.55199    0.822626   0.81882   0.603223     0.195996  0.72688   0.210995
 0.630951   0.328847   0.752974  0.269828     0.982766  0.377089  0.127255
 0.925757   0.341735   0.966223  0.775777     0.5398    0.481439  0.718391

In [39]:
# Clamp `x` to [0.5, 0.7] by composing broadcast `min` (upper bound) then `max` (lower).
@. max(min(0.7f0, x), 0.5)

10×10 Knet.KnetArrays.KnetMatrix{Float32}:
 0.683049  0.7  0.619906  0.7       0.5       …  0.7       0.5       0.7
 0.527889  0.5  0.683464  0.5       0.7          0.5       0.5       0.7
 0.561447  0.5  0.5       0.5       0.5          0.5       0.576912  0.5
 0.613648  0.5  0.5       0.7       0.7          0.5       0.7       0.591589
 0.5       0.5  0.7       0.7       0.5          0.523747  0.7       0.5
 0.7       0.5  0.5       0.538126  0.5       …  0.7       0.5       0.5
 0.5       0.7  0.7       0.595684  0.683364     0.5       0.7       0.5
 0.55199   0.7  0.7       0.603223  0.7          0.5       0.7       0.5
 0.630951  0.5  0.7       0.5       0.529006     0.7       0.5       0.5
 0.7       0.5  0.7       0.7       0.7          0.5398    0.5       0.7

In [49]:
"""
    clamp(x::KnetArray{T}, lo::Real, hi::Real) where T

Return a new `KnetArray` whose elements are those of `x` limited to `[lo, hi]`, computed
as `max.(min.(hi, x), lo)`.

`lo` and `hi` may be any `Real`; they are first converted to the element type `T`. For
example, `clamp(x, 0.5, 0.7)` on a Float32 array uses Float32 bounds, and calls whose
bounds already have type `T` behave exactly as before. Because `max` is applied last,
every element becomes `lo` when `lo > hi`.
"""
function clamp(x::KnetArray{T}, lo::Real, hi::Real) where T
    lo_t = convert(T, lo)
    hi_t = convert(T, hi)
    # Broadcasting a scalar function that branches on comparisons (such as `clip.`) over
    # a KnetArray raises a TypeError, so compose broadcast min/max instead.
    return max.(min.(hi_t, x), lo_t)
end

clamp (generic function with 12 methods)

In [50]:
# Clamp the KnetArray `x` to [0.5, 0.7] through the KnetArray `clamp` method.
let lo = 0.5f0, hi = 0.7f0
    clamp(x, lo, hi)
end

10×10 Knet.KnetArrays.KnetMatrix{Float32}:
 0.683049  0.7  0.619906  0.7       0.5       …  0.7       0.5       0.7
 0.527889  0.5  0.683464  0.5       0.7          0.5       0.5       0.7
 0.561447  0.5  0.5       0.5       0.5          0.5       0.576912  0.5
 0.613648  0.5  0.5       0.7       0.7          0.5       0.7       0.591589
 0.5       0.5  0.7       0.7       0.5          0.523747  0.7       0.5
 0.7       0.5  0.5       0.538126  0.5       …  0.7       0.5       0.5
 0.5       0.7  0.7       0.595684  0.683364     0.5       0.7       0.5
 0.55199   0.7  0.7       0.603223  0.7          0.5       0.7       0.5
 0.630951  0.5  0.7       0.5       0.529006     0.7       0.5       0.5
 0.7       0.5  0.7       0.7       0.7          0.5398    0.5       0.7