Skip to content

Commit

Permalink
Merge branch 'fix_cmake_package_TF' into 'master'
Browse files Browse the repository at this point in the history
tnt: optimizers: use `tnt.optimizers.Optimizer` naming scheme

See merge request carpenamarie/hpdlf!215
  • Loading branch information
Alexandra Carpen-Amarie committed Nov 16, 2022
2 parents b349856 + 968f418 commit a1d6350
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 15 deletions.
1 change: 0 additions & 1 deletion cmake/FindTensorflow.cmake
Expand Up @@ -90,7 +90,6 @@ include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Tensorflow
REQUIRED_VARS Tensorflow_LIBRARY
Tensorflow_INCLUDE_DIR
Tensorflow_CXX11_ABI_FLAG
VERSION_VAR Tensorflow_VERSION)

mark_as_advanced(Tensorflow_INCLUDE_DIR
Expand Down
50 changes: 50 additions & 0 deletions src/mypy.ini
@@ -0,0 +1,50 @@
# mypy static-type-checking configuration for the tarantella sources.
[mypy]
# Permit calling/defining untyped functions, but reject partially-annotated
# defs and still type-check the bodies of untyped functions.
disallow_untyped_calls = False
disallow_untyped_defs = False
disallow_incomplete_defs = True
disallow_subclassing_any = False
check_untyped_defs = True

# Surface stale ignores, missing returns, Any-returns, dead code, and
# config options that no section uses.
warn_unused_ignores = True
warn_no_return = True
warn_return_any = True
warn_unreachable = True
warn_unused_configs = True

# NOTE(review): show_none_errors is deprecated/removed in recent mypy
# releases — confirm against the mypy version pinned for this project.
show_none_errors = True
show_error_codes = True

allow_untyped_globals = False
allow_redefinition = False

#strict = True
#implicit_reexport = True # prevent Module has no attribute errors

# Write per-line count/coverage/precision reports into ./coverage.
linecount_report = coverage
linecoverage_report = coverage
lineprecision_report = coverage

# Verbose (?x) regex of paths excluded from checking; continuation lines
# are indented as ini multiline values require.
exclude = (?x)(
    /examples
    | /gpi_comm_lib
    | /\.mypy
    | tarantella/keras/utilities\.py
    )

# Third-party packages that ship no type stubs.
[mypy-tensorflow.*]
ignore_missing_imports = True

[mypy-networkx.*]
ignore_missing_imports = True

# Project extension modules: no stubs available, and their transitive
# imports are skipped entirely.
[mypy-pygpi]
ignore_missing_imports = True
follow_imports = skip

[mypy-tnt_tfops]
ignore_missing_imports = True
follow_imports = skip

[mypy-GPICommLib]
ignore_missing_imports = True
follow_imports = skip
2 changes: 0 additions & 2 deletions src/tarantella/__init__.py
Expand Up @@ -55,5 +55,3 @@ def is_group_master_rank(group: pygpi.Group) -> bool:
from tarantella.collectives.TensorAllgatherer import TensorAllgatherer

import tarantella.optimizers as optimizers
from tarantella.optimizers.synchronous_distributed_optimizer import Optimizer

3 changes: 1 addition & 2 deletions src/tarantella/keras/model.py
Expand Up @@ -18,8 +18,7 @@ def __call__(cls, *args, **kwargs):
return obj

def _create_tnt_model(cls, model: tf.keras.Model,
parallel_strategy: tnt.ParallelStrategy = tnt.ParallelStrategy.ALL if TF_DEFAULT_PIPELINING_FLAG \
else tnt.ParallelStrategy.DATA,
parallel_strategy: tnt.ParallelStrategy = tnt.ParallelStrategy.DATA,
num_pipeline_stages: int = 1):
replica_group = tnt.Group()

Expand Down
4 changes: 2 additions & 2 deletions src/tarantella/keras/models.py
Expand Up @@ -22,8 +22,8 @@ def load_model(filepath, compile = True, **kwargs):
tnt_model = tnt.Model(keras_model, parallel_strategy = tnt.ParallelStrategy.DATA)
if compile:
try:
tnt_optimizer = tnt.distributed_optimizers.SynchDistributedOptimizer(keras_model.optimizer,
group = tnt_model.group)
tnt_optimizer = tnt.optimizers.Optimizer(keras_model.optimizer,
group = tnt_model.group)
tnt_model.dist_optimizer = tnt_optimizer
tnt_model._set_internal_optimizer(tnt_model.dist_optimizer)
tnt_model.compiled = True
Expand Down
2 changes: 1 addition & 1 deletion src/tarantella/optimizers/__init__.py
@@ -1 +1 @@

from tarantella.optimizers.synchronous_distributed_optimizer import Optimizer
Expand Up @@ -22,7 +22,7 @@ def __init__(self, keras_optimizer: tf.keras.optimizers.Optimizer,
name: str = None,
group: tnt.Group = None):
self.keras_optimizer = keras_optimizer
logger.debug(f"[SynchDistributedOptimizer] Initializing generic tnt.Optimizer of type={type(keras_optimizer)}")
logger.debug(f"[SynchDistributedOptimizer] Initializing generic tnt.optimizers.Optimizer of type={type(keras_optimizer)}")
_construct_from_keras_object(self, keras_optimizer)

if name is None:
Expand Down
Expand Up @@ -46,7 +46,7 @@ def compile(self,
elif isinstance(optimizer, str):
config = {'class_name': optimizer, 'config': {}}
optimizer = tf.keras.optimizers.deserialize(config)
self.dist_optimizer = tnt.Optimizer(optimizer, group = self.group)
self.dist_optimizer = tnt.optimizers.Optimizer(optimizer, group = self.group)

kwargs = self._preprocess_compile_kwargs(kwargs)
return self.model.compile(optimizer = self.dist_optimizer,
Expand Down
6 changes: 1 addition & 5 deletions test/python/model_api.py
@@ -1,10 +1,7 @@
from models import mnist_models as mnist
from tarantella.strategy.parallel_strategy import ParallelStrategy
import training_runner as base_runner
import utilities as util
import tarantella as tnt

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
Expand Down Expand Up @@ -168,5 +165,4 @@ def test_optimizer_with_name(self, optimizer_name, optimizer_type):
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
tnt_optimizer = tnt_model.dist_optimizer
assert isinstance(tnt_optimizer, tnt.distributed_optimizers.SynchDistributedOptimizer)
assert isinstance(tnt_optimizer.underlying_optimizer, optimizer_type)
assert isinstance(tnt_optimizer, optimizer_type)

0 comments on commit a1d6350

Please sign in to comment.