In [1]:
import os
import sys
sys.path.append('./src')

from s3fs import S3FileSystem

import tensorflow as tf
import tensorflow_io as tfio
import horovod.tensorflow as hvd
hvd.init()

from data.datasets import create_dataset
from engine.optimizers import MomentumOptimizer
from engine.schedulers import WarmupScheduler
from smdebug.tensorflow import KerasHook
from smdebug.core.reduction_config import ReductionConfig
from smdebug.core.save_config import SaveConfig
from smdebug.core.collection import CollectionKeys



Extension horovod.torch has not been built: /home/ec2-user/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-38-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.
[2022-06-17 23:23:50.589 ip-172-16-171-119.ec2.internal:8243 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-06-17 23:23:50.636 ip-172-16-171-119.ec2.internal:8243 INFO profiler_config_parser.py:111] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.


In [2]:
train_tdf = create_dataset("s3://jbsnyder-sagemaker-us-east/data/imagenet/tfrecord/train/", 128)

In [3]:
model = tf.keras.applications.ResNet50V2(weights=None)

In [4]:
scheduler = tf.keras.optimizers.schedules.CosineDecay(.025, 10000)
scheduler = WarmupScheduler(scheduler, 0.01, 250)
opt = MomentumOptimizer(learning_rate=scheduler, momentum=0.9)
loss = tf.keras.losses.CategoricalCrossentropy()

training_params = {'optimizer': opt,
                   'loss': loss,
                   'metrics': ['accuracy'],
                   }

In [5]:
reduction_config = ReductionConfig(['mean'])
save_config = SaveConfig(save_interval=25)
include_collections = [CollectionKeys.LOSSES]

hook_config = {
    'out_dir' : './smdebug/',
    'export_tensorboard': True,
    'tensorboard_dir': './smdebug/tensorboard/',
    'dry_run': False,
    'reduction_config': reduction_config,
    'save_config': save_config,
    'include_regex': None,
    'include_collections': include_collections,
    'save_all': False,
    'include_workers': 'one',
}

smd_hook = KerasHook(**hook_config)

[2022-06-17 22:19:41.891 ip-172-16-171-119.ec2.internal:11503 INFO hook.py:254] Saving to ./smdebug/
[2022-06-17 22:19:41.891 ip-172-16-171-119.ec2.internal:11503 INFO state_store.py:77] The checkpoint config file /opt/ml/input/config/checkpointconfig.json does not exist.


In [6]:
model.compile(**training_params)

In [None]:
model.fit(train_tdf, callbacks=[smd_hook])

[2022-06-15 06:21:47.669 ip-172-16-182-210.ec2.internal:12704 INFO hook.py:421] Monitoring the collections: losses, metrics, sm_metrics
    566/Unknown - 599s 1s/step - loss: 6.7382 - accuracy: 0.0062 - batch: 275.0000

In [7]:
data = next(iter(train_tdf))

In [9]:
data[0]

<tf.Tensor: shape=(128, 3, 224, 224), dtype=float32, numpy=
array([[[[ 6.21992350e-01,  6.74132943e-01,  7.11288631e-01, ...,
          -9.24640417e-01, -1.11072648e+00, -1.10904562e+00],
         [ 6.31472468e-01,  6.57992482e-01,  6.67207122e-01, ...,
          -1.03094029e+00, -1.05592501e+00, -1.16160047e+00],
         [ 6.04866982e-01,  7.65309155e-01,  7.69733906e-01, ...,
          -9.74457979e-01, -8.86871576e-01, -1.03263664e+00],
         ...,
         [ 2.16893077e+00,  2.15776896e+00,  2.17953968e+00, ...,
           1.27339363e+00,  1.29099190e+00,  1.29455340e+00],
         [ 2.21082664e+00,  2.19503236e+00,  2.16593480e+00, ...,
           1.27156734e+00,  1.28853250e+00,  1.30949724e+00],
         [ 2.22894549e+00,  2.20412970e+00,  2.15489459e+00, ...,
           1.27970350e+00,  1.25757229e+00,  1.25643945e+00]],

        [[ 8.96708727e-01,  9.50011373e-01,  9.58295763e-01, ...,
          -9.46998060e-01, -1.12503862e+00, -9.82483864e-01],
         [ 8.88893127e-01,  

In [10]:
pred = model(data[0])

In [11]:
pred

<tf.Tensor: shape=(128, 1000), dtype=float32, numpy=
array([[0.0009466 , 0.0009825 , 0.00099284, ..., 0.00099397, 0.00113552,
        0.00098208],
       [0.00089606, 0.00096356, 0.00101398, ..., 0.00096651, 0.00132829,
        0.00098474],
       [0.00091772, 0.00096382, 0.0009931 , ..., 0.00099169, 0.00123101,
        0.00095988],
       ...,
       [0.00089657, 0.00096214, 0.00097533, ..., 0.00099139, 0.00125288,
        0.00096461],
       [0.00092789, 0.00097742, 0.00098778, ..., 0.00099296, 0.00116779,
        0.00097646],
       [0.00091643, 0.00096179, 0.00099195, ..., 0.00099728, 0.0012605 ,
        0.00096166]], dtype=float32)>