Skip to content

Commit

Permalink
Small update
Browse files Browse the repository at this point in the history
  • Loading branch information
hshindo committed Jun 30, 2018
1 parent 3c4735d commit fb5e07f
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 67 deletions.
49 changes: 37 additions & 12 deletions src/functions/cnn/conv1d.jl
@@ -1,58 +1,83 @@
export Conv1d

doc"""
Conv1d(T, ksize, inchannel, outchannel, kwargs...)
Conv1d(T, ksize, inchannel, outchannel, [padding=0, stride=1, dilation=1])
1-dimensional convolution function.
```julia
T = Float32
x = Var(rand(T,10,5))
f = Conv1d(T, 5, 10, 3, pad=2)
f = Conv1d(T, 5, 10, 3, padding=2)
y = f(x)
```
"""
mutable struct Conv1d <: Functor
W::Var
b::Var
ksize::Int
pad::Int
padding::Int
stride::Int
dilation::Int
end

getparams(f::Conv1d) = (f.W, f.b)

function Conv1d(::Type{T}, ksize::Int, inchannel::Int, outchannel::Int;
pad=0, stride=1, dilation=1, init_W=Xavier(), init_b=Fill(0)) where T
padding=0, stride=1, dilation=1, init_W=Xavier(), init_b=Fill(0)) where T

W = init_W(T, ksize*inchannel, outchannel)
b = init_b(T, outchannel)
Conv1d(param(W), param(b), ksize, pad, stride, dilation)
Conv1d(param(W), param(b), ksize, padding, stride, dilation)
end

function (f::Conv1d)(x::Var, length::Vector{Int})
function (f::Conv1d)(xs::Vector{Var})
idxs = map(xs) do x
@assert size(x) == 2
conv1d_index(f, size(x,2))
end
idx = cat(2, idxs...)
x = concat(2, xs...)
h = lookup(x, Var(idx))

@assert ndims(x) == 2 && sum(length) == size(x,2)
idx = conv1d_index(f.ksize, f.pad, f.stride, f.dilation, length)
idx = conv1d_index(f.ksize, f.padding, f.stride, f.dilation, length)
h = lookup(x, Var(idx))
y = linear(h, f.W, f.b)
y
end
(f::Conv1d)(x::Var) = f(x, [size(x,2)])
(f::Conv1d)(xs::Vars) = f(Var(xs), size(xs,2))
(f::Conv1d)(x::Node) = Node(f, x)

function conv1d_index(ksize::Int, pad::Int, stride::Int, dilation::Int, batchsize::Vector{Int})
function conv1d_index(f::Conv1d, inlength::Int)
ksize, padding, stride, dilation = f.ksize, f.padding, f.stride, f.dilation
k = (ksize - 1) * dilation + 1
outlength = (inlength + 2padding - k) ÷ stride + 1

y = Array{Int}(ksize, outlength)
yi = 1
i = -padding + 1
for d = 1:outlength
j = i + (ksize-1)*dilation
for k = i:dilation:j
y[yi] = 0 < k <= inlength ? k : 0
yi += 1
end
i += stride
end
y
end

