/
SpatialCrossEntropyCriterionWithIgnore.lua
110 lines (94 loc) · 4.23 KB
/
SpatialCrossEntropyCriterionWithIgnore.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
-- Include guard: this file may be loaded more than once (e.g. via dofile);
-- only define the class on the first load.
if isdefined_SpatialCrossEntropyCriterionWithIgnore then
   return
end
isdefined_SpatialCrossEntropyCriterionWithIgnore = true

require 'nn'
require 'cudnn' -- cudnn.SpatialLogSoftMax is instantiated in __init

local SpatialCrossEntropyCriterionWithIgnore, parent =
   torch.class('nn.SpatialCrossEntropyCriterionWithIgnore', 'nn.Criterion')
--[[
This criterion does a spatial cross-entropy loss across the feature
(class) dimension for an N-channel image of HxW in size.
It only supports mini-batches (4D input, 3D target).
It does a LogSoftMax on the input (over the channel dimension),
so no LogSoftMax is needed in the network at the end.
input  = batchSize x nClasses x H x W
target = batchSize x H x W
A target value of 0 marks a pixel that is ignored by the loss
(contributes 0 to the output and receives zero gradient).
]]--
--- Construct the criterion.
-- @param weights optional 1D tensor of per-class rescaling weights,
--        forwarded unchanged to nn.ClassNLLCriterion.
function SpatialCrossEntropyCriterionWithIgnore:__init(weights)
   parent.__init(self)
   -- Fail fast with a clear message instead of the opaque
   -- "attempt to index global 'cudnn' (a nil value)" error.
   assert(cudnn and cudnn.SpatialLogSoftMax,
      'SpatialCrossEntropyCriterionWithIgnore requires the cudnn package')
   self.slsm = cudnn.SpatialLogSoftMax()
   self.nll = nn.ClassNLLCriterion(weights)
   -- Sum the per-pixel losses rather than averaging; presumably because
   -- averaging would divide by the total pixel count including ignored
   -- pixels. Mirrored into self.nll on every forward call.
   self.sizeAverage = false
end
-- Flatten a bdhw tensor into (b*h*w) x d rows, so that every spatial
-- position becomes one independent sample for ClassNLLCriterion.
local transpose = function(input)
   local x = input:transpose(2, 4):transpose(2, 3):contiguous() -- bdhw -> bwhd -> bhwd
   local b, h, w, d = x:size(1), x:size(2), x:size(3), x:size(4)
   return x:view(b * h * w, d)
end
-- Inverse of transpose(): reshape (b*h*w) x d rows back into a bdhw
-- tensor. originalInput is consulted only for its sizes.
local transposeBack = function(input, originalInput)
   local b = originalInput:size(1)
   local d = originalInput:size(2)
   local h = originalInput:size(3)
   local w = originalInput:size(4)
   local y = input:view(b, h, w, d) -- rows back to bhwd
   return y:transpose(2, 4):transpose(3, 4):contiguous() -- bhwd -> bdwh -> bdhw
