
Initialization with shared reference leads to averaging the losses after each epoch. #2

Closed
gabrielriqu3ti opened this issue Sep 9, 2021 · 1 comment

@gabrielriqu3ti

Hey,

I believe I have found a bug in this project.

When I train a GLIA-Net network, the total, local, and global average losses are all equal after each epoch.

An example follows:

2021-09-04 23:39:19 [MainThread] INFO [TaskAneurysmSegTrainer] - (Time epoch: 6081.78)train epoch 57/66 finished. total_loss_avg: 0.1874 local_loss_avg: 0.1874 global_loss_avg: 0.1874 ap: 0.1771 auc: 0.9058 precision: 0.6890 recall: 0.0313 dsc: 0.0598 hd95: 20.5659 per_target_precision: 0.0385 per_target_recall: 0.0057

I believe the problem originates from the initialization of the OrderedDicts avg_losses and eval_avg_losses, created in the following lines:

avg_losses = OrderedDict(zip(list(losses.keys()), [RunningAverage()] * len(losses)))

eval_avg_losses = OrderedDict(zip(list(losses.keys()), [RunningAverage()] * len(losses)))

The exact problem is that the list containing the 3 RunningAverages is created using the notation [a] * n. This notation replicates a reference to a single shared object rather than creating n independent objects, so all three entries point to the same RunningAverage. Therefore, when we update one of them, all three losses are updated together.

A solution for this problem would be to use the initialization [RunningAverage() for _ in range(len(losses))] instead of [RunningAverage()] * len(losses), since the list comprehension calls the constructor once per element.
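
Here is a minimal sketch of the behavior, using a simplified RunningAverage stand-in (the real class lives in this repository's utilities; its interface here is an assumption):

```python
from collections import OrderedDict


class RunningAverage:
    """Simplified stand-in for the project's RunningAverage (interface assumed)."""

    def __init__(self):
        self.count = 0
        self.sum = 0.0

    def update(self, value):
        self.count += 1
        self.sum += value

    @property
    def avg(self):
        return self.sum / max(self.count, 1)


keys = ['total_loss', 'local_loss', 'global_loss']

# Buggy: [obj] * n repeats one reference, so every key aliases the same object.
buggy = OrderedDict(zip(keys, [RunningAverage()] * len(keys)))
buggy['local_loss'].update(1.0)
print([a.avg for a in buggy.values()])  # [1.0, 1.0, 1.0] -- all three updated together

# Fixed: the comprehension calls RunningAverage() once per key.
fixed = OrderedDict(zip(keys, [RunningAverage() for _ in range(len(keys))]))
fixed['local_loss'].update(1.0)
print([a.avg for a in fixed.values()])  # [0.0, 1.0, 0.0] -- independent objects
```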

This solution seems to work for me.

An example follows:

2021-09-08 20:58:40 [MainThread] INFO [TaskAneurysmSegTrainer] - (Time epoch: 6101.42)train epoch 86/86 finished. total_loss_avg: 0.2528 local_loss_avg: 0.2200 global_loss_avg: 0.0328 ap: 0.4419 auc: 0.9313 precision: 0.5942 recall: 0.3963 dsc: 0.4755 hd95: 12.1462 per_target_precision: 0.1321 per_target_recall: 0.1637

As we can see, the total average loss now equals the sum of the local and global average losses (0.2528 = 0.2200 + 0.0328).

Great work again!

Best regards,

Gabriel

MeteorsHub added a commit that referenced this issue Sep 9, 2021
@MeteorsHub
Owner

MeteorsHub commented Sep 9, 2021

Hi, Gabriel.

Thank you for the bug report. I have fixed this bug according to your suggestion.

Best regards
