Remove unneeded get_reduction_factor function
F. Dangel committed Apr 23, 2020
1 parent 2fb769d commit 3081690
Showing 1 changed file with 1 addition and 36 deletions.
37 changes: 1 addition & 36 deletions test/test_sum_hessian.py
@@ -83,48 +83,13 @@

 def autograd_sum_hessian(layer, input, targets):
     """Compute the Hessian of a loss module w.r.t. its input."""
-
-    def get_reduction_factor(layer, input, targets):
-        """Determine reduction factor of individual losses.
-
-        Take the first sample and its label, clone it N times,
-        compare the individual loss with the loss of that batch.
-
-        For an `[N, D]` input:
-        CrossEntropyLoss(reduction="mean") -> N
-        CrossEntropyLoss(reduction="sum") -> 1
-        MSELoss(reduction="mean") -> N * D
-        MSELoss(reduction="sum") -> 1
-        """
-        N = input.shape[0]
-
-        sample = input[0].unsqueeze(0)
-        repeat = [N] + [1 for _ in input.shape[1:]]
-        sample_repeated = sample.repeat(*repeat)
-
-        target = targets[0].unsqueeze(0)
-        repeat = [N] + [1 for _ in targets.shape[1:]]
-        target_repeated = target.repeat(*repeat)
-
-        sample_loss = layer(sample, target)
-        batch_loss = layer(sample_repeated, target_repeated)
-
-        factor = torch.round(N * sample_loss / batch_loss).item()
-        print("Factor", factor)
-
-        return factor
-
-    factor = get_reduction_factor(layer, input, targets)
-    factor = 1
-
     input.requires_grad = True
     loss = layer(input, targets)
     hessian = autograd_hessian(loss, input)

     sum_hessian = sum_hessian_blocks(hessian, input)

-    return sum_hessian / factor
+    return sum_hessian


 def sum_hessian_blocks(hessian, input):
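The helper was already dead code before this commit: the line `factor = 1` immediately overwrote the computed value, so the final `return sum_hessian / factor` never rescaled anything. For context, the deleted helper probed a loss module's reduction empirically: repeating one sample N times makes every per-sample loss identical, so the ratio N * sample_loss / batch_loss reveals what the reduction divides the summed loss by. Below is a minimal standalone sketch of that trick; the shapes and the choice of CrossEntropyLoss are illustrative, not taken from the commit.

    import torch

    N, D = 4, 3
    loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")

    # One sample and its label, cloned N times into a batch of identical rows.
    sample = torch.rand(1, D)
    target = torch.randint(D, (1,))
    batch = sample.repeat(N, 1)
    batch_targets = target.repeat(N)

    sample_loss = loss_fn(sample, target)
    batch_loss = loss_fn(batch, batch_targets)

    # mean reduction: batch_loss == sample_loss,     so factor == N.
    # sum reduction:  batch_loss == N * sample_loss, so factor == 1.
    factor = torch.round(N * sample_loss / batch_loss).item()
    print(factor)  # 4.0 for reduction="mean"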
