modular.py
import time
import logging
import os
from collections.abc import Sequence as SequenceCollection
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, Sequence
import torch
import torch.nn as nn
from deepchem.models.torch_models.torch_model import TorchModel
from deepchem.models.optimizers import LearningRateSchedule
from deepchem.utils.typing import LossFn, OneOrMany
logger = logging.getLogger(__name__)


class ModularTorchModel(TorchModel):
    """ModularTorchModel is a subclass of TorchModel that allows for components to be
    pretrained and then combined into a final model. It is designed to be subclassed
    for specific models and is not intended to be used directly. There are 3 main
    differences between ModularTorchModel and TorchModel:

    - The build_components() method is used to define the components of the model.
    - The components are combined into a final model with the build_model() method.
    - The loss function is defined with the loss_func method. This may access the
      components to compute the loss using intermediate values from the network,
      rather than just the full forward pass output.

    Here is an example of how to use ModularTorchModel to pretrain a linear layer,
    load it into another network and then finetune that network:

    >>> import numpy as np
    >>> import deepchem as dc
    >>> import torch
    >>> n_samples = 6
    >>> n_feat = 3
    >>> n_hidden = 2
    >>> n_tasks = 6
    >>> pt_tasks = 3
    >>> X = np.random.rand(n_samples, n_feat)
    >>> y_pretrain = np.zeros((n_samples, pt_tasks)).astype(np.float32)
    >>> dataset_pt = dc.data.NumpyDataset(X, y_pretrain)
    >>> y_finetune = np.zeros((n_samples, n_tasks)).astype(np.float32)
    >>> dataset_ft = dc.data.NumpyDataset(X, y_finetune)
    >>> components = {'linear': torch.nn.Linear(n_feat, n_hidden),
    ...     'activation': torch.nn.ReLU(), 'head': torch.nn.Linear(n_hidden, n_tasks)}
    >>> model = torch.nn.Sequential(components['linear'], components['activation'],
    ...     components['head'])
    >>> modular_model = dc.models.torch_models.modular.ModularTorchModel(model, components)
    >>> def example_loss_func(inputs, labels, weights):
    ...     return (torch.nn.functional.mse_loss(model(inputs), labels[0]) * weights[0]).mean()
    >>> modular_model.loss_func = example_loss_func
    >>> def example_model_build():
    ...     return torch.nn.Sequential(components['linear'], components['activation'],
    ...         components['head'])
    >>> modular_model.build_model = example_model_build
    >>> pretrain_components = {'linear': torch.nn.Linear(n_feat, n_hidden),
    ...     'activation': torch.nn.ReLU(), 'head': torch.nn.Linear(n_hidden, pt_tasks)}
    >>> pretrain_model = torch.nn.Sequential(pretrain_components['linear'],
    ...     pretrain_components['activation'], pretrain_components['head'])
    >>> pretrain_modular_model = dc.models.torch_models.modular.ModularTorchModel(pretrain_model,
    ...     pretrain_components)
    >>> def example_pt_loss_func(inputs, labels, weights):
    ...     return (torch.nn.functional.mse_loss(pretrain_model(inputs), labels[0]) * weights[0]).mean()
    >>> pretrain_modular_model.loss_func = example_pt_loss_func
    >>> pt_loss = pretrain_modular_model.fit(dataset_pt, nb_epoch=1)
    >>> modular_model.load_from_pretrained(pretrain_modular_model, components=['linear'])
    >>> ft_loss = modular_model.fit(dataset_ft, nb_epoch=1)
    """
    def __init__(self, model: nn.Module, components: dict, **kwargs):
        """Create a ModularTorchModel.

        Parameters
        ----------
        model: nn.Module
            The model to be trained.
        components: dict
            A dictionary of the components of the model. The keys are the names of
            the components and the values are the components themselves.
        """
        self.model = model
        self.components = components
        # FIXME self.loss_func is an incorrect argument for TorchModel.loss because
        # it performs more than computing loss
        super().__init__(self.model, self.loss_func, **kwargs)
        self.model.to(self.device)
        self.components = {
            k: v.to(self.device) if isinstance(v, nn.Module) else v
            for k, v in self.components.items()
        }
    def build_model(self) -> nn.Module:
        """Builds the final model from the components."""
        raise NotImplementedError("Subclass must define the model")
    def build_components(self) -> dict:
        """Creates the components dictionary, with the keys being the names of the
        components and the values being torch.nn.Module objects."""
        raise NotImplementedError("Subclass must define the components")

    def loss_func(self, inputs: OneOrMany[torch.Tensor], labels: Sequence,
                  weights: Sequence) -> torch.Tensor:
        """Defines the loss function for the model which can access the components
        using self.components. The loss function should take the inputs, labels, and
        weights as arguments and return the loss."""
        raise NotImplementedError("Subclass must define the loss function")
    def freeze_components(self, components: List[str]):
        """Freezes the parameters of the specified components.

        Each string in components refers to a key in self.components.

        Parameters
        ----------
        components: List[str]
            The components to freeze.
        """
        for component in components:
            for param in self.components[component].parameters():
                param.requires_grad = False
    def unfreeze_components(self, components: List[str]):
        """Unfreezes the parameters of the specified components.

        Each string in components refers to a key in self.components.

        Parameters
        ----------
        components: List[str]
            The components to unfreeze.
        """
        for component in components:
            for param in self.components[component].parameters():
                param.requires_grad = True
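    # A short usage sketch (hypothetical component name): freeze a pretrained
    # encoder while finetuning the rest of the network, then unfreeze it.
    # Parameters with requires_grad=False accumulate no gradients, so the
    # optimizer leaves them unchanged.
    #
    # model.freeze_components(['encoder'])
    # model.fit(dataset_ft, nb_epoch=10)    # only unfrozen parameters update
    # model.unfreeze_components(['encoder'])
    # model.fit(dataset_ft, nb_epoch=10)    # now all parameters update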
    def fit_generator(self,
                      generator: Iterable[Tuple[Any, Any, Any]],
                      max_checkpoints_to_keep: int = 5,
                      checkpoint_interval: int = 1000,
                      restore: bool = False,
                      variables: Optional[Union[List[torch.nn.Parameter],
                                                torch.nn.ParameterList]] = None,
                      loss: Optional[LossFn] = None,
                      callbacks: Union[Callable, List[Callable]] = [],
                      all_losses: Optional[List[float]] = None) -> float:
        """Train this model on data from a generator. This method is similar to
        the TorchModel implementation, but it passes the inputs directly to the
        loss function, rather than passing them through the model first. This
        enables the loss to be calculated from intermediate steps of the model
        and not just the final output.

        Parameters
        ----------
        generator: generator
            this should generate batches, each represented as a tuple of the form
            (inputs, labels, weights).
        max_checkpoints_to_keep: int
            the maximum number of checkpoints to keep. Older checkpoints are discarded.
        checkpoint_interval: int
            the frequency at which to write checkpoints, measured in training steps.
            Set this to 0 to disable automatic checkpointing.
        restore: bool
            if True, restore the model from the most recent checkpoint and continue
            training from there. If False, retrain the model from scratch.
        variables: list of torch.nn.Parameter
            the variables to train. If None (the default), all trainable variables in
            the model are used.
        loss: function
            a function of the form f(outputs, labels, weights) that computes the loss
            for each batch. If None (the default), the model's standard loss function
            is used.
        callbacks: function or list of functions
            one or more functions of the form f(model, step) that will be invoked
            after every step. This can be used to perform validation, logging, etc.
        all_losses: Optional[List[float]], optional (default None)
            If specified, all logged losses are appended into this list. Note that
            you can call `fit()` repeatedly with the same list and losses will
            continue to be appended.

        Returns
        -------
        The average loss over the most recent checkpoint interval
        """
        if not isinstance(callbacks, SequenceCollection):
            callbacks = [callbacks]
        self._ensure_built()
        self.model.train()
        avg_loss = 0.0
        last_avg_loss = 0.0
        averaged_batches = 0
        # FIXME This line is not needed as loss is computed inside the call to loss_func
        if loss is None:
            loss = self._loss_fn
        if variables is None:
            optimizer = self._pytorch_optimizer
            lr_schedule = self._lr_schedule
        else:
            var_key = tuple(variables)
            if var_key in self._optimizer_for_vars:
                optimizer, lr_schedule = self._optimizer_for_vars[var_key]
            else:
                optimizer = self.optimizer._create_pytorch_optimizer(variables)
                if isinstance(self.optimizer.learning_rate,
                              LearningRateSchedule):
                    lr_schedule = self.optimizer.learning_rate._create_pytorch_schedule(
                        optimizer)
                else:
                    lr_schedule = None
                self._optimizer_for_vars[var_key] = (optimizer, lr_schedule)
        time1 = time.time()

        # Main training loop.
        for batch in generator:
            if restore:
                self.restore()
                restore = False
            inputs: OneOrMany[torch.Tensor]
            inputs, labels, weights = self._prepare_batch(batch)

            # Execute the loss function, accumulating the gradients.
            if isinstance(inputs, list) and len(inputs) == 1:
                inputs = inputs[0]

            optimizer.zero_grad()
            batch_loss = self.loss_func(inputs, labels, weights)
            batch_loss.backward()
            optimizer.step()
            if lr_schedule is not None:
                lr_schedule.step()
            self._global_step += 1
            current_step = self._global_step

            avg_loss += float(batch_loss)

            # Report progress and write checkpoints.
            averaged_batches += 1
            should_log = (current_step % self.log_frequency == 0)
            if should_log:
                avg_loss = float(avg_loss) / averaged_batches
                logger.info('Ending global_step %d: Average loss %.10f' %
                            (current_step, avg_loss))
                if all_losses is not None:
                    all_losses.append(avg_loss)
                # Capture the last avg_loss in case of return since we're resetting to 0 now
                last_avg_loss = avg_loss
                avg_loss = 0.0
                averaged_batches = 0

            if checkpoint_interval > 0 and current_step % checkpoint_interval == checkpoint_interval - 1:
                self.save_checkpoint(max_checkpoints_to_keep)
            for c in callbacks:
                c(self, current_step)
            if self.tensorboard and should_log:
                self._log_scalar_to_tensorboard('loss', batch_loss,
                                                current_step)
            if (self.wandb_logger is not None) and should_log:
                all_data = dict({'train/loss': batch_loss})
                self.wandb_logger.log_data(all_data, step=current_step)

        # Report final results.
        if averaged_batches > 0:
            avg_loss = float(avg_loss) / averaged_batches
            logger.info('Ending global_step %d: Average loss %g' %
                        (current_step, avg_loss))
            if all_losses is not None:
                all_losses.append(avg_loss)
            last_avg_loss = avg_loss

        if checkpoint_interval > 0:
            self.save_checkpoint(max_checkpoints_to_keep)

        time2 = time.time()
        logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
        return last_avg_loss
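    # fit_generator consumes batches of the form (inputs, labels, weights),
    # such as those produced by TorchModel.default_generator(). A minimal
    # sketch, reusing dataset_pt and modular_model from the class docstring:
    #
    # losses: list = []
    # avg = modular_model.fit_generator(
    #     modular_model.default_generator(dataset_pt, epochs=1),
    #     checkpoint_interval=0,  # disable checkpointing for a quick run
    #     all_losses=losses)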
    def load_from_pretrained(  # type: ignore
            self,
            source_model: Optional["ModularTorchModel"] = None,
            components: Optional[List[str]] = None,
            checkpoint: Optional[str] = None,
            model_dir: Optional[str] = None,
            inputs: Optional[Sequence[Any]] = None,
            **kwargs) -> None:
        """Copies parameter values from a pretrained model. The pretrained model
        can be loaded as a source_model (ModularTorchModel object), a checkpoint
        (pytorch .ckpt file) or a model_dir (directory with .ckpt files).
        Specific components can be chosen by passing a list of strings with the
        desired component names. If both a source_model and a checkpoint/model_dir
        are given, the source_model weights will be loaded.

        Parameters
        ----------
        source_model: Optional[ModularTorchModel], default None
            the pretrained model, or a ModularTorchModel with the same
            architecture as the pretrained model, whose component weights are
            copied. If None, the weights are restored from checkpoint or
            model_dir instead.
        components: Optional[List[str]], default None
            the names of the components to load. If None, all components found
            in the source are loaded.
        checkpoint: Optional[str], default None
            the path to the checkpoint file to load. If this is None, the most
            recent checkpoint will be chosen automatically. Call
            get_checkpoints() to get a list of all available checkpoints.
        model_dir: Optional[str], default None
            restore the source model from a custom model directory if needed.
        inputs: Optional[Sequence[Any]], default None
            input tensors for the model. If not None, the inputs are run through
            both the source model and self so that the weights of both models
            are built before copying.
        """
        if inputs is not None:
            # Ensure weights for both models are built.
            if source_model:
                source_model.model(inputs)
            self.model(inputs)
        self._ensure_built()
        if source_model is not None:
            for name, module in source_model.components.items():
                if components is None or name in components:
                    self.components[name].load_state_dict(module.state_dict(),
                                                          strict=False)
            self.build_model()
        elif source_model is None:
            self.restore(components=components,
                         checkpoint=checkpoint,
                         model_dir=model_dir)
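    # Besides copying from a live source_model (as in the class docstring),
    # weights can be restored from checkpoint files on disk. A sketch with a
    # hypothetical directory path:
    #
    # modular_model.load_from_pretrained(
    #     components=['linear'],
    #     checkpoint=None,             # None picks the most recent checkpoint
    #     model_dir='./pretrain_dir')  # directory written by save_checkpoint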
    def save_checkpoint(self, max_checkpoints_to_keep=5, model_dir=None):
        """Saves the current state of the model and its components as a checkpoint
        file in the specified model directory. It maintains a maximum number of
        checkpoint files, deleting the oldest one when the limit is reached.

        Parameters
        ----------
        max_checkpoints_to_keep: int, default 5
            Maximum number of checkpoint files to keep.
        model_dir: str, default None
            The directory to save the checkpoint file in. If None, the model_dir
            specified in the constructor is used.
        """
        if model_dir is None:
            model_dir = self.model_dir
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        data = {
            'model': self.model.state_dict(),
            'optimizer_state_dict': self._pytorch_optimizer.state_dict(),
            'global_step': self._global_step
        }
        for name, component in self.components.items():
            if hasattr(component, 'state_dict'):
                data[name] = component.state_dict()

        temp_file = os.path.join(model_dir, 'temp_checkpoint.pt')
        torch.save(data, temp_file)

        # Rename and delete older files.
        paths = [
            os.path.join(model_dir, 'checkpoint%d.pt' % (i + 1))
            for i in range(max_checkpoints_to_keep)
        ]
        if os.path.exists(paths[-1]):
            os.remove(paths[-1])
        for i in reversed(range(max_checkpoints_to_keep - 1)):
            if os.path.exists(paths[i]):
                os.rename(paths[i], paths[i + 1])
        os.rename(temp_file, paths[0])
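    # The rotation above leaves checkpoint1.pt as the newest file and
    # checkpointN.pt as the oldest: each save shifts checkpointN.pt to
    # checkpointN+1.pt, deletes the one past the limit, and moves the freshly
    # written temp file into checkpoint1.pt.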
    def restore(  # type: ignore
            self,
            components: Optional[List[str]] = None,
            checkpoint: Optional[str] = None,
            model_dir: Optional[str] = None) -> None:
        """Restores the state of a ModularTorchModel from a checkpoint file.

        If no checkpoint file is provided, it will use the latest checkpoint
        found in the model directory. If a list of component names is provided,
        only the state of those components will be restored.

        Parameters
        ----------
        components: Optional[List[str]]
            A list of component names to restore. If None, all components will
            be restored.
        checkpoint: Optional[str]
            The path to the checkpoint file. If None, the latest checkpoint in
            the model directory will be used.
        model_dir: Optional[str]
            The path to the model directory. If None, the model directory used
            to initialize the model will be used.
        """
        logger.info('Restoring model')
        if checkpoint is None:
            checkpoints = sorted(self.get_checkpoints(model_dir))
            if len(checkpoints) == 0:
                raise ValueError('No checkpoint found')
            checkpoint = checkpoints[0]
        data = torch.load(checkpoint)
        for name, state_dict in data.items():
            if name != 'model' and name in self.components.keys():
                if components is None or name in components:
                    self.components[name].load_state_dict(state_dict)

        self.build_model()
        self._ensure_built()
        self._pytorch_optimizer.load_state_dict(data['optimizer_state_dict'])
        self._global_step = data['global_step']
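
# A save/restore round trip, sketched with the names from the class docstring
# plus a hypothetical checkpoint path. save_checkpoint stores the model, each
# component's state_dict, the optimizer state and the global step; restore
# loads them back, optionally for a subset of components:
#
# modular_model.fit(dataset_ft, nb_epoch=1)
# modular_model.save_checkpoint(max_checkpoints_to_keep=3)
# modular_model.restore()  # latest checkpoint, all components
# modular_model.restore(components=['linear'],
#                       checkpoint='./model_dir/checkpoint1.pt')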