
Finetuning pipeline #414

Merged 57 commits into datamol-io:main on Aug 9, 2023
Conversation

@WenkelF (Collaborator) commented Jul 21, 2023

@DomInvivo as discussed, here is a first draft of the finetuning pipeline.

Two possible pipelines:

expts/main_run_finetuning_v1.py (probably to be removed):

  • Start with the pretrained model
  • Remove redundant task heads
  • Drop/add layers of the model part to be finetuned
  • Drawback: updating hyperparameters is difficult

expts/main_run_finetuning_v2.py (see the sketch below):

  • Initialize a new model with the final finetuning architecture
  • Overwrite the parameters that are shared with the pretrained model

Main TODOs:

  • Use hydra to simplify the finetuning config
  • Create a unit test with a dummy pretrained model
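
For illustration, the v2 approach sketched above boils down to something like the following; the two nn.Sequential models and their dimensions are placeholders, not the Graphium architectures:

import torch
import torch.nn as nn

# Hypothetical models standing in for the pretrained and finetuning architectures.
pretrained_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
finetune_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))  # new task-head dim

pretrained_state = pretrained_model.state_dict()
finetune_state = finetune_model.state_dict()

# Keep only parameters whose name and shape match the new architecture,
# then load non-strictly so the remaining (new) layers keep their fresh init.
shared = {
    k: v for k, v in pretrained_state.items()
    if k in finetune_state and v.shape == finetune_state[k].shape
}
missing, unexpected = finetune_model.load_state_dict(shared, strict=False)
print(f"Overwrote {len(shared)} tensors; newly initialized: {missing}; unexpected: {unexpected}")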

codecov bot commented Jul 23, 2023

Codecov Report

Merging #414 (6625e0c) into main (9515ad6) will decrease coverage by 2.14%.
The diff coverage is 20.51%.

@@            Coverage Diff             @@
##             main     #414      +/-   ##
==========================================
- Coverage   66.87%   64.74%   -2.14%     
==========================================
  Files          82       89       +7     
  Lines        7838     8211     +373     
==========================================
+ Hits         5242     5316      +74     
- Misses       2596     2895     +299     
Flag        Coverage Δ
unittests   64.74% <20.51%> (-2.14%) ⬇️

Flags with carried forward coverage won't be shown.

Components   Coverage Δ
ipu          49.14% <ø> (ø)

@DomInvivo (Collaborator) left a comment:

Here's my review. It mostly aligns with what we discussed last time.

Resolved review comments on expts/finetune_configs/config_toy_finetuning_v1.yaml (2), expts/finetune_configs/config_toy_finetuning_v2.yaml (3), and graphium/finetuning/finetuning.py (2).
from copy import deepcopy


def modify_cfg_for_finetuning(cfg):
Collaborator:

This could be a function within FeedForwardNN and FullGraphNetwork.

Collaborator Author (@WenkelF, Aug 1, 2023):

I am not sure we can remove the function and move it inside the networks. As discussed in #411, it might be possible to have a similar function within load_architecture in graphium/data/_loader.py, but that depends on that PR.

Two resolved review comments on graphium/nn/architectures/global_architectures.py.
@cwognum (Collaborator) left a comment:

Hi @WenkelF, sorry for the delay here!

Thank you for this first implementation!

I left comments whenever something stood out to me, but I'm aware that this PR is still WIP. Sorry if I pointed out some things that you were already planning to change.

It would be super helpful if you could document the main fine-tuning "flow". Furthermore, I would suggest simplifying the process by adding support for one feature at a time, instead of having half-implemented features. This will make understanding, debugging, testing and maintaining the code base a lot easier.

Happy to help next week!

Resolved review comments on expts/main_run_finetuning.py (5), graphium/finetuning/utils.py (1), and graphium/nn/architectures/global_architectures.py (3).
@DomInvivo (Collaborator) left a comment:

Don't forget that the objective is to release a working but incomplete version first, then refine it with more complex fine-tuning possibilities.

Comment on lines 75 to 103
qm9:
  task_level: graph
  out_dim: 19
  hidden_dims: 128
  depth: 2
  activation: relu
  last_activation: none
  dropout: *dropout
  normalization: *normalization
  last_normalization: "none"
  residual_type: none
tox21:
  task_level: graph
  out_dim: 12
  hidden_dims: 64
  depth: 2
  activation: relu
  last_activation: sigmoid
  dropout: *dropout
  normalization: *normalization
  last_normalization: "none"
  residual_type: none
zinc:
  task_level: graph
  out_dim: 3
  hidden_dims: 32
  depth: 2
  activation: relu
  last_activation: none
Collaborator:

This doesn't make sense. We should not have any architectural choices from the original pre-trained model in here, only the things that would change.

That way, we can take different pre-trained models that have different hparams/seed and fine-tune them all with the same file.

Collaborator Author:

I agree. The configurations are still structured in a way where we have access to both the full config of the pretrained model and the finetuning-related config, and the modify_cfg_for_finetuning function consolidates that information into one config.

This will be fixed once we incorporate the new hydra config from #421. We will still need modify_cfg_for_finetuning for now, so it might be good to wait for the final version.
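
For readers following along, a toy version of what such a consolidation step amounts to; the config keys are made up and the real modify_cfg_for_finetuning does more than this:

from copy import deepcopy


def merge_finetuning_overrides(pretrained_cfg: dict, overrides: dict) -> dict:
    """Recursively merge finetuning overrides into a copy of the pretrained config."""
    cfg = deepcopy(pretrained_cfg)

    def _merge(base: dict, new: dict) -> None:
        for key, value in new.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                _merge(base[key], value)
            else:
                base[key] = value

    _merge(cfg, overrides)
    return cfg


# Illustrative keys only, not the actual Graphium schema.
pretrained_cfg = {"architecture": {"gnn": {"depth": 16, "hidden_dims": 256}}}
overrides = {"architecture": {"gnn": {"depth": 12}}}  # only what changes
print(merge_finetuning_overrides(pretrained_cfg, overrides))
# -> {'architecture': {'gnn': {'depth': 12, 'hidden_dims': 256}}}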

Two resolved review comments on expts/main_run_finetuning.py.
Comment on lines 1964 to 1968
try:
    if "epoch_sampling_fraction" in args[task].keys():
        args[task].pop("epoch_sampling_fraction")
except:
    pass
Collaborator:

I don't even understand the point of having a try and an if there. Dict.pop works even if the key is not available.

But we need to make sure that args[task] is not used elsewhere, even outside the current function, since dicts are passed by reference. We only want to remove epoch_sampling_fraction for the hash key. So I would suggest the following.

Suggested change
-    try:
-        if "epoch_sampling_fraction" in args[task].keys():
-            args[task].pop("epoch_sampling_fraction")
-    except:
-        pass
+    args[task] = deepcopy(args[task])
+    args[task].pop("epoch_sampling_fraction")

Collaborator Author (@WenkelF, Jul 31, 2023):

Indeed, this is not an ideal fix; I wanted to investigate the issue a bit more.

We cannot use args[task].pop("epoch_sampling_fraction") because args[task] may be an instance of DatasetProcessingParams instead of a Dict. In particular, ADMETBenchmarkDataModule makes use of DatasetProcessingParams.

The error originates from the changes in 4b82ba3, where the line args[task].pop("epoch_sampling_fraction") was added. It did not cause errors back then because we were using a Dict in all configs (although I see a comment # To be replaced by a new class "DatasetParams" everywhere it appears).

I will create an issue and think about a fix.
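
One possible direction for that fix, sketched with a stand-in class (the real DatasetProcessingParams may not be a dataclass, so the branch below is illustrative only):

from copy import deepcopy
from dataclasses import dataclass, asdict


@dataclass
class DatasetProcessingParamsStub:
    # Stand-in for graphium's DatasetProcessingParams; fields are illustrative.
    task_level: str = "graph"
    epoch_sampling_fraction: float = 1.0


def args_for_hashing(task_args) -> dict:
    """Return a copy of the per-task args with epoch_sampling_fraction removed,
    leaving the original object untouched (dicts are passed by reference)."""
    if isinstance(task_args, dict):
        cleaned = deepcopy(task_args)
    else:
        cleaned = asdict(task_args)  # works for dataclass-style parameter objects
    cleaned.pop("epoch_sampling_fraction", None)
    return cleaned


print(args_for_hashing({"task_level": "graph", "epoch_sampling_fraction": 0.5}))
print(args_for_hashing(DatasetProcessingParamsStub()))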

One resolved review comment on graphium/data/datamodule.py.


class GraphFinetuning(BaseFinetuning):
    def __init__(self, cfg, train_bn: bool = False):
Collaborator:

Change to explicitly pass parameters.

Suggested change
-    def __init__(self, cfg, train_bn: bool = False):
+    def __init__(self, fine-tuning, architecture, module_from_pretrained, ....................., train_bn: bool = False):

Collaborator Author:

Addressed in 76e2ba6

    modules = pl_module.model.task_heads.graph_output_nn
elif module == "task_heads":
    modules = pl_module.model.task_heads.task_heads

Collaborator:

else: raise ValueError("Wrong module")

Collaborator Author:

Addressed in 76e2ba6

Comment on lines 45 to 56
if module == "pe_encoders":
    modules = pl_module.model.encoder_manager
elif module == "pre_nn":
    modules = pl_module.model.pre_nn
elif module == "pre_nn_edges":
    modules = pl_module.model.pre_nn_edges
elif module == "gnn":
    modules = pl_module.model.gnn
elif module == "graph_output_nn":
    modules = pl_module.model.task_heads.graph_output_nn
elif module == "task_heads":
    modules = pl_module.model.task_heads.task_heads
Collaborator:

I would define all of these in a dictionary _module_map = {"pe_encoders": pl_module.model.encoder_manager, ...} directly in the __init__. That way, with inheritance, someone could modify the entries without copy-pasting all the logic.

_module_map can replace the module_list you already have.

But instead of a regular dict, using an OrderedDict would also allow you to express something like "freeze everything before task_heads" in a very simple way.
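
To make the suggestion concrete, a small sketch of the OrderedDict idea; build_module_map, modules_before, and the SimpleNamespace stub are hypothetical helpers, and the attribute paths simply mirror the snippet above:

from collections import OrderedDict
from itertools import takewhile
from types import SimpleNamespace


def build_module_map(pl_module) -> OrderedDict:
    # Insertion order mirrors the forward pass, so "everything before X" is well defined.
    return OrderedDict(
        pe_encoders=pl_module.model.encoder_manager,
        pre_nn=pl_module.model.pre_nn,
        pre_nn_edges=pl_module.model.pre_nn_edges,
        gnn=pl_module.model.gnn,
        graph_output_nn=pl_module.model.task_heads.graph_output_nn,
        task_heads=pl_module.model.task_heads.task_heads,
    )


def modules_before(module_map: OrderedDict, name: str):
    """Everything preceding `name` in forward order, e.g. to freeze it in one go."""
    return list(takewhile(lambda item: item[0] != name, module_map.items()))


# Tiny stand-in for a LightningModule, just to show the lookup:
model = SimpleNamespace(
    encoder_manager="enc", pre_nn="pre", pre_nn_edges="pre_e", gnn="gnn",
    task_heads=SimpleNamespace(graph_output_nn="out_nn", task_heads="heads"),
)
pl_module = SimpleNamespace(model=model)
print([name for name, _ in modules_before(build_module_map(pl_module), "task_heads")])
# -> ['pe_encoders', 'pre_nn', 'pre_nn_edges', 'gnn', 'graph_output_nn']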

Collaborator:

It's a bad idea to have a FullGraphFinetuningNetwork that basically copy-pastes most of the functionality of FullGraphNetwork.

Either use inheritance, or implement the fine-tuning logic directly within FullGraphNetwork.

@WenkelF (Collaborator Author) commented Jul 29, 2023

@DomInvivo thanks for your comments. Here is also a quick overview of the updates.

Updates:

  • Moved the finetuning-related architecture out of FullGraphMultitaskNetwork into FullGraphFinetuningNetwork under graphium/finetuning/finetuning_architecture
    • More flexibility to build a finetuning head on top of the pretrained model
    • For now, finetuning is specific to FullGraphMultitaskNetwork (as the pretrained model), but we can now finetune from gnn, graph_output_nn and task_heads
    • Additionally, we can add a flexible FinetuningHead, inheriting from nn.Module and MupMixin. For example, it can be a FeedForwardNN, FeedForwardPyg, TaskHeads, or a custom network that can be added to FINETUNING_HEAD_DICT (see the registry sketch after this list)
    • FinetuningHead needs a bit more work to correctly support all possible scenarios (modules/levels to finetune from); if it is not used, we automatically fall back to the basic finetuning logic, where we can only drop & add depth to the module we finetune from
  • Finished the functions for overwriting shared weights and partially freezing modules during training, and extended them from only task_heads to also gnn and graph_output_nn
  • Added a preliminary unit test for the discussed "minimal" finetuning pipeline (finetune from one of the task heads)
    • We currently test that depth, in_dim and out_dim of the changed modules are correct and overwritten correctly
    • The test does not cover training/freezing yet
    • Planning to add tests for finetuning from other modules as well
  • Configuration has been integrated with hydra but needs more work to be easier to use (@luis-mueller is helping with that)
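
To make the FINETUNING_HEAD_DICT bullet concrete, here is an illustrative registry sketch; the decorator, the dictionary name and SimpleMLPHead are hypothetical, and the real registry also expects MupMixin support:

import torch.nn as nn

# Illustrative registry pattern only; the real FINETUNING_HEAD_DICT lives in graphium.finetuning.
FINETUNING_HEAD_DICT_SKETCH = {}


def register_finetuning_head(name):
    def _register(cls):
        FINETUNING_HEAD_DICT_SKETCH[name] = cls
        return cls
    return _register


@register_finetuning_head("mlp")
class SimpleMLPHead(nn.Module):
    """A custom head that could sit on top of the pretrained model's output."""

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)


head = FINETUNING_HEAD_DICT_SKETCH["mlp"](in_dim=32, out_dim=1)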

Remarks:

  • hidden_dims can currently only be used as an int together with depth
  • So far, everything has only been tested on CPU/GPU

@luis-mueller mentioned this pull request on Jul 30, 2023.
@WenkelF (Collaborator Author) commented Aug 4, 2023

@DomInvivo this PR introduces the following:

  • Finetuning pipeline
  • Updates to the hydra configuration (by @cwognum)

The finetuning pipeline is maintained separately from the existing architectures under graphium/finetuning/finetuning_architecture. The FullGraphFinetuningNetwork comes with two submodules:

  • PretrainedModel
  • FinetuningHead (optional)

PretrainedModel can load pretrained models (e.g., FullGraphMultitaskNetwork) and allows basic finetuning operations (in particular from the task heads), including dropping or extending layers of the module to finetune from.

FinetuningHead is optional and allows defining fully customizable networks that are applied on top of PretrainedModel. If it is not specified, we fall back to the basic finetuning explained above.

Training is handled by the GraphFinetuning callback in graphium/finetuning/finetuning.

All methods in graphium/finetuning are implemented such that they are not specific to a particular pretrained model or finetuning head. This is achieved by requiring the pretrained model to come with a module_map (see, e.g., FullGraphMultitaskNetwork) that facilitates setting pretrained weights and freezing the correct layers.

The new unit test tests/test_finetuning tests the pipeline end-to-end for a specific case. Together with the corresponding config, it might be a good starting point for understanding the pipeline and what information is needed in the config.

The updates to hydra make it easy to switch between benchmarking and finetuning. Major changes are documented in WenkelF#10.
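
Not the actual implementation, but a rough sketch of what such a BaseFinetuning callback can look like; method names follow recent PyTorch Lightning versions (older ones pass an extra opt_idx to finetune_function), and the pl_module.model.pretrained_model.module_map path is an assumption based on the description above:

from pytorch_lightning.callbacks import BaseFinetuning  # or: from lightning.pytorch.callbacks import BaseFinetuning


class GraphFinetuningSketch(BaseFinetuning):
    """Freeze everything that comes before the module we finetune from,
    then optionally unfreeze the whole pretrained model after a few epochs."""

    def __init__(self, finetuning_module: str, unfreeze_pretrained_epoch: int = 1, train_bn: bool = False):
        super().__init__()
        self.finetuning_module = finetuning_module
        self.unfreeze_pretrained_epoch = unfreeze_pretrained_epoch
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module):
        # module_map is assumed to be ordered like the forward pass (see the OrderedDict discussion above).
        for name, module in pl_module.model.pretrained_model.module_map.items():
            if name == self.finetuning_module:
                break  # keep the finetuned module and everything after it trainable
            self.freeze(module, train_bn=self.train_bn)

    def finetune_function(self, pl_module, current_epoch, optimizer):
        if current_epoch == self.unfreeze_pretrained_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.model.pretrained_model,
                optimizer=optimizer,
                train_bn=self.train_bn,
            )

With a callback like this, finetuning from graph_output_nn instead of the task heads is just a different finetuning_module string.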

@DomInvivo (Collaborator) left a comment:

Mostly looks good, great work on this major PR!
A few changes to make, and some comments.

Resolved review comments on expts/hydra-configs/finetuning/finetuning.yaml, env.yml, expts/hydra-configs/finetuning/admet.yaml, expts/hydra-configs/tasks/admet.yaml, and graphium/cli/finetune.py.
Comment on lines 99 to 105
if "task_heads_kwargs" in model_kwargs.keys():
task_heads_kwargs = model_kwargs["task_heads_kwargs"]
elif "pretrained_model_kwargs" in model_kwargs.keys():
# This covers finetuning cases where we finetune from the task_heads
task_heads_kwargs = model_kwargs["pretrained_model_kwargs"]["task_heads_kwargs"]
else:
raise ValueError("incorrect model_kwargs")
Collaborator:

I don't think this should be here. I think that, if you are using a pre-trained model, you should pass model_kwargs["pretrained_model_kwargs"] directly as model_kwargs.

Collaborator Author:

Removed in a3d4715 as explained below

Comment on lines 111 to 113
task_level=task_heads_kwargs[key]["task_level"],
task=key
# task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
Collaborator:

In general, the PredictorModule should be agnostic to the model passed. By having the self._get_task_key here, it forces a certain architecture in the config which is not very flexible.

I see that this logic was introduced already in the code prior to this PR. If it requires too many changes, let's open a new issue.

Collaborator Author:

You are right, thanks for pointing that out.

We could achieve this by getting the task-specific information (which is only the task level, as far as I know) from the datamodule.

Here is how this could be done: a3d4715

What do you think?
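
A minimal stand-in illustrating the idea (not the real PredictorModule signature): the task levels arrive explicitly, e.g. straight from the datamodule, so the module never needs to dig into model_kwargs["task_heads_kwargs"]:

from typing import Dict


class PredictorModuleSketch:
    """Hypothetical stand-in showing task levels passed in directly."""

    def __init__(self, model_kwargs: dict, task_levels: Dict[str, str]):
        self.model_kwargs = model_kwargs
        # The module no longer needs to know how the architecture config is nested.
        self.task_levels = task_levels


# The datamodule already knows each task's level:
predictor = PredictorModuleSketch(model_kwargs={}, task_levels={"qm9": "graph", "tox21": "graph"})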

One resolved review comment on graphium/trainer/predictor.py.
Comment on lines 135 to 136
task_level=task_heads_kwargs[key]["task_level"],
task=key
Collaborator:

Again, I don't like how the config structure is imposed. Perhaps task_level should simply be passed to the PredictorModule to keep things flexible.

Collaborator Author:

Good idea, this is implemented in a3d4715

Collaborator:

Why duplicate the model for CPU and GPU? Models should be agnostic to both the training and the fine-tuning hardware. You can train on CPU and fine-tune on GPU or IPU.

Collaborator Author:

Yes, I agree. I only changed from GPU to CPU because GitHub cannot run unit tests on GPU. Should I remove the GPU model?

Collaborator:

Yes, you can remove the GPU model.

@DomInvivo (Collaborator):

@zhiyil-graphcore @s-maddrellmander We'll need your help here to fix the tests for IPU. And ideally, have a test that loads a CPU-trained model onto IPU for finetuning.

@WenkelF (Collaborator Author) commented Aug 8, 2023

@cwognum the bug in the finetuning training is fixed in febdf2d.

I missed a deepcopy operation when defining the data hashes for the TDC datasets. We include the first 5 rows of the df when generating the hash, and the bug reduced the datasets themselves to those 5 rows (molecules) as well.

Make sure to remove the TDC datasets from the datacache. You can use datamodule._path_to_load_from_file("train") to get "train_[data_hash]" and then remove all data hashes (e.g., "train_[data_hash]", "test_[data_hash]", etc.) that end with the same sequence.
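
The bug pattern is easy to reproduce in isolation; the toy functions below are illustrative, not the actual datamodule code:

from copy import deepcopy

import pandas as pd


def hash_args_buggy(task_args: dict) -> dict:
    # Mutates the caller's dict: the dataset itself gets truncated to 5 rows.
    task_args["df"] = task_args["df"].head(5)
    return task_args


def hash_args_fixed(task_args: dict) -> dict:
    # Work on a copy so only the hash ever sees the truncated frame.
    task_args = deepcopy(task_args)
    task_args["df"] = task_args["df"].head(5)
    return task_args


df = pd.DataFrame({"smiles": [f"C{i}" for i in range(100)]})
args = {"df": df}
hash_args_fixed(args)
assert len(args["df"]) == 100   # dataset untouched
hash_args_buggy(args)
assert len(args["df"]) == 5     # dataset silently truncated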

@WenkelF (Collaborator Author) commented Aug 9, 2023

@DomInvivo I added some final improvements in 24354ee:

  1. Made the modify_cfg_for_finetuning function fully agnostic to the pretrained model
  2. Dropped dummy-pretrained-model-gpu and added map_location as an argument to the load_pretrained_models function
  3. Added the option to keep modules after the finetuning module

(3.) makes it much easier to finetune from modules other than the task_heads without manually re-defining the downstream network. When finetuning from task_heads, it is not needed.
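
A hedged sketch of the map_location idea; the function below is illustrative and only mirrors the behavior described, not the actual load_pretrained_models signature:

import torch


def load_pretrained_model_sketch(checkpoint_path: str, map_location: str = "cpu"):
    """Load a checkpoint onto the requested device, so a GPU-trained model
    can be finetuned on CPU (or vice versa) without a second copy of the file."""
    return torch.load(checkpoint_path, map_location=map_location)


# e.g. load a GPU-trained checkpoint for CPU finetuning (path is a placeholder):
# state = load_pretrained_model_sketch("dummy-pretrained-model.ckpt", map_location="cpu")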

s-maddrellmander added a commit that referenced this pull request Aug 9, 2023
Merging the IPU tests so the IPU CLI test is in the correct environment for the major updates to the workflow in #414.
@s-maddrellmander (Contributor):

@WenkelF - try merging from the main branch; I've made a change to the IPU test CLI that should take into account the changes made in this PR. If that doesn't work, let me know.

@DomInvivo merged commit 617356d into datamol-io:main on Aug 9, 2023
4 of 5 checks passed

Successfully merging this pull request may close these issues.

  • Optimizing main_run_multitask
  • Removing epoch_sampling_fraction does not support DatasetProcessingParams
6 participants