Code Review
This pull request introduces a `compile_pipeline` method that leverages `torch.compile` for performance optimization, which is a valuable addition. However, the current implementation has a critical flaw in how `torch.compile` is invoked: it assumes an in-place `.compile()` method on `torch.nn.Module` instances, which will lead to a runtime error on PyTorch versions that do not provide one. I have left a detailed comment with a suggested fix. I have also recommended replacing `print` statements with a standard logger for better maintainability. The other changes to support compilation in the various models and pipelines are well-structured.
```python
repeated_blocks = getattr(model, "_repeated_blocks", None)
# regional compilation for repeated blocks.
if repeated_blocks is not None:
    for submod in model.modules():
        if submod.__class__.__name__ in repeated_blocks:
            submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
# compile the whole model.
else:
    model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
```
`torch.nn.Module` only gained an in-place `.compile()` method in recent PyTorch releases; on older versions this call raises an `AttributeError`. `torch.compile()` itself is not an in-place operation: it returns a new compiled module that must be used in place of the original.

For whole-model compilation, use `setattr(self, name, torch.compile(model, ...))`. For regional compilation, recursively traverse the model and replace the matching submodules in their parents.
Suggested change:

```python
# current
repeated_blocks = getattr(model, "_repeated_blocks", None)
# regional compilation for repeated blocks.
if repeated_blocks is not None:
    for submod in model.modules():
        if submod.__class__.__name__ in repeated_blocks:
            submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
# compile the whole model.
else:
    model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
```

```python
# suggested
repeated_blocks = getattr(model, "_repeated_blocks", None)
# regional compilation for repeated blocks.
if repeated_blocks is not None:
    # Recursively find and replace modules to be compiled.
    def _replace_modules(module):
        for name, child in module.named_children():
            if child.__class__.__name__ in repeated_blocks:
                setattr(module, name, torch.compile(child, mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs))
            else:
                _replace_modules(child)

    _replace_modules(model)
# compile the whole model.
else:
    setattr(self, name, torch.compile(model, mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs))
```
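The replace-in-parent pattern behind the suggested regional compilation can be illustrated without PyTorch. The sketch below uses plain objects: `Node` mimics an `nn.Module` tree with `named_children()`, and `wrap` stands in for `torch.compile`. All names here are hypothetical and for illustration only.

```python
class Node:
    """Minimal stand-in for nn.Module with named children."""

    def __init__(self, **children):
        self._child_names = list(children)
        for name, child in children.items():
            setattr(self, name, child)

    def named_children(self):
        # Reflect current attribute values, so replacements are visible.
        return [(name, getattr(self, name)) for name in self._child_names]


class Block(Node):
    """Stand-in for a repeated block class targeted for compilation."""


def wrap(module):
    # Placeholder for torch.compile: returns a new wrapper object
    # instead of mutating the module in place.
    return ("compiled", module)


def replace_modules(module, targets):
    """Replace each child whose class name is in `targets` on its parent."""
    for name, child in module.named_children():
        if child.__class__.__name__ in targets:
            setattr(module, name, wrap(child))
        elif isinstance(child, Node):
            replace_modules(child, targets)


root = Node(encoder=Node(block=Block()), head=Node())
replace_modules(root, {"Block"})
```

After the traversal, `root.encoder.block` is the wrapped object while untouched children such as `root.head` keep their original type; this mirrors how the suggested `_replace_modules` swaps compiled submodules into their parents.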
```python
    """
    compile_models = compile_models or getattr(self, "compilable_models", [])
    if len(compile_models) == 0:
        print("No compilable models in the pipeline. Skip compilation.")
```
Consider using the `logging` module instead of `print` for informational messages. This gives users of the library control over log levels and output streams. You would need to add `import logging` at the top of the file. The same advice applies to the other `print` statements in this method (lines 361 and 372).
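As a minimal sketch of this advice using only the standard library (the `compile_pipeline_stub` function and its return values are hypothetical, standing in for the method under review):

```python
import logging

# Module-level logger; library code should not configure handlers itself.
logger = logging.getLogger(__name__)


def compile_pipeline_stub(compile_models):
    """Hypothetical stand-in for the compile_pipeline method discussed above."""
    if len(compile_models) == 0:
        # Replaces: print("No compilable models in the pipeline. Skip compilation.")
        logger.info("No compilable models in the pipeline. Skipping compilation.")
        return False
    return True
```

An application can then opt into these messages with `logging.basicConfig(level=logging.INFO)`, or silence them entirely, which is not possible with bare `print` calls.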