
Replace train Function with Trainer Class #174

Closed
PicoCentauri opened this issue May 3, 2024 · 18 comments

Comments

@PicoCentauri
Contributor

PicoCentauri commented May 3, 2024

We are considering replacing the train function for each model architecture with a
Trainer class. This change aims not only to improve our code structure but also to pave
the way for integrating PyTorch Lightning (Trainer API) to manage the training process.

The primary concern with using Lightning is that it may impose excessive constraints on
new architectures, particularly those intended to use JAX, TensorFlow, Fortran, TeX or
rely heavily on linear algebra operations. While I think it is possible to use a generic PyTorch
Lightning trainer for these architectures, the criticism is acknowledged.

Hence, I propose the following compromise, which introduces additional maintenance
and coding overhead to differentiate between Lightning and non-Lightning architectures,
not to mention the complexity it adds to our documentation.

The compromise means that architectures not opting for Lightning must still provide
their own Model and Trainer classes that follow the PyTorch Lightning API. This
consistency ensures we can maintain a single API call within our train.py.

While non-Lightning architectures are supported, adopting Lightning is strongly
encouraged because of its advanced features, such as online logging and distributed
multi-GPU training, which will not be as readily available to non-Lightning Trainer classes.

Here's a detailed look at the proposed design changes in train.py:

# Load a model from a restart file, or create a fresh one
model = (
    architecture.load_from_checkpoint()
    if restart
    else architecture.model(**architecture_hypers)
)

# Create dataloaders using `collate_fn` and previously defined datasets
train_dataloaders = [
    DataLoader(
        dataset=dataset,
        batch_size=options["batch_size"],
        shuffle=True,
        collate_fn=collate_fn,
    )
    for dataset in train_datasets
]

# Initialize the trainer object using `train_hypers` and fit the model
trainer = architecture.Trainer(accelerator, devices, logger=logger, **train_hypers)
trainer.fit(
    model,
    train_dataloaders,
    val_dataloaders,
    ckpt_path=options["continue_from"],
)

Custom Model Class

Since we initialize the model before calling the trainer, non-Lightning models must
implement a load_from_checkpoint(checkpoint_path: str) method to continue training
from a saved state. These models aren't required to implement a public
save_to_checkpoint method, as checkpoint saving is managed internally.
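A minimal sketch of how a non-Lightning model could satisfy this; the checkpoint layout with torch.save/torch.load and the hypers argument are assumptions for illustration, not part of the proposal:

import torch


class MyModel(torch.nn.Module):
    """Hypothetical non-Lightning model exposing the proposed checkpoint API."""

    def __init__(self, hypers: dict):
        super().__init__()
        self.hypers = hypers
        self.linear = torch.nn.Linear(hypers.get("num_features", 16), 1)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str) -> "MyModel":
        # assumes the checkpoint stores both the hypers and the state dict
        checkpoint = torch.load(checkpoint_path)
        model = cls(checkpoint["hypers"])
        model.load_state_dict(checkpoint["state_dict"])
        return model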

Extra Custom Functions

Architectures may define their own export and collate_fn functions. If not provided,
default implementations from our utility folder will be used. The API of these custom
functions, which should live in each architecture's model.py, should look like:

from pathlib import Path
from typing import Dict, List, NamedTuple, Tuple, Union
import torch
from metatensor.torch import TensorMap

def export(model: torch.nn.Module, path: Union[str, Path]) -> None:
    ...

def collate_fn(batch: List[NamedTuple]) -> Tuple[List, Dict[str, TensorMap]]:
    ...

DataLoaders instead of Datasets in train.py

To reduce friction among developers, we will not use the preferred
LightningDataModule but use torch DataLoaders directly in train.py.
This allows sufficient flexibility while maintaining architectural consistency.

Custom Trainer Class

Custom Trainer classes must be compatible with the following API, accommodating various devices
and logging requirements:

from typing import Iterable, Optional

from torch.utils.data import DataLoader


class Trainer:
    def __init__(self, accelerator: str, devices: int, logger=None, **training_hypers):
        """
        :param accelerator: Type of accelerator ("cpu", "cuda", "mps", etc.)
        :param devices: Number of devices, or -1 for all available devices.
        :param logger: Optional; for online experiment tracking.
        """

    def fit(
        self,
        model,
        train_dataloaders: Iterable[DataLoader],
        val_dataloaders: Iterable[DataLoader],
        ckpt_path: Optional[str] = None,
    ):
        ...
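For architectures that do not use Lightning, a minimal sketch of a class following this API could look like the following; the single-device handling, the hyper-parameter names (learning_rate, num_epochs), and the _compute_loss placeholder are assumptions for illustration:

from typing import Iterable, Optional

import torch
from torch.utils.data import DataLoader


class SimpleTrainer:
    """Hypothetical non-Lightning trainer following the proposed API (a sketch)."""

    def __init__(self, accelerator: str, devices: int, logger=None, **training_hypers):
        # `devices` is accepted for API compatibility; this sketch only uses one device
        self.device = torch.device(accelerator)
        self.logger = logger
        self.hypers = training_hypers

    def fit(
        self,
        model: torch.nn.Module,
        train_dataloaders: Iterable[DataLoader],
        val_dataloaders: Iterable[DataLoader],
        ckpt_path: Optional[str] = None,
    ):
        model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.hypers.get("learning_rate", 1e-3))
        for epoch in range(self.hypers.get("num_epochs", 1)):
            for dataloader in train_dataloaders:
                for batch in dataloader:
                    optimizer.zero_grad()
                    loss = self._compute_loss(model, batch)  # architecture-specific
                    loss.backward()
                    optimizer.step()

    def _compute_loss(self, model, batch):
        raise NotImplementedError("to be provided by the architecture")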

(online) Loggers

The use of the Python logging library will continue for console and local file logging.
This is separate from any online logging mechanisms provided by Lightning.
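As a small sketch of that separation (the logger name and handlers here are assumptions):

import logging

# console and local file logging via the standard library, independent of Lightning
logger = logging.getLogger("metatensor-models")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())
logger.addHandler(logging.FileHandler("train.log"))
logger.info("starting training")

# online experiment tracking (e.g. a Lightning logger such as
# lightning.pytorch.loggers.TensorBoardLogger) would be passed to the Trainer separately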

Lightning team: Please advise if there are additional public API components that
non-Lightning architectures need to implement. Let's discuss any necessary details
before I proceed with the implementation.

@Luthaf
Contributor

Luthaf commented May 3, 2024

I don't necessarily love the naming used by Lightning (accelerator where everyone else is using device), but I understand that we need a single API for the train function.

What we could also do is define the Trainer class and API ourselves, and then provide a class LightningTrainer which would do the adaptation between our API and the Lightning API:

# in `mts/models/train`
class TrainerInterface:
    def __init__(self, mts_models_args):
        ...

    def train(self, mts_models_args):
        ...

# in `mts/models/utils/lightning`
class LightningTrainer(lightning.Trainer):
    def __init__(self, mts_models_args):
        super().__init__(adapt_to_lightning_args(mts_models_args))

    def train(self, mts_models_args):
        return self.fit(adapt_to_lightning_args(mts_models_args))

# in sota_architecture.py
class SOTATrainer(LightningTrainer):
    def __init__(self, mts_models_args):
        super().__init__(mts_models_args)

    def fit(self, lightning_args):
        ...

I see a handful of advantages to this approach:

  • we control the API we use in train.py, and we can make any modification we want to it, reflect it in LightningTrainer and immediately use it
  • we are not married to lightning, and if something nicer comes later, we can easily switch new architectures to it, without breaking the architectures still using lightning
  • we can start with the trainer class being similar to the current function, which should make for a smooth transition

There are some drawbacks as well; the main one is that we need to sit down and define the API (although we will be able to evolve it), and then maintain the corresponding code and documentation.

@abmazitov
Contributor

Thank you @PicoCentauri! This looks very good to me! I see the point in what @Luthaf says; perhaps it would really be better to have this additional compatibility layer.

I'll start working on updating PET this weekend, and then I'll add some comments depending on what issues I face.

@PicoCentauri
Contributor Author

Yeah, that might be an option. And for super mega SOTA custom architectures, the developer only has to inherit from TrainerInterface?

class CustomTrainer(TrainerInterface):
    def __init__(self, mts_models_args):
        super().__init__(mts_models_args)

    def fit(self, mts_models_args):
        ...

I had the same idea and like it. But @abmazitov was against inheriting from the Lightning Trainer class and wants us to work with the untouched trainer directly. And @frostedoyster was against adding additional wrapper layers because they are hard to follow for people.

@Luthaf
Contributor

Luthaf commented May 3, 2024

How is one typically using the Lightning Trainer? Is it by defining a custom class MyTrainer(lightning.Trainer) or with some other mechanism?

Regarding additional wrapping layer, this one would have a very low cognitive overhead IMO. We can say to contributors "if you are using a lightning trainer, inherit from mts.models.utils.LightningTrainer instead of lightning.Trainer and that's it". If they are not using lightning trainer already, then my suggestion would basically be the same as the existing train function.

for super mega SOTA custom architectures, the developer only has to inherit from TrainerInterface?

Not necessarily inherit. Since Python is duck-typed, they just have to follow the same API — this is similar to the ModelInterface.
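To make the duck typing explicit, one could even describe the expected API as a typing.Protocol; this is just an illustration, not part of the proposal:

from typing import Optional, Protocol, runtime_checkable


@runtime_checkable
class TrainerProtocol(Protocol):
    """Anything with a matching fit() counts as a trainer, no inheritance needed."""

    def fit(self, model, train_dataloaders, val_dataloaders, ckpt_path: Optional[str] = None): ...


class HomeGrownTrainer:  # does not inherit from anything
    def fit(self, model, train_dataloaders, val_dataloaders, ckpt_path: Optional[str] = None):
        pass


assert isinstance(HomeGrownTrainer(), TrainerProtocol)  # structural check only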

@abmazitov
Contributor

abmazitov commented May 3, 2024

Typically, Lightning requires you to define a set of API methods on the model (i.e. a lightning.pytorch.LightningModule), like model.training_step(), model.validation_step(), etc. Once that is done, the idea is to use lightning.pytorch.Trainer directly and pass the model to its fit() method:

import lightning.pytorch as pl

class Model(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        pass

    def training_step(self, ...):
        pass

    def validation_step(self, ...):
        pass

model = Model(*args, **kwargs)
trainer = pl.Trainer(accelerator=..., devices=..., logger=..., callbacks=...)
trainer.fit(model, train_dataloaders=..., val_dataloaders=...)

@Luthaf
Contributor

Luthaf commented May 3, 2024

Ahh, ok. Then with my proposal that would become something like this:

# mts/models/utils.py
import lightning.pytorch as pl

class LightningTrainer:
    def __init__(self, lightning_module, mts_train_init_args):
        self.lightning_module = lightning_module
        self.trainer = pl.Trainer(accelerator=..., devices=..., logger=..., callbacks=...)
    
    def train(self, mts_train_args):
        # handle the adaptation between model/lightning_module
        return self.trainer.fit(self.lightning_module)


# sota_arch.py
import lightning.pytorch as pl
from mts.models.utils import LightningTrainer

class TrainingLoop(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        pass

    def training_step(self, ...):
        pass

    def validation_step(self, ...):
        pass

class Trainer(LightningTrainer):
    def __init__(self, mts_train_init_args):
        lightning_module = TrainingLoop()
        super().__init__(lightning_module, mts_train_init_args)

It's not very clear how we can handle checkpointing/initial model creation with this API, but I don't see a fundamental issue.

@PicoCentauri
Contributor Author

Regarding additional wrapping layer, this one would have a very low cognitive overhead IMO. We can say to contributors "if you are using a lightning trainer, inherit from mts.models.utils.LightningTrainer instead of lightning.Trainer and that's it". If they are not using lightning trainer already, then my suggestion would basically be the same as the existing train function.

I think a contributor shouldn't have to do anything. I would use the same strategy as we do for export and collate_fn: if there is no Trainer class defined, we use ours, which is based on Lightning.

Of course a contributor can also inherit from "our" Trainer class to do their SOTA optimization stuff. If they do, we recognize that there is a Trainer class defined and use that one, of course with the same API as ours.

Or do we really want to support several base classes?

@PicoCentauri
Contributor Author

It's not very clear how we can handle checkpointing/initial model creation with this API, but I don't see a fundamental issue.

Initial model creation is done via the load_from_checkpoint method that the model class has to provide.

For checkpointing:

I think if you are living in the Lightning world, the Lightning trainer will take care of it and you only have to provide a save_to_checkpoint function in the model class. But I don't know this in detail either. Maybe @abmazitov can provide a nice example as above.

A custom trainer has to implement checkpoint saving itself.
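To illustrate both sides (the callback arguments and the checkpoint layout here are assumptions):

# Lightning world: a ModelCheckpoint callback handles saving during training
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/", save_top_k=1, monitor="val_loss")
# passed as `callbacks=[checkpoint_callback]` when constructing pl.Trainer

# custom (non-Lightning) trainer: save explicitly from the training loop
import torch

def save_to_checkpoint(model: torch.nn.Module, hypers: dict, path: str) -> None:
    torch.save({"hypers": hypers, "state_dict": model.state_dict()}, path)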

@Luthaf
Contributor

Luthaf commented May 3, 2024

Of course a contributor can also inherit from "our" Trainer class to do their SOTA optimization stuff. If they do, we recognize that there is a Trainer class defined and use that one, of course with the same API as ours.

There are a couple of different use cases here:

  1. user already has a training script without Lightning that they want to keep using
  2. user already has a training script with Lightning that they want to keep using
  3. user already has a training script with some other training framework that they want to keep using
  4. user has no training script, or doesn't really care

In my proposal, case (1) would be a custom implementation of TrainerInterface, using their code inside the train function. Case (2) would use LightningTrainer, which manages the mts-models <=> lightning interface. Case (3) would be the same as case (1) initially, but if enough people want to use it we can introduce something similar to LightningTrainer for this other framework. For case (4) we could maybe provide some default training implementation (although I'm not sure we actually want to do that …); whether it is based on LightningTrainer or some custom code does not matter much to me.


Or do we really want to support several base classes?

There is one interface (not used through inheritance) and one implementation of this interface (LightningTrainer) which is used through inheritance. Contributors can then add more ad-hoc implementations of the interface for their own code if they care and don't want to use lightning.

@PicoCentauri
Contributor Author

Okay, I think I see the points and we agree. Your basic point is to have one class in between the Lightning trainer and our interface.

I would also like to add some checks in this wrapper class: supported precision, GPU selection, and better checks that the datasets match the model's and trainer's capabilities. For example, if the user wants to train on target properties that the architecture doesn't provide, or if the architecture can only train on one dataset but the user wants more, etc.

Of course one could do everything in standalone functions, but this is currently very unreadable, and every architecture has to implement these checks again; see for example PET:

https://github.com/lab-cosmo/metatensor-models/blob/2fa2ac46a8a9960d6657abd88c37abafb3db84ec/src/metatensor/models/experimental/pet/train.py#L30-L38

This is very annoying...
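A sketch of what such a shared check could look like inside the wrapper class; the argument names here are placeholders, not the final interface:

from typing import List


def check_datasets(
    requested_targets: List[str],
    supported_targets: List[str],
    n_datasets: int,
    max_datasets: int,
) -> None:
    """Hypothetical shared validation, instead of re-implementing it in every architecture."""
    unsupported = [t for t in requested_targets if t not in supported_targets]
    if unsupported:
        raise ValueError(f"architecture does not support the requested target(s): {unsupported}")
    if n_datasets > max_datasets:
        raise ValueError(
            f"architecture can train on at most {max_datasets} dataset(s), got {n_datasets}"
        )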

@frostedoyster
Collaborator

I like having our own Trainer class and then providing utilities for interfacing it with Lightning

@Luthaf
Contributor

Luthaf commented May 7, 2024

After a discussion with @PicoCentauri, we arrived at the following:

# interface definition
class ModelInterface:
    def __init__(self, architecture_hypers):
        pass

    def save_checkpoint(self, path):
        pass

    @classmethod
    def load_checkpoint(cls, path) -> "ModelInterface":
        pass


class TrainerInterface:
    def __init__(self, train_hypers):
        pass

    def train(
        self,
        model: ModelInterface,
        devices: List[torch.device],
        train_datasets: List[torch.Dataset],
        validation_datasets: List[torch.Dataset],
        datasets_metadata: DatasetMetadata,
        checkpoints_dir: str,
    ):
        pass

    def export(
        self,
        model: ModelInterface,
        dataset_metadata: DatasetMetadata,
    ) -> MetatensorAtomisticModel:
        pass


# interface usage in the CLI

__MODEL_CLASS__, __TRAINER_CLASS__ = import(...)

hypers = {}
if "continue_from_checkpoint":
    model = __MODEL_CLASS__.load_checkpoint("path")
else:
    model = __MODEL_CLASS__(hypers["architecture"])

trainer = __TRAINER_CLASS__(hypers["train"])

trainer.train(
    model=model,
    devices=[],
    train_datasets=[],
    validation_datasets=[],
    datasets_metadata=datasets_metadata,
    checkpoints_dir="path",
)

model.save_checkpoint("final.ckpt")

exported = trainer.export(model, datasets_metadata)
exported.save("path", collect_extensions="extensions-dir/")


# Lightning utilities
class LightningTrainer:
    def __init__(self, trainer_options=None):
        if trainer_options is None:
            self._trainer_options = {}
        else:
            self._trainer_options = trainer_options

        if "accelerator" in self._trainer_options:
            raise Exception()

        if "devices" in self._trainer_options:
            raise Exception()

    def train(
        self,
        model: ModelInterface,
        datasets_metadata: DatasetMetadata,
        train_datasets: List[torch.Dataset],
        validation_datasets: List[torch.Dataset],
        devices: List[torch.device],
        checkpoints_dir: str,
    ):
        accelerator, devices = self._devices_to_lightning(devices)

        # TODO: init loggers
        self._trainer = pl.Trainer(
            accelerator=accelerator,
            devices=devices,
            **self._trainer_options,
        )

        self.check_dataset(datasets_metadata)

        # TODO: should we allow custom dataloaders?
        train_dataloaders = [DataLoader(d) for d in train_datasets]
        val_dataloaders = [DataLoader(d) for d in validation_datasets]

        self._trainer.fit(
            self.module,
            train_dataloaders=train_dataloaders,
            val_dataloaders=val_dataloaders,
            ckpt_path=checkpoints_dir,
        )

    def export(
        self,
        model: ModelInterface,
        dataset_metadata: DatasetMetadata,
    ) -> MetatensorAtomisticModel:
        raise NotImplementedError()

    def check_dataset(self, datasets_metadata):
        """To be overridden by the actual trainer"""

    def _devices_to_lightning(self, devices):
        """TODO"""

This code would be used as follows by contributors:

# lightning-based model

class CustomSOTAModel(pl.LightningModule):
    def __init__(self, hypers):
        """TODO"""

    def training_step(self, batch, batch_idx):
        """TODO"""

    def configure_optimizers(self):
        """TODO"""

    def save_checkpoint(self, path):
        """TODO"""

    @classmethod
    def load_checkpoint(cls, path) -> "CustomSOTAModel":
        """TODO"""


class Trainer(LightningTrainer):
    def __init__(self, hypers, devices):
        super().__init__(
            module=CustomSOTAModel(hypers),
            devices=devices,
            trainer_options={
                "logger": lightning_logger_create(hypers["architecture"]),
            },
        )

    def check_dataset(self, metadata: DatasetMetadata):
        """TODO"""

    def export(
        self,
        model: CustomSOTAModel,
        dataset_metadata: DatasetMetadata,
    ) -> MetatensorAtomisticModel:
        """TODO"""


__MODEL_CLASS__ = CustomSOTAModel
__TRAINER_CLASS__ = Trainer


# non-lightning-based model

class NumpySOTAModel:
    def __init__(self, hypers):
        """TODO"""

    def save_checkpoint(self, path):
        """TODO"""

    @classmethod
    def load_checkpoint(cls, path) -> "NumpySOTAModel":
        """TODO"""


class Trainer:
    def __init__(self, hypers):
        """TODO"""

    def train(
        self,
        model: NumpySOTAModel,
        devices: List[torch.device],
        train_datasets: List[torch.Dataset],
        validation_datasets: List[torch.Dataset],
        datasets_metadata: DatasetMetadata,
        checkpoints_dir: str,
    ):
        """TODO"""

    def export(
        self,
        model: NumpySOTAModel,
        dataset_metadata: DatasetMetadata,
    ) -> MetatensorAtomisticModel:
        """TODO"""


__MODEL_CLASS__ = NumpySOTAModel
__TRAINER_CLASS__ = Trainer

Any thoughts? We could also provide a LightningModel class that would implement save_checkpoint/load_checkpoint in terms of the lightning equivalents.
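A sketch of what that LightningModel could look like, assuming the module is attached to a Lightning trainer when save_checkpoint is called:

import lightning.pytorch as pl


class LightningModel(pl.LightningModule):
    """Hypothetical base class mapping the mts-models checkpoint API onto Lightning's."""

    def save_checkpoint(self, path: str) -> None:
        # delegate to the attached Lightning trainer, which also stores optimizer
        # state, epoch, etc. in the checkpoint
        self.trainer.save_checkpoint(path)

    @classmethod
    def load_checkpoint(cls, path: str) -> "LightningModel":
        # Lightning already provides the loading side as a classmethod
        return cls.load_from_checkpoint(path)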

@PicoCentauri
Contributor Author

As a minor addition, we decided that we won't perform any checks on whether the datasets fit what the architecture can handle. The architecture developer has to do this.

Also, if we go with this design, I suggest that we first wrap all architectures as non-Lightning-based models, and then we can adjust them later.

@frostedoyster
Collaborator

class ModelInterface:
    def __init__(self, architecture_hypers):
        pass

This is not enough IMO. The model needs information about the targets to be initialized correctly. We will need at least dataset_info in this constructor, as well as possibly the training_hypers (although I hope we can do without those)

@frostedoyster
Collaborator

Another somewhat related comment: the DatasetInfo class is difficult to design in a way that applies to all types of models. Since the metadata is naturally there in the TensorMaps of the datasets, I think that a good choice would be to include a data-less TensorMap (meta device), as well as Systems, inside DatasetInfo. That would naturally include all the metadata we need (perhaps to be complemented with something like TargetInfo for the units). Essentially this would be like including a dummy sample from the Dataset as part of the DatasetInfo

@PicoCentauri
Contributor Author

What do you suggest for the interface? Keep it as we currently have it, like

class ModelInterface:
    def __init__(self, capabilities: ModelCapabilities, architecture_hypers: Dict):
        ...

But then we have to create the ModelCapabilities inside the train function.

@Luthaf
Contributor

Luthaf commented May 17, 2024

Capabilities should not be created by anyone other than the trainer. But we can add TargetInfo to the model __init__!

I would not change TargetInfo/DatasetInfo for now; we can explore this once we have a Trainer class!

@Luthaf
Contributor

Luthaf commented May 17, 2024

After another round of discussion with @PicoCentauri and @frostedoyster, we modified the interface a bit; the current one is now:

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class TargetInfo:
    quantity: str
    unit: str = ""
    per_atom: bool = False
    gradients: List[str] = field(default_factory=list)


@dataclass
class DatasetInfo:
    length_unit: str
    species: List[int]  # add this to be able to restart training with a new dataset
    targets: Dict[str, TargetInfo]



class ModelInterface:
    def __init__(self, architecture_hypers, dataset_info: DatasetInfo):
        pass

    def save_checkpoint(self, path):
        pass

    @classmethod
    def load_checkpoint(cls, path) -> "ModelInterface":
        pass

    def restart(self, dataset: DatasetInfo) -> "ModelInterface":
        """
        This function is called whenever training restarts, with the same or a different dataset.

        It enables transfer learning (changing the targets), and fine tuning (same targets, different dataset)
        """
        pass

    def export(self) -> MetatensorAtomisticModel:
        pass


class TrainerInterface:
    def __init__(self, train_hypers):
        pass

    def train(
        self,
        model: ModelInterface,
        devices: List[torch.device],
        train_datasets: List[torch.Dataset],
        validation_datasets: List[torch.Dataset],
        checkpoints_dir: str,
    ):
        pass


__MODEL_CLASS__, __TRAINER_CLASS__ = import(...)

hypers = {...}
dataset_info = ...

if "continue_from_checkpoint":
    model = __MODEL_CLASS__.load_checkpoint("path")
    model = model.restart(dataset_info)
else:
    model = __MODEL_CLASS__(hypers["architecture"], dataset_info)

trainer = __TRAINER_CLASS__(hypers["train"])

trainer.train(
    model=model,
    devices=[...],
    train_datasets=[...],
    validation_datasets=[...],
    checkpoints_dir="path",
)

model.save_checkpoint("final.ckpt")

mts_atomistic_model = model.export()
mts_atomistic_model.export("path", collect_extensions="extensions-dir/")
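As an illustration of how an architecture could implement restart under this interface; the merging of species and targets and the dataset_info attribute are assumptions about a hypothetical model, and a real model would additionally extend its embeddings and output heads for new targets:

class SOTAModel:  # follows ModelInterface by duck typing
    def __init__(self, architecture_hypers, dataset_info: DatasetInfo):
        self.hypers = architecture_hypers
        self.dataset_info = dataset_info

    def restart(self, dataset_info: DatasetInfo) -> "SOTAModel":
        # fine tuning: same targets, possibly different data -> keep the weights,
        # merge the metadata so the model knows about any new species
        merged = DatasetInfo(
            length_unit=dataset_info.length_unit,
            species=sorted(set(self.dataset_info.species) | set(dataset_info.species)),
            targets={**self.dataset_info.targets, **dataset_info.targets},
        )
        # transfer learning: for targets that are new in `dataset_info`, the model
        # would add fresh output heads here while keeping the shared backbone
        self.dataset_info = merged
        return self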
