Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

New args for KL Div and DistNLL

  • Loading branch information...
commit bfba91da04b9f23e56abe6077f296b96e424c79f 1 parent 63eb5e3
@clementfarabet authored
Showing with 12 additions and 8 deletions.
  1. +7 −5 DistNLLCriterion.lua
  2. +5 −3 KLDivCriterion.lua
View
12 DistNLLCriterion.lua
@@ -1,12 +1,14 @@
local DistNLLCriterion, parent = torch.class('nn.DistNLLCriterion', 'nn.Criterion')
-function DistNLLCriterion:__init()
+function DistNLLCriterion:__init(opts)
parent.__init(self)
-- user options
- self.inputIsADistance = false
- self.inputIsProbability = false
- self.inputIsLogProbability = false
- self.targetIsProbability = false
+ opts = opts or {}
+ self.inputIsADistance = opts.inputIsADistance or false
+ self.inputIsProbability = opts.inputIsProbability or false
+ self.inputIsLogProbability = opts.inputIsLogProbability or false
+ self.targetIsProbability = opts.targetIsProbability
+ if self.targetIsProbability == nil then self.targetIsProbability = true end
-- internal
self.targetSoftMax = nn.SoftMax()
self.inputLogSoftMax = nn.LogSoftMax()
View
8 KLDivCriterion.lua
@@ -1,10 +1,12 @@
local KLDivCriterion, parent = torch.class('nn.KLDivCriterion', 'nn.Criterion')
-function KLDivCriterion:__init()
+function KLDivCriterion:__init(opts)
parent.__init(self)
-- user options
- self.inputIsProbability = false
- self.targetIsProbability = false
+ opts = opts or {}
+ self.inputIsProbability = opts.inputIsProbability or false
+ self.targetIsProbability = opts.targetIsProbability
+ if self.targetIsProbability == nil then self.targetIsProbability = true end
-- internal
self.targetSoftMax = nn.SoftMax()
self.inputSoftMax = nn.SoftMax()
Please sign in to comment.
Something went wrong with that request. Please try again.