Skip to content

Commit

Permalink
fix load_state_dict error when loading class weights for multi-stage loss
Browse files Browse the repository at this point in the history
  • Loading branch information
drprojects committed Jul 21, 2023
1 parent b46a9d2 commit 6b9ac9a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
4 changes: 4 additions & 0 deletions configs/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ defaults:
- extras: default.yaml
- hydra: default.yaml

# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: null

# optional local config for machine/user specific settings
# it's optional since it doesn't need to exist and is excluded from version control
- optional local: default.yaml
Expand Down
31 changes: 23 additions & 8 deletions src/models/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,29 @@ def configure_optimizers(self):
"reduce_on_plateau": reduce_on_plateau}}

def load_state_dict(self, state_dict, strict=True):
    """Basic `load_state_dict` from `torch.nn.Module` with a little
    bit of acrobatics due to `criterion.weight`.

    This attribute, when present in the `state_dict`, causes
    `load_state_dict` to crash. More precisely, `criterion.weight`
    holds the per-class weights for classification losses. For
    multi-stage losses the weights may instead appear under
    `criterion.<stage>.weight` keys, so all such keys must be
    stripped before loading and restored afterwards.

    Parameters
    ----------
    state_dict : dict
        State dictionary to load. May contain criterion class-weight
        entries, which are removed in place before loading.
    strict : bool
        Forwarded to the parent `load_state_dict`.
    """
    # Collect every 'criterion.weight' or 'criterion.*.weight' key.
    # NOTE: 'criterion.weight' itself matches, since it both starts
    # with 'criterion.' and ends with '.weight'.
    weight_keys = [
        k for k in state_dict
        if k.startswith('criterion.') and k.endswith('.weight')]

    # Recover the class weights from the first matching key (the
    # stages presumably share the same per-class weights — TODO
    # confirm) and remove all matching keys from the state_dict
    class_weight = state_dict.pop(weight_keys[0]) if weight_keys else None
    for k in weight_keys[1:]:
        state_dict.pop(k)

    # Load the state_dict without the offending criterion entries
    super().load_state_dict(state_dict, strict=strict)

    # If need be, assign the class weights back to the criterion
    if class_weight is not None and hasattr(self.criterion, 'weight'):
        self.criterion.weight = class_weight

@staticmethod
Expand Down

0 comments on commit 6b9ac9a

Please sign in to comment.