IndexError when max_checkpoints_to_keep=0 in TorchModel's fit() #3810

Closed
arunppsg opened this issue Jan 30, 2024 · 6 comments · Fixed by #3820

Comments

@arunppsg
Contributor

Steps to reproduce:

from deepchem.models.torch_models import GCNModel
from deepchem.molnet import load_delaney
from deepchem.feat import MolGraphConvFeaturizer

# Featurize the Delaney (ESOL) dataset as molecular graphs.
featurizer = MolGraphConvFeaturizer()
tasks, dataset, transformers = load_delaney(featurizer=featurizer)

n_tasks = len(tasks)
model = GCNModel(mode='regression',
                 n_tasks=n_tasks,
                 number_atom_features=30,
                 batch_size=10,
                 learning_rate=0.001,
                 device='cpu')

# Train split of the (train, valid, test) tuple.
train = dataset[0]

model.fit(train, nb_epoch=10, max_checkpoints_to_keep=0, checkpoint_interval=1)

The last line raises the following error:

File ~/deepchem/deepchem/models/torch_models/torch_model.py:338, in TorchModel.fit(self, dataset, nb_epoch, max_checkpoints_to_keep, checkpoint_interval, deterministic, restore, variables, loss, callbacks, all_losses)
    289 def fit(self,
    290         dataset: Dataset,
    291         nb_epoch: int = 10,
   (...)
    298         callbacks: Union[Callable, List[Callable]] = [],
    299         all_losses: Optional[List[float]] = None) -> float:
    300     """Train this model on a dataset.
    301 
    302     Parameters
   (...)
    336     The average loss over the most recent checkpoint interval
    337     """
--> 338     return self.fit_generator(
    339         self.default_generator(dataset,
    340                                epochs=nb_epoch,
    341                                deterministic=deterministic),
    342         max_checkpoints_to_keep, checkpoint_interval, restore, variables,
    343         loss, callbacks, all_losses)

File ~/deepchem/deepchem/models/torch_models/torch_model.py:460, in TorchModel.fit_generator(self, generator, max_checkpoints_to_keep, checkpoint_interval, restore, variables, loss, callbacks, all_losses)
    457     averaged_batches = 0
    459 if checkpoint_interval > 0 and current_step % checkpoint_interval == checkpoint_interval - 1:
--> 460     self.save_checkpoint(max_checkpoints_to_keep)
    461 for c in callbacks:
    462     c(self, current_step)

File ~/deepchem/deepchem/models/torch_models/torch_model.py:1022, in TorchModel.save_checkpoint(self, max_checkpoints_to_keep, model_dir)
   1016 # Rename and delete older files.
   1018 paths = [
   1019     os.path.join(model_dir, 'checkpoint%d.pt' % (i + 1))
   1020     for i in range(max_checkpoints_to_keep)
   1021 ]
-> 1022 if os.path.exists(paths[-1]):
   1023     os.remove(paths[-1])
   1024 for i in reversed(range(max_checkpoints_to_keep - 1)):

IndexError: list index out of range

Setting max_checkpoints_to_keep=0 helps avoid the time spent on disk I/O when developing large models.
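
For context, the crash comes from the checkpoint-rotation code shown at the bottom of the traceback: with max_checkpoints_to_keep=0 the paths list built in save_checkpoint() is empty, so indexing paths[-1] fails. A minimal illustration of the failing expression:

max_checkpoints_to_keep = 0

# Mirrors the list built in save_checkpoint(); it is empty when
# max_checkpoints_to_keep == 0.
paths = ['checkpoint%d.pt' % (i + 1) for i in range(max_checkpoints_to_keep)]

paths[-1]  # IndexError: list index out of range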

@quincylin1
Contributor

Hello @arunppsg, thank you so much for raising this. I'm looking into this issue.

@quincylin1
Contributor

Since save_checkpoint() keeps at most max_checkpoints_to_keep checkpoints, I think it makes sense to just return from the function without doing anything when max_checkpoints_to_keep = 0, as setting it to zero means you don't want to keep any checkpoints.

What do you think?

@quincylin1
Contributor

Something like this:

def save_checkpoint(self,
                    max_checkpoints_to_keep: int = 5,
                    model_dir: Optional[str] = None) -> None:
    """Save a checkpoint to disk.

    Usually you do not need to call this method, since fit() saves checkpoints
    automatically.  If you have disabled automatic checkpointing during fitting,
    this can be called to manually write checkpoints.

    Parameters
    ----------
    max_checkpoints_to_keep: int
        the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    model_dir: str, default None
        Model directory to save checkpoint to. If None, revert to self.model_dir
    """
    if max_checkpoints_to_keep == 0:
        return
    self._ensure_built()
    if model_dir is None:
        model_dir = self.model_dir
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
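
If that guard goes in, a rough manual check could reuse the model and train objects from the reproduction script above (a sketch only; the checkpoint%d.pt filename pattern is taken from the traceback):

import os

# With the early return in place, fit() should complete without raising and
# should leave no checkpoint files behind when max_checkpoints_to_keep=0.
model.fit(train, nb_epoch=1, max_checkpoints_to_keep=0, checkpoint_interval=1)
leftover = [f for f in os.listdir(model.model_dir) if f.startswith('checkpoint')]
assert leftover == []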

@quincylin1
Contributor

quincylin1 commented Feb 6, 2024

Or do you think we should disable calling save_checkpoint() when max_checkpoints_to_keep = 0? Since setting it to zero means you don't want to save any checkpoints, it makes more sense not to call save_checkpoint() at all.

In TorchModel.fit_generator():

if checkpoint_interval > 0 and current_step % checkpoint_interval == checkpoint_interval - 1 and max_checkpoints_to_keep > 0:
    self.save_checkpoint(max_checkpoints_to_keep)
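
Either way, a small regression test along these lines could lock in the behaviour (a sketch only: the test name and pytest's tmpdir fixture are illustrative, and it wraps a bare torch.nn.Linear in a TorchModel rather than building the GCNModel from the report):

import numpy as np
import torch
import deepchem as dc
from deepchem.models.torch_models import TorchModel
from deepchem.models.losses import L2Loss


def test_fit_with_zero_max_checkpoints(tmpdir):
    # Tiny random regression dataset so the test runs quickly.
    X = np.random.rand(10, 3).astype(np.float32)
    y = np.random.rand(10, 1).astype(np.float32)
    dataset = dc.data.NumpyDataset(X, y)

    # Minimal torch module wrapped in a TorchModel.
    pytorch_model = torch.nn.Linear(3, 1)
    model = TorchModel(pytorch_model, loss=L2Loss(), batch_size=5, model_dir=str(tmpdir))

    # Should complete without IndexError once the fix is in place.
    model.fit(dataset, nb_epoch=1, max_checkpoints_to_keep=0, checkpoint_interval=1)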

@arunppsg
Contributor Author

arunppsg commented Feb 6, 2024

> Since save_checkpoint() keeps at most max_checkpoints_to_keep checkpoints, I think it makes sense to just return from the function without doing anything when max_checkpoints_to_keep = 0, as setting it to zero means you don't want to keep any checkpoints.

This sounds like a good idea.

@quincylin1
Contributor

> Since save_checkpoint() keeps at most max_checkpoints_to_keep checkpoints, I think it makes sense to just return from the function without doing anything when max_checkpoints_to_keep = 0, as setting it to zero means you don't want to keep any checkpoints.
>
> This sounds like a good idea.

Cool! Let me create a PR!

quincylin1 added a commit to quincylin1/deepchem that referenced this issue Feb 6, 2024
quincylin1 mentioned this issue Feb 6, 2024
rbharath added a commit that referenced this issue Feb 8, 2024