
Make CLI support all relevant model types #13

Closed
kmaziarz opened this issue Apr 29, 2022 · 9 comments · Fixed by #24
Labels
enhancement New feature or request

Comments

@kmaziarz
Collaborator

Currently all CLI entry points only work for MoLeRVae as they use VaeWrapper, while e.g. molecule_generation sample could also support MoLeRGenerator. While generalizing things, we may also want to rethink the choice to do model type discovery based on filenames (@sarahnlewis may have thoughts on this).

@kmaziarz kmaziarz added the enhancement New feature or request label Apr 29, 2022
@anamika-yadav99
Contributor

Hi @kmaziarz, I'm an undergrad student from India. I have experience working on similar projects and CLIs. Can I work on this issue?

@kmaziarz
Collaborator Author

Yes, sure! Sorry for the slow response: the best way of addressing this issue wasn't immediately clear, so I had to give it some thought. One thing to note is that the underlying model loading utility (load_vae_model_and_dataset) already works with all model types; the missing piece is choosing the right wrapper class (either VaeWrapper or GeneratorWrapper). Ideally, apart from being able to do

with VaeWrapper(model_dir, **model_kwargs) as model:
    (...)

which is what we do currently (see e.g. cli/encode.py), we could also do

with load_model_from_directory(model_dir, **model_kwargs) as model:
    (...)

and load_model_from_directory would return the right wrapper class. Then, for scripts that can work with both VAE-style and generator-style models (e.g. cli/sample.py), we'd use this new function to load them.

To return the right wrapper class, we would just need to select one and then pass all the arguments through. We currently do a bunch of filename matching in wrapper.py, and we could go further in that direction, but ultimately that feels unreliable. Instead, maybe we could take the following steps:

  • Get rid of the _is_moler_model_filename function that differs between the two wrapper classes, so that ModelWrapper._get_model_file would grab all the files that end with *.pkl as potential model files (and then assert there's exactly one, as is currently done).
  • Create a helper get_model_class similar to get_model_parameters in model_utils.py.
  • Connect the two things above to infer which wrapper to construct: first run ModelWrapper._get_model_file to get the model path, and then get_model_class to get the class (e.g. MoLeRVae); based on this, we'd return either VaeWrapper or GeneratorWrapper.
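The steps above could be sketched roughly as follows. This is a minimal, self-contained illustration, not the final implementation: the stub classes (MoLeRVae, MoLeRGenerator, VaeWrapper, GeneratorWrapper) stand in for the real ones, locating the checkpoint is simplified to a glob over *.pkl files, and the "model_class" key is an assumption about the checkpoint dict layout.

```python
import glob
import os
import pickle


# Illustrative stand-ins for the real model and wrapper classes.
class MoLeRVae: ...
class MoLeRGenerator: ...


class VaeWrapper:
    def __init__(self, model_dir, **model_kwargs):
        self.model_dir = model_dir


class GeneratorWrapper:
    def __init__(self, model_dir, **model_kwargs):
        self.model_dir = model_dir


def _get_model_file(model_dir):
    # Grab all files ending in *.pkl as potential model files, and then
    # assert there is exactly one, as described in the first step.
    candidates = glob.glob(os.path.join(model_dir, "*.pkl"))
    assert len(candidates) == 1, f"Expected exactly one model file, got {candidates}"
    return candidates[0]


def get_model_class(model_path):
    # The checkpoint dict already stores the model class alongside the
    # weights and hyperparameters.
    with open(model_path, "rb") as f:
        return pickle.load(f)["model_class"]


def load_model_from_directory(model_dir, **model_kwargs):
    # Connect the two helpers above: locate the model file, read the model
    # class, and pick the matching wrapper, passing all arguments through.
    model_path = _get_model_file(model_dir)
    model_class = get_model_class(model_path)
    wrapper_class = VaeWrapper if model_class is MoLeRVae else GeneratorWrapper
    return wrapper_class(model_dir, **model_kwargs)
```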

What do you think? If it's all too confusing, I'm also happy to take a stab at this myself; arguably, addressing this issue requires more fiddling with internals than it would initially seem... Also pinging @sarahnlewis in case she has any comments.

@sarahnlewis
Contributor

@kmaziarz is your suggestion that the model type be found in the contents of the same .pkl file that get_model_parameters reads, or that we save a separate file e.g. model_type.txt with each trained model? I think I would prefer the latter.

@kmaziarz
Collaborator Author

> @kmaziarz is your suggestion that the model type be found in the contents of the same .pkl file that get_model_parameters reads, or that we save a separate file e.g. model_type.txt with each trained model?

The model class is already being saved in the *.pkl file (which is a dict, containing not only the model class and weights but also various other hyperparams), and load_vae_model_and_dataset (the lower-level model loading utility used for all model types) uses the class it reads to load the model. So my proposal is to just make use of that (which has the advantage of being compatible with old checkpoints out-of-the-box).
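For concreteness, reading the class out of an existing checkpoint might look like this (the "model_class" key is an assumption about the dict layout, and the helper name is hypothetical):

```python
import pickle


def read_model_class(model_path):
    # The *.pkl checkpoint is a plain dict; alongside the weights and
    # hyperparameters it stores the model class itself, so old checkpoints
    # need no extra model_type.txt file.
    with open(model_path, "rb") as f:
        checkpoint = pickle.load(f)
    return checkpoint["model_class"]
```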

@sarahnlewis
Contributor

OK, sounds good.

@anamika-yadav99
Contributor

Thanks for the head start @kmaziarz and @sarahnlewis . I'll get started with the task.

@anamika-yadav99
Contributor

Hi @kmaziarz, so far I have been able to extract the model class from the .pkl file and have defined a method that returns which wrapper class to use based on it, i.e. VaeWrapper for MoLeRVae and GeneratorWrapper for MoLeRGenerator. The confusion I have is that GeneratorWrapper doesn't have an encode method. From the sample function already defined, it looks like GeneratorWrapper doesn't need one. Should I add a method similar to VaeWrapper's, or just return the sample_latents?

@kmaziarz
Collaborator Author

> Now the confusion that I have is that GeneratorWrapper doesn't have encode method. From the sample function already defined, it looks like GeneratorWrapper doesn't need one. Should I add a method similar to VaeWrapper or just return the sample_latents?

GeneratorWrapper should not have encode, as it represents latent-space-less models that can only sample (that's why it needs to be a separate class, because the API is more limited). Anywhere encode is called (e.g. cli/encode.py) we should keep using the VaeWrapper, while e.g. in cli/sample.py we can use the generic load_model_from_directory.
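The intended API split can be illustrated with a stub hierarchy. Everything here is a simplified sketch: the base class name matches the discussion, but the method signatures are assumptions, not the real API.

```python
class ModelWrapper:
    def __init__(self, model_dir, **model_kwargs):
        self.model_dir = model_dir

    def sample(self, num_samples):
        # All model types can sample molecules.
        raise NotImplementedError


class GeneratorWrapper(ModelWrapper):
    # Latent-space-less models: sampling is the entire public API,
    # so there is deliberately no encode method here.
    pass


class VaeWrapper(ModelWrapper):
    # VAE-style models additionally expose a latent space.
    def encode(self, smiles_list):
        ...  # map molecules to latent vectors

    def decode(self, latents):
        ...  # map latent vectors back to molecules
```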

@kmaziarz
Collaborator Author

@anamika-yadav99: So, summing up, for now let's use the generic way of loading the wrapper for sample.py only. Technically some modes of visualization would also work for MoLeRGenerator, but the visualizer is a bit of a work-in-progress at the moment, so I would leave it out for now.
