This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

[API] Design approach for data metrics dependency #14

Closed
SeanNaren opened this issue Jan 5, 2021 · 2 comments
Assignees
Labels
enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@SeanNaren
Contributor

Some metrics depend on the datamodule. For example, text classification requires the model to know the number of classes, which is tied to the dataset.

This is currently passed as below:

    model: LitAutoModelTransformer = instantiate_downstream_model(
        ...
        **data_module.data_model_kwargs
    )

Where data_module.data_model_kwargs is:

class LitTextClassificationDataModule(LitTransformerDataModule):
    @property
    def num_classes(self):
        return self.labels.num_classes

    @property
    def data_model_kwargs(self):
        return {
            'num_classes': self.num_classes
        }

This isn't very easy to grok IMO. The kwargs can be anything, making it hard to tell what is actually being passed to the module.

Potential solution 1

Pass the data module to the model (I don't think save_hyperparameters works in this instance though):

class LitAutoModelTextClassificationTransformer(LitAutoModelTransformer):
    def __init__(self,
                 downstream_model_type: str,
                 backbone: DictConfig,
                 optim: DictConfig,
                 datamodule: LitTextClassificationDataModule,
                 scheduler: Optional[DictConfig] = None):
        super().__init__(
            downstream_model_type=downstream_model_type,
            backbone=backbone,
            optim=optim,
            scheduler=scheduler
        )
        self.num_classes = datamodule.num_classes
        self._initialize_metrics(num_classes=self.num_classes)  # use num_classes from the datamodule directly

Potential solution 2

Initialize metrics once we know we can, since the trainer now holds a reference to the datamodule:

class LitAutoModelTextClassificationTransformer(LitAutoModelTransformer):
    def on_fit_start(self):
        datamodule = self.trainer.datamodule
        self._initialize_metrics(num_classes=datamodule.num_classes)

We could add this hook within the base class and expose an initialize_metrics method that receives the datamodule, for example.
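As a rough, dependency-free sketch of what that base-class hook could look like (the class and attribute names here are illustrative, not the actual lightning-transformers API; in practice these would be LightningModule subclasses and the trainer would be pytorch_lightning.Trainer):

```python
# Sketch of the proposed base-class hook: metrics that depend on dataset
# properties are only built at fit time, when the trainer already holds a
# datamodule reference.

class LitTransformerBase:
    def __init__(self):
        self.trainer = None  # attached by the trainer before fit starts
        self.metrics = {}

    def on_fit_start(self):
        # By fit time the trainer has a datamodule, so dataset-dependent
        # metrics can be initialized here.
        self.configure_metrics(self.trainer.datamodule)

    def configure_metrics(self, datamodule):
        # Subclasses override this to build their metrics.
        raise NotImplementedError


class TextClassificationModel(LitTransformerBase):
    def configure_metrics(self, datamodule):
        # In practice this would be e.g. a torchmetrics Accuracy built with
        # num_classes; a plain dict entry keeps the sketch self-contained.
        self.metrics["num_classes"] = datamodule.num_classes


# Minimal wiring to show the flow:
class DummyDataModule:
    num_classes = 4

class DummyTrainer:
    datamodule = DummyDataModule()

model = TextClassificationModel()
model.trainer = DummyTrainer()
model.on_fit_start()
print(model.metrics["num_classes"])  # 4
```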

@SeanNaren added the enhancement and help wanted labels on Jan 5, 2021
@SeanNaren SeanNaren self-assigned this Jan 7, 2021
@SeanNaren
Contributor Author

After discussion with @carmocca we've agreed on proposal 2. From Carlos:

class LitAutoModelTextClassificationTransformer(LitAutoModelTransformer):
    @property
    def num_classes(self):
        return self.trainer.datamodule.num_classes

    def configure_metrics(self):
        self.metric1 = metric(self.num_classes)

    def on_fit_start(self):
        self.configure_metrics()

class SomeoneElsesDataModule(LitTransformerDataModule):
    @property
    def C(self):
        # this user chose to set num_classes as C
        ...

# If necessary, one could create this "adapter" without rewriting configure_metrics
class AdaptedTransformer(LitAutoModelTextClassificationTransformer):
    @property
    def num_classes(self):
        return self.trainer.datamodule.C
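To sanity-check the adapter idea, here is a self-contained sketch with stand-in classes (names are illustrative; the only point is that overriding the num_classes property adapts a differently named datamodule field without touching configure_metrics):

```python
# Stand-ins for the real model / datamodule classes, just to show how a
# property override redirects num_classes to whatever the datamodule exposes.

class BaseModel:
    trainer = None

    @property
    def num_classes(self):
        # Default contract: the datamodule exposes num_classes.
        return self.trainer.datamodule.num_classes


class SomeoneElsesDataModule:
    # This user chose to expose the class count as C instead.
    C = 7


class AdaptedModel(BaseModel):
    @property
    def num_classes(self):
        # Adapter: redirect to the datamodule's own naming; any base-class
        # code reading self.num_classes keeps working unchanged.
        return self.trainer.datamodule.C


class DummyTrainer:
    datamodule = SomeoneElsesDataModule()

model = AdaptedModel()
model.trainer = DummyTrainer()
print(model.num_classes)  # 7
```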

@SeanNaren
Contributor Author

Closed with #17. The above has been added, but not in its entirety: we still pass data config args from the datamodule, e.g. so the autoconfig can load num_labels.
