/
modeling.py
1590 lines (1309 loc) · 60.9 KB
/
modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2017 Google LLC.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""Provides an abstraction around deep learning models in DeepVariant.
This class allows us to encapsulate all of the model management, loading,
saving, and data processing in a single place so those details don't spill over
into the more general deepvariant codebase. The key thing we are aiming for here
is to make sure we can easily play with other model architectures without
modifying the surrounding training and evaluation code.
"""
import enum
import itertools
import math
from absl import flags
from absl import logging
from tensorflow import estimator as tf_estimator
from tensorflow.compat.v1 import estimator as tf_compat_v1_estimator
import tensorflow as tf
import tf_slim
from tf_slim.nets import inception_v3
from deepvariant import dv_constants
from deepvariant import tf_utils
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_estimator
from tensorflow.python.tpu import tpu_optimizer
# pylint: enable=g-direct-tensorflow-import
from deepvariant import attention_inception_v3
# The training/eval code below is TF1-style graph code built on Estimators,
# so eager execution must be disabled at import time.
tf.compat.v1.disable_eager_execution()

flags.DEFINE_float(
    'label_smoothing', 1e-6,
    'Amount of label smoothing to use. By default this is 0.0001% '
    'meaning that we expect a label error at a rate of 1 / 1,000,000')

# Training parameters.
flags.DEFINE_float('learning_rate', 0.064, 'Initial learning rate.')
flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
flags.DEFINE_float('rmsprop_epsilon', 1.0, 'Epsilon term for RMSProp.')
flags.DEFINE_float('learning_rate_decay_factor', 0.94,
                   'Learning rate decay factor.')
flags.DEFINE_float('num_epochs_per_decay', 2.0,
                   'Number of epochs after which learning rate decays.')
flags.DEFINE_float('moving_average_decay', 0.9999,
                   'The decay to use for the moving average.')
flags.DEFINE_integer(
    'save_summary_steps', 100, 'Number of steps which must have run before '
    'showing summaries.')
# NOTE: save_interval_secs and save_interval_steps are mutually exclusive;
# make_estimator forwards whichever one is >= 0 to the estimator RunConfig.
flags.DEFINE_integer(
    'save_interval_secs', 60 * 10,
    'Interval (in seconds) at which the model data '
    'should be checkpointed. Set to 0 to disable, -1 to ignore. '
    'Exclusive with save_interval_steps.')
flags.DEFINE_integer(
    'save_interval_steps', -1, 'Interval (in steps) at which the model data '
    'should be checkpointed. Set to 0 to disable, -1 to ignore. '
    'Exclusive with save_interval_secs.')
flags.DEFINE_integer(
    'seq_type_embedding_size', 200,
    'Set the embedding size for the sequencing type embeddings. Default is 200. '
    'This flag is only useful when model_name is `inception_v3_embedding`.')
flags.DEFINE_bool(
    'allow_warmstart_from_different_num_channels', False,
    'If True, always allow warmstarting from model checkpoints '
    'that has different number of channels.')

FLAGS = flags.FLAGS

# Short alias used throughout the rest of this file.
slim = tf_slim
class UnsupportedImageDimensionsError(Exception):
  """Raised when a model cannot be built for the given pileup image dimensions."""
def binarize(labels, target_class):
  """Maps labels equal to target_class to 0 and every other label to 1.

  Args:
    labels: the ground-truth labels for the examples.
    target_class: index of the class that is mapped to zero.

  Returns:
    Tensor of the same shape as labels.
  """
  is_target = tf.equal(labels, tf.constant(target_class, dtype=tf.int64))
  return tf.compat.v1.where(is_target, tf.zeros_like(labels),
                            tf.ones_like(labels))
def get_class_recall(labels, predicted_class, target_class):
  """Computes recall for target_class after binarizing labels and predictions.

  Examples with label target_class are positives. Other classes are negatives.

  Args:
    labels: the ground-truth labels for the examples.
    predicted_class: the predicted labels for the examples.
    target_class: index of the class treated as the positive class.

  Returns:
    Tensor containing the recall value.
  """
  return tf.compat.v1.metrics.recall(
      binarize(labels, target_class), binarize(predicted_class, target_class))
def get_class_precision(labels, predicted_class, target_class):
  """Computes precision for target_class after binarizing labels/predictions.

  Examples with label target_class are positives. Other classes are negatives.

  Args:
    labels: the ground-truth labels for the examples.
    predicted_class: the predicted labels for the examples.
    target_class: index of the class treated as the positive class.

  Returns:
    Tensor containing the precision value.
  """
  return tf.compat.v1.metrics.precision(
      binarize(labels, target_class), binarize(predicted_class, target_class))
# TODO: Verify this F1 score is correct.
def get_f1_score(labels, predictions, target_class=None):
  """Compute F1 score of predictions with respect to the labels.

  Args:
    labels: tensor whose dimensions must match predictions. The ground-truth
      labels for the examples.
    predictions: tensor of arbitrary dimension. The predicted labels for the
      examples.
    target_class: int or None. If an int, labels and predictions are first
      binarized against this class index (see binarize); if None, the raw
      multi-class tensors are used directly.

  Returns:
    f1_score: scalar float tensor whose dimensions match predictions. The
      calculated f1 score.
    update_op: operation that updates the f1 score streaming metric.
  """
  # BUG FIX: this used to read `if target_class:`, which silently skipped
  # binarization for target_class == 0 (the HomRef genotype class in
  # _METRICS_GENOTYPE_CLASSES), computing F1 over the raw multi-class labels
  # instead of the class-0-vs-rest binarization.
  if target_class is not None:
    labels = binarize(labels, target_class)
    predictions = binarize(predictions, target_class)
  precision, precision_op = tf.compat.v1.metrics.precision(labels, predictions)
  recall, recall_op = tf.compat.v1.metrics.recall(labels, predictions)

  def compute_f1_score(name):
    # F1 = 2 * P * R / (P + R); the 1e-12 guards against division by zero
    # when both precision and recall are 0.
    pr_product = tf.multiply(precision, recall)
    return tf.math.divide(
        tf.multiply(2.0, pr_product), tf.add(tf.add(precision, recall), 1e-12),
        name)

  f1_score = compute_f1_score('value')
  # The update op must run the underlying precision/recall updates first.
  with ops.control_dependencies([precision_op, recall_op]):
    update_op = compute_f1_score('update_op')
  return f1_score, update_op
def is_encoded_variant_type(variant_types_tensor, type_to_select):
  """Returns a bool tensor marking which variant_types equal type_to_select.

  Args:
    variant_types_tensor: Tensor of shape (batch_size, 1) whose elements are
      EncodedVariantType.value int64 values indicating each variant's type.
    type_to_select: EncodedVariantType. The type of variant we want to select.

  Returns:
    Tensor of shape (batch_size, 1) of type tf.bool; True wherever the
    variant_type matches type_to_select, False elsewhere.
  """
  wanted = tf.constant(type_to_select.value, dtype=tf.int64)
  return tf.equal(variant_types_tensor, wanted)
# This dictionary contains a mapping from the human readable name of a metric
# function (e.g., Accuracy) and its associated TensorFlow metric function. All
# of the entries here will be stratified by variant_type in eval_metric_fn.
_METRICS_FUNCS_BY_VARIANT_TYPE = {
    'Accuracy': tf.compat.v1.metrics.accuracy,
    'Precision': tf.compat.v1.metrics.precision,
    'Recall': tf.compat.v1.metrics.recall,
    'FPs': tf.compat.v1.metrics.false_positives,
    'FNs': tf.compat.v1.metrics.false_negatives,
    'TPs': tf.compat.v1.metrics.true_positives,
    'TNs': tf.compat.v1.metrics.true_negatives,
}

# A set containing the names of the variant types we split our metrics by.
# This data structure isn't a dictionary like its neighbors because
# eval_metric_fn requires special logic (selection weights rather than class
# labels) to compute the values associated with each of these names.
_METRICS_BY_VARIANT_TYPE = {'All', 'SNPs', 'Indels'}

# This dictionary contains a mapping from the human readable name of a genotype
# class (e.g., Het) and its associated class label (e.g., 1). All of the entries
# here will be stratified by genotype_class in eval_metric_fn.
_METRICS_GENOTYPE_CLASSES = {
    'HomRef': 0,
    'Het': 1,
    'HomVar': 2,
}

# This dictionary contains a mapping from the human readable name of a metric
# function (e.g., Accuracy) and its associated metric function. All
# of the entries here will be stratified by genotype class (e.g., Het) in
# eval_metric_fn.
_METRICS_FUNCS_BY_GENOTYPE_CLASS = {
    'Precision': get_class_precision,
    'Recall': get_class_recall,
    'F1': get_f1_score,
}
def _eval_name(metric_name, stratification_name):
return metric_name + '/' + stratification_name
class EvalMetricOrdering(enum.Enum):
  """Direction in which an eval metric improves (bigger vs. smaller)."""
  BIGGER_IS_BETTER = 1
  SMALLER_IS_BETTER = 2
def eval_function_metrics(has_variant_types=True):
  """Gets the set of eval_metrics names and their directionality.

  Args:
    has_variant_types: bool. Will variant_type information be provided during
      eval, so that metrics stratified by variant_type are produced?

  Returns:
    dict mapping from a metric name string (e.g., "F1/All") to an
    EvalMetricOrdering enum indicating whether larger metric values are better
    or worse.
  """
  variant_strata = _METRICS_BY_VARIANT_TYPE if has_variant_types else {'All'}
  # Every metric we currently emit is bigger-is-better, including the special
  # ('F1', 'All') entry that eval_metric_fn computes separately.
  name_pairs = itertools.chain(
      [('F1', 'All')],
      itertools.product(_METRICS_FUNCS_BY_VARIANT_TYPE, variant_strata),
      itertools.product(_METRICS_FUNCS_BY_GENOTYPE_CLASS,
                        _METRICS_GENOTYPE_CLASSES))
  return {
      _eval_name(metric, stratum): EvalMetricOrdering.BIGGER_IS_BETTER
      for metric, stratum in name_pairs
  }
# NB. This includes only a subset of our usual metrics.
# We'll add the rest back in a subsequent change.
def eval_metric_fn(labels, predictions, variant_types):
  """Calculates eval metrics from Tensors, on the CPU host.

  Args:
    labels: the ground-truth labels for the examples.
    predictions: the predicted labels for the examples.
    variant_types: variant types (int64 of EncodedVariantType.value) as a
      tensor of (batch_size,) or None. The types of these variants. If None,
      no type-specific evals will be performed.

  Returns:
    A dictionary of string name to metric.

  Raises:
    AssertionError: if the produced metric names disagree with the names
      promised by eval_function_metrics.
  """
  predicted_classes = tf.argmax(input=predictions, axis=1)

  # Selection weights for each variant-type stratification: 'All' is
  # unweighted; SNPs/Indels mask down to examples of the matching type.
  weights_by_type = {'All': None}
  if variant_types is not None:
    weights_by_type['SNPs'] = is_encoded_variant_type(
        variant_types, tf_utils.EncodedVariantType.SNP)
    weights_by_type['Indels'] = is_encoded_variant_type(
        variant_types, tf_utils.EncodedVariantType.INDEL)

  # Metrics stratified by variant_type.
  metrics = {
      _eval_name(metric_name, weight_name): metric_func(
          labels, predicted_classes, weights=weights)
      for metric_name, metric_func in _METRICS_FUNCS_BY_VARIANT_TYPE.items()
      for weight_name, weights in weights_by_type.items()
  }
  # Metrics stratified by genotype class (HomRef/Het/HomVar).
  metrics.update({
      _eval_name(metric_name, class_name): metric_func(
          labels, predicted_classes, class_value)
      for metric_name, metric_func in _METRICS_FUNCS_BY_GENOTYPE_CLASS.items()
      for class_name, class_value in _METRICS_GENOTYPE_CLASSES.items()
  })
  # Special case F1/All to avoid a clash between the two different ways that we
  # can compute Precision and Recall (e.g., get_class_precision vs.
  # tf.compat.v1.metrics.precision).
  metrics[_eval_name('F1', 'All')] = get_f1_score(labels, predicted_classes)
  logging.info('Metrics are %s', metrics.keys())

  # Sanity-check the produced names against eval_function_metrics.
  expected_metrics = eval_function_metrics(
      has_variant_types=variant_types is not None)
  if set(expected_metrics) != set(metrics):
    raise AssertionError(
        'Bug: actual metrics={} not equal to expected={}'.format(
            ','.join(metrics), ','.join(expected_metrics)))
  return metrics
# The following two classes support loading exponential moving averages into
# their corresponding variables when a checkpoint is loaded. They're called
# as hooks by the Estimators. Note for future work: this is the documented
# way, but someone on the mailing list suggested that using the scaffold_fn
# mechanism might be better.
class LoadEMAHook(tf_estimator.SessionRunHook):
  """Hook that restores exponential moving averages into their variables.

  This looks for the latest checkpoint in the model dir.
  """

  def __init__(self, model_dir, ignore_missing_vars=False):
    super(LoadEMAHook, self).__init__()
    self._model_dir = model_dir
    self._ignore_missing_vars = ignore_missing_vars

  def begin(self):
    # Build the restore function at graph-construction time.
    ema = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
    self._load_ema = slim.assign_from_checkpoint_fn(
        tf.train.latest_checkpoint(self._model_dir),
        ema.variables_to_restore(),
        ignore_missing_vars=self._ignore_missing_vars)

  def after_create_session(self, sess, coord):
    tf.compat.v1.logging.info('Reloading EMA...')
    self._load_ema(sess)
class PredictEMAHook(tf_estimator.SessionRunHook):
  """Hook that restores exponential moving averages into their variables.

  This reads the specified checkpoint.
  """

  def __init__(self, checkpoint_path, ignore_missing_vars=False):
    super(PredictEMAHook, self).__init__()
    self._checkpoint_path = checkpoint_path
    self._ignore_missing_vars = ignore_missing_vars

  def begin(self):
    # Build the restore function at graph-construction time.
    ema = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
    self._load_ema = slim.assign_from_checkpoint_fn(
        self._checkpoint_path,
        ema.variables_to_restore(),
        ignore_missing_vars=self._ignore_missing_vars)

  def after_create_session(self, sess, coord):
    tf.compat.v1.logging.info('Reloading EMA...')
    self._load_ema(sess)
class DeepVariantModel(object):
"""Base class for models that compute genotype likelihoods from an image.
This class is intended for use anywhere in DeepVariant where we want to train
or evaluate a model that computes genotype likelihoods from a pileup image. A
bit of encapsulation helps us to try new models (beyond inception_v3) and unit
test our code.
The base class cannot be used directly; concrete subclasses actually implement
specific models and all of the associated machinery to create/load/save
models.
Attributes:
name: str. The name of this model, such as `inception_v3`.
pretrained_model_path: str. Path to a root checkpoint where we can start
training the model, if we are not starting from scratch.
supported_dimensions_message: str. A human-readable string containing info
about what image dimensions are supported by this model. E.g., "only
widths between 42 and 189".
use_tpu: bool or None. If True, we are executing the model on a TPU, False
if we are using some other hardware. If None, the execution hardware is
not yet known.
model_dir: str or None. The path to the location where model checkpoint are
being stored. If None, the path hasn't been set yet or is unknown.
"""
def __init__(self, name, pretrained_model_path):
"""Creates a new DeepVariantModel with name and pretrained_model_path.
Args:
name: str. The name of the model. Passed to DeepVariantModel name.
pretrained_model_path: str. A path to a pretrained model to initialize our
network from when starting from the 'model_default'. If None, training
will start from randomly-initialized parameters.
Raises:
ValueError: if any of the arguments is invalid.
"""
if not name:
raise ValueError('Got an empty value for name', name)
self.name = name
self.pretrained_model_path = pretrained_model_path
self.supported_dimensions_message = 'unknown'
self.use_tpu = None
# Set the model_dir to None by default. We capture its actual value during
# a call to make_estimator below.
self.model_dir = None
  def construct_scalar_host_call(self,
                                 metric_dict,
                                 model_dir,
                                 prefix='',
                                 record_frequency_in_steps=100):
    """Construct a host call to log scalars when training on TPU.

    Args:
      metric_dict: A dict of the tensors to be logged.
      model_dir: The location to write the summary.
      prefix: The prefix (if any) to prepend to the metric names.
      record_frequency_in_steps: int; How often should we log our metrics in
        step units.

    Returns:
      A tuple of (function, args_to_be_passed_to_said_function)
    """
    # type: (dict, str) -> (function, list)
    # Freeze the metric name ordering so host_call_fn can match positional
    # *args back to names.
    metric_names = list(metric_dict.keys())

    def host_call_fn(global_step, *args):
      """Training host call.

      Creates scalar summaries for training metrics.
      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the
      model to the `metric_fn`, provide as part of the `host_call`. See
      https://www.tensorflow.org/api_docs/python/tf/compat/v1/estimator/tpu/TPUEstimator
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `host_call`.

      Args:
        global_step: Tensor with shape `[batch]` for the global_step
        *args: Remaining tensors to log.

      Returns:
        List of summary ops to run on the CPU host.
      """
      step = global_step[0]
      with tf.compat.v2.summary.create_file_writer(
          logdir=model_dir, filename_suffix='.host_call').as_default():
        # Only emit summaries every record_frequency_in_steps steps.
        with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(step % record_frequency_in_steps, 0)):
          for i, name in enumerate(metric_names):
            # args[i] has shape [batch]; element 0 carries the scalar value.
            tf.compat.v2.summary.scalar(
                name=prefix + name, data=args[i][0], step=step)
          return tf.compat.v1.summary.all_v2_summary_ops()

    # To log the current learning rate, and gradient norm for Tensorboard, the
    # summary op needs to be run on the host CPU via host_call. host_call
    # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
    # dimension. These Tensors are implicitly concatenated to
    # [params['batch_size']].
    global_step_tensor = tf.reshape(
        tf.compat.v1.train.get_or_create_global_step(), [1])
    other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]
    return host_call_fn, [global_step_tensor] + other_tensors
  def _create_warm_start_settings(self, start_from_checkpoint):
    """Create a proper WarmStartSettings based on start_from_checkpoint.

    Args:
      start_from_checkpoint: str or falsy. Either the special value
        'model_default' (use self.pretrained_model_path), a concrete
        checkpoint path, or a falsy value meaning no warm start.

    Returns:
      A tf_estimator.WarmStartSettings, or None if no checkpoint was given.
    """
    # If the special value "model_default" was passed, ask the model for
    # its default.
    if start_from_checkpoint == 'model_default':
      start_from_checkpoint = self.pretrained_model_path

    # If the path is non-False, use it.
    if start_from_checkpoint:
      logging.info('Initializing model from checkpoint at %s',
                   start_from_checkpoint)
      excluded_scopes = set()
      reader = tf.compat.v1.train.NewCheckpointReader(start_from_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      # Exclude the classification head when the checkpoint was trained with a
      # different number of output classes than we need.
      if tf_utils.model_num_classes(
          start_from_checkpoint,
          self.n_classes_model_variable) != dv_constants.NUM_CLASSES:
        excluded_scopes.update(self.excluded_scopes_for_incompatible_classes)
      # Optionally exclude input-layer scopes so checkpoints with a different
      # number of image channels can still be warm-started from.
      if FLAGS.allow_warmstart_from_different_num_channels:
        excluded_scopes.update(self.excluded_scopes_for_incompatible_channels)
      if excluded_scopes:
        logging.info(
            'The model checkpoint to warm start from has different '
            'shapes. If this is in training, we will '
            'exclude: %s', excluded_scopes)
        vars_to_include = [
            v for v in var_to_shape_map.keys()
            if not v.startswith(tuple(excluded_scopes))
        ]
      else:
        logging.info('The model checkpoint to warm start from should have the '
                     'same number of classes and same numbers of channels.'
                     'If this is in training, we will include everything for '
                     'warm starting....')
        vars_to_include = var_to_shape_map.keys()
      # vars_to_warm_start takes a regex; '|' joins the names as alternatives.
      return tf_estimator.WarmStartSettings(
          ckpt_to_initialize_from=start_from_checkpoint,
          vars_to_warm_start='|'.join(vars_to_include))
    else:
      # If warm_start_from is an empty string, specifically set it to None.
      logging.vlog(3, 'Initializing model with random parameters')
      return None
  def make_estimator(self,
                     batch_size,
                     model_dir=None,
                     max_checkpoints_to_keep=100000,
                     iterations_per_loop=100,
                     params=None,
                     unused_device_fn=None,
                     master='',
                     use_tpu=False,
                     start_from_checkpoint=None,
                     session_config=None,
                     include_debug_info=False):
    """Returns a new tf.estimator.Estimator object for training or prediction.

    The estimator needs to know batch_size. We use the same value for all
    of eval, train, and predict. The estimator will automatically save
    checkpoints to model_dir and keep the specified number of them. The value
    of iterations_per_loop is not critical, and we default to the recommended
    value. Some optional arguments are only required for use with TPU.

    This function will use self.model_fn and self.use_tpu when constructing the
    model specific Estimator object.

    Estimators are also sometimes called classifiers.

    Args:
      batch_size: the batch size to use (for TRAIN, EVAL, and PREDICT modes).
      model_dir: an (optional) string directory to use as the model directory.
      max_checkpoints_to_keep: an (optional) integer count of saved
        checkpoints.
      iterations_per_loop: an (optional) integer count of log_step_count_steps.
      params: an (optional) dictionary of parameters to pass to the Estimator
        constructor.
      unused_device_fn: a device_fn to pass to RunConfig, if not use_tpu.
      master: a string necessary for TPU, pass FLAGS.master through.
      use_tpu: boolean. set self.use_tpu if not None.
      start_from_checkpoint: string. If not None, initialize model from this
        path. According to the current implementation of Estimator, this will
        only be used in training. The inference checkpoint is loaded in a
        different place.
      session_config: a tf.ConfigProto to pass to RunConfig, if not use_tpu.
      include_debug_info: from call_variants. If True, PREDICT mode will
        include extra info such as logits and prelogits.

    Returns:
      an object implementing the tf.estimator.Estimator interface (will be a
      TPUEstimator if self.use_tpu is True).
    """
    if use_tpu is not None:
      self.use_tpu = use_tpu
    self.include_debug_info = include_debug_info
    # Set the model dir of this class to the model_dir passed in here. It's not
    # so clean but it appears to be necessary due to the way estimators are
    # constructed (i.e., model_dir is set late).
    self.model_dir = model_dir
    # These flags are exclusive if not None, and 0 means disable.
    save_checkpoints_secs = None
    save_checkpoints_steps = None
    if FLAGS.save_interval_secs >= 0:
      save_checkpoints_secs = FLAGS.save_interval_secs
    if FLAGS.save_interval_steps >= 0:
      save_checkpoints_steps = FLAGS.save_interval_steps
    params = params if params is not None else {}
    warm_start_from = self._create_warm_start_settings(start_from_checkpoint)
    if self.use_tpu:
      tpu_cfg = tpu_config.TPUConfig(iterations_per_loop=iterations_per_loop)
      config = tpu_config.RunConfig(
          master=master,
          evaluation_master=master,
          model_dir=model_dir,
          log_step_count_steps=iterations_per_loop,
          keep_checkpoint_max=max_checkpoints_to_keep,
          save_checkpoints_secs=save_checkpoints_secs,
          save_checkpoints_steps=save_checkpoints_steps,
          save_summary_steps=FLAGS.save_summary_steps,
          tpu_config=tpu_cfg)
      classifier = tpu_estimator.TPUEstimator(
          use_tpu=self.use_tpu,
          model_fn=self.model_fn,
          config=config,
          # TODO: enable setting these independently.
          train_batch_size=batch_size,
          eval_batch_size=batch_size,
          predict_batch_size=batch_size,
          params=params,
          warm_start_from=warm_start_from,
      )
    else:
      config = tf_estimator.RunConfig(
          model_dir=model_dir,
          log_step_count_steps=iterations_per_loop,
          keep_checkpoint_max=max_checkpoints_to_keep,
          # device_fn=device_fn,  # Not in tf1.8?
          save_checkpoints_secs=save_checkpoints_secs,
          save_checkpoints_steps=save_checkpoints_steps,
          save_summary_steps=FLAGS.save_summary_steps,
          session_config=session_config,
      )
      # The TPUEstimator interface implicitly adds batch_size to the params
      # dict. Do so explicitly here, so that we can use the same model_fn.
      params_with_batch_size = {'batch_size': batch_size}
      params_with_batch_size.update(params)
      classifier = tf_estimator.Estimator(
          model_fn=self.model_fn,
          config=config,
          params=params_with_batch_size,
          warm_start_from=warm_start_from)
    return classifier
def model_fn(self, features, labels, mode, params):
"""A model_fn satisfying the Estimator API.
Args:
features: a dictionary supplying features.
labels: a tensor of labels.
mode: one of tf.estimator.ModeKeys.{EVAL,TRAIN}
params: a dictionary of parameters.
Returns:
a tf.estimator.EstimatorSpec or tpu_estimator.TPUEstimatorSpec,
depending on self.use_tpu.
"""
raise NotImplementedError
def session_eval_hooks(self):
"""Returns a list of tf.train.SessionRunHook classes.
A typical use case is to provide a hook to load the EMA variables.
These will be instantiated and invoked by
eval_hooks = [
h(model_dir) for h in model.session_eval_hooks()
]
estimator.evaluate(hooks=...).
Note that this is done according to the instructions in
cloud_tpu/models/inception/inception_v3.py. A newer idea is in
tpuestimator-scaffold, but we haven't tried that approach.
"""
return []
def session_predict_hooks(self):
"""Returns a list of tf.train.SessionRunHook classes.
A typical use case is to provide a hook to load the EMA variables.
These will be instantiated and invoked by
predict_hooks = [
h(checkpoint_path) for h in model.session_predict_hooks()
]
estimator.predict(hooks=...).
Note that this is done according to the instructions in
cloud_tpu/models/inception/inception_v3.py. A newer idea is in
tpuestimator-scaffold, but we haven't tried that approach.
"""
return []
def create(self, images, num_classes, is_training):
"""Creates a new model.
Args:
images: A 4-D tensor of (batch_size, height, width, channels) of pileup
images.
num_classes: integer. How many prediction classes are we expecting in
model?
is_training: boolean. Should we setup model for training (True) or for
inference (False).
Returns:
A dictionary, containing string keys mapped to endpoint tensors of this
model. The dictionary must contain a key 'Predictions' that contains the
probability of having each of 'num_classes' classes.
"""
try:
return self._create(images, num_classes, is_training)
except (ValueError, tf.errors.OpError) as e:
if self._is_bad_image_dimension_exception(e):
_, height, width, _ = images.get_shape().as_list()
message = (
'Unsupported image dimensions detected: model {} was given images '
'of w={} x h={} but a TensorFlow exception occurred while building '
'the model, which typically indicates those dimensions are not '
'supported by the model. The supported dimensions for {} are {}'
).format(self.name, width, height, self.name,
self.supported_dimensions_message)
raise UnsupportedImageDimensionsError(message)
else:
raise
def _is_bad_image_dimension_exception(self, exception):
return any(
x in str(exception) for x in ['Negative dimension', 'SpatialSqueeze'])
def _create(self, images, num_classes, is_training):
"""To be overloaded by subclasses to actually create the model."""
raise NotImplementedError
def preprocess_images(self, images):
"""Preprocessing steps needed for this model to process a batch of images.
Args:
images: A (batch_size, height, width, channels) 4-D Tensor of type uint8.
Returns:
A new batch of images, potentially with different dimensions, based on the
input but transformed as necessary to use with this model.
"""
raise NotImplementedError
@property
def is_trainable(self):
"""Returns True if this model can be trained."""
return True
# TODO: Add export to save representation suitable for inference.
def __str__(self):
return 'DeepVariantModel(name={})'.format(self.name)
def variables_to_restore_from_model(self, exclude_scopes=None):
  """Gets the list of model variables that should be restored.

  The primary use of this function is to get a subset of tf.Variables from a
  slim-defined model that we'd like to restore from a checkpoint. A training
  checkpoint generally contains every variable in the graph — backprop state,
  moving averages for visualization, etc. — and blindly restoring all of them
  is brittle: a new run may use a different optimizer, different
  visualization variables, or replace part of the model with a new
  classification layer, so unneeded checkpoint variables get loaded and/or
  new graph variables cannot be found, raising exceptions. This function
  returns just the *model* variables, optionally dropping parts of the model
  graph via exclude scopes.

  It calls slim.get_model_variables() to fetch everything in the
  MODEL_VARIABLES collection, then filters out any variable matching one of
  the scopes in exclude_scopes. For example, given variables named:

    w1 = model/l1/weight1
    w2 = model/l2/weight2
    w3 = model/l2/weight3

  no exclude scopes returns [w1, w2, w3]; exclude_scopes=['model/l2']
  returns [w1]; exclude_scopes=['model/l1'] returns [w2, w3].

  Args:
    exclude_scopes: None, or a list of strings. Each string is a scope
      specification, such as "model/l1" to match all variables whose name
      starts with "model/l1".

  Returns:
    A list of tf.Variable objects.
  """
  included = slim.get_model_variables()
  if not exclude_scopes:
    # Nothing to filter; every model variable is restorable.
    return included
  excluded = set()
  for scope in exclude_scopes:
    excluded.update(slim.get_variables(scope))
  return [variable for variable in included if variable not in excluded]
class DeepVariantSlimModel(DeepVariantModel):
"""Baseclass for DeepVariant models based on Slim networks."""
def __init__(self, name, pretrained_model_path, n_classes_model_variable,
             excluded_scopes_for_incompatible_classes,
             excluded_scopes_for_incompatible_channels):
  """Creates an DeepVariant CNN network based on a tf.slim model.

  Args:
    name: see baseclass.
    pretrained_model_path: see baseclass.
    n_classes_model_variable: str. A fully-qualitified TF variable name in the
      model that we can use to determine the shape of the output
      classification layer of the model. For example, in inception-v3 from
      slim this is 'InceptionV3/Logits/Conv2d_1c_1x1/weights'.
    excluded_scopes_for_incompatible_classes: set of str. A set of scopes
      that will be excluded when restoring from a checkpoint to avoid loading
      incompatible #classes.
    excluded_scopes_for_incompatible_channels: set of str. A set of scopes
      that will be excluded when restoring from a checkpoint to avoid loading
      incompatible #channels.

  Raises:
    ValueError: If any of the arguments are invalid.
  """
  # Name and checkpoint path validation is delegated to the base class.
  super(DeepVariantSlimModel, self).__init__(
      name=name, pretrained_model_path=pretrained_model_path)
  self.n_classes_model_variable = n_classes_model_variable
  self.excluded_scopes_for_incompatible_classes = excluded_scopes_for_incompatible_classes
  self.excluded_scopes_for_incompatible_channels = excluded_scopes_for_incompatible_channels
def preprocess_images(self, images):
  """Applies preprocessing operations for Inception images.

  Because this will run in model_fn, on the accelerator, we use operations
  that efficiently execute there.

  Args:
    images: An Tensor of shape [batch_size height, width, channel] with uint8
      values.

  Returns:
    A tensor of images of shape [batch_size height, width, channel]
    containing floating point values, with all points rescaled between
    -1 and 1 and possibly resized.
  """
  # Cast uint8 pixels to float32, then map [0, 255] to roughly [-1, 1]
  # by subtracting and dividing by 128.
  return tf.math.divide(
      tf.subtract(tf.cast(images, dtype=tf.float32), 128.0), 128.0)
def model_fn(self, features, labels, mode, params):
  """A model_fn for slim (really inception_v3), satisfying the Estimator API.

  Args:
    features: a single Tensor or dict of same (from input_fn).
    labels: a single Tensor or dict of same (from input_fn).
    mode: tf.estimator.ModeKeys.
    params: dict.

  Returns:
    EstimatorSpec or TPUEstimatorSpec depending on self.use_tpu.
  """
  # NB. The basic structure of this started from
  # //third_party/cloud_tpu/models/inception/inception_v3.py
  # TODO: get this from the model.
  num_classes = dv_constants.NUM_CLASSES
  images = features['image']
  # Rescale raw pixel values into the range the network expects.
  images = self.preprocess_images(images)
  endpoints = self.create(
      images=images,
      num_classes=num_classes,
      is_training=mode == tf_estimator.ModeKeys.TRAIN)
  logits = endpoints['Logits']
  # NOTE: `predictions` aliases `endpoints` — the update() below also
  # mutates the endpoints dict returned by create().
  predictions = endpoints
  predictions.update({
      'classes': tf.argmax(input=logits, axis=1, output_type=tf.int32),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  })
  # PreLogits are only surfaced when debug info is requested.
  prelogits = endpoints['PreLogits'] if self.include_debug_info else None
  if mode == tf_estimator.ModeKeys.PREDICT:
    return self._model_fn_predict(mode, features, logits, prelogits=prelogits)
  # Compute loss. softmax_cross_entropy registers its loss in the graph's
  # loss collection; get_total_loss() then sums it with any regularization
  # losses into a single scalar.
  one_hot_labels = tf.one_hot(labels, num_classes, dtype=tf.int32)
  tf.compat.v1.losses.softmax_cross_entropy(
      onehot_labels=one_hot_labels,
      logits=logits,
      weights=1.0,
      label_smoothing=FLAGS.label_smoothing)
  total_loss = tf.compat.v1.losses.get_total_loss(
      add_regularization_losses=True)
  return self.make_ops_and_estimator(features, endpoints, labels, logits,
                                     predictions, total_loss, mode, params)
def make_ops_and_estimator(self, features, endpoints, labels, logits,
predictions, total_loss, mode, params):
"""Make EstimatorSpec for the current model.
Args:
features: a single Tensor or dict of same (from input_fn).
endpoints: a dictionary, containing string keys mapped to endpoint
tensors of this model. The dictionary must contain a key 'Predictions'
that contains the probability of having each of 'num_classes' classes.
labels: a single Tensor or dict of same (from input_fn).
logits: a single Tensor with logits
predictions: A dictionaty that must contain the following keys: 'Logits'
and 'Predictions'.
total_loss: a single Tensor with a loss
mode: tf.estimator.ModeKeys.
params: dict.
Returns:
EstimatorSpec or TPUEstimatorSpec depending on self.use_tpu.
"""
# Note, below, one of train_op or eval_metrics will be None, and the other
# will be populated, depending on mode.
# There are a lot of arguments here; that's to avoid referencing flags in
# leaf functions.
train_op, host_call = self._model_fn_train(
mode=mode,
total_loss=total_loss,
# get() here to be robust when we are in eval mode and batches_per_epoch
# hasn't been provided. In eval mode, model_fn_train will return without
# doing anything.
batches_per_epoch=params.get('batches_per_epoch', None),
num_epochs_per_decay=FLAGS.num_epochs_per_decay,
initial_learning_rate=FLAGS.learning_rate,
learning_rate_decay_factor=FLAGS.learning_rate_decay_factor,
rmsprop_decay=FLAGS.rmsprop_decay,
rmsprop_momentum=FLAGS.rmsprop_momentum,
rmsprop_epsilon=FLAGS.rmsprop_epsilon,
moving_average_decay=FLAGS.moving_average_decay)
eval_metrics = self._model_fn_eval(
mode=mode,
features=features,
labels=labels,
endpoints=endpoints,
logits=logits,
use_logits=False)
spec = tpu_estimator.TPUEstimatorSpec(
mode=mode,
loss=total_loss,