# Training and Deploying TensorFlow Models at Scale

In [1]:
import sys

assert sys.version_info >= (3, 7)

In [4]:
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tf_keras

In [5]:
from packaging import version
import tensorflow as tf

assert version.parse(tf.__version__) >= version.parse("2.8.0")

In [6]:
if not tf.config.list_physical_devices('GPU'):
    print("No GPU was detected. Neural nets can be very slow without a GPU.")
    if "google.colab" in sys.modules:
        print("Go to Runtime > Change runtime and select a GPU hardware "
              "accelerator.")
    if "kaggle_secrets" in sys.modules:
        print("Go to Settings > Accelerator and select GPU.")

No GPU was detected. Neural nets can be very slow without a GPU.


## Serving a TensorFlow Model

While calling a model's `predict()` method may be fine early on, as the infrastructure grows there comes a point where it is preferable to wrap the model in a small service whose sole role is to make predictions, and have the rest of the infrastructure query it (e.g., via a REST or gRPC API). This makes it easy to switch model versions or scale the service up as needed, to perform A/B experiments, and to ensure that all software components rely on the same model versions.

### Using TensorFlow Serving

In [21]:
from pathlib import Path
import tensorflow as tf

# extra code – load and split the MNIST dataset
mnist = tf.keras.datasets.mnist.load_data()
(X_train_full, y_train_full), (X_test, y_test) = mnist
X_valid, X_train = X_train_full[:5000], X_train_full[5000:]
y_valid, y_train = y_train_full[:5000], y_train_full[5000:]

In [None]:

# extra code – build & train an MNIST model (also handles image preprocessing)
tf.random.set_seed(42)
tf.keras.backend.clear_session()
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=[28, 28], dtype=tf.uint8),
    tf.keras.layers.Rescaling(scale=1 / 255),
    tf.keras.layers.Dense(100, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax")
])
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=tf.keras.optimizers.SGD(learning_rate=1e-2),
              metrics=["accuracy"])
model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid))

model_name = "my_mnist_model"
model_version = "0001"
model_path = Path(model_name) / model_version
model.save(model_path, save_format="tf")

In [7]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 rescaling (Rescaling)       (None, 784)               0         
                                                                 
 dense (Dense)               (None, 100)               78500     
                                                                 
 dense_1 (Dense)             (None, 10)                1010      
                                                                 
Total params: 79510 (310.59 KB)
Trainable params: 79510 (310.59 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


Let's take a look at the file tree (we discussed what each of these files is used for in chapter 10):

In [8]:
sorted([str(path) for path in model_path.parent.glob("**/*")])  # extra code

['my_mnist_model/0001',
 'my_mnist_model/0001/assets',
 'my_mnist_model/0001/fingerprint.pb',
 'my_mnist_model/0001/keras_metadata.pb',
 'my_mnist_model/0001/saved_model.pb',
 'my_mnist_model/0001/variables',
 'my_mnist_model/0001/variables/variables.data-00000-of-00001',
 'my_mnist_model/0001/variables/variables.index']

In [9]:
!saved_model_cli show --dir '{model_path}'

The given SavedModel contains the following tag-sets:
'serve'


A SavedModel contains one or more *metagraphs*, each of which is a computation graph plus some function signature definitions, including their input and output names, types, and shapes. Each metagraph is identified by a set of tags.

In [10]:
!saved_model_cli show --dir '{model_path}' --tag_set serve

The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:
SignatureDef key: "__saved_model_init_op"
SignatureDef key: "serving_default"


In [11]:
!saved_model_cli show --dir '{model_path}' --tag_set serve \
                      --signature_def serving_default

The given SavedModel SignatureDef contains the following input(s):
  inputs['flatten_input'] tensor_info:
      dtype: DT_UINT8
      shape: (-1, 28, 28)
      name: serving_default_flatten_input:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['dense_1'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 10)
      name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict


### Installing and starting TensorFlow Serving

In [37]:
%%bash --bg
docker pull tensorflow/serving  # downloads the latest TF Serving image

# no -it flag: a background cell has no terminal (TTY) attached
docker run --rm -v "$(pwd)/my_mnist_model:/models/my_mnist_model" \
    -p 8500:8500 -p 8501:8501 -e MODEL_NAME=my_mnist_model tensorflow/serving

In [42]:
import time

time.sleep(2) # let's wait a couple seconds for the server to start

### Querying TF Serving through the REST API

In [22]:
import json

X_new = X_test[:3]  # pretend we have 3 new digit images to classify
request_json = json.dumps({
    "signature_name": "serving_default",
    "instances": X_new.tolist(),
})

In [14]:
request_json[:100] + "..." + request_json[-10:]

'{"signature_name": "serving_default", "instances": [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0..., 0, 0]]]}'

In [None]:
import requests

server_url = "http://localhost:8501/v1/models/my_mnist_model:predict"
response = requests.post(server_url, data=request_json)
response.raise_for_status()  # raise an exception in case of error
response = response.json()

### Querying TF Serving through the gRPC API

(requires the TF Serving server started above to be running, plus the `tensorflow-serving-api` package for the `tensorflow_serving` imports)

In [21]:
from tensorflow_serving.apis.predict_pb2 import PredictRequest

model_name = "my_mnist_model"

request = PredictRequest()
request.model_spec.name = model_name
request.model_spec.signature_name = "serving_default"
input_name = model.input_names[0]  # == "flatten_input"
request.inputs[input_name].CopyFrom(tf.make_tensor_proto(X_new))


In [None]:
import grpc
from tensorflow_serving.apis import prediction_service_pb2_grpc

channel = grpc.insecure_channel('localhost:8500')
predict_service = prediction_service_pb2_grpc.PredictionServiceStub(channel)
response = predict_service.Predict(request, timeout=10.0)

In [None]:
output_name = model.output_names[0]
outputs_proto = response.outputs[output_name]
y_proba = tf.make_ndarray(outputs_proto)

In [None]:
y_proba.round(2)

If your client does not include the TensorFlow library, you can convert the response to a NumPy array like this:

In [None]:
# extra code – shows how to avoid using tf.make_ndarray()
import numpy as np

output_name = model.output_names[0]
outputs_proto = response.outputs[output_name]
shape = [dim.size for dim in outputs_proto.tensor_shape.dim]
y_proba = np.array(outputs_proto.float_val).reshape(shape)
y_proba.round(2)

### Deploying a new model version

In [24]:
import numpy as np
# extra code – build and train a new MNIST model version
np.random.seed(42)
tf.random.set_seed(42)
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=[28, 28], dtype=tf.uint8),
    tf.keras.layers.Rescaling(scale=1 / 255),
    tf.keras.layers.Dense(50, activation="relu"),
    tf.keras.layers.Dense(50, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax")
])
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=tf.keras.optimizers.SGD(learning_rate=1e-2),
              metrics=["accuracy"])
history = model.fit(X_train, y_train, epochs=10,
                    validation_data=(X_valid, y_valid))



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [25]:
model_version = "0002"
model_path = Path(model_name) / model_version
model.save(model_path, save_format="tf")

INFO:tensorflow:Assets written to: my_mnist_model/0002/assets


In [26]:
sorted([str(path) for path in model_path.parent.glob("**/*")])  # extra code

['my_mnist_model/0001',
 'my_mnist_model/0001/assets',
 'my_mnist_model/0001/fingerprint.pb',
 'my_mnist_model/0001/keras_metadata.pb',
 'my_mnist_model/0001/saved_model.pb',
 'my_mnist_model/0001/variables',
 'my_mnist_model/0001/variables/variables.data-00000-of-00001',
 'my_mnist_model/0001/variables/variables.index',
 'my_mnist_model/0002',
 'my_mnist_model/0002/assets',
 'my_mnist_model/0002/fingerprint.pb',
 'my_mnist_model/0002/keras_metadata.pb',
 'my_mnist_model/0002/saved_model.pb',
 'my_mnist_model/0002/variables',
 'my_mnist_model/0002/variables/variables.data-00000-of-00001',
 'my_mnist_model/0002/variables/variables.index']

In [None]:
import requests

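# TF Serving detects the new 0002 directory and serves it automatically,
# so the same request now goes to the new model version
server_url = "http://localhost:8501/v1/models/my_mnist_model:predict"
response = requests.post(server_url, data=request_json)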
response.raise_for_status()
response = response.json()

In [None]:
response.keys()

In [None]:
y_proba = np.array(response["predictions"])
y_proba.round(2)
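By default, TF Serving keeps serving only the latest version it finds under the model's base path, which is why the unchanged request above now reaches version 0002 without restarting the server. Its REST API also accepts an explicit version in the URL (`/v1/models/<name>/versions/<n>:predict`) and reports the state of the loaded versions via `GET /v1/models/<name>`. The helper below is a small sketch that only builds these URLs; `model_url()` and its parameters are our own names, not part of any library, and the commented-out calls assume the server started earlier is still running. Under the default policy, version 1 has most likely been unloaded by now, so a request pinned to it may fail.

```python
def model_url(model_name, version=None, verb=None, host="localhost", port=8501):
    """Build a TF Serving REST URL (hypothetical helper, not a library function)."""
    url = f"http://{host}:{port}/v1/models/{model_name}"
    if version is not None:
        # version directories such as "0002" correspond to integer version 2
        url += f"/versions/{int(version)}"
    if verb is not None:
        # e.g. "predict", "classify" or "regress"
        url += f":{verb}"
    return url

# Usage, assuming the TF Serving server started above is still running:
# import requests
# status = requests.get(model_url("my_mnist_model")).json()  # loaded versions and their state
# response = requests.post(model_url("my_mnist_model", version="0002", verb="predict"),
#                          data=request_json)
```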

### Creating a Prediction Service on Vertex AI

Vertex AI is a platform within GCP that offers a wide range of AI-related tools and services. With it, we can upload datasets, get humans to label them, store commonly used features in a feature store and use them for training or in production, and train models across many GPU or TPU servers with automatic hyperparameter tuning or model architecture search (AutoML).

In [8]:
project_id = "lucid-bond-463306-k8"  ##### CHANGE THIS TO YOUR PROJECT ID #####

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "lucid-bond-463306-k8-1062e758e5e5.json"

In [10]:
from google.cloud import storage

bucket_name = "mnist_model_bucket"  ##### CHANGE THIS TO A UNIQUE BUCKET NAME #####
location = "us-central1"

storage_client = storage.Client(project=project_id)
# bucket = storage_client.create_bucket(bucket_name, location=location)
bucket = storage_client.bucket(bucket_name)  # to reuse a bucket instead

In [32]:
def upload_directory(bucket, dirpath):
    dirpath = Path(dirpath)
    for filepath in dirpath.glob("**/*"):
        if filepath.is_file():
            blob = bucket.blob(filepath.relative_to(dirpath.parent).as_posix())
            blob.upload_from_filename(filepath)

upload_directory(bucket, "my_mnist_model")

In [27]:
# extra code – a much faster multithreaded implementation of upload_directory()
#              which also accepts a prefix for the target path, and prints stuff

from concurrent import futures

def upload_file(bucket, filepath, blob_path):
    blob = bucket.blob(blob_path)
    blob.upload_from_filename(filepath)

def upload_directory(bucket, dirpath, prefix=None, max_workers=50):
    dirpath = Path(dirpath)
    prefix = prefix or dirpath.name
    with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_filepath = {
            executor.submit(
                upload_file,
                bucket, filepath,
                f"{prefix}/{filepath.relative_to(dirpath).as_posix()}"
            ): filepath
            for filepath in sorted(dirpath.glob("**/*"))
            if filepath.is_file()
        }
        for future in futures.as_completed(future_to_filepath):
            filepath = future_to_filepath[future]
            try:
                result = future.result()
            except Exception as ex:
                print(f"Error uploading {filepath!s:60}: {ex}")  # f!s is str(f)
            else:
                print(f"Uploaded {filepath!s:60}", end="\r")

    print(f"Uploaded {dirpath!s:60}")

In [12]:
from google.cloud import aiplatform

server_image = "gcr.io/cloud-aiplatform/prediction/tf2-gpu.2-8:latest"

aiplatform.init(project=project_id, location=location)
mnist_model = aiplatform.Model.upload(
    display_name="mnist",
    artifact_uri=f"gs://{bucket_name}/my_mnist_model/0001",
    serving_container_image_uri=server_image,
)

Creating Model
Create Model backing LRO: projects/954975238569/locations/us-central1/models/681559770267648000/operations/3461124993486684160
Model created. Resource name: projects/954975238569/locations/us-central1/models/681559770267648000@1
To use this Model in another session:
model = aiplatform.Model('projects/954975238569/locations/us-central1/models/681559770267648000@1')


Next, we create an *endpoint*, which is what client applications connect to when they want to access a service, and then we deploy the model to this endpoint:

In [18]:
endpoint = aiplatform.Endpoint.create(display_name="mnist-endpoint")

endpoint.deploy(
    mnist_model,
    min_replica_count=1,
    # the maximum number of replicas you can request is limited by
    # your GPU quota for this accelerator type in this region
    max_replica_count=1,
    machine_type="n1-standard-4",
    accelerator_type="NVIDIA_TESLA_T4",
    accelerator_count=1
)

Creating Endpoint
Create Endpoint backing LRO: projects/954975238569/locations/us-central1/endpoints/8770669014838935552/operations/1401854063871524864
Endpoint created. Resource name: projects/954975238569/locations/us-central1/endpoints/8770669014838935552
To use this Endpoint in another session:
endpoint = aiplatform.Endpoint('projects/954975238569/locations/us-central1/endpoints/8770669014838935552')
Deploying Model projects/954975238569/locations/us-central1/models/681559770267648000 to Endpoint : projects/954975238569/locations/us-central1/endpoints/8770669014838935552
Deploy Endpoint model backing LRO: projects/954975238569/locations/us-central1/endpoints/8770669014838935552/operations/2740549053107404800
Endpoint model deployed. Resource name: projects/954975238569/locations/us-central1/endpoints/8770669014838935552


In [23]:
response = endpoint.predict(instances=X_new.tolist())

In [24]:
import numpy as np

np.round(response.predictions, 2)

array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.97, 0.02, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.98, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ]])

In [25]:
endpoint.undeploy_all()  # undeploy all models from the endpoint
endpoint.delete()

Undeploying Endpoint model: projects/954975238569/locations/us-central1/endpoints/8770669014838935552
Undeploy Endpoint model backing LRO: projects/954975238569/locations/us-central1/endpoints/8770669014838935552/operations/7790773085249994752
Endpoint model undeployed. Resource name: projects/954975238569/locations/us-central1/endpoints/8770669014838935552
Deleting Endpoint : projects/954975238569/locations/us-central1/endpoints/8770669014838935552
Endpoint deleted. . Resource name: projects/954975238569/locations/us-central1/endpoints/8770669014838935552
Deleting Endpoint resource: projects/954975238569/locations/us-central1/endpoints/8770669014838935552
Delete Endpoint backing LRO: projects/954975238569/locations/us-central1/operations/6545527788282052608
Endpoint resource projects/954975238569/locations/us-central1/endpoints/8770669014838935552 deleted.


### Running Batch Prediction Jobs on Vertex AI

In [28]:
batch_path = Path("my_mnist_batch")
batch_path.mkdir(exist_ok=True)
with open(batch_path / "my_mnist_batch.jsonl", "w") as jsonl_file:
    for image in X_test[:100].tolist():
        jsonl_file.write(json.dumps(image))
        jsonl_file.write("\n")

upload_directory(bucket, batch_path)

Uploaded my_mnist_batch                                              


In [29]:
batch_prediction_job = mnist_model.batch_predict(
    job_display_name="my_batch_prediction_job",
    machine_type="n1-standard-4",
    starting_replica_count=1,
    max_replica_count=1,
    accelerator_type="NVIDIA_TESLA_T4",
    accelerator_count=1,
    gcs_source=[f"gs://{bucket_name}/{batch_path.name}/my_mnist_batch.jsonl"],
    gcs_destination_prefix=f"gs://{bucket_name}/my_mnist_predictions/",
    sync=True  # set to False if you don't want to wait for completion
)

Creating BatchPredictionJob
BatchPredictionJob created. Resource name: projects/954975238569/locations/us-central1/batchPredictionJobs/6587159748095770624
To use this BatchPredictionJob in another session:
bpj = aiplatform.BatchPredictionJob('projects/954975238569/locations/us-central1/batchPredictionJobs/6587159748095770624')
View Batch Prediction Job:
https://console.cloud.google.com/ai/platform/locations/us-central1/batch-predictions/6587159748095770624?project=954975238569
BatchPredictionJob projects/954975238569/locations/us-central1/batchPredictionJobs/6587159748095770624 current state:
JobState.JOB_STATE_PENDING
BatchPredictionJob projects/954975238569/locations/us-central1/batchPredictionJobs/6587159748095770624 current state:
JobState.JOB_STATE_PENDING
BatchPredictionJob projects/954975238569/locations/us-central1/batchPredictionJobs/6587159748095770624 current state:
JobState.JOB_STATE_PENDING
BatchPredictionJob projects/954975238569/locations/us-central1/batchPredictionJobs/

In [30]:
batch_prediction_job.output_info  # extra code – shows the output directory

gcs_output_directory: "gs://mnist_model_bucket/my_mnist_predictions/prediction-mnist-2025_06_18T00_19_59_803Z"

In [31]:
y_probas = []
for blob in batch_prediction_job.iter_outputs():
    print(blob.name)  # extra code
    if "prediction.results" in blob.name:
        for line in blob.download_as_text().splitlines():
            y_proba = json.loads(line)["prediction"]
            y_probas.append(y_proba)

my_mnist_predictions/prediction-mnist-2025_06_18T00_19_59_803Z/prediction.errors_stats-00000-of-00001
my_mnist_predictions/prediction-mnist-2025_06_18T00_19_59_803Z/prediction.results-00000-of-00001


In [32]:
y_pred = np.argmax(y_probas, axis=1)
accuracy = np.sum(y_pred == y_test[:100]) / 100

In [33]:
accuracy

np.float64(0.98)

In [34]:
mnist_model.delete()

Deleting Model : projects/954975238569/locations/us-central1/models/681559770267648000
Model deleted. . Resource name: projects/954975238569/locations/us-central1/models/681559770267648000
Deleting Model resource: projects/954975238569/locations/us-central1/models/681559770267648000
Delete Model backing LRO: projects/954975238569/locations/us-central1/models/681559770267648000/operations/7771395292322070528
Model resource projects/954975238569/locations/us-central1/models/681559770267648000 deleted.


Let's delete all the directories we created on GCS (i.e., all the blobs with these prefixes):

In [35]:
for prefix in ["my_mnist_model/", "my_mnist_batch/", "my_mnist_predictions/"]:
    blobs = bucket.list_blobs(prefix=prefix)
    for blob in blobs:
        blob.delete()

#bucket.delete()  # uncomment and run if you want to delete the bucket itself
batch_prediction_job.delete()

Deleting BatchPredictionJob : projects/954975238569/locations/us-central1/batchPredictionJobs/6587159748095770624
BatchPredictionJob deleted. . Resource name: projects/954975238569/locations/us-central1/batchPredictionJobs/6587159748095770624
Deleting BatchPredictionJob resource: projects/954975238569/locations/us-central1/batchPredictionJobs/6587159748095770624
Delete BatchPredictionJob backing LRO: projects/954975238569/locations/us-central1/operations/553250989554008064
BatchPredictionJob resource projects/954975238569/locations/us-central1/batchPredictionJobs/6587159748095770624 deleted.


## Deploying a Model to a Mobile or Embedded Device

Deploying a model to such a device is an example of *edge computing*, where machine learning models run closer to the source of data, for example on the user's mobile device or on an embedded device.

Advantages include:
1. The device can be smart even when it is not connected to the internet;
2. Latency is reduced, since data does not need to be sent to a remote server;
3. The load on the servers is reduced;
4. Privacy may improve, since the user's data can stay on the device.

Disadvantages:
1. Computing resources are generally tiny compared to multi-GPU servers;
2. A large model may not fit on the device;
3. It may use too much RAM and CPU;
4. It may take too long to download.

To overcome these issues, we can use the TFLite library, which can:
1. Reduce the model size, shortening download time and reducing RAM usage;
2. Reduce the amount of computation needed for each prediction, reducing latency, battery usage, and heating;
3. Adapt the model to device-specific constraints.



In [None]:
converter = tf.lite.TFLiteConverter.from_saved_model(str(model_path))
tflite_model = converter.convert()
with open("my_converted_savedmodel.tflite", "wb") as f:
    f.write(tflite_model)

In [None]:
# extra code – shows how to convert a Keras model
converter = tf.lite.TFLiteConverter.from_keras_model(model)

**Post-Training Quantization**

This quantizes the weights after training, using a symmetrical quantization technique: it finds the maximum absolute weight value, *m*, then maps the floating-point range -m to +m to the fixed-point (integer) range -127 to +127.

This is mostly useful to reduce the application's size, since the model still needs to convert the quantized weights back to floats when it uses them.
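The mapping described above can be sketched in NumPy. This is only an illustration of symmetric int8 quantization under simple assumptions (one scale for the whole tensor, integer range -127 to +127); `quantize_symmetric()` and `dequantize()` are hypothetical helpers, not TFLite's actual implementation.

```python
import numpy as np

def quantize_symmetric(weights):
    """Map floats in [-m, +m] to int8 values in [-127, +127], where m = max |w|."""
    weights = np.asarray(weights, dtype=np.float32)
    m = float(np.max(np.abs(weights)))
    scale = m / 127 if m > 0 else 1.0          # float value of one integer step
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Convert back to approximate floats, as needed before the weights are used."""
    return q.astype(np.float32) * np.float32(scale)

rng = np.random.default_rng(42)
w = rng.normal(scale=0.1, size=1000).astype(np.float32)  # stand-in for a layer's weights
q, scale = quantize_symmetric(w)
w_restored = dequantize(q, scale)
print(w.nbytes, q.nbytes)  # 4000 1000
print("max error:", np.max(np.abs(w - w_restored)), "half a step:", scale / 2)
```

The int8 copy takes a quarter of the space, but `dequantize()` has to run before the weights can be used in float arithmetic, which is why this mainly shrinks the file rather than the runtime cost.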

In [None]:
converter.optimizations = [tf.lite.Optimize.DEFAULT]

In [None]:
tflite_model = converter.convert()
with open("my_converted_keras_model.tflite", "wb") as f:
    f.write(tflite_model)

The main issue with quantization is that it loses a bit of accuracy (similar to adding noise to the weights and activations). If the accuracy drop is too severe, you may need to use *quantization-aware training*: this adds fake quantization operations to the model so it can learn to ignore the quantization noise during training, making the final model weights more robust to quantization.

Check out the O’Reilly books **TinyML: Machine Learning with TensorFlow Lite on Arduino and Ultra-Low-Power Microcontrollers**, by Pete Warden (former lead of the TFLite team) and Daniel Situnayake, and **AI and Machine Learning for On-Device Development**, by Laurence Moroney.

## Running a Model in a Web Page

Code examples for this section are hosted on glitch.com, a website that lets you create Web apps for free.

* https://homl.info/tfjscode: a simple TFJS Web app that loads a pretrained model and classifies an image.
* https://homl.info/tfjswpa: the same Web app set up as a PWA (progressive web app). Try opening this link on various platforms, including mobile devices.
    * https://homl.info/wpacode: this PWA's source code.
* https://tensorflow.org/js: the TFJS library.
    * https://www.tensorflow.org/js/demos: some fun demos.

If you want to learn more about TensorFlow.js, check out the O’Reilly books **Practical Deep Learning for Cloud, Mobile, and Edge**, by Anirudh Koul et al., or **Learning TensorFlow.js**, by Gant Laborde.

## Using GPUs to Speed Up Computations

Things to consider when buying a GPU card:
1. Amount of RAM (typically at least 10 GB for image processing or NLP);
2. Bandwidth (how fast data can be sent into and out of the GPU);
3. Number of cores;
4. Cooling system.

https://timdettmers.com/2023/01/30/which-gpu-for-deep-learning/

### Managing the GPU RAM

1. Assign each program specific GPU devices (e.g., by setting the `CUDA_VISIBLE_DEVICES` environment variable);
2. Tell TensorFlow to grab only a specific amount of GPU RAM;
3. Tell TensorFlow to grab memory only when it needs it;
4. Split a GPU into two or more logical devices.
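Option 1 is usually done from the shell, e.g. `CUDA_VISIBLE_DEVICES=0,1 python my_script.py` (the script name is a placeholder). Options 2 to 4 map onto `tf.config` calls, sketched below in a hypothetical `configure_gpus()` helper. It assumes it runs at program start: TensorFlow raises a `RuntimeError` if these settings are changed after the GPUs have been initialized, so in this notebook you would need to restart the kernel first. On a machine without a GPU, the loop simply does nothing.

```python
import tensorflow as tf

def configure_gpus(memory_limit_mb=None, n_logical=1):
    """Apply one GPU RAM policy to every visible GPU (sketch; call before using any GPU)."""
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        if memory_limit_mb is None:
            # option 3: only grab GPU memory when it is actually needed
            tf.config.experimental.set_memory_growth(gpu, True)
        else:
            # options 2 and 4: n_logical logical devices, each limited to memory_limit_mb
            tf.config.set_logical_device_configuration(
                gpu,
                [tf.config.LogicalDeviceConfiguration(memory_limit=memory_limit_mb)
                 for _ in range(n_logical)])
    return gpus

gpus = configure_gpus()  # or e.g. configure_gpus(memory_limit_mb=2048, n_logical=2)
print(len(gpus), "physical GPU(s),",
      len(tf.config.list_logical_devices("GPU")), "logical GPU(s)")
```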

### Placing Operations and Variables on Devices

Best practices:
1. Run data preprocessing on the CPU and neural network operations on the GPU;
2. GPUs have fairly limited communication bandwidth, so avoid unnecessary data transfers into and out of the GPUs;
3. Adding more CPU RAM is simpler and cheaper than increasing a GPU's RAM, so if a variable is not needed in the next few training steps, it should probably be placed on the CPU.

Note that the CPU is always treated as a single device, even if the machine has multiple CPU cores.
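A minimal sketch of these practices with `tf.device()`: a large, rarely used variable is pinned to the CPU, while the computation goes to the first GPU when one is available (falling back to the CPU otherwise). The name `big_table` and the sizes are arbitrary placeholders.

```python
import tensorflow as tf

# Keep a large variable that is rarely needed in (cheaper, larger) CPU RAM
with tf.device("/cpu:0"):
    big_table = tf.Variable(tf.random.uniform([10_000, 32]))

# Run the heavy math on the first GPU if there is one, else on the CPU
logical_gpus = tf.config.list_logical_devices("GPU")
compute_device = logical_gpus[0].name if logical_gpus else "/cpu:0"
with tf.device(compute_device):
    rows = tf.gather(big_table, [0, 1, 2])   # only a few rows are needed for this step
    scores = tf.matmul(rows, rows, transpose_b=True)

print(big_table.device)
print(scores.device)
```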

### Parallel Execution Across Multiple Devices

1. We could train several models in parallel, each on its own GPU, by setting `CUDA_DEVICE_ORDER` and `CUDA_VISIBLE_DEVICES` so that each script only sees a single GPU device. This is great for hyperparameter tuning, as you can train multiple models with different hyperparameters in parallel.
2. We can train a model on a single GPU and perform all the preprocessing in parallel on the CPU, using the dataset's `prefetch()` method to prepare the next few batches in advance so that they are ready when the GPU needs them.
3. If the model takes two images as input and processes them using two CNNs before joining their outputs, then it will probably run much faster if each CNN is placed on a different GPU.
4. Create an efficient ensemble by placing a different trained model on each GPU, so that you can get all the predictions much faster to produce the ensemble's final prediction.
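For item 1, a launcher can give each training process its own environment. The sketch below assumes a hypothetical training script `my_train.py` that accepts a `--learning_rate` argument; `env_for_gpu()` and `launch_trials()` are our own helper names. Setting `CUDA_DEVICE_ORDER=PCI_BUS_ID` makes the GPU IDs match the ones shown by `nvidia-smi`.

```python
import os
import subprocess
import sys

def env_for_gpu(gpu_id, base_env=None):
    """Copy of the environment in which CUDA only sees GPU `gpu_id`."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"    # number GPUs by bus ID, like nvidia-smi
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)  # this process only sees this one GPU
    return env

def launch_trials(script, learning_rates):
    """Start one training process per GPU (script is a hypothetical training script)."""
    return [subprocess.Popen([sys.executable, script, "--learning_rate", str(lr)],
                             env=env_for_gpu(gpu_id))
            for gpu_id, lr in enumerate(learning_rates)]

# procs = launch_trials("my_train.py", [1e-3, 3e-3])  # trial 0 on GPU 0, trial 1 on GPU 1
# for proc in procs:
#     proc.wait()
```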

## Training Models Across Multiple Devices

### Model Parallelism (model split across devices)

tl;dr May speed up running / training some types of neural networks, and requires special care and tuning!

The model is split across the devices: for example, a single neural network is trained across multiple devices by chopping it into separate chunks and running each chunk on a different device. The effectiveness of this approach depends on the architecture of the neural network; fully connected networks will not gain much from it, since each layer needs the full output of the previous layer, so the devices either wait for each other or must communicate constantly.

For some network architectures, such as CNNs, some layers are only partially connected to the lower layers, so it's much easier to distribute chunks across devices in an efficient way.

Deep recurrent neural networks can be split a bit more efficiently across multiple GPUs by placing each layer on a different device (splitting horizontally), but it takes a while before all the GPUs are active.
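To see why fully connected layers split poorly, here is a NumPy toy (arbitrary sizes, random weights) that splits each dense layer's units between two pretend devices. The result matches the single-device network, but each device needs the other's activations before it can compute the next layer, so a real implementation would pay a cross-device transfer at every layer.

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(size=(8, 6))    # a batch of 8 inputs
W1 = rng.normal(size=(6, 4))   # dense layer 1 (ReLU)
W2 = rng.normal(size=(4, 3))   # dense layer 2

# Reference: the whole network on a single device
full = np.maximum(x @ W1, 0) @ W2

# Model parallelism: each "device" holds some of each layer's units (columns)
W1_a, W1_b = np.hsplit(W1, 2)
W2_a, W2_b = np.hsplit(W2, [2])   # uneven splits work too
h_a = np.maximum(x @ W1_a, 0)     # computed on device A
h_b = np.maximum(x @ W1_b, 0)     # computed on device B
# Layer 2 on each device needs ALL of layer 1's outputs, so the halves
# must be exchanged across devices before continuing:
h = np.concatenate([h_a, h_b], axis=1)
out = np.concatenate([h @ W2_a, h @ W2_b], axis=1)
print(np.allclose(out, full))  # True
```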

### Data Parallelism / Single Program, Multiple Data (SPMD)

Each device runs the same model (a *replica*) on a different mini-batch, and the gradients computed by the replicas are averaged before updating the model parameters.
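For a loss that is a mean over the batch, averaging the gradients of equal-sized shards gives exactly the full-batch gradient, which is why averaging is the right way to combine them. A quick NumPy check with a linear regression model (toy random data; `mse_gradient()` is our own helper):

```python
import numpy as np

def mse_gradient(X, y, w):
    """Gradient of the mean squared error of a linear model with respect to w."""
    return 2 * X.T @ (X @ w - y) / len(X)

rng = np.random.default_rng(42)
X = rng.normal(size=(64, 3))
y = rng.normal(size=64)
w = rng.normal(size=3)

full_batch_grad = mse_gradient(X, y, w)

# 4 replicas, each receiving a different, equal-sized slice of the batch
replica_grads = [mse_gradient(X_part, y_part, w)
                 for X_part, y_part in zip(np.split(X, 4), np.split(y, 4))]
averaged_grad = np.mean(replica_grads, axis=0)
print(np.allclose(averaged_grad, full_batch_grad))  # True
```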

#### Using Mirrored Strategy

This approach mirrors all the model parameters across all the GPUs and applies the same parameter updates on every GPU, so the replicas always stay identical. The tricky part is to efficiently compute the mean of all the gradients from all the GPUs and distribute the result across all the GPUs. This can be done using an *AllReduce* algorithm, a class of algorithms where multiple nodes collaborate to efficiently perform a *reduce operation* (such as computing a sum or a mean), while ensuring that all nodes obtain the same final result.
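In TensorFlow you would typically create a `tf.distribute.MirroredStrategy()` and build and compile the model inside `with strategy.scope():`, letting TF pick the AllReduce implementation. To show what such an algorithm does, here is a NumPy simulation of the classic *ring* AllReduce; it is an illustration only, not TF's implementation, and `ring_allreduce_mean()` is a hypothetical name. Each node only ever passes one chunk to its neighbour per step, so the data each node sends stays roughly constant as nodes are added.

```python
import numpy as np

def ring_allreduce_mean(node_grads):
    """Simulate ring AllReduce: every node ends up with the mean of all gradients."""
    n = len(node_grads)
    # each node splits its own gradient vector into n chunks
    chunks = [np.array_split(np.asarray(g, dtype=np.float64), n) for g in node_grads]

    # Phase 1, reduce-scatter: after n - 1 steps, node i holds the full sum of chunk (i + 1) % n
    for step in range(n - 1):
        sends = [(i, (i - step) % n) for i in range(n)]
        payloads = [chunks[i][c].copy() for i, c in sends]  # all nodes send "simultaneously"
        for (i, c), payload in zip(sends, payloads):
            dst = (i + 1) % n
            chunks[dst][c] = chunks[dst][c] + payload

    # Phase 2, all-gather: pass the fully reduced chunks around the ring
    for step in range(n - 1):
        sends = [(i, (i + 1 - step) % n) for i in range(n)]
        payloads = [chunks[i][c].copy() for i, c in sends]
        for (i, c), payload in zip(sends, payloads):
            chunks[(i + 1) % n][c] = payload

    return [np.concatenate(node_chunks) / n for node_chunks in chunks]

rng = np.random.default_rng(42)
grads = [rng.normal(size=10) for _ in range(4)]  # 4 GPUs, 10 parameters each
results = ring_allreduce_mean(grads)
print(all(np.allclose(r, np.mean(grads, axis=0)) for r in results))  # True
```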

#### Centralized Parameters

The model parameters are stored outside the GPU devices that perform the computations (the *workers*), for example on the CPU. This approach allows either synchronous or asynchronous updates.

- With synchronous updates, the aggregator waits until all gradients are available before it computes the average gradients and passes them to the optimizer, which updates the model parameters. Once a replica has finished computing its gradients, it must wait for the parameters to be updated before it can proceed to the next mini-batch. The downside is that some devices may be slower than others, so the fast devices have to wait for the slow ones at every step (i.e., the slowest device becomes the bottleneck). This can be mitigated by ignoring the gradients from the slowest few replicas (typically ~10%).

- With asynchronous updates, whenever a replica has finished computing its gradients, they are immediately used to update the model parameters (there is no aggregation step). It is somewhat surprising that this works at all, since there is no guarantee that the computed gradients still point in the right direction: by the time they are applied, the parameters they were computed with are most likely already outdated (these are called *stale gradients*). Several methods can reduce the effect of stale gradients:
    - Reduce the learning rate;
    - Drop the stale gradients or scale them down;
    - Adjust the mini-batch size;
    - Start the first few epochs using just one replica (a *warmup phase*), as stale gradients tend to be more damaging at the beginning of training, when gradients are typically large and the parameters have not settled into a valley of the cost function yet.
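A toy simulation shows why stale gradients hurt and why lowering the learning rate helps. It runs gradient descent on `f(w) = w**2 / 2`, where each update uses a gradient computed `staleness` steps earlier; this is a deliberately simplified model of asynchronous replicas (a single parameter and a fixed delay), and `delayed_sgd()` is our own helper.

```python
def delayed_sgd(learning_rate, staleness, n_steps=200, w0=1.0):
    """Gradient descent on f(w) = w**2 / 2 (so the gradient is simply w), where each
    update uses the gradient computed `staleness` steps earlier."""
    history = [w0] * (staleness + 1)           # history[-1] is the current value of w
    for _ in range(n_steps):
        stale_w = history[-1 - staleness]      # parameters the gradient was computed with
        history.append(history[-1] - learning_rate * stale_w)
    return history[-1]

print(delayed_sgd(0.5, staleness=0))  # fresh gradients: converges toward 0
print(delayed_sgd(0.5, staleness=4))  # stale gradients: oscillates and diverges
print(delayed_sgd(0.1, staleness=4))  # stale gradients, lower learning rate: converges
```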

That said, synchronous updates with a few spare replicas appear to be more efficient than asynchronous updates.
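The spare-replica trick amounts to averaging only the gradients that arrive first. A NumPy sketch under simplified assumptions (finish times are known in advance; `aggregate_fastest()` is a hypothetical helper, and its 10% default follows the figure quoted above):

```python
import numpy as np

def aggregate_fastest(grads, finish_times, drop_fraction=0.1):
    """Synchronous aggregation that ignores the slowest replicas.
    Returns the mean gradient of the fastest replicas and their indices."""
    n_replicas = len(grads)
    n_kept = max(1, n_replicas - int(n_replicas * drop_fraction))
    fastest = np.argsort(finish_times)[:n_kept]
    return np.mean([grads[i] for i in fastest], axis=0), fastest

grads = [np.full(3, float(i)) for i in range(10)]  # replica i sends the gradient [i, i, i]
finish_times = np.linspace(1.0, 2.0, 10)[::-1]     # replica 0 is the slowest
mean_grad, kept = aggregate_fastest(grads, finish_times)
print(mean_grad)   # [5. 5. 5.] (replicas 1 to 9 were kept)
print(0 in kept)   # False
```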

#### Bandwidth saturation

Bandwidth saturation happens when adding an extra GPU no longer improves performance, because the time spent moving the data into and out of the GPU outweighs the speedup obtained by splitting the computation load; beyond that point, adding GPUs actually slows down training.

This issue is more severe for large dense models, since they have a lot of parameters and gradients to transfer.
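A back-of-the-envelope model makes the saturation point concrete. The numbers below are hypothetical (1 s of compute per step, split evenly across GPUs, plus 0.04 s of transfer time per extra GPU); real transfer costs depend on the model size, the interconnect, and the communication algorithm.

```python
def step_time(n_gpus, compute_s=1.0, transfer_s=0.04):
    """Hypothetical time per training step: the compute is split evenly across
    GPUs, but every extra GPU adds a fixed amount of data-transfer time."""
    return compute_s / n_gpus + transfer_s * (n_gpus - 1)

times = {n: step_time(n) for n in range(1, 17)}
best = min(times, key=times.get)
print(best, round(times[best], 3))  # 5 0.36
```

Under these assumptions, going from 5 to 16 GPUs makes each step slower, not faster.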

A method to overcome this is *pipeline parallelism* (a combination of model and data parallelism): the model is chopped into consecutive parts (*stages*), each of which is trained on a different machine. This results in an asynchronous pipeline in which all machines work in parallel with very little idle time. During training, each stage alternates one round of forward propagation and one round of backpropagation: it pulls a mini-batch from its input queue, processes it, and sends the output to the next stage's input queue; then it pulls one mini-batch of gradients from its gradient queue, backpropagates these gradients, updates its own model parameters, and pushes the backpropagated gradients to the previous stage's gradient queue. This approach, however, suffers from stale gradients, which is mitigated by having each stage save its weights during forward propagation and restore them during backpropagation, so that the same weights are used for both the forward pass and the backward pass (*weight stashing*).
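Weight stashing can be sketched with a single linear stage in NumPy. This is a toy under our own assumptions (one stage, plain SGD, micro-batch IDs used as dictionary keys); `LinearStage` is a hypothetical class, not code from any pipeline library. Micro-batch 1 enters the stage before micro-batch 0's gradients arrive, so by the time micro-batch 1 is backpropagated the weights have changed, and the stash provides the version used in its forward pass.

```python
import numpy as np

class LinearStage:
    """One pipeline stage computing x @ W, with weight stashing (toy sketch)."""

    def __init__(self, W, learning_rate=0.1):
        self.W = np.array(W, dtype=np.float64)
        self.learning_rate = learning_rate
        self.stash = {}  # micro-batch id -> (stage input, weights used in the forward pass)

    def forward(self, batch_id, x):
        self.stash[batch_id] = (x, self.W.copy())  # save the weights used for this batch
        return x @ self.W

    def backward(self, batch_id, grad_out):
        x, W_used = self.stash.pop(batch_id)       # restore them for the backward pass
        grad_W = x.T @ grad_out
        grad_in = grad_out @ W_used.T              # sent to the previous stage
        self.W -= self.learning_rate * grad_W      # update the latest weights
        return grad_in

stage = LinearStage(np.eye(2))
x = np.ones((1, 2))
stage.forward(0, x)                  # micro-batch 0 goes forward...
stage.forward(1, x)                  # ...then micro-batch 1, with the same weights
stage.backward(0, np.ones((1, 2)))   # micro-batch 0's gradients update W
grad_in = stage.backward(1, np.ones((1, 2)))
print(grad_in)  # [[1. 1.]]: uses the stashed identity; the updated W would give [[0.8 0.8]]
```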

Look up *Pathways* by Google, which uses automated model parallelism, asynchronous gang scheduling, and other techniques to reach close to 100% hardware utilization across thousands of TPUs!

Realistically, here is what can be done right now:
1. Use a few powerful GPUs rather than many weak ones;
2. Group the GPUs on a few very well interconnected servers;
3. Drop the float precision (from 32 bits to 16 bits);
4. If centralized parameters are used, shard (split) the parameters across multiple parameter servers: adding more parameter servers reduces the network load on each server and limits the risk of bandwidth saturation.
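For item 3, the savings are easy to quantify: halving the bytes per parameter also halves the gradient volume that must be transferred at each step. A worked check using the parameter count of the first MNIST model above (`param_bytes()` is our own helper):

```python
import numpy as np

def param_bytes(n_params, dtype):
    """Bytes needed to store (or transfer) n_params values of the given dtype."""
    return n_params * np.dtype(dtype).itemsize

n_params = 79_510  # total parameters of the first MNIST model (see model.summary() above)
for dtype in (np.float32, np.float16):
    n_bytes = param_bytes(n_params, dtype)
    print(f"{np.dtype(dtype).name}: {n_bytes:,} bytes = {n_bytes / 1024:.2f} KB")
# float32: 318,040 bytes = 310.59 KB (matches model.summary())
# float16: 159,020 bytes = 155.29 KB
```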