# Replace `train` function with `Trainer` class (#174)
I don't necessarily love the naming used by Lightning (…). What we could also do is define the following:

```python
# in `mts/models/train`
class TrainerInterface:
    def __init__(self, mts_models_args):
        ...

    def train(self, mts_models_args):
        ...


# in `mts/models/utils/lightning`
class LightningTrainer(lightning.Trainer):
    def __init__(self, mts_models_args):
        super().__init__(adapt_to_lightning_args(mts_models_args))

    def train(self, mts_models_args):
        return self.fit(adapt_to_lightning_args(mts_models_args))


# in sota_architecture.py
class SOTATrainer(LightningTrainer):
    def __init__(self, mts_models_args):
        super().__init__(mts_models_args)

    def fit(self, lightning_args):
        ...
```

I see a handful of advantages to this approach: …

There are some drawbacks as well; the main one is that we need to sit down and define the API (although we will be able to evolve it), and then maintain the corresponding code and documentation.
Thank you @PicoCentauri! This looks very good to me! I see the point in what @Luthaf says; perhaps it would really be better to have this additional compatibility layer. I'll start working on updating PET this weekend, and perhaps then I'll add some comments depending on what issues I face.
Yeah, that might be an option, and for super mega SOTA custom architectures the developer only has to inherit from the interface:

```python
class CustomTrainer(TrainerInterface):
    def __init__(self, mts_models_args):
        super().__init__(mts_models_args)

    def fit(self, mts_models_args):
        ...
```

I had the same idea and like it. But @abmazitov was against inheriting from the lightning `Trainer` class and wants us to work with the untouched trainer directly, and @frostedoyster was against adding additional wrapper layers because they are hard to follow for people.
How is one typically using the lightning `Trainer`? Is it by defining a custom class …? Regarding the additional wrapping layer, this one would have a very low cognitive overhead IMO. We can say to contributors: "if you are using a lightning trainer, inherit from …".
Not necessarily inherit. Since Python is duck-typed, they just have to follow the same API; this is similar to the `ModelInterface`.
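To make the duck-typing point concrete, here is a minimal runnable sketch (all names are hypothetical, loosely following the `TrainerInterface` shape discussed in this thread): a trainer that inherits from nothing can still be used wherever the interface is expected, as long as the method names and signatures match.

```python
class TrainerInterface:
    """Reference interface: documents the expected API, never subclassed."""

    def train(self, model, train_datasets, checkpoints_dir):
        raise NotImplementedError


class NumpyTrainer:
    """Does NOT inherit from TrainerInterface, but follows the same API."""

    def train(self, model, train_datasets, checkpoints_dir):
        # a real implementation would run the actual training loop here
        return f"trained {model} on {len(train_datasets)} dataset(s)"


def run_training(trainer, model, train_datasets, checkpoints_dir):
    # the caller only relies on the method being present (duck typing),
    # not on the trainer's base class
    return trainer.train(model, train_datasets, checkpoints_dir)


result = run_training(NumpyTrainer(), "my-model", [[], []], "checkpoints/")
```

This keeps the interface purely a documentation artifact: the CLI can call `trainer.train(...)` on any object that provides the method.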
Typically Lightning requires one to define the set of API methods for the model (i.e. a `LightningModule`) and then hand it to the trainer:

```python
import lightning.pytorch as pl

class Model(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        pass

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

model = Model(*args, **kwargs)
trainer = pl.Trainer(accelerator=..., devices=..., logger=..., callbacks=...)
trainer.fit(model, train_dataloaders=..., val_dataloaders=...)
```
Ahh, ok. Then with my proposal that would become something like this:

```python
# mts/models/utils.py
import lightning.pytorch as pl

class LightningTrainer:
    def __init__(self, lightning_module, mts_train_init_args):
        self.lightning_module = lightning_module
        self.trainer = pl.Trainer(accelerator=..., devices=..., logger=..., callbacks=...)

    def train(self, mts_train_args):
        # handle the adaptation between model/lightning_module
        return self.trainer.fit(self.lightning_module)


# sota_arch.py
import lightning.pytorch as pl
from mts.models.utils import LightningTrainer

class TrainingLoop(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        pass

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

class Trainer(LightningTrainer):
    def __init__(self, mts_train_init_args):
        lightning_module = TrainingLoop()
        super().__init__(lightning_module, mts_train_init_args)
```

It's not very clear how we can handle checkpointing/initial model creation with this API, but I don't see a fundamental issue.
I think a contributor doesn't have to do anything. I would use the same strategy as we do for … Of course a contributor can also inherit from "our" … Or do we really want to support several base classes?
Initial model creation is done via the … For checkpointing: I think if you are living in the lightning world, the lightning trainer will take care of it and you only have to provide a … A custom trainer has to implement checkpoint saving by itself.
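As a sketch of what "implement checkpoint saving by themselves" could look like for a non-Lightning model. The JSON layout, the class name, and the `n_features` hyperparameter are illustrative assumptions, not the project's actual checkpoint format:

```python
import json
import os
import tempfile


class NumpyModel:
    """Toy stand-in for a non-Lightning model; weights are plain lists."""

    def __init__(self, hypers, weights=None):
        self.hypers = hypers
        self.weights = weights if weights is not None else [0.0] * hypers["n_features"]

    def save_checkpoint(self, path):
        # store hypers and weights together so training can be resumed later
        with open(path, "w") as fd:
            json.dump({"hypers": self.hypers, "weights": self.weights}, fd)

    @classmethod
    def load_checkpoint(cls, path):
        with open(path) as fd:
            data = json.load(fd)
        return cls(data["hypers"], data["weights"])


path = os.path.join(tempfile.gettempdir(), "final.ckpt")
model = NumpyModel({"n_features": 3})
model.weights = [1.0, 2.0, 3.0]
model.save_checkpoint(path)
restored = NumpyModel.load_checkpoint(path)
```

The point is that the pair `save_checkpoint`/`load_checkpoint` is symmetric and self-contained, so the CLI never needs to know the architecture's serialization details.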
There are a couple of different use cases here: …

In my proposal, case (1) would be doing a custom implementation of …

There is one interface (not used through inheritance) and one implementation of this interface (…).
Okay, I think I see the points and we agree. Your basic point is to have one class in between the lightning trainer and our interface to it. I would also like to add some checks in this wrapper class: supported precision, GPU selection, and better checking of model and training capabilities against the datasets, i.e. if the user wants to train on target properties that the architecture doesn't provide, or if the architecture can only train on one dataset but the user wants more, etc. Of course one could do everything in single functions, but this is very unreadable currently, or every architecture has to implement these checks again, see for example PET. This is very annoying…
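The kind of capability check described here could be sketched as follows (a minimal illustration with hypothetical names and error types, not the project's API):

```python
class IncompatibleDatasetError(Exception):
    """Raised when the requested training setup exceeds the architecture."""


def check_dataset(architecture_targets, requested_targets, max_datasets, n_datasets):
    # refuse target properties the architecture cannot predict
    unsupported = set(requested_targets) - set(architecture_targets)
    if unsupported:
        raise IncompatibleDatasetError(f"unsupported targets: {sorted(unsupported)}")

    # refuse more datasets than the architecture can train on
    if n_datasets > max_datasets:
        raise IncompatibleDatasetError(
            f"architecture supports {max_datasets} dataset(s), got {n_datasets}"
        )


# a compatible request passes silently
check_dataset(["energy", "forces"], ["energy"], max_datasets=1, n_datasets=1)

# an incompatible request is rejected before training starts
try:
    check_dataset(["energy"], ["energy", "dipole"], max_datasets=1, n_datasets=1)
    rejected = False
except IncompatibleDatasetError:
    rejected = True
```

Running such checks once in the shared wrapper means every architecture gets them for free instead of re-implementing them.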
I like having our own `Trainer` class and then providing utilities for interfacing it with Lightning.
After a discussion with @PicoCentauri, we arrived at the following:

```python
# interface definition
class ModelInterface:
    def __init__(self, architecture_hypers):
        pass

    def save_checkpoint(self, path):
        pass

    @classmethod
    def load_checkpoint(cls, path) -> "ModelInterface":
        pass


class TrainerInterface:
    def __init__(self, train_hypers):
        pass

    def train(
        self,
        model: ModelInterface,
        devices: List[torch.device],
        train_datasets: List[torch.Dataset],
        validation_datasets: List[torch.Dataset],
        datasets_metadata: DatasetMetadata,
        checkpoints_dir: str,
    ):
        pass

    def export(
        self,
        model: ModelInterface,
        dataset_metadata: DatasetMetadata,
    ) -> MetatensorAtomisticModel:
        pass
```

```python
# interface usage in the CLI
__MODEL_CLASS__, __TRAINER_CLASS__ = import(...)

hypers = {}
if "continue_from_checkpoint":
    model = __MODEL_CLASS__.from_checkpoint("path")
else:
    model = __MODEL_CLASS__(hypers["architecture"])

trainer = __TRAINER_CLASS__(hypers["train"])
trainer.train(
    model=model,
    devices=[],
    train_datasets=[],
    validation_datasets=[],
    datasets_metadata=datasets_metadata,
    checkpoints_dir="path",
)

model.save_checkpoint("final.ckpt")

exported = trainer.export(model, datasets_metadata)
exported.save("path", collect_extensions="extensions-dir/")
```

```python
# Lightning utilities
class LightningTrainer:
    def __init__(self, trainer_options=None):
        if trainer_options is None:
            self._trainer_options = {}
        else:
            self._trainer_options = trainer_options

        if "accelerator" in self._trainer_options:
            raise Exception()
        if "devices" in self._trainer_options:
            raise Exception()

    def train(
        self,
        model: ModelInterface,
        datasets_metadata: DatasetMetadata,
        train_datasets: List[torch.Dataset],
        validation_datasets: List[torch.Dataset],
        devices: List[torch.device],
        checkpoints_dir: str,
    ):
        accelerator, devices = self._devices_to_lightning(devices)
        # TODO: init loggers
        self._trainer = pl.Trainer(
            accelerator=accelerator,
            devices=devices,
            **self._trainer_options,
        )

        self.check_dataset(datasets_metadata)

        # TODO: should we allow custom dataloaders?
        train_dataloaders = [DataLoader(d) for d in train_datasets]
        val_dataloaders = [DataLoader(d) for d in validation_datasets]

        self._trainer.fit(
            self.module,
            train_dataloaders=train_dataloaders,
            val_dataloaders=val_dataloaders,
            ckpt_path=checkpoints_dir,
        )

    def export(
        self,
        model: ModelInterface,
        dataset_metadata: DatasetMetadata,
    ) -> MetatensorAtomisticModel:
        raise NotImplementedError()

    def check_dataset(self, datasets_metadata):
        """To be overridden by the actual trainer"""

    def _devices_to_lightning(self, devices):
        """TODO"""
```

This code would be used as follows by contributors:

```python
# lightning-based model
class CustomSOTAModel(pl.LightningModule):
    def __init__(self, hypers):
        """TODO"""

    def training_step(self, batch, batch_idx):
        """TODO"""

    def configure_optimizers(self):
        """TODO"""

    def save_checkpoint(self, path):
        """TODO"""

    @classmethod
    def load_checkpoint(cls, path) -> "CustomSOTAModel":
        """TODO"""


class Trainer(LightningTrainer):
    def __init__(self, hypers, devices):
        super().__init__(
            module=CustomSOTAModel(hypers),
            devices=devices,
            trainer_options={
                "logger": lightning_logger_create(hypers["architecture"]),
            },
        )

    def check_dataset(self, metadata: DatasetMetadata):
        """TODO"""

    def export(
        self,
        model: CustomSOTAModel,
        dataset_metadata: DatasetMetadata,
    ) -> MetatensorAtomisticModel:
        """TODO"""


__MODEL_CLASS__ = CustomSOTAModel
__TRAINER_CLASS__ = Trainer
```

```python
# non lightning-based model
class NumpySOTAModel:
    def __init__(self, hypers):
        """TODO"""

    def save_checkpoint(self, path):
        """TODO"""

    @classmethod
    def load_checkpoint(cls, path) -> "NumpySOTAModel":
        """TODO"""


class Trainer:
    def __init__(self, hypers):
        """TODO"""

    def train(
        self,
        model: NumpySOTAModel,
        devices: List[torch.device],
        train_datasets: List[torch.Dataset],
        validation_datasets: List[torch.Dataset],
        datasets_metadata: DatasetMetadata,
        checkpoints_dir: str,
    ):
        """TODO"""

    def export(
        self,
        model: NumpySOTAModel,
        dataset_metadata: DatasetMetadata,
    ) -> MetatensorAtomisticModel:
        """TODO"""


__MODEL_CLASS__ = NumpySOTAModel
__TRAINER_CLASS__ = Trainer
```

Any thoughts? We could also provide a …
As a minor addition, we decided that we don't perform any checks on whether the datasets fit what the architecture can handle; the arch developer has to do this. Also, if we go with this design, I suggest that we first wrap the architectures as non-lightning based models, and then we can adjust them later.
This is not enough IMO. The model needs information about the targets to be initialized correctly. We will need at least …
Another somewhat related comment: the …
What do you suggest for the interface? Keep it like we have it currently, like:

```python
class ModelInterface:
    def __init__(self, capabilities: ModelCapabilities, architecture_hypers: Dict):
        ...
```

But then we have to create the `ModelCapabilities` inside the train function.
Capabilities should not be created by anyone other than the trainer. But we can add … I would not change …
After another round of discussion with @PicoCentauri and @frostedoyster, we modified the interface a bit; the current one is now:

```python
from dataclasses import dataclass, field

@dataclass
class TargetInfo:
    quantity: str
    unit: str = ""
    per_atom: bool = False
    gradients: List[str] = field(default_factory=list)


@dataclass
class DatasetInfo:
    length_unit: str
    species: List[int]  # add this to be able to restart training with a new dataset
    targets: Dict[str, TargetInfo]


class ModelInterface:
    def __init__(self, architecture_hypers, dataset_info: DatasetInfo):
        pass

    def save_checkpoint(self, path):
        pass

    @classmethod
    def load_checkpoint(cls, path) -> "ModelInterface":
        pass

    def restart(self, dataset: DatasetInfo) -> "ModelInterface":
        """
        This function is called whenever training restarts, with the same or a
        different dataset. It enables transfer learning (changing the targets)
        and fine-tuning (same targets, different dataset).
        """
        pass

    def export(self) -> MetatensorAtomisticModel:
        pass


class TrainerInterface:
    def __init__(self, train_hypers):
        pass

    def train(
        self,
        model: ModelInterface,
        devices: List[torch.device],
        train_datasets: List[torch.Dataset],
        validation_datasets: List[torch.Dataset],
        checkpoints_dir: str,
    ):
        pass
```

```python
__MODEL_CLASS__, __TRAINER_CLASS__ = import(...)

hypers = {...}
dataset_info = ...

if "continue_from_checkpoint":
    model = __MODEL_CLASS__.from_checkpoint("path")
    model = model.restart(dataset_info)
else:
    model = __MODEL_CLASS__(hypers["architecture"], dataset_info)

trainer = __TRAINER_CLASS__(hypers["train"])
trainer.train(
    model=model,
    devices=[...],
    train_datasets=[...],
    validation_datasets=[...],
    checkpoints_dir="path",
)

model.save_checkpoint("final.ckpt")

mts_atomistic_model = model.export()
mts_atomistic_model.export("path", collect_extensions="extensions-dir/")
```
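For concreteness, a runnable sketch of the `TargetInfo`/`DatasetInfo` containers above, assuming they are implemented as plain dataclasses (the example values are illustrative, and torch-specific types are omitted to keep it self-contained):

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class TargetInfo:
    quantity: str
    unit: str = ""
    per_atom: bool = False
    gradients: List[str] = field(default_factory=list)


@dataclass
class DatasetInfo:
    length_unit: str
    species: List[int]
    targets: Dict[str, TargetInfo]


# describe an energy target with position gradients (i.e. forces)
energy = TargetInfo(quantity="energy", unit="eV", gradients=["positions"])

# describe a water dataset (H and O) using that target
info = DatasetInfo(length_unit="angstrom", species=[1, 8], targets={"energy": energy})
```

Since the containers are pure data, comparing the `DatasetInfo` stored in a checkpoint with the one passed to `restart()` is a simple equality check, which is what makes the transfer-learning/fine-tuning distinction cheap to detect.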
We are considering replacing the `train` function for each model architecture with a `Trainer` class. This change aims not only to improve our code structure but also to pave the way for integrating PyTorch Lightning (Trainer API) to manage the training process.

The primary concern with using Lightning is that it may impose excessive constraints on new architectures, particularly those intended to use JAX, TensorFlow, Fortran, TeX, or to rely heavily on linear algebra operations. While I think it is possible to use a generic PyTorch Lightning trainer for these architectures, the criticism is acknowledged.

Hence, I propose the following compromise, which will introduce additional maintenance and coding overhead to differentiate between Lightning and non-Lightning architectures, not to mention the complexity it adds to our documentation.
The compromise means that architectures not opting for Lightning must still provide their own `Model` and `Trainer` classes that conform to the PyTorch Lightning API. This consistency ensures we can maintain a single API call within our `train.py`.

While non-Lightning architectures are supported, it is strongly encouraged to adopt Lightning due to its advanced features, such as online logging and distributed multi-GPU training, which will not be as readily available to non-Lightning `Trainer` classes.

Here's a detailed look at the proposed design changes in `train.py`:
**Custom Model Class**

Since we initialize the model before calling the trainer, non-Lightning models must implement a `load_from_checkpoint(checkpoint_path: str)` method to continue training from a saved state. These models aren't required to implement a public `save_from_checkpoint` method, as checkpoint saving is managed internally.

**Extra Custom Functions**

Architectures may define their own `export` and `collate_fn` functions. If not provided, default implementations from our `utility` folder will be used. The API of these custom functions, which should be located in the `model.py` of each architecture, should look like:
**DataLoaders instead of Datasets in `train.py`**

To reduce friction among developers, we will not use the preferred `LightningDataModule`, but use torch `DataLoader`s directly in `train.py`. This allows sufficient flexibility while maintaining architectural consistency.
**Custom `Trainer` Class**

Custom `Trainer` classes must be compatible with the following API, accommodating various devices and logging requirements:
**(online) Loggers**

The use of the Python logging library will continue for console and local file logging. This is separate from any online logging mechanisms provided by Lightning.
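A minimal sketch of such console-plus-file logging with the standard library (the logger name, formats, and file location are illustrative assumptions, not the project's actual setup):

```python
import logging
import os
import tempfile

log_path = os.path.join(tempfile.gettempdir(), "mts-train.log")

logger = logging.getLogger("mts.train")
logger.setLevel(logging.INFO)
logger.propagate = False  # keep our records out of the root logger

# console output, short format
console = logging.StreamHandler()
console.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
logger.addHandler(console)

# local file output, with timestamps
file_handler = logging.FileHandler(log_path)
file_handler.setFormatter(
    logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
)
logger.addHandler(file_handler)

logger.info("starting training")
```

Because this uses only the standard `logging` machinery, it works identically for Lightning and non-Lightning architectures, while any online logger stays an independent, Lightning-side concern.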
Lightning team: Please advise if there are additional public API components that
non-Lightning architectures need to implement. Let's discuss any necessary details
before I proceed with the implementation.