end
--- Forward pass.
-- input  : batchSize x nClasses x H x W raw scores (log-softmax is applied here,
--          so the network must NOT end in a LogSoftMax).
-- target : batchSize x H x W labels; a value of 0 marks a pixel to ignore.
-- Returns the NLL over the non-ignored pixels (summed, since __init sets
-- sizeAverage = false; ignored pixels contribute exactly 0).
function SpatialCrossEntropyCriterionWithIgnore:updateOutput(input, target)
assert(input:dim() == 4, 'mini-batch supported only')
assert(target:dim() == 3, 'mini-batch supported only')
assert(input:size(1) == target:size(1), 'input and target should be of same size')
assert(input:size(3) == target:size(2), 'input and target should be of same size')
assert(input:size(4) == target:size(3), 'input and target should be of same size')
-- apply SpatialLogSoftMax to input
self.slsm:updateOutput(input)
-- Mask of pixels to keep: nonzero where target ~= 0, zero at ignored pixels.
-- NOTE(review): cmul of the log-softmax output by this mask assumes both live
-- on the same backend (e.g. CudaTensor); on CPU the ne() result would need a
-- type cast first — confirm against the intended usage.
local ignore_mask = target:ne(0)
--print(target:eq(0):sum() / (target:size(2) * target:size(3)))
local num_classes = self.slsm.output:size(2)
-- Zero the log-probabilities of every class at ignored pixels so the NLL
-- picked up there is exactly 0. This mutates self.slsm.output in place;
-- updateGradInput below reuses this masked output.
for i = 1, num_classes do
--print(self.slsm.output[{{},i}]:size())
self.slsm.output[{{},i}]:cmul(ignore_mask)
--assert(self.slsm.output[{{},i}][ignore_mask:eq(0)]:sum() == 0)
end
-- Update submodule sizeAverage to make it consistent.
self.nll.sizeAverage = self.sizeAverage
--local tmp_target = target + target:eq(0):type("torch.IntTensor")
-- Remap ignored labels 0 -> 1 so ClassNLLCriterion receives a valid class
-- index; the corresponding log-prob was zeroed above, so the loss there is 0.
local tmp_target = target + target:eq(0)
-- fold the height and width dims into the mini-batch dim.
--self.nll:updateOutput(transpose(self.slsm.output), target:view(-1))
self.nll:updateOutput(transpose(self.slsm.output), tmp_target:view(-1))
self.output = self.nll.output
return self.output
end
--- Backward pass.
-- Reuses self.slsm.output produced (and masked in place) by the preceding
-- updateOutput call, so updateOutput must be called first with the same
-- input/target — the standard nn.Criterion calling convention.
function SpatialCrossEntropyCriterionWithIgnore:updateGradInput(input, target)
assert(input:dim() == 4, 'mini-batch supported only')
assert(target:dim() == 3, 'mini-batch supported only')
assert(input:size(1) == target:size(1), 'input and target should be of same size')
assert(input:size(3) == target:size(2), 'input and target should be of same size')
assert(input:size(4) == target:size(3), 'input and target should be of same size')
-- Remap ignored labels 0 -> 1, mirroring updateOutput, so NLL indexing is valid.
local tmp_target = target + target:eq(0)
self.nll:updateGradInput(transpose(self.slsm.output), tmp_target:view(-1))
--print(self.nll.gradInput[{{},1}]:size())
--print(target:view(-1):eq(0):size())
-- Zero the gradient of every class at ignored pixels so they receive no
-- learning signal (the flattened mask matches transpose()'s row layout).
local ignore_mask = target:view(-1):ne(0)
local num_classes = self.nll.gradInput:size(2)
for i = 1, num_classes do
--print(self.nll.gradInput[{{},i}][target:view(-1):eq(0)]:sum())
self.nll.gradInput[{{},i}]:cmul(ignore_mask)
end
-- unfold the height and width dims back
-- Backpropagate the masked gradient through the SpatialLogSoftMax.
self.slsm:updateGradInput(input, transposeBack(self.nll.gradInput, input))
self.gradInput = self.slsm.gradInput
--print(self.gradInput:size())
--for i = 1, num_classes do
-- print(self.gradInput[{1,i}][target:eq(0)]:sum())
-- --assert(self.gradInput[{1,i}][target:eq(0)]:sum() == 0)
--end
return self.gradInput
end
-- Propagate a tensor-type cast to the internal SpatialLogSoftMax and
-- ClassNLLCriterion submodules, then cast the criterion itself.
-- (Parameter renamed locally to avoid shadowing the global type().)
function SpatialCrossEntropyCriterionWithIgnore:type(tensorType)
   if tensorType then
      self.slsm:type(tensorType)
      self.nll:type(tensorType)
   end
   parent.type(self, tensorType)
   return self
end