diff --git a/GANDLF/models/modelBase.py b/GANDLF/models/modelBase.py index 226c6f911..0fb628835 100644 --- a/GANDLF/models/modelBase.py +++ b/GANDLF/models/modelBase.py @@ -13,8 +13,10 @@ GlobalAveragePooling2D, ) +from huggingface_hub import PyTorchModelHubMixin -class ModelBase(nn.Module): + +class ModelBase(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/mlcommons/GaNDLF/", tags=["segmentation", "medical", "pytorch"]): """ This is the base model class that all other architectures will need to derive from """