<a href="https://colab.research.google.com/github/fyr-repo/parallel_programming_intro/blob/main/matrix_mul_tf_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import time
import os

# Check if TPU is available
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    # This is the TPU initialization code that has to be at the beginning.
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))
except:
    print("TPU not available")

All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]


In [2]:
tpu_strategy = tf.distribute.TPUStrategy(resolver)

In [10]:
# Define the dimensions and values of the arrays
rows_cols = 1000
value = 55.55

batch_size = 32



# Create the input arrays filled with the specified value
matrix_a = tf.constant([[value] * rows_cols] * rows_cols, dtype=tf.float32)
matrix_b = tf.constant([[value] * rows_cols] * rows_cols, dtype=tf.float32)

# Perform matrix multiplication on the TPU
# with tf.device('/TPU:0'):
#     result = tf.matmul(matrix_a, matrix_b)

start_time = time.time()

with tpu_strategy.scope():
    result = tf.linalg.matmul(matrix_a, matrix_b)

# result = strategy.run(matmul_fn, args=(matrix_a, matrix_b))

end_time = time.time()

processing_time = end_time - start_time

# Print the result (will be executed on the TPU)
print("Matrix multiplication result:")
print(result)
print(f"Time taken for processing: {processing_time:.4f} seconds")


Matrix multiplication result:
tf.Tensor(
[[3085808.2 3085808.2 3085808.2 ... 3085808.2 3085808.2 3085808.2]
 [3085808.2 3085808.2 3085808.2 ... 3085808.2 3085808.2 3085808.2]
 [3085808.2 3085808.2 3085808.2 ... 3085808.2 3085808.2 3085808.2]
 ...
 [3085808.2 3085808.2 3085808.2 ... 3085808.2 3085808.2 3085808.2]
 [3085808.2 3085808.2 3085808.2 ... 3085808.2 3085808.2 3085808.2]
 [3085808.2 3085808.2 3085808.2 ... 3085808.2 3085808.2 3085808.2]], shape=(1000, 1000), dtype=float32)
Time taken for processing: 0.0481 seconds


In [None]:
# import tensorflow as tf

# # Define the TPU address
# tpu_address = 'grpc://<TPU_IP_ADDRESS>'

# # Initialize the TPU
# tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
# tf.config.experimental_connect_to_cluster(tpu_resolver)
# tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
# tpu_strategy = tf.distribute.TPUStrategy(tpu_resolver)

# # Define matrix sizes
# matrix_size = 1000
# batch_size = 32

# # Create random matrices
# matrix_a = tf.random.normal([batch_size, matrix_size, matrix_size])
# matrix_b = tf.random.normal([batch_size, matrix_size, matrix_size])

# # Wrap the computation inside the TPU strategy scope
# with tpu_strategy.scope():
#     # Create TensorFlow Dataset
#     dataset = tf.data.Dataset.from_tensor_slices((matrix_a, matrix_b)).batch(batch_size)

#     # Define matrix multiplication function
#     @tf.function
#     def matmul_fn(a, b):
#         return tf.linalg.matmul(a, b)

#     # Perform matrix multiplication using distributed training
#     for step, (batch_a, batch_b) in enumerate(dataset):
#         # Perform matrix multiplication
#         result = tpu_strategy.run(matmul_fn, args=(batch_a, batch_b))

#         # Print result or perform further operations
#         print(f"Step {step + 1}: Result shape {result.shape}")

# # Clean up resources
# tf.tpu.experimental.shutdown_tpu_system()