In [2]:
import monai.losses as losses
import torchgeometry as tgm
import yaml


In [3]:
class LossFactory:
    """Build MONAI loss objects from plain config dicts.

    A config dict names the loss under its ``'name'`` key; every other key is
    forwarded as a keyword argument to the registered loss class's constructor.
    """

    def __init__(self):
        # Registry of supported loss names -> MONAI loss classes.
        self.losses = {
            'DiceLoss': losses.DiceLoss,
            'MaskedDiceLoss': losses.MaskedDiceLoss,
            'GeneralizedDiceLoss': losses.GeneralizedDiceLoss,
            'GeneralizedWassersteinDiceLoss': losses.GeneralizedWassersteinDiceLoss,
            'DiceCELoss': losses.DiceCELoss,
            'DiceFocalLoss': losses.DiceFocalLoss
        }

    def create_loss(self, loss_dict):
        """Instantiate the loss described by ``loss_dict``.

        Args:
            loss_dict: mapping with a ``'name'`` key naming a registered loss;
                all remaining keys are passed as keyword arguments to that
                loss class. The mapping is NOT modified, so the same config
                can be reused for several calls.

        Returns:
            A new instance of the registered loss class.

        Raises:
            KeyError: if ``loss_dict`` has no ``'name'`` key.
            ValueError: if the name is not registered in ``self.losses``.
        """
        # Work on a shallow copy: popping from the caller's dict would strip
        # 'name' and make a second call with the same config fail.
        params = dict(loss_dict)
        if 'name' not in params:
            raise KeyError("loss_dict must contain a 'name' key")
        loss_name = params.pop('name')
        if loss_name not in self.losses:
            raise ValueError(
                f"Unknown loss function name: {loss_name}. "
                f"Available: {sorted(self.losses)}"
            )
        return self.losses[loss_name](**params)

In [7]:
factory = LossFactory()
# Dice + cross-entropy loss with both terms given equal weight (1 and 1).
loss_dict = dict(
    name='DiceCELoss',
    include_background=True,
    reduction='mean',
    lambda_ce=1,
    lambda_dice=1,
)
dice_loss = factory.create_loss(loss_dict)

In [8]:
# Build the optimizer from the 'weightedadam' entry of the YAML optimizer config.
# NOTE(review): `OptimizerFactory` is not defined or imported in any visible cell;
# it must come from an earlier cell/module, otherwise this cell fails on a fresh kernel.
OPTIMIZER_CONFIG_PATH = 'optimizers.yaml'

with open(OPTIMIZER_CONFIG_PATH, 'r', encoding='utf-8') as f:
    optimizer_config = yaml.safe_load(f)

# yaml.safe_load returns None for an empty file; fail with a clear message
# instead of an opaque TypeError/KeyError on the lookup below.
if not isinstance(optimizer_config, dict) or 'weightedadam' not in optimizer_config:
    raise KeyError(f"{OPTIMIZER_CONFIG_PATH} must define a 'weightedadam' optimizer entry")

optimizer_factory = OptimizerFactory()
optimizer = optimizer_factory.create_optimizer(optimizer_config['weightedadam'])
optimizer

DiceCELoss(
  (dice): DiceLoss()
  (cross_entropy): CrossEntropyLoss()
)