Skip to content

Commit

Permalink
add TemporalConvolution2.lua, add test for TemporlaConvolution2, add …
Browse files Browse the repository at this point in the history
…failing test for SpatialConvolution, factorize im2col out of SpatialConvolutionMM, in line with cunn
  • Loading branch information
hughperkins committed Mar 26, 2016
1 parent 3232a17 commit 53f3a3f
Show file tree
Hide file tree
Showing 12 changed files with 437 additions and 220 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ set(src init.cpp utils.cpp
)

set(luasrc init.lua MSECriterion.lua Pointwise.lua Threshold.lua LookupTable.lua
LogSoftMax.lua ClassNLLCriterion.lua StatefulTimer.lua THCLNN.lua
LogSoftMax.lua ClassNLLCriterion.lua StatefulTimer.lua THCLNN.lua TemporalConvolution2.lua
Narrow.lua CMulTable.lua test.lua test/testSpatialMaxPooling.lua test/testSpatialConvolutionMM.lua
test/testLookupTable.lua test/testMSECriterion.lua test/testSpatialUpSamplingNearest.lua
test/testELU.lua test/testhelpers.lua
test/testELU.lua test/testhelpers.lua test/testTemporalConvolution2.lua
test/testClassNLLCriterion.lua test/testSoftMax.lua test/testLogSoftMax.lua test/testSpatialAveragePooling.lua)

ADD_TORCH_PACKAGE(clnn "${src}" "${luasrc}" )
Expand Down
77 changes: 77 additions & 0 deletions TemporalConvolution2.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
--[[
Compared to the base TemporalConvolution, this is:
- faster on GPUs (both in CUDA and in OpenCL)
- slower on CPUs
- the weights saved to file are incompatible with the original TemporalConvolution
Conceptually, it's a wrapper around the highly optimized SpatialConvolutionMM class
Just use it like TemporalConvolution, only with a '2' added to the class name. Easy :-)
]]--

require 'torch'
require 'nn'

local TemporalConvolution2, parent = torch.class('nn.TemporalConvolution2', 'nn.Module')

function TemporalConvolution2:__init(inputFrameSize, outputFrameSize, kW, dW, padW)
parent.__init(self)

self.inputFrameSize = inputFrameSize
self.outputFrameSize = outputFrameSize
self.kW = kW
self.dW = dW or 1
self.padW = padW or 0
self.sconv = nn.SpatialConvolution(inputFrameSize, outputFrameSize, 1, kW, 1, dW, 0, self.padW)
self.weight = self.sconv.weight
self.bias = self.sconv.bias
self.gradWeight = self.sconv.gradWeight
self.gradBias = self.sconv.gradBias
end

function TemporalConvolution2:clearState()
self.sconv:clearState()
parent:clearState()
end

function TemporalConvolution2:updateOutput(input)
assert(input:dim() == 3, 'must provide batched input')
local batchSize = input:size(1)
local numFrames = input:size(2)
local outFrames = numFrames - math.floor(self.kW/2)*2 + 2 * self.padW
if self.kW%2 == 0 then outFrames = outFrames+1 end

input = input:view(batchSize, numFrames, self.inputFrameSize, 1):transpose(2,3)
local output = self.sconv:updateOutput(input):transpose(2,3)
self.output:resize(batchSize, outFrames, self.outputFrameSize):copy(output)
return self.output
end

function TemporalConvolution2:updateGradInput(input, gradOutput)
assert(input:dim() == 3, 'must provide batched input')
local batchSize = input:size(1)
local numFrames = input:size(2)
local outFrames = numFrames - math.floor(self.kW/2)*2 + 2 * self.padW
if self.kW%2 == 0 then outFrames = outFrames+1 end

input = input:view(batchSize, numFrames, self.inputFrameSize, 1):transpose(2,3)
gradOutput = gradOutput:view(batchSize, outFrames, self.outputFrameSize, 1):transpose(2,3)
local gradInput = self.sconv:updateGradInput(input, gradOutput):transpose(2,3)
self.gradInput:resize(batchSize, numFrames, self.inputFrameSize):copy(gradInput)

return self.gradInput
end

function TemporalConvolution2:accGradParameters(input, gradOutput, scale)
assert(input:dim() == 3, 'must provide batched input')
local batchSize = input:size(1)
local numFrames = input:size(2)
local outFrames = numFrames - math.floor(self.kW/2)*2 + 2 * self.padW
if self.kW%2 == 0 then outFrames = outFrames+1 end

input = input:view(batchSize, numFrames, self.inputFrameSize, 1):transpose(2,3)
gradOutput = gradOutput:view(batchSize, outFrames, self.outputFrameSize, 1):transpose(2,3)
self.sconv:accGradParameters(input, gradOutput, scale)
end
2 changes: 2 additions & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ end

torch.ClTensor.nn = {}

include 'TemporalConvolution2.lua'

include 'LookupTable.lua'
include 'Pointwise.lua'
include 'Threshold.lua'
Expand Down
2 changes: 1 addition & 1 deletion lib/THCLNN/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ENDIF()

#FILE(GLOB src-cl *.cpp)
set(src-cl Abs.cpp SpatialConvolutionMM.cpp ELU.cpp SpatialAveragePooling.cpp SpatialMaxPooling.cpp
SoftMax.cpp Tanh.cpp common.cpp SpatialUpSamplingNearest.cpp
SoftMax.cpp Tanh.cpp common.cpp SpatialUpSamplingNearest.cpp im2col.cpp
)

add_library(THCLNN MODULE ${src-cl})
Expand Down
Loading

0 comments on commit 53f3a3f

Please sign in to comment.