Added infograph model finetuning support #3491
Conversation
Force-pushed from 29b796e to bc90b17
@tonydavis629 , just want to check in - does it make sense to add finetuning support for the InfoGraph model? From my understanding, the InfoGraph paper performed two kinds of experiments:
Am I right in my understanding here? With the current setup, we can do exp 2, and I believe this code adds support for exp 1. But is there a better way to do it?
@arunppsg I believe this change is redundant; the functionality to do pretraining and finetuning is already there in InfoGraphModel and InfoGraphStarModel. The difference is that this code uses one combined InfoGraph model for finetuning and pretraining, while the current implementation uses InfoGraph for pretraining and InfoGraphStar for finetuning. The two experiments done in the paper are: 1. unsupervised mutual information maximization between a global and local encoded graph representation (with no finetuning or supervised objective), and 2. layer-by-layer mutual information maximization between a trained encoder and an untrained encoder, plus a supervised loss (the unsupervised objective + supervised objective being referred to in the paper as semi-supervised). So InfoGraphModel is reserved for the #1 unsupervised task to train the encoder. InfoGraphStarModel is to be used for #2 by loading the weights of a pretrained InfoGraph into the encoder. Your InfoGraphFinetune functionality is very similar to InfoGraphStarModel. This pretrain + finetune regime is implemented in
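The pretrain-then-finetune regime described above can be sketched with plain dicts standing in for PyTorch state_dicts. The function and key names here are illustrative only, not the actual DeepChem API: the point is that only the encoder weights transfer, while the supervised head stays untouched.

```python
# Minimal sketch of the pretrain -> finetune weight-transfer pattern,
# using plain dicts in place of PyTorch state_dicts. All names are
# hypothetical, not DeepChem's real API.

def pretrain_encoder():
    """Stand-in for unsupervised InfoGraph pretraining: returns encoder weights."""
    return {"conv1.weight": [0.1, 0.2], "conv2.weight": [0.3, 0.4]}

def load_pretrained_encoder(finetune_weights, encoder_weights):
    """Copy only the encoder entries into the finetuning model's weights,
    leaving its supervised head untouched (analogous to InfoGraphStarModel
    loading a pretrained InfoGraph encoder)."""
    for key, value in encoder_weights.items():
        finetune_weights[key] = value
    return finetune_weights

finetune_weights = {"conv1.weight": [0.0, 0.0],
                    "conv2.weight": [0.0, 0.0],
                    "head.weight": [0.5]}
finetune_weights = load_pretrained_encoder(finetune_weights, pretrain_encoder())
print(finetune_weights["conv1.weight"])  # encoder weights transferred
print(finetune_weights["head.weight"])   # supervised head untouched
```

Whether this lives in one combined model (as in this PR) or two separate models (InfoGraph + InfoGraphStar) is the design question being debated in this thread.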
I agree that InfoGraph* does semi-supervised finetuning and InfoGraph does unsupervised learning to train an encoder. From Table 2 in the paper, they have two results: results from the InfoGraph model finetuned in a supervised setting, and results from InfoGraph* finetuned via the semi-supervised learning approach. In this pull request, I am proposing an approach for finetuning the encoder in a supervised setting. If it will be useful, we can merge it in; else we can close it.
Figure 2 shows only InfoGraph*, while Figure 1 shows InfoGraph. Those two approaches were implemented with InfoGraphStarModel and InfoGraphModel. Finetuning the encoder with a supervised dataset is already possible with InfoGraphStarModel. The difference I see in your implementation is that you've combined both into a single model, which may be more convenient, but otherwise I believe it is the same functionality.
I leave it to @rbharath for final call on whether it will be useful or not to users. |
This is redundant with InfographStar, but I think it could be nice for users since our other models allow for pretraining/finetuning in the same model and it's convenient to have the same for infograph. @arunppsg Let me know once all tests are passing and this is ready for my full review |
This is ready for review and tests are passing. |
LGTM
@@ -194,8 +195,8 @@ def __init__(self,
         if device is None:
             if torch.cuda.is_available():
                 device = torch.device('cuda')
-            elif torch.backends.mps.is_available():
-                device = torch.device('mps')
+            # elif torch.backends.mps.is_available():
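For context, the fallback order in this hunk (CUDA first, then MPS, then CPU) can be sketched with plain booleans. This is a hypothetical stand-in for the torch availability checks, not the actual DeepChem code:

```python
def select_device(cuda_available: bool, mps_available: bool) -> str:
    # Mirrors the fallback order in the hunk above: prefer CUDA, then MPS
    # (Apple Silicon), otherwise CPU. With the MPS branch commented out,
    # as in this diff, MPS-only machines would fall through to 'cpu'.
    if cuda_available:
        return 'cuda'
    if mps_available:
        return 'mps'
    return 'cpu'

print(select_device(False, True))   # → mps
print(select_device(False, False))  # → cpu
```

The comment below asks for the commented-out branch to be removed entirely rather than left as dead code.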
Can you remove this before merging in? Looks like cruft
Sure, forgot to remove it. Removed.
@@ -386,6 +385,7 @@ def restore(  # type: ignore
         model_dir: Optional[str]
             The path to the model directory. If None, the model directory used to initialize the model will be used.
         """
+        logger.info('Restoring model')
These logger changes are also reflected in your other PR. I'm fine merging them in as part of this PR since they are small
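As an illustration of what the added line buys: with a standard module-level logger (the setup is assumed here, since it is not shown in the diff), the message surfaces once logging is configured. The `restore` stub below is hypothetical, not DeepChem's actual method:

```python
import logging

# Module-level logger, the conventional Python pattern assumed by the diff.
logger = logging.getLogger(__name__)

def restore(model_dir=None):
    # Hypothetical stand-in for the restore() method in the hunk above.
    logger.info('Restoring model')
    return model_dir or 'default_model_dir'

# Without some handler configuration, INFO messages are dropped by default.
logging.basicConfig(level=logging.INFO)
restore()  # emits an INFO "Restoring model" record
```

Because `logging.getLogger(__name__)` is a no-op until the application configures a handler, small `logger.info` additions like this are low-risk to merge.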
Description
Added support for infograph model finetuning task.
Type of change
Please check the option that is related to your PR.
Checklist
- I have run `yapf -i <modified file>` and checked no errors (yapf version must be 0.32.0)
- I have run `mypy -p deepchem` and checked no errors
- I have run `flake8 <modified file> --count` and checked no errors
- I have run `python -m doctest <modified file>` and checked no errors