Implement simple model checkpointing #37

Merged: 4 commits, Nov 7, 2022

meffmadd (Contributor) commented Nov 4, 2022

No description provided.

nmichlo (Owner) commented Nov 4, 2022

Hi @meffmadd, thank you so much for this contribution. It looks great! 😁🎉

Before we merge, could you please update the checkpointing in the configs? Some tests are currently failing, and it would also help for documentation purposes.

On that note, we might want to just add the checkpointing to the PyTorch Lightning train step, using the same hook. Then, possibly at the end, load the checkpoint to make sure it works.

    @pytest.mark.parametrize(['Framework', 'cfg_kwargs', 'Data'], _TEST_FRAMEWORKS)
    def test_frameworks(Framework, cfg_kwargs, Data):
        DataSampler = {
            1: GroundTruthSingleSampler,
            2: GroundTruthPairSampler,
            3: GroundTruthTripleSampler,
        }[Framework.REQUIRED_OBS]
        data = XYObjectData() if (Data is None) else Data()
        dataset = DisentDataset(data, DataSampler(), transform=ToImgTensorF32())
        dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=0)
        framework = Framework(
            model=AutoEncoder(
                encoder=EncoderLinear(x_shape=data.x_shape, z_size=6, z_multiplier=2 if issubclass(Framework, Vae) else 1),
                decoder=DecoderLinear(x_shape=data.x_shape, z_size=6),
            ),
            cfg=Framework.cfg(**cfg_kwargs)
        )
        # test pickling before training
        pickle.dumps(framework)
        # train!
        trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=256, fast_dev_run=True)
        trainer.fit(framework, dataloader)
        # test pickling after training, something may have changed!
        pickle.dumps(framework)
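
The "load the checkpoint at the end to make sure it works" idea can be illustrated with a small save-then-reload round trip. This is only a sketch under stated assumptions: `TinyFramework`, `save_checkpoint`, `load_checkpoint`, and `roundtrip_check` are hypothetical stand-ins written with the standard library, not the disent or PyTorch Lightning API, and the actual test in this PR may look quite different.

```python
import os
import pickle
import tempfile


class TinyFramework:
    """Hypothetical stand-in for a trainable framework; not part of disent."""

    def __init__(self):
        self.weights = [0.0, 0.0, 0.0]
        self.steps = 0

    def training_step(self):
        # Pretend to learn: nudge every weight and count the step.
        self.weights = [w + 0.5 for w in self.weights]
        self.steps += 1

    def state_dict(self):
        # Return a copy of everything needed to restore this object.
        return {"weights": list(self.weights), "steps": self.steps}

    def load_state_dict(self, state):
        self.weights = list(state["weights"])
        self.steps = state["steps"]


def save_checkpoint(framework, path):
    # Hypothetical helper: persist only the state, not the whole object.
    with open(path, "wb") as f:
        pickle.dump(framework.state_dict(), f)


def load_checkpoint(path):
    # Hypothetical helper: build a fresh object and restore the saved state.
    framework = TinyFramework()
    with open(path, "rb") as f:
        framework.load_state_dict(pickle.load(f))
    return framework


def roundtrip_check(num_steps=4):
    # Train, save, reload, and return both states so a test can compare them.
    framework = TinyFramework()
    for _ in range(num_steps):
        framework.training_step()
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "last.ckpt")
        save_checkpoint(framework, path)
        restored = load_checkpoint(path)
    return framework.state_dict(), restored.state_dict()
```

In a test, one might call `roundtrip_check()` after training and assert that the two returned states are equal. With a real model, the same comparison would presumably be made on the model's parameters instead of a plain list.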

meffmadd (Contributor, Author) commented Nov 7, 2022

Hi, I couldn't work on it this weekend but will start now and fix the configs so that the tests work again. I will also add a test case that tests the behavior.

Added save_checkpoint to experiment configs
Added tests for checkpointing

codecov bot commented Nov 7, 2022

Codecov Report

Base: 70.01% // Head: 70.04% // Increases project coverage by +0.02% 🎉

Coverage data is based on head (28d0e95) compared to base (8dee583).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #37      +/-   ##
==========================================
+ Coverage   70.01%   70.04%   +0.02%     
==========================================
  Files         135      135              
  Lines        7531     7538       +7     
==========================================
+ Hits         5273     5280       +7     
  Misses       2258     2258              
Impacted Files                    Coverage Δ
disent/frameworks/_ae_mixin.py    90.00% <100.00%> (+1.32%) ⬆️


Review thread on disent/frameworks/vae/_unsupervised__vae.py (outdated, resolved)
Review thread on experiment/run.py (resolved)

nmichlo (Owner) commented Nov 7, 2022

Great work! Thank you so much for making these changes!

nmichlo merged commit 950ba81 into nmichlo:main on Nov 7, 2022