##### Copyright 2019 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Distributed Training with DTensors



## Overview

DTensor provides a way for you to distribute the training of your model across devices to improve efficiency, reliability and scalability. For more details on DTensor concepts, see [The DTensor Programming Guide](https://www.tensorflow.org/guide/dtensor_overview).

In this tutorial, you will train a Sentiment Analysis model with DTensor. Three distributed training schemes are demonstrated with this example:

 - Data Parallel training, where the training samples are sharded (partitioned) across devices.
 - Model Parallel training, where the model variables are sharded across devices.
 - Spatial Parallel training, where the features of the input data are sharded across devices. This is also known as [Spatial Partitioning](https://cloud.google.com/blog/products/ai-machine-learning/train-ml-models-on-large-images-and-3d-volumes-with-spatial-partitioning-on-cloud-tpus).
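As a rough illustration of what each scheme partitions, the NumPy sketch below splits an input batch and a weight matrix across 8 hypothetical devices and prints the shard shapes. The sizes (a 64 x 1200 input batch and a 1200 x 48 weight) are chosen to match the model built later in this tutorial, and the model parallel split along the weight's output axis is one illustrative choice. NumPy only stands in for the partitioning here; it does not show how DTensor places or communicates the shards.

```python
import numpy as np

NUM_DEVICES = 8  # assumed device count, matching the 8 virtual CPUs used below

x = np.zeros((64, 1200), dtype=np.float32)  # a batch of 64 examples with 1200 features
w = np.zeros((1200, 48), dtype=np.float32)  # the weights of a Dense layer

# Data parallel: each device gets a slice of the examples (axis 0 of x).
data_parallel = np.split(x, NUM_DEVICES, axis=0)

# Model parallel: each device gets a slice of the weights (here, output axis 1 of w).
model_parallel = np.split(w, NUM_DEVICES, axis=1)

# Spatial parallel: each device gets a slice of the features of every example (axis 1 of x).
spatial_parallel = np.split(x, NUM_DEVICES, axis=1)

for name, shards in [("data", data_parallel),
                     ("model", model_parallel),
                     ("spatial", spatial_parallel)]:
    print(name, len(shards), shards[0].shape)
```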

The training portion of this tutorial is inspired by the Kaggle notebook [A Kaggle guide on Sentiment Analysis](https://www.kaggle.com/code/anasofiauzsoy/yelp-review-sentiment-analysis-tensorflow-tfds/notebook). To learn about the complete training and evaluation workflow (without DTensor), refer to that notebook.

This tutorial will walk through the following steps:

- First, start with some data cleaning to obtain a `tf.data.Dataset` of tokenized sentences and their polarity.

- Next, build an MLP model with custom Dense and BatchNorm layers. Use a `tf.Module` to track the inference variables. The model constructor takes additional `Layout` arguments to control the sharding of variables.

- For training, you will first use data parallel training together with `tf.experimental.dtensor`'s checkpoint feature. Then, continue with model parallel training and spatial parallel training.

- The final section briefly describes the interaction between `tf.saved_model` and `tf.experimental.dtensor` as of TensorFlow 2.9.


## Setup

DTensor is part of the TensorFlow 2.9.0 release.

In [2]:
!pip install --quiet --upgrade --pre tensorflow tensorflow-datasets

Next, import `tensorflow` and `tensorflow.experimental.dtensor`. Then configure TensorFlow to use 8 virtual CPUs.

Even though this example uses CPUs, DTensor works the same way on CPU, GPU or TPU devices.

In [3]:
import tempfile
import numpy as np
import tensorflow_datasets as tfds

import tensorflow as tf

from tensorflow.experimental import dtensor
print('TensorFlow version:', tf.__version__)

TensorFlow version: 2.11.0


In [4]:
def configure_virtual_cpus(ncpu):
  phy_devices = tf.config.list_physical_devices('CPU')
  tf.config.set_logical_device_configuration(phy_devices[0], [
        tf.config.LogicalDeviceConfiguration(),
    ] * ncpu)

configure_virtual_cpus(8)
DEVICES = [f'CPU:{i}' for i in range(8)]

tf.config.list_logical_devices('CPU')

[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:CPU:1', device_type='CPU'),
 LogicalDevice(name='/device:CPU:2', device_type='CPU'),
 LogicalDevice(name='/device:CPU:3', device_type='CPU'),
 LogicalDevice(name='/device:CPU:4', device_type='CPU'),
 LogicalDevice(name='/device:CPU:5', device_type='CPU'),
 LogicalDevice(name='/device:CPU:6', device_type='CPU'),
 LogicalDevice(name='/device:CPU:7', device_type='CPU')]

## Download the dataset

Download the IMDB reviews dataset to train the sentiment analysis model.

In [5]:
train_data = tfds.load('imdb_reviews', split='train', shuffle_files=True, batch_size=64)
train_data

<PrefetchDataset element_spec={'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None), 'text': TensorSpec(shape=(None,), dtype=tf.string, name=None)}>

## Prepare the data

First, tokenize the text. Here, use an extension of one-hot encoding: the `'tf_idf'` mode of `tf.keras.layers.TextVectorization`.

- For the sake of speed, limit the number of tokens to 1200.
- To keep the `tf.Module` simple, run `TextVectorization` as a preprocessing step before the training.

The final result of the data cleaning section is a `Dataset` with the tokenized text as `x` and label as `y`.

**Note**: Running `TextVectorization` as a preprocessing step is **neither a usual practice nor a recommended one**, as doing so assumes that the training data fits into the client memory, which is not always the case.
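For intuition about what the `'tf_idf'` output contains, here is a pure-Python sketch of TF-IDF weighting over a toy corpus. It assumes the smoothed inverse document frequency `log(1 + N / (1 + df))`, which is one common variant; `TextVectorization` also performs standardization, reserves an out-of-vocabulary token, and applies the `max_tokens` cut-off, so its exact numbers will differ. The helper name `tf_idf_vectors` and the corpus are made up for this sketch.

```python
import math
from collections import Counter

def tf_idf_vectors(docs, vocab):
    """Hypothetical helper: returns one TF-IDF vector (a list of floats) per document.

    Term frequency is the raw count of a token in a document; the inverse
    document frequency uses the smoothed form log(1 + N / (1 + df)).
    """
    n_docs = len(docs)
    tokenized = [doc.lower().split() for doc in docs]
    # Document frequency: in how many documents each vocabulary token appears.
    df = {tok: sum(tok in toks for toks in tokenized) for tok in vocab}
    idf = {tok: math.log(1 + n_docs / (1 + df[tok])) for tok in vocab}
    vectors = []
    for toks in tokenized:
        counts = Counter(toks)
        vectors.append([counts[tok] * idf[tok] for tok in vocab])
    return vectors

docs = ["great movie great cast", "terrible movie", "great fun"]
vocab = ["great", "movie", "terrible", "fun"]
for vec in tf_idf_vectors(docs, vocab):
    print([round(v, 3) for v in vec])
```

Tokens that appear in fewer documents (here `terrible` and `fun`) receive a larger weight per occurrence than common ones, and repeated tokens scale linearly with their count.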


In [6]:
text_vectorization = tf.keras.layers.TextVectorization(output_mode='tf_idf', max_tokens=1200, output_sequence_length=None)
text_vectorization.adapt(data=train_data.map(lambda x: x['text']))


In [7]:
def vectorize(features):
  return text_vectorization(features['text']), features['label']

train_data_vec = train_data.map(vectorize)
train_data_vec

<MapDataset element_spec=(TensorSpec(shape=(None, 1200), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

## Build a neural network with DTensor

Now build a Multi-Layer Perceptron (MLP) network with `DTensor`. The network will use fully connected Dense and BatchNorm layers.

`DTensor` expands TensorFlow through single-program, multiple-data (SPMD) expansion of regular TensorFlow ops, according to the `dtensor.Layout` attributes of their input tensors and variables.

Variables of DTensor-aware layers are `dtensor.DVariable`s, and the constructors of DTensor-aware layer objects take `Layout` inputs in addition to the usual layer parameters.

Note: As of TensorFlow 2.9, Keras layers such as `tf.keras.layers.Dense` and `tf.keras.layers.BatchNormalization` accept `dtensor.Layout` arguments. Refer to the [DTensor Keras Integration Tutorial](/tutorials/distribute/dtensor_keras_tutorial) for more information on using Keras with DTensor.
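SPMD expansion can be pictured with a small NumPy sketch: one global operation becomes the same local operation running on every shard, plus communication only where the operation mixes data across shards. The 4-way split and the two operations are illustrative assumptions; DTensor performs this expansion itself from the `Layout`s and does not expose it in this form.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 6))      # the global tensor
shards = np.split(x, 4, axis=0)  # sharded along axis 0 over 4 hypothetical devices

# Element-wise op: every device runs the same program on its own shard; no communication.
local_relu = [np.maximum(s, 0.0) for s in shards]
assert np.allclose(np.concatenate(local_relu, axis=0), np.maximum(x, 0.0))

# Reduction over the sharded axis: each device reduces locally, then the partial
# results are combined across devices (an all-reduce in a real system).
partial_sums = [s.sum(axis=0) for s in shards]
global_sum = np.sum(partial_sums, axis=0)
assert np.allclose(global_sum, x.sum(axis=0))
print("SPMD expansion matches the global computation")
```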

### Dense Layer

The following custom Dense layer defines two layer variables: $W_{ij}$ is the variable for the weights, and $b_j$ is the variable for the biases.

$$
y_j = \sigma(\sum_i x_i W_{ij} + b_j)
$$


### Layout deduction

The `Dense` layer below shards the bias $\mathbf{b}$ the same way as the last axis of the weight $\mathbf{W}$. This choice comes from the following observations:

- The preferred DTensor sharding for operands to a matrix dot product $t_j = \sum_i x_i W_{ij}$ is to shard $\mathbf{W}$ and $\mathbf{x}$ the same way along the $i$-axis.

- The preferred DTensor sharding for operands to a sum $t_j + b_j$ is to shard $\mathbf{t}$ and $\mathbf{b}$ the same way along the $j$-axis.
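Written in the notation that this tutorial uses for the `MLP` layouts, these observations give the following layouts for a single `Dense` layer. Here $\hat{i}$ and $\hat{j}$ stand for whichever mesh dimension (or replication) is chosen for the $i$ and $j$ axes, and the input $\mathbf{x}$ shares $\hat{i}$ with $\mathbf{W}$. This restates the rule in the layer's code (the bias layout is the weight layout with its first axis deleted) rather than adding a new constraint:

$$
\mathsf{Layout}[{W_{ij}}; i, j] = \left[\hat{i}, \hat{j}\right] \\
\mathsf{Layout}[{b_j}; j] = \left[\hat{j}\right]
$$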


In [8]:
class Dense(tf.Module):

  def __init__(self, input_size, output_size,
               init_seed, weight_layout, activation=None):
    super().__init__()

    random_normal_initializer = tf.function(tf.random.stateless_normal)

    self.weight = dtensor.DVariable(
        dtensor.call_with_layout(
            random_normal_initializer, weight_layout,
            shape=[input_size, output_size],
            seed=init_seed
            ))
    if activation is None:
      activation = lambda x:x
    self.activation = activation
    
    # bias is sharded the same way as the last axis of weight.
    bias_layout = weight_layout.delete([0])

    self.bias = dtensor.DVariable(
        dtensor.call_with_layout(tf.zeros, bias_layout, [output_size]))

  def __call__(self, x):
    y = tf.matmul(x, self.weight) + self.bias
    y = self.activation(y)

    return y
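To see why these layouts are preferred, the NumPy sketch below evaluates the same forward pass $y = \sigma(xW + b)$ on 4 hypothetical shards in the two ways described under Layout deduction. Sharding $\mathbf{x}$ and $\mathbf{W}$ along $i$ yields partial products that must be summed across devices before the bias and activation are applied; sharding $\mathbf{W}$ and $\mathbf{b}$ along $j$ yields independent blocks of output columns that need no communication. The shapes, the shard count, and the ReLU activation are arbitrary choices for illustration.

```python
import numpy as np

def relu(t):
    return np.maximum(t, 0.0)

rng = np.random.default_rng(1)
x = rng.normal(size=(16, 12))  # batch of 16 examples, 12 input features (axis i)
w = rng.normal(size=(12, 8))   # weight W_ij
b = rng.normal(size=(8,))      # bias b_j

reference = relu(x @ w + b)

# Shard x and W the same way along i: each device holds a partial product.
x_shards = np.split(x, 4, axis=1)
w_row_shards = np.split(w, 4, axis=0)
partials = [xs @ ws for xs, ws in zip(x_shards, w_row_shards)]
# The partial products are summed (an all-reduce) before adding b and applying relu.
y_i_sharded = relu(sum(partials) + b)

# Shard W and b the same way along j: each device computes a block of output columns.
w_col_shards = np.split(w, 4, axis=1)
b_shards = np.split(b, 4)
blocks = [relu(x @ ws + bs) for ws, bs in zip(w_col_shards, b_shards)]
y_j_sharded = np.concatenate(blocks, axis=1)

print(np.allclose(y_i_sharded, reference), np.allclose(y_j_sharded, reference))
```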

### BatchNorm

A batch normalization layer helps avoid collapsing modes while training. In this case, adding batch normalization layers helps prevent training from producing a model that only outputs zeros.

The constructor of the custom `BatchNorm` layer below does not take a `Layout` argument. This is because `BatchNorm` has no layer variables. This still works with DTensor because `x`, the only input to the layer, is already a DTensor that represents the global batch.

Note: With DTensor, the input tensor `x` always represents the global batch. Therefore, `tf.nn.batch_normalization` is applied to the global batch. This differs from training with `tf.distribute.MirroredStrategy`, where the tensor `x` only represents the per-replica shard of the batch (the local batch).
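The difference matters because per-replica statistics are not the global statistics. The NumPy sketch below uses a made-up batch of 8 values for one feature, split across 2 hypothetical replicas: each local variance is 1.25, while the global variance is 5.25, because each local batch misses the spread between the two halves.

```python
import numpy as np

# One feature, a global batch of 8 examples, split into 2 hypothetical replicas of 4.
x = np.arange(8, dtype=np.float64)
local_batches = np.split(x, 2)

global_mean, global_var = float(x.mean()), float(x.var())
local_means = [float(lb.mean()) for lb in local_batches]
local_vars = [float(lb.var()) for lb in local_batches]

print("global:", global_mean, global_var)  # the statistics DTensor's BatchNorm sees
print("local: ", local_means, local_vars)  # per-replica statistics (MirroredStrategy-style)
```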

In [9]:
class BatchNorm(tf.Module):

  def __init__(self):
    super().__init__()

  def __call__(self, x, training=True):
    if not training:
      # This branch is not used in the Tutorial.
      pass
    mean, variance = tf.nn.moments(x, axes=[0])
    return tf.nn.batch_normalization(x, mean, variance, 0.0, 1.0, 1e-5)

A full-featured batch normalization layer (such as `tf.keras.layers.BatchNormalization`) needs `Layout` arguments for its variables.

In [10]:
def make_keras_bn(bn_layout):
  return tf.keras.layers.BatchNormalization(gamma_layout=bn_layout,
                                            beta_layout=bn_layout,
                                            moving_mean_layout=bn_layout,
                                            moving_variance_layout=bn_layout,
                                            fused=False)

### Putting Layers Together

Next, build a Multi-layer perceptron (MLP) network with the building blocks above.  The diagram below shows the axis relationships between the input `x` and the weight matrices for the two `Dense` layers without any DTensor sharding or replication applied.

<img src="https://www.tensorflow.org/images/dtensor/no_dtensor.png" alt="The input and weight matrices for a non distributed model." class="no-filter">


The output of the first `Dense` layer is passed into the input of the second `Dense` layer (after the `BatchNorm`). Therefore, the preferred DTensor sharding for the output of the first `Dense` layer ($\mathbf{W_1}$) and the input of the second `Dense` layer ($\mathbf{W_2}$) is to shard $\mathbf{W_1}$ and $\mathbf{W_2}$ the same way along the common axis $\hat{j}$:

$$
\mathsf{Layout}[{W_{1,ij}}; i, j] = \left[\hat{i}, \hat{j}\right] \\
\mathsf{Layout}[{W_{2,jk}}; j, k] = \left[\hat{j}, \hat{k} \right]
$$

Even though the layout deduction shows that the 2 layouts are not independent, for the sake of simplicity of the model interface, `MLP` will take 2 `Layout` arguments, one per Dense layer.
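The benefit of sharding $\mathbf{W_1}$ and $\mathbf{W_2}$ the same way along $\hat{j}$ can be checked with a NumPy sketch of the `MLP` forward pass: each of 4 hypothetical devices keeps its own block of the 48 hidden units from the first layer through the per-feature normalization to the second layer, and only the final partial outputs are summed. The first layer's bias is sharded along $j$ with $\mathbf{W_1}$; the second layer's bias is added once after the sum. The sizes follow the `MLP` class, but the data is random and `batch_norm` is a NumPy stand-in for the custom `BatchNorm` layer.

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(size=(64, 1200))
w1, b1 = rng.normal(size=(1200, 48)), rng.normal(size=(48,))
w2, b2 = rng.normal(size=(48, 2)), rng.normal(size=(2,))

def batch_norm(t, eps=1e-5):
    # Per-feature normalization over the batch axis, like tf.nn.moments(x, axes=[0]).
    return (t - t.mean(axis=0)) / np.sqrt(t.var(axis=0) + eps)

def mlp_without_output_bias(x, w1, b1, w2):
    h = batch_norm(np.maximum(x @ w1 + b1, 0.0))  # dense1 (ReLU) followed by BatchNorm
    return h @ w2                                 # dense2 before its bias

reference = mlp_without_output_bias(x, w1, b1, w2) + b2

# Shard the common axis j: columns of W1 and b1, rows of W2.
partials = [
    mlp_without_output_bias(x, w1_s, b1_s, w2_s)
    for w1_s, b1_s, w2_s in zip(np.split(w1, 4, axis=1),
                                np.split(b1, 4),
                                np.split(w2, 4, axis=0))
]
# Only the final (64, 2) partial outputs need to be summed across devices.
y = sum(partials) + b2
print(np.allclose(y, reference))
```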

In [11]:
from typing import Tuple

class MLP(tf.Module):

  def __init__(self, dense_layouts: Tuple[dtensor.Layout, dtensor.Layout]):
    super().__init__()

    self.dense1 = Dense(
        1200, 48, (1, 2), dense_layouts[0], activation=tf.nn.relu)
    self.bn = BatchNorm()
    self.dense2 = Dense(48, 2, (3, 4), dense_layouts[1])

  def __call__(self, x):
    y = x
    y = self.dense1(y)
    y = self.bn(y)
    y = self.dense2(y)
    return y


The trade-off between correctly expressing layout deduction constraints and keeping the API simple is a common design point of APIs that use DTensor.
It is also possible to capture the dependency between `Layout`s with a different API. For example, the `MLPStricter` class creates the `Layout` objects in the constructor.

In [12]:
class MLPStricter(tf.Module):

  def __init__(self, mesh, input_mesh_dim, inner_mesh_dim1, output_mesh_dim):
    super().__init__()

    self.dense1 = Dense(
        1200, 48, (1, 2), dtensor.Layout([input_mesh_dim, inner_mesh_dim1], mesh),
        activation=tf.nn.relu)
    self.bn = BatchNorm()
    self.dense2 = Dense(48, 2, (3, 4), dtensor.Layout([inner_mesh_dim1, output_mesh_dim], mesh))


  def __call__(self, x):
    y = x
    y = self.dense1(y)
    y = self.bn(y)
    y = self.dense2(y)
    return y

To make sure the model runs, probe your model with fully replicated layouts and a fully replicated batch of `'x'` input.

In [13]:
WORLD = dtensor.create_mesh([("world", 8)], devices=DEVICES)

model = MLP([dtensor.Layout.replicated(WORLD, rank=2),
             dtensor.Layout.replicated(WORLD, rank=2)])

sample_x, sample_y = train_data_vec.take(1).get_single_element()
sample_x = dtensor.copy_to_mesh(sample_x, dtensor.Layout.replicated(WORLD, rank=2))
print(model(sample_x))

tf.Tensor([[-5.61041546 5.04737568]
 [-7.14075 6.86515808]
 [-3.10483789 1.5816828]
 ...
 [6.87280321 -3.56776118]
 [8.27548695 -5.7091856]
 [-1.98807693 1.71495843]], layout="sharding_specs:unsharded,unsharded, mesh:|world=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7", shape=(64, 2), dtype=float32)


## Moving data to the device

Usually, `tf.data` iterators (and other data fetching methods) yield tensor objects backed by the local host device memory. This data must be transferred to the accelerator device memory that backs DTensor's component tensors.

`dtensor.copy_to_mesh` is unsuitable for this situation because it replicates input tensors to all devices due to DTensor's global perspective. So in this tutorial, you will use a helper function, `repack_local_tensor`, to facilitate the transfer of data. This helper function uses `dtensor.pack` to send (and only send) the shard of the global batch that is intended for a replica to the device backing that replica.

This simplified function assumes a single-client application. In a multi-client application, determining the correct way to split the local tensor, and the mapping between the resulting pieces and the local devices, can be laborious.

Additional DTensor API to simplify `tf.data` integration is planned, supporting both single-client and multi-client applications. Please stay tuned.
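A back-of-the-envelope calculation shows the difference in host-to-device traffic for one batch of this tutorial's data (64 examples x 1200 `float32` features) on the 8-device mesh. It counts only the tensor payload and ignores any transfer overhead, so treat it as an estimate.

```python
BATCH, FEATURES, BYTES_PER_FLOAT32, NUM_DEVICES = 64, 1200, 4, 8

batch_bytes = BATCH * FEATURES * BYTES_PER_FLOAT32

# Replicating (copy_to_mesh with a replicated layout): every device receives the full batch.
replicated_bytes = NUM_DEVICES * batch_bytes

# Packing shards (dtensor.pack with a batch-sharded layout): each device receives 1/8 of the batch.
sharded_bytes = NUM_DEVICES * (batch_bytes // NUM_DEVICES)

print(batch_bytes, replicated_bytes, sharded_bytes)  # 307200 2457600 307200
```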

In [14]:
def repack_local_tensor(x, layout):
  """Repacks a local Tensor-like to a DTensor with layout.

  This function assumes a single-client application.
  """
  x = tf.convert_to_tensor(x)
  sharded_dims = []

  # For every sharded dimension, use tf.split to split the tensor along that dimension.
  # The result is a nested list of split-tensors in queue[0].
  queue = [x]
  for axis, dim in enumerate(layout.sharding_specs):
    if dim == dtensor.UNSHARDED:
      continue
    num_splits = layout.shape[axis]
    queue = tf.nest.map_structure(lambda x: tf.split(x, num_splits, axis=axis), queue)
    sharded_dims.append(dim)

  # Now we can build the list of component tensors by looking up the location in
  # the nested list of split-tensors created in queue[0].
  components = []
  for locations in layout.mesh.local_device_locations():
    t = queue[0]
    for dim in sharded_dims:
      split_index = locations[dim]  # Only valid on single-client mesh.
      t = t[split_index]
    components.append(t)

  return dtensor.pack(components, layout)
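The splitting logic of `repack_local_tensor` can be exercised without TensorFlow. The NumPy sketch below is a hypothetical analogue: `split_for_mesh` and the `"unsharded"` string are stand-ins invented for this sketch, and enumerating devices in row-major order over the mesh dimensions is an assumption of the sketch rather than documented DTensor behavior (the real helper asks `layout.mesh.local_device_locations()` instead).

```python
import itertools
import numpy as np

UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED in this sketch

def split_for_mesh(x, sharding_specs, mesh_dims):
    """Hypothetical NumPy analogue of repack_local_tensor's splitting logic.

    mesh_dims: dict mapping mesh dimension name -> size, in mesh order.
    Returns one component per device, with devices enumerated in row-major
    order over the mesh dimensions (an assumption of this sketch).
    """
    components = []
    for coords in itertools.product(*(range(n) for n in mesh_dims.values())):
        location = dict(zip(mesh_dims.keys(), coords))
        t = x
        for axis, dim in enumerate(sharding_specs):
            if dim == UNSHARDED:
                continue
            # Keep the piece of this axis that belongs to this device's mesh coordinate.
            t = np.split(t, mesh_dims[dim], axis=axis)[location[dim]]
        components.append(t)
    return components

x = np.arange(16 * 3).reshape(16, 3)
parts = split_for_mesh(x, ["batch", UNSHARDED], {"batch": 8})
print(len(parts), parts[0].shape, parts[1][:, 0])  # 8 (2, 3) [6 9]
```

The same function also handles a 2D mesh, for example `split_for_mesh(x2, ["batch", "model"], {"batch": 2, "model": 2})` returns four row-and-column blocks of a 4 x 4 array `x2`.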

## Data parallel training

In this section, you will train your MLP model with data parallel training. The following sections will demonstrate model parallel training and spatial parallel training.

Data parallel training is a commonly used scheme for distributed machine learning:

 - Model variables are replicated on each of the N devices.
 - A global batch is split into N per-replica batches.
 - Each per-replica batch is trained on its replica device.
 - The gradients are reduced across replicas before the weight update is performed collectively on all replicas.

Data parallel training can provide a nearly linear speedup with respect to the number of devices.
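The gradient reduction step relies on the fact that, for a loss that is a sum over examples (as in the `train_step` below, which uses `tf.reduce_sum`), the gradient of the global loss equals the sum of the per-replica gradients. The NumPy sketch below checks this for a least-squares linear model with random data split across 8 hypothetical replicas; the model and data are illustrative assumptions, not the tutorial's MLP.

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=(64, 5))
y = rng.normal(size=(64,))
w = rng.normal(size=(5,))

def grad_sum_squared_error(x, y, w):
    # Gradient of sum((x @ w - y) ** 2) with respect to w.
    return 2.0 * x.T @ (x @ w - y)

global_grad = grad_sum_squared_error(x, y, w)

# Each replica computes the gradient on its own shard of 8 examples ...
replica_grads = [grad_sum_squared_error(xs, ys, w)
                 for xs, ys in zip(np.split(x, 8), np.split(y, 8))]
# ... and an all-reduce (here, a plain sum) combines them before the update.
reduced_grad = np.sum(replica_grads, axis=0)

print(np.allclose(reduced_grad, global_grad))  # True
```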

### Creating a data parallel mesh

A typical data parallelism training loop uses a DTensor `Mesh` that consists of a single `batch` dimension, where each device becomes a replica that receives a shard from the global batch.

<img src="https://www.tensorflow.org/images/dtensor/dtensor_data_para.png" alt="Data parallel mesh" class="no-filter">


The replicated model runs on each replica; therefore, the model variables are fully replicated (unsharded).

In [15]:
mesh = dtensor.create_mesh([("batch", 8)], devices=DEVICES)

model = MLP([dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh),
             dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh),])


### Packing training data to DTensors

The training data batch should be packed into DTensors sharded along the `'batch'` (first) axis, such that DTensor will evenly distribute the training data to the `'batch'` mesh dimension.

**Note**: In DTensor, the `batch size` always refers to the global batch size. The batch size should be chosen such that it can be divided evenly by the size of the `batch` mesh dimension.
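As a quick check of this requirement for the data in this tutorial: the IMDB training split has 25,000 reviews and is batched with `batch_size=64`, giving 391 batches per epoch. The sketch below, built around a hypothetical helper `per_replica_batch_size`, confirms that both the full batches and the final partial batch of 40 reviews divide evenly across the 8-way `batch` mesh dimension.

```python
def per_replica_batch_size(global_batch_size, batch_mesh_size):
    """Hypothetical helper: the number of examples each replica receives."""
    if global_batch_size % batch_mesh_size != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"the batch mesh dimension size {batch_mesh_size}")
    return global_batch_size // batch_mesh_size

NUM_EXAMPLES, BATCH_SIZE, MESH_SIZE = 25_000, 64, 8
num_full_batches, last_batch = divmod(NUM_EXAMPLES, BATCH_SIZE)

print(num_full_batches + (last_batch > 0))            # 391 batches per epoch
print(per_replica_batch_size(BATCH_SIZE, MESH_SIZE))  # 8 examples per replica
print(per_replica_batch_size(last_batch, MESH_SIZE))  # final batch of 40 -> 5 per replica
```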

In [16]:
def repack_batch(x, y, mesh):
  x = repack_local_tensor(x, layout=dtensor.Layout(['batch', dtensor.UNSHARDED], mesh))
  y = repack_local_tensor(y, layout=dtensor.Layout(['batch'], mesh))
  return x, y

sample_x, sample_y = train_data_vec.take(1).get_single_element()
sample_x, sample_y = repack_batch(sample_x, sample_y, mesh)

print('x', sample_x[:, 0])
print('y', sample_y)

x tf.Tensor({"CPU:0": [85.6979828 57.1319885 139.655975 ... 260.267944 438.011902 111.089973], "CPU:1": [117.437973 66.6539841 107.915977 ... 146.003967 260.267944 47.6099892], "CPU:2": [136.481964 215.831955 285.659943 ... 355.487915 206.309952 101.567978], "CPU:3": [107.915977 57.1319885 79.3499832 ... 63.4799881 203.135956 371.35791], "CPU:4": [206.309952 73.0019836 34.9139938 ... 82.5239792 44.4359894 69.8279877], "CPU:5": [95.2199783 219.005951 434.837891 ... 98.3939819 95.2199783 345.965912], "CPU:6": [174.569962 282.485931 38.0879898 ... 234.875946 79.3499832 79.3499832], "CPU:7": [215.831955 590.363892 107.915977 ... 238.049942 244.397949 82.5239792]}, layout="sharding_specs:batch, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:

### Training step

This example uses a Stochastic Gradient Descent optimizer with a Custom Training Loop (CTL). Consult the [Custom Training Loop guide](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch) and the [walkthrough](https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough) for more information on those topics.

The `train_step` is encapsulated as a `tf.function` to indicate this body is to be traced as a TensorFlow Graph. The body of `train_step` consists of a forward inference pass, a backward gradient pass, and the variable update.

Note that the body of `train_step` does not contain any special DTensor annotations. Instead, `train_step` only contains high-level TensorFlow operations that process the input `x` and `y` from the global view of the input batch and the model. All of the DTensor annotations (`Mesh`, `Layout`) are factored out of the train step.
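For reference, the loss and metrics that `train_step` computes can be written out directly. The NumPy sketch below is a readable re-derivation rather than TensorFlow's implementation: a hypothetical `sparse_softmax_cross_entropy` helper computes the per-example cross entropy from integer labels (the quantity `tf.nn.sparse_softmax_cross_entropy_with_logits` returns), followed by the summed loss, `loss_per_sample`, and `accuracy` as in `train_step`. The two-example batch is made up.

```python
import numpy as np

def sparse_softmax_cross_entropy(logits, labels):
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true class for each example.
    return -log_probs[np.arange(len(labels)), labels]

logits = np.array([[2.0, 0.0], [0.0, 0.0]])
labels = np.array([0, 1])

loss = sparse_softmax_cross_entropy(logits, labels).sum()  # like tf.reduce_sum
loss_per_sample = loss / len(labels)
accuracy = 1.0 - np.sum(np.argmax(logits, axis=-1) != labels) / len(labels)
print(round(float(loss_per_sample), 4), accuracy)
```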

In [17]:
# Refer to the CTL (custom training loop guide)
@tf.function
def train_step(model, x, y, learning_rate=tf.constant(1e-4)):
  with tf.GradientTape() as tape:
    logits = model(x)
    # tf.reduce_sum sums the batch sharded per-example loss to a replicated
    # global loss (scalar).
    loss = tf.reduce_sum(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=y))
  parameters = model.trainable_variables
  gradients = tape.gradient(loss, parameters)
  for parameter, parameter_gradient in zip(parameters, gradients):
    parameter.assign_sub(learning_rate * parameter_gradient)

  # Define some metrics
  accuracy = 1.0 - tf.reduce_sum(tf.cast(tf.argmax(logits, axis=-1, output_type=tf.int64) != y, tf.float32)) / x.shape[0]
  loss_per_sample = loss / len(x)
  return {'loss': loss_per_sample, 'accuracy': accuracy}

### Checkpointing

You can checkpoint a DTensor model using `tf.train.Checkpoint` out of the box. Saving and restoring sharded DVariables will perform an efficient sharded save and restore. Currently, when using `tf.train.Checkpoint.save` and `tf.train.Checkpoint.restore`, all DVariables must be on the same host mesh, and DVariables and regular variables cannot be saved together. You can learn more about checkpointing in [this guide](../../guide/checkpoint.ipynb).

When a DTensor checkpoint is restored, the `Layout`s of variables can differ from those used when the checkpoint was saved. That is, saving DTensor models is layout- and mesh-agnostic; the layout only affects the efficiency of sharded saving. You can save a DTensor model with one mesh and layout and restore it on a different mesh and layout. This tutorial uses this feature to continue training in the Model Parallel training and Spatial Parallel training sections.
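One way to picture why this works is that a variable's saved value is its global value, not a particular set of per-device pieces. The NumPy sketch below imitates that idea: a weight saved from one sharding (rows split 8 ways) is reassembled and then split a different way (columns split 4 ways). It is a conceptual model only; the helpers `save_sharded` and `restore_sharded` are hypothetical, and the real checkpoint format and restore path of `tf.train.Checkpoint` are not shown.

```python
import numpy as np

def save_sharded(shards, axis):
    # Hypothetical "save": record the global value assembled from the shards.
    return np.concatenate(shards, axis=axis)

def restore_sharded(saved, num_shards, axis):
    # Hypothetical "restore": split the global value for the new layout.
    return np.split(saved, num_shards, axis=axis)

weight = np.arange(1200 * 48, dtype=np.float32).reshape(1200, 48)

# Saved while sharded 8 ways along axis 0 ...
saved = save_sharded(np.split(weight, 8, axis=0), axis=0)
# ... restored sharded 4 ways along axis 1.
restored = restore_sharded(saved, 4, axis=1)

print(len(restored), restored[0].shape,
      np.array_equal(np.concatenate(restored, axis=1), weight))
```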


In [18]:
CHECKPOINT_DIR = tempfile.mkdtemp()

def start_checkpoint_manager(model):
  ckpt = tf.train.Checkpoint(root=model)
  manager = tf.train.CheckpointManager(ckpt, CHECKPOINT_DIR, max_to_keep=3)

  if manager.latest_checkpoint:
    print("Restoring a checkpoint")
    ckpt.restore(manager.latest_checkpoint).assert_consumed()
  else:
    print("New training")
  return manager


### Training loop

For the data parallel training scheme, train for a few epochs and report the progress. Two epochs, as used below, are insufficient for training the model; an accuracy of 50% is no better than random guessing.

Enable checkpointing so that you can pick up the training later. In the following section, you will load the checkpoint and train with a different parallel scheme.

In [19]:
num_epochs = 2
manager = start_checkpoint_manager(model)

for epoch in range(num_epochs):
  step = 0
  pbar = tf.keras.utils.Progbar(target=int(train_data_vec.cardinality()), stateful_metrics=[])
  metrics = {'epoch': epoch}
  for x,y in train_data_vec:

    x, y = repack_batch(x, y, mesh)

    metrics.update(train_step(model, x, y, 1e-2))

    pbar.update(step, values=metrics.items(), finalize=False)
    step += 1
  manager.save()
  pbar.update(step, values=metrics.items(), finalize=True)

New training


 87/391 [=====>........................] - ETA: 12s - epoch: 0.0000e+00 - loss: 1.7148 - accuracy: 0.5206

 89/391 [=====>........................] - ETA: 12s - epoch: 0.0000e+00 - loss: 1.7026 - accuracy: 0.5208

 91/391 [=====>........................] - ETA: 11s - epoch: 0.0000e+00 - loss: 1.6962 - accuracy: 0.5200

  0/391 [..............................] - ETA: 0s - epoch: 1.0000 - loss: 0.8295 - accuracy: 0.5625

  1/391 [..............................] - ETA: 2:33 - epoch: 1.0000 - loss: 0.9183 - accuracy: 0.5312

  2/391 [..............................] - ETA: 24s - epoch: 1.0000 - loss: 0.9802 - accuracy: 0.5208 

  3/391 [..............................] - ETA: 25s - epoch: 1.0000 - loss: 1.1663 - accuracy: 0.5195

  4/391 [..............................] - ETA: 24s - epoch: 1.0000 - loss: 1.3465 - accuracy: 0.5156

  5/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 1.4158 - accuracy: 0.5182

  6/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 1.3362 - accuracy: 0.5268

  7/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 1.2632 - accuracy: 0.5352

  8/391 [..............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2339 - accuracy: 0.5382

  9/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 1.2263 - accuracy: 0.5375

 10/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 1.2733 - accuracy: 0.5312

 11/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 1.2434 - accuracy: 0.5404

 12/391 [..............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2063 - accuracy: 0.5493

 13/391 [..............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1795 - accuracy: 0.5502

 14/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2024 - accuracy: 0.5417

 15/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2203 - accuracy: 0.5361

 16/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2314 - accuracy: 0.5395

 17/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2502 - accuracy: 0.5347

 18/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2873 - accuracy: 0.5354

 19/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2668 - accuracy: 0.5367

 20/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2569 - accuracy: 0.5379

 21/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2720 - accuracy: 0.5412

 22/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2526 - accuracy: 0.5408

 23/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2319 - accuracy: 0.5495

 24/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.2090 - accuracy: 0.5550

 25/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1914 - accuracy: 0.5565

 26/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1983 - accuracy: 0.5561

 27/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1970 - accuracy: 0.5564

 28/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1790 - accuracy: 0.5587

 29/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1730 - accuracy: 0.5578

 30/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1846 - accuracy: 0.5580

 31/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1803 - accuracy: 0.5571

 32/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1999 - accuracy: 0.5568

 33/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1894 - accuracy: 0.5551

 34/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1883 - accuracy: 0.5571

 35/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1731 - accuracy: 0.5629

 36/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1598 - accuracy: 0.5638

 37/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1495 - accuracy: 0.5662

 38/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1524 - accuracy: 0.5665

 39/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 1.1682 - accuracy: 0.5660

 40/391 [==>...........................] - ETA: 22s - epoch: 1.0000 - loss: 1.1802 - accuracy: 0.5655

 41/391 [==>...........................] - ETA: 22s - epoch: 1.0000 - loss: 1.1757 - accuracy: 0.5655

 42/391 [==>...........................] - ETA: 22s - epoch: 1.0000 - loss: 1.1696 - accuracy: 0.5658

 43/391 [==>...........................] - ETA: 22s - epoch: 1.0000 - loss: 1.1845 - accuracy: 0.5632

 44/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 1.2015 - accuracy: 0.5625

 45/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 1.2076 - accuracy: 0.5615

 46/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1973 - accuracy: 0.5622

 47/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 1.2013 - accuracy: 0.5618

 48/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1924 - accuracy: 0.5625

 49/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1811 - accuracy: 0.5650

 50/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1694 - accuracy: 0.5686

 51/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1579 - accuracy: 0.5721

 52/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1484 - accuracy: 0.5743

 53/391 [===>..........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1389 - accuracy: 0.5752

 54/391 [===>..........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1328 - accuracy: 0.5764

 55/391 [===>..........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1421 - accuracy: 0.5753

 56/391 [===>..........................] - ETA: 21s - epoch: 1.0000 - loss: 1.1627 - accuracy: 0.5735

 57/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1655 - accuracy: 0.5735

 58/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1643 - accuracy: 0.5728

 59/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1686 - accuracy: 0.5719

 60/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1663 - accuracy: 0.5720

 61/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1661 - accuracy: 0.5721

 62/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1723 - accuracy: 0.5709

 63/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1723 - accuracy: 0.5708

 64/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1662 - accuracy: 0.5712

 65/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1615 - accuracy: 0.5717

 66/391 [====>.........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1680 - accuracy: 0.5697

 67/391 [====>.........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1780 - accuracy: 0.5692

 68/391 [====>.........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1745 - accuracy: 0.5691

 69/391 [====>.........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1701 - accuracy: 0.5690

 70/391 [====>.........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1677 - accuracy: 0.5682

 71/391 [====>.........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1665 - accuracy: 0.5686

 72/391 [====>.........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1620 - accuracy: 0.5685

 73/391 [====>.........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1599 - accuracy: 0.5686

 74/391 [====>.........................] - ETA: 20s - epoch: 1.0000 - loss: 1.1536 - accuracy: 0.5700

 75/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1499 - accuracy: 0.5705

 76/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1491 - accuracy: 0.5698

 77/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1523 - accuracy: 0.5707

 78/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1471 - accuracy: 0.5718

 79/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1422 - accuracy: 0.5725

 80/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1358 - accuracy: 0.5741

 81/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1323 - accuracy: 0.5749

 82/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1271 - accuracy: 0.5759

 83/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1235 - accuracy: 0.5757

 84/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1182 - accuracy: 0.5767

 85/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1139 - accuracy: 0.5774

 86/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1138 - accuracy: 0.5772

 87/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1169 - accuracy: 0.5779

 88/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1185 - accuracy: 0.5774

 89/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1213 - accuracy: 0.5773

 90/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1204 - accuracy: 0.5768

 91/391 [=====>........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1225 - accuracy: 0.5766

## Model Parallel Training

If you switch to a 2-dimensional `Mesh` and shard the model variables along the second mesh dimension, the training becomes Model Parallel.

In Model Parallel training, each model replica spans multiple devices (2 in this case):

- There are 4 model replicas, and the training data batch is distributed to the 4 replicas.
- The 2 devices within a single model replica receive replicated training data.


<img src="https://www.tensorflow.org/images/dtensor/dtensor_model_para.png" alt="Model parallel mesh" class="no-filter">
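
The grouping of devices into replicas can be sketched in a few lines of plain Python. This is a hypothetical illustration, not DTensor code: it assumes the 8 devices fill the `("batch", 4) x ("model", 2)` mesh in row-major order, and the helpers `device_coordinates` and `replica_groups` are made-up names.

```python
# Hypothetical sketch: assumes the 8 devices fill the ("batch", "model") mesh
# in row-major order, i.e. device ids arranged like range(8) reshaped to 4 x 2.
MESH_DIMS = [("batch", 4), ("model", 2)]

def device_coordinates(device_id, mesh_dims):
    """Return {dim_name: index} for one device of a row-major mesh."""
    coords = {}
    remaining = device_id
    # Peel off indices from the fastest-varying (last) dimension first.
    for name, size in reversed(mesh_dims):
        coords[name] = remaining % size
        remaining //= size
    return coords

def replica_groups(mesh_dims):
    """Group device ids by "batch" index: each group is one model replica."""
    num_devices = 1
    for _, size in mesh_dims:
        num_devices *= size
    groups = {}
    for device_id in range(num_devices):
        b = device_coordinates(device_id, mesh_dims)["batch"]
        groups.setdefault(b, []).append(device_id)
    return groups

for replica, devices in replica_groups(MESH_DIMS).items():
    print(f"replica {replica}: devices {devices}")
```

Under this assumption, each group of two neighbouring devices forms one model replica, giving the 4 replicas of 2 devices each described above.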


In [20]:
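# 8 devices: 4 model replicas ("batch") x 2 devices per replica ("model").
mesh = dtensor.create_mesh([("batch", 4), ("model", 2)], devices=DEVICES)
model = MLP([
    # First dense layer's weight: axis 1 is sharded across "model".
    dtensor.Layout([dtensor.UNSHARDED, "model"], mesh),
    # Second dense layer's weight: axis 0 is sharded across "model".
    dtensor.Layout(["model", dtensor.UNSHARDED], mesh)])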
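
Why pair these two layouts? Assuming, as in the `MLP` defined earlier in this tutorial, that the first layout applies to the first dense layer's weight (shape `[in, hidden]`) and the second to the second dense layer's weight (shape `[hidden, out]`), each device along `"model"` owns half of the hidden units. The pure-Python sketch below is illustrative only: the shapes and values are made up, biases are omitted, and `local_shape` is a hypothetical helper. It shows the per-device shard shapes and checks that summing the two partial outputs reproduces the unsharded forward pass. In DTensor, that cross-device sum is inserted for you; the sketch only demonstrates the arithmetic.

```python
# Pure-Python sketch (not DTensor's implementation) of the sharding implied by
# Layout([UNSHARDED, "model"]) for layer 1 and Layout(["model", UNSHARDED])
# for layer 2. Shapes and values are made up; biases are omitted.
UNSHARDED = "unsharded"          # stand-in for dtensor.UNSHARDED
MESH = {"batch": 4, "model": 2}  # mesh dimension sizes

def local_shape(global_shape, sharding_specs, mesh):
    """Per-device shape: every sharded axis is split evenly over its mesh dim."""
    shape = []
    for dim, spec in zip(global_shape, sharding_specs):
        if spec == UNSHARDED:
            shape.append(dim)
        else:
            assert dim % mesh[spec] == 0, "axis must divide evenly"
            shape.append(dim // mesh[spec])
    return shape

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def relu(m):
    return [[max(0.0, v) for v in row] for row in m]

def add(a, b):
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]

x = [[1.0, 2.0, -0.5]]                                    # [batch=1, in=3]
w1 = [[0.2, -0.1, 0.4, 0.3],
      [0.5, 0.6, -0.2, 0.1],
      [-0.3, 0.2, 0.7, -0.4]]                             # [in=3, hidden=4]
w2 = [[1.0, -1.0], [0.5, 0.2], [-0.3, 0.8], [0.9, 0.1]]  # [hidden=4, out=2]

# Reference: the unsharded forward pass.
full = matmul(relu(matmul(x, w1)), w2)

# Model-parallel forward pass: one iteration per device along "model".
per_device = local_shape([4, 2], ["model", UNSHARDED], MESH)[0]  # hidden units per device
partials = []
for m in range(MESH["model"]):
    lo, hi = m * per_device, (m + 1) * per_device
    w1_shard = [row[lo:hi] for row in w1]       # layer-1 weight split on axis 1
    w2_shard = w2[lo:hi]                        # layer-2 weight split on axis 0
    h_shard = relu(matmul(x, w1_shard))         # this device's hidden units
    partials.append(matmul(h_shard, w2_shard))  # partial sum of the output

# Summing the partial outputs across "model" recovers the unsharded result.
sharded = partials[0]
for p in partials[1:]:
    sharded = add(sharded, p)
```

Because the activation is elementwise, each device can apply it to its own slice of hidden units. The only communication this pairing needs is the final sum over `"model"`.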

As the training data is still sharded along the batch dimension, you can reuse the same `repack_batch` function as the Data Parallel training case. DTensor will automatically replicate the per-replica batch to all devices inside the replica along the `"model"` mesh dimension.

In [21]:
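def repack_batch(x, y, mesh):
  # Shard inputs and labels along the "batch" mesh dimension; they are
  # replicated across the "model" mesh dimension.
  x = repack_local_tensor(x, layout=dtensor.Layout(['batch', dtensor.UNSHARDED], mesh))
  y = repack_local_tensor(y, layout=dtensor.Layout(['batch'], mesh))
  return x, y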
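
To make the replication along `"model"` concrete, here is a small pure-Python sketch of which rows of a global batch each device holds when inputs use `Layout(['batch', dtensor.UNSHARDED], mesh)` on this 4 x 2 mesh. It assumes a hypothetical global batch of 64 examples and row-major device order; `rows_per_device` is a made-up helper, not part of DTensor.

```python
# Hypothetical sketch: which rows of a global batch each device holds when
# inputs use Layout(["batch", UNSHARDED]) on a ("batch", 4) x ("model", 2)
# mesh. Assumes row-major device order; not DTensor's actual code.
MESH_DIMS = [("batch", 4), ("model", 2)]

def rows_per_device(global_batch, mesh_dims):
    sizes = dict(mesh_dims)
    num_shards = sizes["batch"]
    assert global_batch % num_shards == 0, "batch must divide evenly"
    per_shard = global_batch // num_shards
    assignment = {}
    for device_id in range(num_shards * sizes["model"]):
        b = device_id // sizes["model"]  # "batch" coordinate (row-major)
        # The "model" coordinate does not change the slice: it is replicated.
        assignment[device_id] = range(b * per_shard, (b + 1) * per_shard)
    return assignment

rows = rows_per_device(64, MESH_DIMS)
for device_id, r in rows.items():
    print(f"device {device_id}: rows {r.start}..{r.stop - 1}")
```

The two devices that share a `"batch"` index receive the same rows, which is the per-replica replication described above.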

Next, run the training loop. It reuses the same checkpoint manager as the Data Parallel training example, and the code is identical.

Because the checkpoint manager restores the latest checkpoint, you can continue training the model that was trained with Data Parallel, now under Model Parallel training.

In [22]:
num_epochs = 2
# Reuses the checkpoint manager helper from the Data Parallel example.
manager = start_checkpoint_manager(model)

for epoch in range(num_epochs):
  step = 0
  pbar = tf.keras.utils.Progbar(target=int(train_data_vec.cardinality()))
  metrics = {'epoch': epoch}
  for x, y in train_data_vec:
    # Lay out the local batch on the 2D mesh before the training step.
    x, y = repack_batch(x, y, mesh)
    metrics.update(train_step(model, x, y, 1e-2))
    pbar.update(step, values=metrics.items(), finalize=False)
    step += 1
  # Save a checkpoint at the end of every epoch.
  manager.save()
  pbar.update(step, values=metrics.items(), finalize=True)

Restoring a checkpoint


  0/391 [..............................] - ETA: 0s - epoch: 0.0000e+00 - loss: 0.9760 - accuracy: 0.4844

  1/391 [..............................] - ETA: 4:11 - epoch: 0.0000e+00 - loss: 1.1283 - accuracy: 0.5078

  2/391 [..............................] - ETA: 29s - epoch: 0.0000e+00 - loss: 0.9604 - accuracy: 0.5729 

  3/391 [..............................] - ETA: 29s - epoch: 0.0000e+00 - loss: 1.0287 - accuracy: 0.5742

  4/391 [..............................] - ETA: 27s - epoch: 0.0000e+00 - loss: 1.0962 - accuracy: 0.5656

  5/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 1.1744 - accuracy: 0.5599

  6/391 [..............................] - ETA: 27s - epoch: 0.0000e+00 - loss: 1.1026 - accuracy: 0.5692

  7/391 [..............................] - ETA: 27s - epoch: 0.0000e+00 - loss: 1.0744 - accuracy: 0.5742

  8/391 [..............................] - ETA: 27s - epoch: 0.0000e+00 - loss: 1.0553 - accuracy: 0.5816

  9/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 1.0362 - accuracy: 0.5813

 10/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 1.0463 - accuracy: 0.5739

 11/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 1.0185 - accuracy: 0.5820

 12/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.9964 - accuracy: 0.5913

 13/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.9788 - accuracy: 0.5960

 14/391 [>.............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.9498 - accuracy: 0.6062

 15/391 [>.............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.9406 - accuracy: 0.6055

 16/391 [>.............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.9377 - accuracy: 0.6085

 17/391 [>.............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.9690 - accuracy: 0.6016

 18/391 [>.............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 1.0026 - accuracy: 0.6012

 19/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9880 - accuracy: 0.6023

 20/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9740 - accuracy: 0.6049

 21/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9729 - accuracy: 0.6072

 22/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9613 - accuracy: 0.6094

 23/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9482 - accuracy: 0.6172

 24/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9372 - accuracy: 0.6212

 25/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9279 - accuracy: 0.6220

 26/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9336 - accuracy: 0.6209

 27/391 [=>............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9282 - accuracy: 0.6228

 28/391 [=>............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9163 - accuracy: 0.6261

 29/391 [=>............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.9079 - accuracy: 0.6281

 30/391 [=>............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.8992 - accuracy: 0.6305

 31/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.8956 - accuracy: 0.6304

 32/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.9076 - accuracy: 0.6307

 33/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.9062 - accuracy: 0.6282

 34/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.9083 - accuracy: 0.6286

 35/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.8959 - accuracy: 0.6328

 36/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.8910 - accuracy: 0.6322

 37/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.8868 - accuracy: 0.6340

 38/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.8904 - accuracy: 0.6338

 39/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.9041 - accuracy: 0.6328

 40/391 [==>...........................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.9170 - accuracy: 0.6315

 41/391 [==>...........................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.9152 - accuracy: 0.6321

 42/391 [==>...........................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.9146 - accuracy: 0.6315

 43/391 [==>...........................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.9256 - accuracy: 0.6293

 44/391 [==>...........................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.9406 - accuracy: 0.6285

 45/391 [==>...........................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.9456 - accuracy: 0.6277

 46/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9415 - accuracy: 0.6280

 47/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9475 - accuracy: 0.6279

 48/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9398 - accuracy: 0.6285

 49/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9313 - accuracy: 0.6313

 50/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9232 - accuracy: 0.6330

 51/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9176 - accuracy: 0.6328

 52/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9177 - accuracy: 0.6338

 53/391 [===>..........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9109 - accuracy: 0.6354

 54/391 [===>..........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9060 - accuracy: 0.6355

 55/391 [===>..........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9084 - accuracy: 0.6362

 56/391 [===>..........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9237 - accuracy: 0.6338

 57/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9285 - accuracy: 0.6334

 58/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9294 - accuracy: 0.6337

 59/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9338 - accuracy: 0.6328

 60/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9314 - accuracy: 0.6347

 61/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9315 - accuracy: 0.6343

 62/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9310 - accuracy: 0.6332

 63/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9337 - accuracy: 0.6323

 64/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9291 - accuracy: 0.6344

 65/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9241 - accuracy: 0.6352

 66/391 [====>.........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9258 - accuracy: 0.6339

 67/391 [====>.........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9336 - accuracy: 0.6328

 68/391 [====>.........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9338 - accuracy: 0.6322

 69/391 [====>.........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9327 - accuracy: 0.6326

 70/391 [====>.........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9349 - accuracy: 0.6316

 71/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9354 - accuracy: 0.6319

 72/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9307 - accuracy: 0.6323

 73/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9280 - accuracy: 0.6330

 74/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9229 - accuracy: 0.6344

 75/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9215 - accuracy: 0.6343

 76/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9180 - accuracy: 0.6347

 77/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9166 - accuracy: 0.6354

 78/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9140 - accuracy: 0.6355

 79/391 [=====>........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9108 - accuracy: 0.6363

 80/391 [=====>........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9074 - accuracy: 0.6375

 81/391 [=====>........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9058 - accuracy: 0.6381

 82/391 [=====>........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9019 - accuracy: 0.6386

 83/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.8998 - accuracy: 0.6391

 84/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.8952 - accuracy: 0.6404

 85/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.8912 - accuracy: 0.6417

 86/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.8873 - accuracy: 0.6431

 87/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.8863 - accuracy: 0.6435

 88/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.8899 - accuracy: 0.6420

 89/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.8939 - accuracy: 0.6413

 90/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.8936 - accuracy: 0.6411

 91/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.8957 - accuracy: 0.6410

 91/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 0.7556 - accuracy: 0.6941


## Spatial Parallel Training

When training on data of very high dimensionality (for example, a very large image or a video), it may be desirable to shard along the feature dimension. This is called [Spatial Partitioning](https://cloud.google.com/blog/products/ai-machine-learning/train-ml-models-on-large-images-and-3d-volumes-with-spatial-partitioning-on-cloud-tpus), which was first introduced into TensorFlow for training models with large 3D input samples.

<img src="https://www.tensorflow.org/images/dtensor/dtensor_spatial_para.png" alt="Spatial parallel mesh" class="no-filter">

DTensor also supports this case. The only change you need to make is to create a `Mesh` that includes a `feature` dimension and apply the corresponding `Layout`s.


In [23]:
mesh = dtensor.create_mesh([("batch", 2), ("feature", 2), ("model", 2)], devices=DEVICES)
model = MLP([dtensor.Layout(["feature", "model"], mesh), 
             dtensor.Layout(["model", dtensor.UNSHARDED], mesh)])
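
To see what these layouts mean on each device, here is a minimal pure-Python sketch. It assumes DTensor's usual sharding semantics. A tensor axis mapped to a mesh dimension is split evenly across that dimension's devices. An `UNSHARDED` axis is kept whole. A mesh dimension that no axis mentions (here `batch`, for the weights) replicates the shard. The helpers `local_shard_shape` and `replica_count`, the string stand-in for `dtensor.UNSHARDED`, and the weight shapes are illustrative assumptions, not values read from the `MLP` definition.

```python
# Stand-in for dtensor.UNSHARDED (assumption: used here only as a sentinel).
UNSHARDED = "unsharded"

# Mesh from the cell above: 2 x 2 x 2 = 8 devices.
MESH_DIMS = {"batch": 2, "feature": 2, "model": 2}


def local_shard_shape(global_shape, layout, mesh_dims):
    """Return the per-device shape of a tensor under a DTensor-style layout."""
    if len(global_shape) != len(layout):
        raise ValueError("layout needs one entry per tensor axis")
    shape = []
    for size, spec in zip(global_shape, layout):
        if spec == UNSHARDED:
            shape.append(size)  # axis kept whole on every device
            continue
        parts = mesh_dims[spec]
        if size % parts:
            raise ValueError(f"axis of size {size} does not split into {parts} shards")
        shape.append(size // parts)  # axis split evenly along mesh dim `spec`
    return tuple(shape)


def replica_count(layout, mesh_dims):
    """Number of devices holding each shard: product of the unused mesh dims."""
    used = {spec for spec in layout if spec != UNSHARDED}
    count = 1
    for name, size in mesh_dims.items():
        if name not in used:
            count *= size
    return count


# Illustrative weight shapes (assumed, not taken from the MLP definition).
print(local_shard_shape((1200, 48), ["feature", "model"], MESH_DIMS))  # (600, 24)
print(replica_count(["feature", "model"], MESH_DIMS))                  # 2
print(local_shard_shape((48, 2), ["model", UNSHARDED], MESH_DIMS))     # (24, 2)
print(replica_count(["model", UNSHARDED], MESH_DIMS))                  # 4
```

Under these assumptions, each shard of the first weight is stored on the 2 devices that differ only in their `batch` coordinate.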


Shard the input data along the `feature` dimension when packing the input tensors into DTensors. You do this with a slightly different repack function, `repack_batch_for_spt`, where `spt` stands for Spatial Parallel Training.

In [24]:
def repack_batch_for_spt(x, y, mesh):
    # Shard data on feature dimension, too
    x = repack_local_tensor(x, layout=dtensor.Layout(["batch", 'feature'], mesh))
    y = repack_local_tensor(y, layout=dtensor.Layout(["batch"], mesh))
    return x, y
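
The `["batch", "feature"]` layout used in `repack_batch_for_spt` can be pictured with NumPy. The local batch is cut into a 2-by-2 grid of blocks. Because the layout does not mention `model`, every block is copied to both devices along that axis. This is a conceptual sketch with a made-up 8-by-6 input and a hypothetical helper, `shard_batch_feature`. It does not call `repack_local_tensor` or any DTensor API. Under the same reasoning, labels packed with `Layout(["batch"], mesh)` would be split only along `batch` and replicated across `feature` and `model`.

```python
import numpy as np

MESH_DIMS = {"batch": 2, "feature": 2, "model": 2}


def shard_batch_feature(x, mesh_dims):
    """Map each mesh coordinate (b, f, m) to the block of `x` it would hold
    under a ["batch", "feature"] layout (conceptual model only)."""
    shards = {}
    # np.split raises ValueError unless each axis divides evenly.
    for b, rows in enumerate(np.split(x, mesh_dims["batch"], axis=0)):
        for f, block in enumerate(np.split(rows, mesh_dims["feature"], axis=1)):
            # "model" is not in the layout, so the block is replicated along it.
            for m in range(mesh_dims["model"]):
                shards[(b, f, m)] = block
    return shards


x = np.arange(8 * 6).reshape(8, 6)  # made-up local batch: 8 examples, 6 features
shards = shard_batch_feature(x, MESH_DIMS)
print(len(shards))              # 8
print(shards[(1, 0, 0)].shape)  # (4, 3)

# Reassembling one copy per (b, f) recovers the original batch.
rebuilt = np.block([[shards[(b, f, 0)] for f in range(2)] for b in range(2)])
print(np.array_equal(rebuilt, x))  # True
```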

Spatial parallel training can also continue from a checkpoint created with other parallel training schemes.
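
One way to picture why this works is to treat a checkpoint as holding each variable's global value. This is a simplifying assumption, not a description of DTensor's on-disk format. Under that assumption, restoring under a new layout amounts to reassembling the shards and re-splitting them. The NumPy sketch below does this for a hypothetical 8-by-4 weight that is saved while sharded only along `model` and then restored under `["feature", "model"]`. `split_2d` and `merge_2d` are illustrative helpers.

```python
import numpy as np


def split_2d(x, rows, cols):
    """Split a 2-D array into a rows-by-cols grid of equal blocks keyed by (r, c)."""
    return {
        (r, c): block
        for r, row_block in enumerate(np.split(x, rows, axis=0))
        for c, block in enumerate(np.split(row_block, cols, axis=1))
    }


def merge_2d(blocks, rows, cols):
    """Inverse of split_2d: rebuild the global array from its blocks."""
    return np.block([[blocks[(r, c)] for c in range(cols)] for r in range(rows)])


# Hypothetical weight, saved while sharded only along "model" (a 1-by-2 grid).
weight = np.random.default_rng(0).normal(size=(8, 4))
saved_blocks = split_2d(weight, 1, 2)

# Restore under ["feature", "model"] (a 2-by-2 grid): merge, then re-split.
global_value = merge_2d(saved_blocks, 1, 2)
restored_blocks = split_2d(global_value, 2, 2)
print(np.array_equal(global_value, weight))  # True
print(restored_blocks[(1, 1)].shape)         # (4, 2)
```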

In [25]:
num_epochs = 2

manager = start_checkpoint_manager(model)
for epoch in range(num_epochs):
  step = 0
  metrics = {'epoch': epoch}
  pbar = tf.keras.utils.Progbar(target=int(train_data_vec.cardinality()))

  for x, y in train_data_vec:
    x, y = repack_batch_for_spt(x, y, mesh)
    metrics.update(train_step(model, x, y, 1e-2))

    pbar.update(step, values=metrics.items(), finalize=False)
    step += 1
  manager.save()
  pbar.update(step, values=metrics.items(), finalize=True)

Restoring a checkpoint


  0/391 [..............................] - ETA: 0s - epoch: 0.0000e+00 - loss: 0.7536 - accuracy: 0.6250

  1/391 [..............................] - ETA: 4:12 - epoch: 0.0000e+00 - loss: 0.7093 - accuracy: 0.6719

  2/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.6289 - accuracy: 0.7135 

  3/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.7070 - accuracy: 0.6992

  4/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.7735 - accuracy: 0.6875

  5/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.8479 - accuracy: 0.6693

  6/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.7911 - accuracy: 0.6830

  7/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.8112 - accuracy: 0.6719

  8/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.8101 - accuracy: 0.6753

  9/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.7871 - accuracy: 0.6828

 10/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.7771 - accuracy: 0.6832

 11/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.7612 - accuracy: 0.6875

 12/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.7818 - accuracy: 0.6815

 13/391 [..............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.7907 - accuracy: 0.6797

 14/391 [>.............................] - ETA: 26s - epoch: 0.0000e+00 - loss: 0.7660 - accuracy: 0.6875

 15/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.7637 - accuracy: 0.6826

 16/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.7607 - accuracy: 0.6847

 17/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.7737 - accuracy: 0.6814

 18/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.7856 - accuracy: 0.6793

 19/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.7689 - accuracy: 0.6852

 20/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.7550 - accuracy: 0.6912

 21/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.7410 - accuracy: 0.6953

 22/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.7303 - accuracy: 0.6997

 23/391 [>.............................] - ETA: 25s - epoch: 0.0000e+00 - loss: 0.7224 - accuracy: 0.7031

 24/391 [>.............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7137 - accuracy: 0.7056

 25/391 [>.............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7114 - accuracy: 0.7061

 26/391 [>.............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7098 - accuracy: 0.7089

 27/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7045 - accuracy: 0.7093

 28/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.6995 - accuracy: 0.7101

 29/391 [=>............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.6961 - accuracy: 0.7099

 30/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6907 - accuracy: 0.7122

 31/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6876 - accuracy: 0.7134

 32/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6905 - accuracy: 0.7121

 33/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6922 - accuracy: 0.7100

 34/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6937 - accuracy: 0.7089

 35/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6834 - accuracy: 0.7140

 36/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6806 - accuracy: 0.7166

 37/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6793 - accuracy: 0.7171

 38/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6795 - accuracy: 0.7188

 39/391 [=>............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.6831 - accuracy: 0.7176

 40/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.6940 - accuracy: 0.7161

 41/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.6945 - accuracy: 0.7147

 42/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.6952 - accuracy: 0.7155

 43/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.6998 - accuracy: 0.7124

 44/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7078 - accuracy: 0.7115

 45/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7125 - accuracy: 0.7113

 46/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7096 - accuracy: 0.7131

 47/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7134 - accuracy: 0.7129

 48/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7079 - accuracy: 0.7156

 49/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7019 - accuracy: 0.7169

 50/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.6975 - accuracy: 0.7175

 51/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.6957 - accuracy: 0.7175

 52/391 [==>...........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.6963 - accuracy: 0.7173

 53/391 [===>..........................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.6908 - accuracy: 0.7196

 54/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6894 - accuracy: 0.7193

 55/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6914 - accuracy: 0.7193

 56/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6973 - accuracy: 0.7168

 57/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7034 - accuracy: 0.7161

 58/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7049 - accuracy: 0.7153

 59/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7051 - accuracy: 0.7167

 60/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7026 - accuracy: 0.7185

 61/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7015 - accuracy: 0.7190

 62/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6985 - accuracy: 0.7197

 63/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6986 - accuracy: 0.7200

 64/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6952 - accuracy: 0.7204

 65/391 [===>..........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6911 - accuracy: 0.7216

 66/391 [====>.........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6931 - accuracy: 0.7206

 67/391 [====>.........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6993 - accuracy: 0.7190

 68/391 [====>.........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7007 - accuracy: 0.7181

 69/391 [====>.........................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7025 - accuracy: 0.7179

 70/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.7044 - accuracy: 0.7170

 71/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.7062 - accuracy: 0.7164

 72/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.7034 - accuracy: 0.7173

 73/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.7017 - accuracy: 0.7173

 74/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6990 - accuracy: 0.7177

 75/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6983 - accuracy: 0.7185

 76/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6968 - accuracy: 0.7190

 77/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6962 - accuracy: 0.7194

 78/391 [====>.........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6945 - accuracy: 0.7203

 79/391 [=====>........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6933 - accuracy: 0.7201

 80/391 [=====>........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6907 - accuracy: 0.7209

 81/391 [=====>........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6885 - accuracy: 0.7214

 82/391 [=====>........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6880 - accuracy: 0.7206

 83/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6884 - accuracy: 0.7201

 84/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6862 - accuracy: 0.7210

 85/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6863 - accuracy: 0.7206

 86/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6841 - accuracy: 0.7209

 87/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6834 - accuracy: 0.7207

 88/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6822 - accuracy: 0.7205

 89/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6799 - accuracy: 0.7215

 90/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6785 - accuracy: 0.7215

 91/391 [=====>........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6774 - accuracy: 0.7223

  0/391 [..............................] - ETA: 0s - epoch: 1.0000 - loss: 0.6446 - accuracy: 0.7344

  1/391 [..............................] - ETA: 2:41 - epoch: 1.0000 - loss: 0.6031 - accuracy: 0.7578

  2/391 [..............................] - ETA: 22s - epoch: 1.0000 - loss: 0.5305 - accuracy: 0.7708 

  3/391 [..............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6011 - accuracy: 0.7461

  4/391 [..............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6807 - accuracy: 0.7312

  5/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7588 - accuracy: 0.7057

  6/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7086 - accuracy: 0.7188

  7/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7376 - accuracy: 0.7070

  8/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7467 - accuracy: 0.7049

  9/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7279 - accuracy: 0.7109

 10/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7144 - accuracy: 0.7131

 11/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7015 - accuracy: 0.7174

 12/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7261 - accuracy: 0.7115

 13/391 [..............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7402 - accuracy: 0.7087

 14/391 [>.............................] - ETA: 23s - epoch: 1.0000 - loss: 0.7202 - accuracy: 0.7125

 15/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.7163 - accuracy: 0.7051

 16/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.7135 - accuracy: 0.7068

 17/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.7257 - accuracy: 0.7031

 18/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.7379 - accuracy: 0.7007

 19/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.7233 - accuracy: 0.7063

 20/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.7109 - accuracy: 0.7128

 21/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6974 - accuracy: 0.7173

 22/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6878 - accuracy: 0.7188

 23/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6804 - accuracy: 0.7220

 24/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6728 - accuracy: 0.7256

 25/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6682 - accuracy: 0.7290

 26/391 [>.............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6654 - accuracy: 0.7321

 27/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6602 - accuracy: 0.7344

 28/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6567 - accuracy: 0.7355

 29/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6518 - accuracy: 0.7359

 30/391 [=>............................] - ETA: 22s - epoch: 1.0000 - loss: 0.6449 - accuracy: 0.7374

 31/391 [=>............................] - ETA: 21s - epoch: 1.0000 - loss: 0.6392 - accuracy: 0.7412

 32/391 [=>............................] - ETA: 21s - epoch: 1.0000 - loss: 0.6368 - accuracy: 0.7420

 33/391 [=>............................] - ETA: 21s - epoch: 1.0000 - loss: 0.6373 - accuracy: 0.7413

 34/391 [=>............................] - ETA: 21s - epoch: 1.0000 - loss: 0.6366 - accuracy: 0.7415

 35/391 [=>............................] - ETA: 21s - epoch: 1.0000 - loss: 0.6270 - accuracy: 0.7457

 36/391 [=>............................] - ETA: 21s - epoch: 1.0000 - loss: 0.6250 - accuracy: 0.7479

 37/391 [=>............................] - ETA: 21s - epoch: 1.0000 - loss: 0.6235 - accuracy: 0.7484

 38/391 [=>............................] - ETA: 21s - epoch: 1.0000 - loss: 0.6234 - accuracy: 0.7496

 39/391 [=>............................] - ETA: 21s - epoch: 1.0000 - loss: 0.6263 - accuracy: 0.7488

 40/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 0.6325 - accuracy: 0.7466

 41/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 0.6333 - accuracy: 0.7455

 42/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 0.6338 - accuracy: 0.7456

 43/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 0.6367 - accuracy: 0.7436

 44/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 0.6408 - accuracy: 0.7424

 45/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 0.6449 - accuracy: 0.7415

 46/391 [==>...........................] - ETA: 21s - epoch: 1.0000 - loss: 0.6424 - accuracy: 0.7430

 47/391 [==>...........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6460 - accuracy: 0.7432

 48/391 [==>...........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6414 - accuracy: 0.7455

 49/391 [==>...........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6364 - accuracy: 0.7466

 50/391 [==>...........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6327 - accuracy: 0.7463

 51/391 [==>...........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6309 - accuracy: 0.7458

 52/391 [==>...........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6298 - accuracy: 0.7456

 53/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6248 - accuracy: 0.7477

 54/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6240 - accuracy: 0.7474

 55/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6264 - accuracy: 0.7472

 56/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6326 - accuracy: 0.7442

 57/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6376 - accuracy: 0.7435

 58/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6388 - accuracy: 0.7431

 59/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6382 - accuracy: 0.7440

 60/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6354 - accuracy: 0.7462

 61/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6340 - accuracy: 0.7465

 62/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6313 - accuracy: 0.7478

 63/391 [===>..........................] - ETA: 20s - epoch: 1.0000 - loss: 0.6315 - accuracy: 0.7471

 64/391 [===>..........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6288 - accuracy: 0.7478

 65/391 [===>..........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6256 - accuracy: 0.7491

 66/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6286 - accuracy: 0.7479

 67/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6333 - accuracy: 0.7468

 68/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6344 - accuracy: 0.7457

 69/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6369 - accuracy: 0.7453

 70/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6378 - accuracy: 0.7447

 71/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6392 - accuracy: 0.7441

 72/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6373 - accuracy: 0.7449

 73/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6367 - accuracy: 0.7447

 74/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6340 - accuracy: 0.7452

 75/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6338 - accuracy: 0.7463

 76/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6340 - accuracy: 0.7463

 77/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6341 - accuracy: 0.7464

 78/391 [====>.........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6331 - accuracy: 0.7472

 79/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6324 - accuracy: 0.7471

 80/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6300 - accuracy: 0.7471

 81/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6285 - accuracy: 0.7475

 82/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6309 - accuracy: 0.7459

 83/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6335 - accuracy: 0.7448

 84/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6313 - accuracy: 0.7454

 85/391 [=====>........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6298 - accuracy: 0.7462

 86/391 [=====>........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6273 - accuracy: 0.7471

 87/391 [=====>........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6264 - accuracy: 0.7468

 88/391 [=====>........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6259 - accuracy: 0.7461

 89/391 [=====>........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6242 - accuracy: 0.7467

 90/391 [=====>........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6228 - accuracy: 0.7469

 91/391 [=====>........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6219 - accuracy: 0.7480
error: 'tfg.Case' op branch #0 function argument #3 type 'tensor<600x24xf32>' is not compatible with corresponding operand type: 'tensor<1200x48xf32>'
2022-12-14 03:48:36.169982: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] tfg_optimizer{tfg-consolidate-attrs,tfg-toposort,tfg-shape-inference{graph-version=0},tfg-prepare-attrs-export} failed: INVALID_ARGUMENT: MLIR Graph Optimizer failed: 
2022-12-14 03:48:36.170563: W tensorflow/core/common_runtime/process_function_library_runtime.cc:915] Ignoring multi-device function optimization failure: INVALID_ARGUMENT: MLIR Graph Optimizer failed: 
error: 'tfg.Case' op branch #0 function argument #3 type 'tensor<600x24xf32>' is not compatible with corresponding operand type: 'tensor<1200x48xf32>'
2022-12-14 03:48:36.175401: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] tfg_optimizer{tfg-consolidate-attrs,tfg-toposort,tfg-shape-inference{graph-version=0},tfg-prepare-attrs-export} failed: INVALID_ARGUMENT: MLIR Graph 



## SavedModel and DTensor

The integration of DTensor and SavedModel is still under development. 

As of TensorFlow 2.11, `tf.saved_model` can save sharded and replicated DTensor models, and the save is performed efficiently in shards across the devices of the mesh. However, once a model is saved, all DTensor annotations are lost, and the saved signatures can only be called with regular Tensors, not DTensors.
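To make the idea of a sharded save more concrete, here is a NumPy-only sketch of what happens to one 2-D variable sharded along its first axis over a hypothetical 2-device mesh. Each device writes only its own shard, and restoring concatenates the shards back into one ordinary array. The helpers `shard_rows` and `save_and_restore` are hypothetical illustrations, not DTensor or SavedModel APIs. The point is only that the restored value carries no layout information, which mirrors the loss of DTensor annotations described above.

```python
import numpy as np

def shard_rows(value, num_devices):
    """Split a 2-D array into equal row blocks, one per (hypothetical) device."""
    if value.shape[0] % num_devices != 0:
        raise ValueError("first dimension must be divisible by the mesh size")
    return np.split(value, num_devices, axis=0)

def save_and_restore(shards):
    """Simulate a sharded save: each device 'writes' its own shard.

    Restoring concatenates the shards into a single plain array; the layout
    (which device held which rows) is not part of the result.
    """
    checkpoint = {f"shard_{i}": shard.copy() for i, shard in enumerate(shards)}
    ordered = [checkpoint[f"shard_{i}"] for i in range(len(checkpoint))]
    return np.concatenate(ordered, axis=0)

weights = np.arange(12, dtype=np.float32).reshape(4, 3)
shards = shard_rows(weights, num_devices=2)
restored = save_and_restore(shards)

print([s.shape for s in shards])          # [(2, 3), (2, 3)]
print(np.array_equal(restored, weights))  # True
```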

In [26]:
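# A single-device mesh with unsharded layouts for both MLP weight matrices.
mesh = dtensor.create_mesh([("world", 1)], devices=DEVICES[:1])
mlp = MLP([dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh),
           dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)])

# Restore the most recent training checkpoint, if one exists, into the new MLP.
manager = start_checkpoint_manager(mlp)

# Bundle text preprocessing with the MLP so the saved model accepts raw strings.
model_for_saving = tf.keras.Sequential([
  text_vectorization,
  mlp
])

# Serving signature: a batch of raw strings in, a dict of model outputs out.
@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
def run(inputs):
  return {'result': model_for_saving(inputs)}

tf.saved_model.save(
    model_for_saving, "/tmp/saved_model",
    signatures=run)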

Restoring a checkpoint


INFO:tensorflow:Assets written to: /tmp/saved_model/assets



As of TensorFlow 2.9.0, a loaded signature can only be called with a regular Tensor or with a fully replicated DTensor (which is converted to a regular Tensor).
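The restriction can be pictured with per-device component lists. In the NumPy sketch below, a DTensor is modeled, as a simplifying assumption, as a list of per-device arrays plus a flag saying whether its layout is fully replicated. A fully replicated value holds an identical full copy on every device, so any single copy already is the regular tensor; a sharded value holds only pieces and cannot stand in for the full tensor without first being gathered. The helper `to_regular_tensor` is hypothetical and is not part of the DTensor API.

```python
import numpy as np

def to_regular_tensor(components, fully_replicated):
    """Return a plain array for a (hypothetical) per-device component list.

    Only fully replicated values are accepted, mirroring the restriction on
    loaded signatures: every device must hold the same full value.
    """
    if not fully_replicated:
        raise TypeError("sharded values must be gathered before calling the signature")
    first = components[0]
    for other in components[1:]:
        if not np.array_equal(first, other):
            raise ValueError("replicated components disagree")
    return first

batch = np.array([3.0, 1.0, 4.0, 1.0])

# Fully replicated on two devices: each device holds the whole batch.
replicated = [batch.copy(), batch.copy()]
print(to_regular_tensor(replicated, fully_replicated=True))  # [3. 1. 4. 1.]

# Sharded on two devices: each device holds half of the batch.
sharded = np.split(batch, 2)
try:
    to_regular_tensor(sharded, fully_replicated=False)
except TypeError as err:
    print("rejected:", err)
```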

In [27]:
sample_batch = train_data.take(1).get_single_element()
sample_batch

{'label': <tf.Tensor: shape=(64,), dtype=int64, numpy=
 array([0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1,
        1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1])>,
 'text': <tf.Tensor: shape=(64,), dtype=string, numpy=
 array([b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Wa

In [28]:
loaded = tf.saved_model.load("/tmp/saved_model")

run_sig = loaded.signatures["serving_default"]
result = run_sig(sample_batch['text'])['result']

In [29]:
np.mean(tf.argmax(result, axis=-1) == sample_batch['label'])

0.75
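The evaluation cell compares, for each example, the index of the largest output score with the integer label, then averages the resulting booleans. Here is the same arithmetic on a small hand-made batch, using `np.argmax` in place of `tf.argmax` (an illustrative substitution); the scores and labels are made up and unrelated to the run above.

```python
import numpy as np

# Made-up two-class scores for five examples (one row per example).
logits = np.array([
    [2.0, -1.0],   # predicts class 0
    [0.1, 0.9],    # predicts class 1
    [1.5, 0.5],    # predicts class 0
    [-0.3, 0.2],   # predicts class 1
    [0.7, 0.4],    # predicts class 0
])
labels = np.array([0, 1, 1, 1, 0])

predictions = np.argmax(logits, axis=-1)   # [0, 1, 0, 1, 0]
accuracy = np.mean(predictions == labels)  # 4 of 5 correct
print(accuracy)  # 0.8
```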

## What's next?

This tutorial demonstrated building and training an MLP sentiment analysis model with DTensor.

Through the `Mesh` and `Layout` primitives, DTensor can transform a TensorFlow `tf.function` into a distributed program suitable for a variety of training schemes.
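A simplified way to see how a `Layout` decides what each device holds: every tensor axis is either unsharded or sharded over a named mesh dimension, and a sharded axis is divided by that dimension's size. The plain-Python sketch below reproduces only that shape arithmetic. `local_shape`, the `UNSHARDED` string, the dictionary form of the mesh, and the dimension names `"batch"` and `"model"` are hypothetical stand-ins for `dtensor.Mesh` and `dtensor.Layout`, not their real interfaces.

```python
UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED

def local_shape(global_shape, sharding_specs, mesh_dims):
    """Per-device shape of a tensor under a (hypothetical) layout.

    global_shape:   e.g. (1200, 48)
    sharding_specs: one entry per tensor axis, UNSHARDED or a mesh dimension name
    mesh_dims:      mapping from mesh dimension name to its size
    """
    if len(global_shape) != len(sharding_specs):
        raise ValueError("need one sharding spec per tensor axis")
    shape = []
    for size, spec in zip(global_shape, sharding_specs):
        if spec == UNSHARDED:
            shape.append(size)
        else:
            parts = mesh_dims[spec]
            if size % parts != 0:
                raise ValueError(f"axis of size {size} not divisible by {parts}")
            shape.append(size // parts)
    return tuple(shape)

mesh = {"batch": 2, "model": 2}
print(local_shape((1200, 48), [UNSHARDED, UNSHARDED], mesh))  # (1200, 48)
print(local_shape((1200, 48), ["batch", UNSHARDED], mesh))    # (600, 48)
print(local_shape((1200, 48), ["batch", "model"], mesh))      # (600, 24)
```

Roughly speaking, data parallelism shards the batch axis of inputs over one mesh dimension, while model parallelism also shards weight axes over another; in both cases the per-device shapes follow from this division.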

In a real-world machine learning application, evaluation and cross-validation should be applied to avoid producing an over-fitted model. The techniques introduced in this tutorial can also be used to parallelize evaluation.
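Evaluation can be parallelized in the same data-parallel way as training: each device scores its slice of the evaluation batch, and the per-device counts of correct predictions are summed (the role an all-reduce would play across devices) before dividing by the total. The NumPy sketch below uses a hypothetical `sharded_accuracy` helper and random data to check that combining per-shard counts gives the same accuracy as evaluating the whole batch at once.

```python
import numpy as np

def sharded_accuracy(logits, labels, num_devices):
    """Data-parallel accuracy: split the batch, count per shard, then reduce."""
    logit_shards = np.array_split(logits, num_devices)
    label_shards = np.array_split(labels, num_devices)
    correct = [int(np.sum(np.argmax(shard_logits, axis=-1) == shard_labels))
               for shard_logits, shard_labels in zip(logit_shards, label_shards)]
    total_correct = sum(correct)  # stands in for a sum all-reduce across devices
    return total_correct / len(labels)

rng = np.random.default_rng(0)
logits = rng.normal(size=(64, 2))      # a batch of 64 two-class scores
labels = rng.integers(0, 2, size=64)

full = np.mean(np.argmax(logits, axis=-1) == labels)
parallel = sharded_accuracy(logits, labels, num_devices=4)
print(np.isclose(full, parallel))  # True
```

Summing counts rather than averaging per-shard accuracies keeps the result exact even when shards have unequal sizes, which `np.array_split` can produce.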

Composing a model with `tf.Module` from scratch is a lot of work, and reusing existing building blocks such as layers and helper functions can drastically speed up model development.
As of TensorFlow 2.9, all Keras layers under `tf.keras.layers` accept DTensor layouts as arguments and can be used to build DTensor models. You can even reuse a Keras model with DTensor directly, without modifying the model implementation. Refer to the [DTensor Keras Integration Tutorial](https://www.tensorflow.org/tutorials/distribute/dtensor_keras_tutorial) for information on using DTensor with Keras.