Skip to content

Commit

Permalink
further fix MultiLoss load_state_dict errors
Browse files Browse the repository at this point in the history
  • Loading branch information
drprojects committed Jul 26, 2023
1 parent 6b9ac9a commit 5558532
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
39 changes: 39 additions & 0 deletions src/loss/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,42 @@ def weight(self, weight):
"""
for i in range(len(self)):
self.criteria[i].weight = weight

def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
"""Normal `state_dict` behavior, except for the shared criterion
weights, which are not saved under `prefix.criteria.i.weight`
but under `prefix.weight`.
"""
destination = super().state_dict(
*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

# Remove the 'weight' from the criteria
for i in range(len(self)):
destination.pop(f"{prefix}criteria.{i}.weight")

# Only save the global shared weight
destination[f"{prefix}weight"] = self.weight

return destination

def load_state_dict(self, state_dict, strict=True):
"""Normal `load_state_dict` behavior, except for the shared
criterion weights, which are not saved under `criteria.i.weight`
but under `prefix.weight`.
"""
# Get the weight from the state_dict
old_format = state_dict.get('criteria.0.weight')
new_format = state_dict.get('weight')
weight = new_format if new_format is not None else old_format
for k in [f"criteria.{i}.weight" for i in range(len(self))]:
if k in state_dict.keys():
state_dict.pop(k)

# Normal load_state_dict, ignoring self.criteria.0.weight and
# self.weight
out = super().load_state_dict(state_dict, strict=strict)

# Set the weight
self.weight = weight

return out
30 changes: 18 additions & 12 deletions src/models/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,23 +545,29 @@ def load_state_dict(self, state_dict, strict=True):
`load_state_dict` to crash. More precisely, `criterion.weight`
is holding the per-class weights for classification losses.
"""
# Recover the class weights from any 'criterion.weight' or
# 'criterion.*.weight' key and remove those keys from the
# state_dict
keys = []
for key in state_dict.keys():
if key.startswith('criterion.') and key.endswith('.weight'):
keys.append(key)
class_weight = state_dict[keys[0]] if len(keys) > 0 else None
for key in keys:
state_dict.pop(key)
# Special treatment for MultiLoss
if self.multi_stage_loss:
class_weight_bckp = self.criterion.weight
self.criterion.weight = None

# Recover the class weights from any 'criterion.weight' or
# 'criterion.*.weight' key and remove those keys from the
# state_dict
keys = []
for key in state_dict.keys():
if key.startswith('criterion.') and key.endswith('.weight'):
keys.append(key)
class_weight = state_dict[keys[0]] if len(keys) > 0 else None
for key in keys:
state_dict.pop(key)

# Load the state_dict
super().load_state_dict(state_dict, strict=strict)

# If need be, assign the class weights to the criterion
if class_weight is not None and hasattr(self.criterion, 'weight'):
self.criterion.weight = class_weight
if self.multi_stage_loss:
self.criterion.weight = class_weight if class_weight is not None \
else class_weight_bckp

@staticmethod
def sanitize_step_output(out_dict):
Expand Down

0 comments on commit 5558532

Please sign in to comment.