Skip to content

Commit

Permalink
keep track of initial output for easy comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
civodlu committed Feb 9, 2020
1 parent ef1110f commit 545b1da
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/trw/train/meaningful_perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def __call__(self, inputs, target_class_name, target_class=None):
break
output = MeaningfulPerturbation._get_output(target_class_name, outputs, self.model_output_postprocessing)
logger.info('original model output={}'.format(utilities.to_value(output)))
output_start = utilities.to_value(output)

if target_class is None:
target_class = torch.argmax(output, dim=1)
Expand Down Expand Up @@ -251,15 +252,17 @@ def __call__(self, inputs, target_class_name, target_class=None):
c = output[:, target_class]
loss = l1 + tv + c

if i == 0:
# must be collected BEFORE backward!
c_start = utilities.to_value(c)

loss.backward()
optimizer.step()
scheduler.step()

# Optional: clamping seems to give better results
mask.data.clamp_(0, 1)

if i == 0:
c_start = utilities.to_value(c)

if i % 20 == 0:
logger.info('iter={}, total_loss={}, l1_loss={}, tv_loss={}, c_loss={}'.format(
Expand All @@ -282,6 +285,7 @@ def __call__(self, inputs, target_class_name, target_class=None):
'smoothed_input': utilities.to_value(blurred_img),
'loss_c_start': c_start,
'loss_c_end': utilities.to_value(c),
'output_start': output_start,
'output_end': utilities.to_value(output),
}

Expand Down

0 comments on commit 545b1da

Please sign in to comment.