function conv1d_index(ksize::Int, padding::Int, stride::Int, dilation::Int, batchsize::Vector{Int})
outdims = map(batchsize) do d
k = (ksize - 1) * dilation + 1
(d + 2pad - k) ÷ stride + 1
(d + 2padding - k) ÷ stride + 1
end
cumdim = 0
y = Array{Int}(ksize, sum(outdims))
yi = 1
for n = 1:length(batchsize)
ndims = batchsize[n]
i = cumdim - pad + 1
i = cumdim - padding + 1
for d = 1:outdims[n]
for j = i:dilation:i+(ksize-1)*dilation
xi = cumdim < j <= cumdim+ndims ? j : 0
Expand Down
7 changes: 3 additions & 4 deletions src/functions/concat.jl
Expand Up @@ -12,19 +12,18 @@ x2 = Var(rand(T,4,5))
y = concat(2, x1, x2)
```
"""
function concat(dim::Int, xs::Vector{Var})
function concat(dim::Int, xs::Var...)
configure!(xs...)
y = cat(dim, map(x -> x.data, xs)...)
Var(y, (concat,dim,xs))
end
concat(dim::Int, xs::Var...) = concat(dim, [xs...])
concat(dim::Int, xs::Node...) = Node(concat, dim, xs...)

function addgrad!(y::Var, ::typeof(concat), dim::Int, xs::Vector{Var})
function addgrad!(y::Var, ::typeof(concat), dim::Int, xs::Tuple)
∇concat!(y, dim, xs)
end

function ∇concat!(y::Var, dim::Int, xs::Vector{Var})
function ∇concat!(y::Var, dim::Int, xs::Tuple)
offset = 0
for x in xs
s = size(x, dim)
Expand Down
52 changes: 15 additions & 37 deletions src/functions/split.jl
@@ -1,7 +1,5 @@
export unsafe_split

doc"""
split(x::Var, dim::Int, dims::Vector{Int})
split(x::Var, dim::Int, size::Vector{Int})
# Example
```julia
Expand All @@ -10,42 +8,22 @@ x = Var(rand(T,10,10))
ys = split(x, 2, [2,3,5])
```
"""
function Base.split(x::Var, size::Vector)
function Base.split(x::Var, dim::Int, size::Vector{Int})
@assert sum(size) == Base.size(x,dim)
if dim == ndims(x)
copy()
cumdim = 0
front = Base.front(Base.size(x))
m = prod(front)
ys = Var[]
for s in size
range = cumdim+1:cumdim+s
data = view(x.data, front..., range)
y = Var(data, (x,dim,range))
push!(ys, y)
cumdim += s
end
else

end

cumdim = 0
for d in dims
y = x[]
cumdim += d
end
ys

@assert dim == ndims(x)
@assert sum(dims) == size(x,dim)
front = Base.front(size(x))
cumdim = 0
ys = Var[]
for d in dims
y = x[front...,cumdim+1:cumdim+d]
push!(ys, y)
cumdim += d
end
ys
end

function Base.split(x::Var, size::Vector)
ys = Var[]
offset = 0
for i = 1:length(size)
s = size[i]
y = view(x.data, I...)
y = Var(y, (split,x,size,i))
push!(ys, y)
offset += prod(s)
throw("Not implemented yet.")
end
ys
end
Expand Down
16 changes: 7 additions & 9 deletions src/var.jl
@@ -1,4 +1,5 @@
export Var, param, zerograd!, isvoid, isparam, gradient!, topsort, create_batch
export Var
export param, zerograd!, isvoid, isparam, gradient!, topsort, create_batch

doc"""
Var
Expand All @@ -19,16 +20,13 @@ x = zerograd(rand(T,10,5)) # x.grad is initialized as zero.
"""
mutable struct Var
data
size
args
grad
end

Var(data::Array) = Var(data, size(data))
Var(data, size) = Var(data, size, ())
Var(data, size, args) = Var(data, size, args, nothing)
Var(data, args=(); grad=nothing) = Var(data, args, grad)

function param(data::Array)
function param(data)
v = Var(data)
v.grad = zeros(data)
v
Expand All @@ -49,10 +47,10 @@ function concat(x::Var)
Var(x.data, size, x.args, x.grad)
end

Base.size(x::Var) = x.size
Base.size(x::Var, i::Int) = i <= ndims(x) ? x.size[i] : 1
Base.size(x::Var) = size(x.data)
Base.size(x::Var, i::Int) = size(x.data, i)
Base.length(x::Var) = length(x.data)
Base.ndims(x::Var) = length(x.size)
Base.ndims(x::Var) = ndims(x.data)
Base.eltype(x::Var) = eltype(x.data)
Base.getindex(x::Var, i::Int) = x.args[i]
isvoid(x) = x == nothing
Expand Down
25 changes: 21 additions & 4 deletions src/vars.jl
@@ -1,11 +1,28 @@
export Vars

mutable struct Vars
var::Var
size::Tuple
dims::Tuple
end

function Vars(xs::Vector{Var})
y = concat(dim, xs...)
dims = ntuple(ndims(y)) do i
i == dim ? map(x -> size(x,dim), xs) : size(y,i)
end
Vars(y, dims)
end
function Vars(xs::Vector{Array{T,N}}) where {T,N}
y = cat(dim, xs...)
dims = ntuple(ndims(y)) do i
i == dim ? map(x -> size(x,dim), xs) : size(y,i)
end
Vars(Var(y), dims)
end

Base.size(x::Vars) = x.size
Base.size(x::Vars, i::Int) = i <= length(x.size) ? x.size[i] : 1
Base.ndims(x::Vars) = length(x.size)
Base.size(x::Vars) = x.dims
Base.size(x::Vars, i::Int) = i <= ndims(x) ? x.dims[i] : 1
Base.ndims(x::Vars) = length(x.dims)

function Var(x::Vars)
for s in x.size
Expand Down
2 changes: 1 addition & 1 deletion test/functions.jl
Expand Up @@ -76,7 +76,7 @@ end
##### conv1d #####
x = param(randn(T,20,10))
batchdims = [3,7]
conv = Conv1d(T, 5, 20, 15, pad=2)
conv = Conv1d(T, 5, 20, 15, padding=2)
@test_grad conv x batchdims
@test_cuda conv x batchdims

Expand Down

0 comments on commit fb5e07f

Please sign in to comment.