Skip to content
This repository has been archived by the owner on Jan 13, 2022. It is now read-only.

Commit

Permalink
Pushing internal changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ebetica committed Jun 3, 2016
1 parent 5dc9bb6 commit 5833c85
Show file tree
Hide file tree
Showing 42 changed files with 3,831 additions and 424 deletions.
2 changes: 1 addition & 1 deletion PATENTS
Expand Up @@ -30,4 +30,4 @@ necessarily infringed by the Software standing alone.

A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
or contributory infringement or inducement to infringe any patent, including a
cross-claim or counterclaim.
cross-claim or counterclaim.
32 changes: 28 additions & 4 deletions fbnn/CachingLookupTable.lua
Expand Up @@ -9,6 +9,29 @@ local util = require('fb.util')

local dprint = require('fb.util.dbg').new('CLuT')

-- Return row `row` of the lookup table's weight matrix.
-- This is a view into the underlying storage (no copy is made),
-- so mutations through the returned tensor affect the table.
local function readLutRow(lut, row)
   return lut.weight:select(1, row)
end

-- Overwrite row `row` of the weight matrix with the contents of `val`
-- (element-wise copy into the existing storage).
local function writeLutRow(lut, row, val)
   lut.weight:select(1, row):copy(val)
end

-- In-place accumulate: weight[row] += alpha * val.
-- `alpha` is optional and defaults to 1.0.
-- Returns the updated row (a view into the weight matrix).
local function updateLutRow(self, row, val, alpha)
   -- Original declared `local alpha = alpha or 1.0`, shadowing the
   -- parameter; apply the default inline instead.
   return self.weight:select(1, row):add(alpha or 1.0, val)
end

-- Batched form of updateLutRow: for each i, weight[rows[i]] += alpha * val[i].
-- `rows` is a 1-D tensor of row indices; `val` is indexable by the same i.
-- The optional `alpha` scaling factor is shared across all rows.
local function updateLutRows(self, rows, val, alpha)
   local n = rows:size(1)
   for idx = 1, n do
      updateLutRow(self, rows[idx], val[idx], alpha)
   end
end

-- A way is a fully associative portion of the cache, with fixed
-- capacity. Since we search it by brute-force, it needs to be
-- modestly sized.
Expand Down Expand Up @@ -69,8 +92,8 @@ function Way:_writeBackOne(row)
if self.bufferedGrads[row] then
assert(self.rows[row]) -- invariant
dprint("updating row", row)
self.backing:updateRow(row, self.bufferedGrads[row])
dprint("after update", row, self.backing:readRow(row))
updateLutRow(self.backing, row, self.bufferedGrads[row])
dprint("after update", row, readLutRow(self.backing, row))
self.bufferedGrads[row] = nil
end
assert(not self.bufferedGrads[row])
Expand Down Expand Up @@ -132,12 +155,13 @@ function Way:pull(row)
self:_incStat('miss')
self:trim()
assert(self.numRows < self.size)
self.rows[row] = self.backing:readRow(row):clone()
self.rows[row] = readLutRow(self.backing, row):clone()
self.numRows = self.numRows + 1
assert(not self.bufferedGrads[row])
return self.rows[row]
end


-- The lookup table itself is a hash table of Ways.
local CachingLookupTable, parent = torch.class('nn.CachingLookupTable',
'nn.Module')
Expand Down Expand Up @@ -226,7 +250,7 @@ function CachingLookupTable:updateRow(row, val, lr)
end

-- Accumulate a batch of row updates directly into the backing lookup
-- table's weights (does not route through the cached Ways).
function CachingLookupTable:updateRows(rows, val)
   local backing = self.backing
   updateLutRows(backing, rows, val)
end

function CachingLookupTable:updateOutput(input)
Expand Down
4 changes: 2 additions & 2 deletions fbnn/ClassNLLCriterionWithUNK.lua
Expand Up @@ -33,7 +33,7 @@ function ClassNLLCriterionWithUNK:updateOutput(input, target)
local n = 0
if input:dim() == 1 then
if ((type(target) == 'number') and (target ~= self.unk_index)) or
((type(target) ~= 'number') and (taget[1] ~= self.unk_index))
((type(target) ~= 'number') and (target[1] ~= self.unk_index))
then
self.output = self.crit:updateOutput(input, target)
n = 1
Expand Down Expand Up @@ -71,7 +71,7 @@ end
function ClassNLLCriterionWithUNK:updateGradInput(input, target)
if input:dim() == 1 then
if ((type(target) == 'number') and (target ~= self.unk_index)) or
((type(target) ~= 'number') and (taget[1] ~= self.unk_index))
((type(target) ~= 'number') and (target[1] ~= self.unk_index))
then
self.gradInput = self.crit:updateGradInput(input, target)
end
Expand Down
29 changes: 29 additions & 0 deletions fbnn/Constant.lua
@@ -0,0 +1,29 @@
------------------------------------------------------------------------
--[[ Constant ]]--
-- author : Nicolas Leonard
-- A module that ignores its input's values and emits a fixed tensor,
-- replicated once per example along a leading batch dimension.
------------------------------------------------------------------------
local Constant, parent = torch.class("fbnn.Constant", "nn.Module")

-- value: a number (wrapped into a 1-element Tensor) or a Tensor to be
-- emitted for every example in the batch.
function Constant:__init(value)
   if torch.type(value) == 'number' then
      value = torch.Tensor{value}
   end
   self.value = value
   assert(torch.isTensor(self.value), "Expecting number or tensor at arg 1")
   parent.__init(self)
end

-- Resize output to (batchSize, value:size()...) and fill each slice with
-- `value`. NOTE(review): `input:size(1)` assumes batch-mode input.
function Constant:updateOutput(input)
   local valueSizes = self.value:size():totable()
   self.output:resize(input:size(1), table.unpack(valueSizes))
   local expanded = self.value
      :view(1, table.unpack(valueSizes))
      :expand(self.output:size())
   self.output:copy(expanded)
   return self.output
end

-- The output does not depend on the input, so the gradient w.r.t. the
-- input is identically zero.
function Constant:updateGradInput(input, gradOutput)
   self.gradInput:resizeAs(input):zero()
   return self.gradInput
end

0 comments on commit 5833c85

Please sign in to comment.