Modular loading from pretrained #3305

Merged: 20 commits, Mar 29, 2023
Changes from 17 commits
9 changes: 6 additions & 3 deletions deepchem/models/tests/test_modular.py
@@ -157,7 +157,7 @@ def test_load_freeze_unfreeze():
d_hidden = 3
n_layers = 1
ft_tasks = 6
pt_tasks = 3
pt_tasks = 6

X_ft = np.random.rand(n_samples, n_feat)
y_ft = np.random.rand(n_samples, ft_tasks).astype(np.float32)
@@ -176,8 +176,8 @@ def test_load_freeze_unfreeze():

example_pretrainer.fit(dataset_pt, nb_epoch=1000)

example_model.load_pretrained_components(source_model=example_pretrainer,
components=['encoder'])
example_model.load_from_pretrained(model_dir=example_pretrainer.model_dir,
components=['encoder'])

example_model.freeze_components(['encoder'])

@@ -201,3 +201,6 @@ def test_load_freeze_unfreeze():
assert not np.array_equal(
example_pretrainer.components['encoder'][0].weight.data.cpu().numpy(),
example_model.components['encoder'][0].weight.data.cpu().numpy())


test_load_freeze_unfreeze()
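
For readers skimming the diff: the change above replaces the old source_model-based call with the new directory-based one. A minimal before/after sketch reusing the variable names from the test (illustrative only, not additional test code):

# Old API, removed in this PR: copy components directly from the source model object.
# example_model.load_pretrained_components(source_model=example_pretrainer,
#                                          components=['encoder'])

# New API: restore the selected components from the pretrainer's checkpoint directory.
example_model.load_from_pretrained(model_dir=example_pretrainer.model_dir,
                                   components=['encoder'])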
71 changes: 32 additions & 39 deletions deepchem/models/torch_models/infograph.py
@@ -563,63 +563,56 @@ def build_components(self):
if self.task == 'supervised':
return {
'encoder':
InfoGraphEncoder(self.num_features, self.edge_features,
self.embedding_dim),
InfoGraphEncoder(self.num_features, self.edge_features,
self.embedding_dim),
'unsup_encoder':
InfoGraphEncoder(self.num_features, self.edge_features,
self.embedding_dim),
InfoGraphEncoder(self.num_features, self.edge_features,
self.embedding_dim),
'ff1':
MultilayerPerceptron(2 * self.embedding_dim,
self.embedding_dim,
(self.embedding_dim,)),
MultilayerPerceptron(2 * self.embedding_dim, self.embedding_dim,
(self.embedding_dim,)),
'ff2':
MultilayerPerceptron(2 * self.embedding_dim,
self.embedding_dim,
(self.embedding_dim,)),
MultilayerPerceptron(2 * self.embedding_dim, self.embedding_dim,
(self.embedding_dim,)),
'fc1':
torch.nn.Linear(2 * self.embedding_dim, self.embedding_dim),
torch.nn.Linear(2 * self.embedding_dim, self.embedding_dim),
'fc2':
torch.nn.Linear(self.embedding_dim, 1),
torch.nn.Linear(self.embedding_dim, 1),
'local_d':
MultilayerPerceptron(self.embedding_dim,
self.embedding_dim,
(self.embedding_dim,),
skip_connection=True),
MultilayerPerceptron(self.embedding_dim,
self.embedding_dim, (self.embedding_dim,),
skip_connection=True),
'global_d':
MultilayerPerceptron(2 * self.embedding_dim,
self.embedding_dim,
(self.embedding_dim,),
skip_connection=True)
MultilayerPerceptron(2 * self.embedding_dim,
self.embedding_dim, (self.embedding_dim,),
skip_connection=True)
}
elif self.task == 'semisupervised':
return {
'encoder':
InfoGraphEncoder(self.num_features, self.edge_features,
self.embedding_dim),
InfoGraphEncoder(self.num_features, self.edge_features,
self.embedding_dim),
'unsup_encoder':
GINEncoder(self.num_features, self.embedding_dim,
self.num_gc_layers),
GINEncoder(self.num_features, self.embedding_dim,
self.num_gc_layers),
'ff1':
MultilayerPerceptron(2 * self.embedding_dim,
self.embedding_dim,
(self.embedding_dim,)),
MultilayerPerceptron(2 * self.embedding_dim, self.embedding_dim,
(self.embedding_dim,)),
'ff2':
MultilayerPerceptron(self.embedding_dim, self.embedding_dim,
(self.embedding_dim,)),
MultilayerPerceptron(self.embedding_dim, self.embedding_dim,
(self.embedding_dim,)),
'fc1':
torch.nn.Linear(2 * self.embedding_dim, self.embedding_dim),
torch.nn.Linear(2 * self.embedding_dim, self.embedding_dim),
'fc2':
torch.nn.Linear(self.embedding_dim, 1),
torch.nn.Linear(self.embedding_dim, 1),
'local_d':
MultilayerPerceptron(self.embedding_dim,
self.embedding_dim,
(self.embedding_dim,),
skip_connection=True),
MultilayerPerceptron(self.embedding_dim,
self.embedding_dim, (self.embedding_dim,),
skip_connection=True),
'global_d':
MultilayerPerceptron(self.embedding_dim,
self.embedding_dim,
(self.embedding_dim,),
skip_connection=True)
MultilayerPerceptron(self.embedding_dim,
self.embedding_dim, (self.embedding_dim,),
skip_connection=True)
}

def build_model(self):
199 changes: 126 additions & 73 deletions deepchem/models/torch_models/modular.py
@@ -1,6 +1,6 @@
import time
import logging
import copy
import os
from collections.abc import Sequence as SequenceCollection
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, Sequence
import torch
@@ -62,7 +62,7 @@ class ModularTorchModel(TorchModel):
... return (torch.nn.functional.mse_loss(pretrain_model(inputs), labels[0]) * weights[0]).mean()
>>> pretrain_modular_model.loss_func = example_pt_loss_func
>>> pt_loss = pretrain_modular_model.fit(dataset_pt, nb_epoch=1)
>>> modular_model.load_pretrained_components(pretrain_modular_model, components=['linear'])
>>> modular_model.load_from_pretrained(pretrain_modular_model, components=['linear'])
>>> ft_loss = modular_model.fit(dataset_ft, nb_epoch=1)

"""
@@ -131,77 +131,6 @@ def unfreeze_components(self, components: List[str]):
for param in self.components[component].parameters():
param.requires_grad = True

def load_pretrained_components(
self,
source_model: Optional['ModularTorchModel'] = None,
checkpoint: Optional[str] = None,
model_dir: Optional[str] = None,
components: Optional[list] = None) -> None:
"""Modifies the TorchModel load_from_pretrained method to allow for loading
from a ModularTorchModel and specifying which components to load.

If the user does not a specify a source model, a checkpoint is used to load
the weights. In this case, the user cannot specify which components to load
because the components are not stored in the checkpoint. All layers will
then be loaded if they have the same name and shape. This can cause issues
if a pretrained model has similar but not identical layers to the model where
a user may expect the weights to be loaded. ModularTorchModel subclasses
should be written such that the components are atomic and will be preserved
across as many tasks as possible. For example, an encoder may have varying
input dimensions for different datasets, so the encoder should be written
such that the input layer is not included in the encoder, allowing the
encoder to be loaded with any input dimension.

Parameters
----------
source_model: Optional[ModularTorchModel]
The model to load the weights from.
checkpoint: Optional[str]
The path to the checkpoint to load the weights from.
model_dir: Optional[str]
The path to the directory containing the checkpoint to load the weights.
components: Optional[list]
The components to load the weights from. If None, all components will be
loaded.
"""

# generate the source state dict
if source_model is not None:
source_state_dict = source_model.model.state_dict()
elif checkpoint is not None:
source_state_dict = torch.load(checkpoint)['model_state_dict']
elif model_dir is not None:
checkpoints = sorted(self.get_checkpoints(model_dir))
source_state_dict = torch.load(checkpoints[0])['model_state_dict']
else:
raise ValueError(
"Must provide a source model, checkpoint, or model_dir")

if components is not None: # load the specified components
if source_model is not None:
assignment_map = {
k: v
for k, v in source_model.components.items()
if k in components
}
assignment_map_copy = copy.deepcopy(
assignment_map) # deep copy to avoid modifying source_model
self.components.update(assignment_map_copy)
self.model = self.build_model()
else:
raise ValueError(
"If loading from checkpoint, you cannot pass a list of components to load"
)
else: # or all components with matching names and shapes
model_dict = self.model.state_dict()
assignment_map = {
k: v
for k, v in source_state_dict.items()
if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(assignment_map)
self.model.load_state_dict(model_dict)

def fit_generator(self,
generator: Iterable[Tuple[Any, Any, Any]],
max_checkpoints_to_keep: int = 5,
@@ -342,3 +271,127 @@ def fit_generator(self,
time2 = time.time()
logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
return last_avg_loss

def load_from_pretrained( # type: ignore
self,
source_model: Optional["ModularTorchModel"] = None,
Review comment (Contributor): Are all three required (source_model, checkpoint, model_dir)? I think model_dir alone would be sufficient: given a model_dir, the method loads the state_dict, and if any of the current model's layers or components match keys in the state_dict, it can update those components' weights.

Reply (Member): I think it's relatively harmless to support multiple loading options here; it gives users a bit more flexibility. (A sketch of the three call patterns follows this method below.)

components: Optional[List[str]] = None,
checkpoint: Optional[str] = None,
model_dir: Optional[str] = None,
inputs: Optional[Sequence[Any]] = None,
**kwargs) -> None:
"""Copies parameter values from a pretrained model. The pretrained model can be loaded as a source_model (ModularTorchModel object), checkpoint (pytorch .ckpt file) or a model_dir (directory with .ckpt files).
Specific components can be chosen by passing a list of strings with the desired component names. If both a source_model and a checkpoint/model_dir are loaded, the source_model weights will be loaded.

Parameters
----------
source_model: ModularTorchModel, optional (default None)
The pretrained ModularTorchModel (or a model with the same architecture) to copy component weights from. If None, weights are restored from the given checkpoint or model_dir instead.
components: Optional[List[str]]
Names of the components to load. If None, all components are loaded.
checkpoint: str, default None
the path to the checkpoint file to load. If this is None, the most recent
checkpoint will be chosen automatically. Call get_checkpoints() to get a
list of all available checkpoints
model_dir: str, default None
Restore source model from custom model directory if needed
inputs: Optional[Sequence], default None
Input tensors for the model. If not None, the weights of both the source model and this model are built before loading.
"""
if inputs is not None:
# Ensure weights for both models are built.
if source_model:
source_model.model(inputs)
self.model(inputs)

self._ensure_built()

if source_model is not None:
for name, module in source_model.components.items():
if components is None or name in components:
self.components[name] = module
self.build_model()

elif source_model is None:
self.restore(components=components,
checkpoint=checkpoint,
model_dir=model_dir)
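
As noted in the review thread above, this method accepts three alternative sources for the pretrained weights. A minimal sketch of the three call patterns, assuming pretrained and model are ModularTorchModel instances with compatible components, and relying on os from the module imports above (the checkpoint filename follows the convention used by save_checkpoint below):

# 1) From an in-memory source model: the named component modules are copied over directly.
model.load_from_pretrained(source_model=pretrained, components=['encoder'])

# 2) From a specific checkpoint file previously written by save_checkpoint().
model.load_from_pretrained(checkpoint=os.path.join(pretrained.model_dir, 'checkpoint1.pt'),
                           components=['encoder'])

# 3) From a model directory: the most recent checkpoint there is chosen automatically.
model.load_from_pretrained(model_dir=pretrained.model_dir, components=['encoder'])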

def save_checkpoint(self, max_checkpoints_to_keep=5, model_dir=None):
"""
Saves the current state of the model and its components as a checkpoint file in the specified model directory.
It maintains a maximum number of checkpoint files, deleting the oldest one when the limit is reached.

Parameters
----------
max_checkpoints_to_keep: int, default 5
Maximum number of checkpoint files to keep.
model_dir: str, default None
The directory to save the checkpoint file in. If None, the model_dir specified in the constructor is used.
"""

if model_dir is None:
model_dir = self.model_dir
if not os.path.exists(model_dir):
os.makedirs(model_dir)

data = {
'model': self.model.state_dict(),
'optimizer_state_dict': self._pytorch_optimizer.state_dict(),
'global_step': self._global_step
}

for name, component in self.components.items():
data[name] = component.state_dict()

temp_file = os.path.join(model_dir, 'temp_checkpoint.pt')
torch.save(data, temp_file)

# Rename and delete older files.

paths = [
os.path.join(model_dir, 'checkpoint%d.pt' % (i + 1))
for i in range(max_checkpoints_to_keep)
]
if os.path.exists(paths[-1]):
os.remove(paths[-1])
for i in reversed(range(max_checkpoints_to_keep - 1)):
if os.path.exists(paths[i]):
os.rename(paths[i], paths[i + 1])
os.rename(temp_file, paths[0])
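
For illustration, a brief sketch of the rotation behaviour implemented above (the directory name is a placeholder):

# Each call writes the newest state to checkpoint1.pt and shifts older files to
# checkpoint2.pt, checkpoint3.pt, ..., deleting the oldest once the limit is reached.
model.save_checkpoint(max_checkpoints_to_keep=5, model_dir='./pretrain_checkpoints')
# The saved dict contains 'model', 'optimizer_state_dict', 'global_step',
# plus one state_dict entry per component (e.g. 'encoder').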

def restore( # type: ignore
self,
components: Optional[List[str]] = None,
checkpoint: Optional[str] = None,
model_dir: Optional[str] = None) -> None:
"""
Restores the state of a ModularTorchModel from a checkpoint file.

If no checkpoint file is provided, it will use the latest checkpoint found in the model directory. If a list of component names is provided, only the state of those components will be restored.

Parameters
----------
components: Optional[List[str]]
A list of component names to restore. If None, all components will be restored.
checkpoint: Optional[str]
The path to the checkpoint file. If None, the latest checkpoint in the model directory will
be used.
model_dir: Optional[str]
The path to the model directory. If None, the model directory used to initialize the model will be used.
"""
if checkpoint is None:
checkpoints = sorted(self.get_checkpoints(model_dir))
if len(checkpoints) == 0:
raise ValueError('No checkpoint found')
checkpoint = checkpoints[0]
data = torch.load(checkpoint)
for name, state_dict in data.items():
if name != 'model' and name in self.components.keys():
if components is None or name in components:
self.components[name].load_state_dict(state_dict)

self.build_model()
self._pytorch_optimizer.load_state_dict(data['optimizer_state_dict'])
self._global_step = data['global_step']
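
Putting the pieces together, a hedged end-to-end sketch of the intended pretrain/fine-tune round trip (the dataset and model names are placeholders that follow the pattern of the updated test):

pretrainer.fit(dataset_pt, nb_epoch=1000)           # fit() writes checkpoints into pretrainer.model_dir
model.load_from_pretrained(model_dir=pretrainer.model_dir,
                           components=['encoder'])  # restore only the pretrained encoder weights
model.freeze_components(['encoder'])                # optionally freeze before fine-tuning
model.fit(dataset_ft, nb_epoch=10)
model.unfreeze_components(['encoder'])              # unfreeze later if further training is needed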