<a href="https://colab.research.google.com/github/jemmee/tpu/blob/main/tpu_matmul_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorflow

Collecting tensorflow
  Downloading tensorflow-2.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.5 kB)
Collecting astunparse>=1.6.0 (from tensorflow)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow)
  Downloading flatbuffers-25.9.23-py2.py3-none-any.whl.metadata (875 bytes)
Collecting google_pasta>=0.1.1 (from tensorflow)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow)
  Downloading libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl.metadata (5.2 kB)
Collecting tensorboard~=2.20.0 (from tensorflow)
  Downloading tensorboard-2.20.0-py3-none-any.whl.metadata (1.8 kB)
Collecting wheel<1.0,>=0.23.0 (from astunparse>=1.6.0->tensorflow)
  Downloading wheel-0.45.1-py3-none-any.whl.metadata (2.3 kB)
Collecting tensorboard-data-server<0.8.0,>=0.7.0 (from tensorboard~=2.20.0->tensorflow)
  Downloading tensorboard_data_server-0.

In [None]:
import jax

# On a working TPU runtime the list below contains TpuDevice entries;
# other entries presumably mean JAX is running on a different backend.
visible_devices = jax.devices()
print(visible_devices)

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]


In [None]:
import jax

try:
    # On a TPU runtime this is a list of TPU devices.
    devices = jax.devices()
    print("JAX saw these devices:", devices)

    # Compare the reported platform instead of substring-matching the repr,
    # and report explicitly when no TPU is present instead of printing nothing.
    tpu_devices = [d for d in devices if d.platform == "tpu"]
    if tpu_devices:
        print("✅ TPU is active and working!")
    else:
        print(f"⚠️ No TPU found; JAX is using the {jax.default_backend()!r} backend.")
except Exception as e:
    # Deliberately caught so the cell displays the initialization failure
    # instead of aborting the notebook.
    print(f"❌ Hardware error: {e}")

JAX saw these devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
✅ TPU is active and working!


In [None]:
import jax

# List the devices of JAX's default backend. Label them with the backend that
# was actually selected, so the message stays correct on a CPU/GPU fallback.
# ("device" rather than "core": NOTE(review) a JAX device may correspond to a
# chip rather than a single core on some TPU generations — confirm for this runtime.)
devices = jax.devices()
backend = jax.default_backend().upper()
print(f"Connected to {len(devices)} {backend} device(s):")
for d in devices:
    print(f" - {d}")

Connected to 1 TPU cores:
 - TPU_0(process=0,(0,0,0,0))


In [None]:
import jax.numpy as jnp
import jax

# Create large random matrices on the TPU
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8192, 8192))
y = jax.random.normal(key, (8192, 8192))

# Time the multiplication
%timeit jnp.matmul(x, y).block_until_ready()

7.05 ms ± 28.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
import os
import time

import tensorflow as tf

# Side length of the square operands being multiplied.
MATRIX_SIZE = 8192

# 1. Initialize the TPU.
# NOTE(review): the JAX cells above have already attached to the TPU in this
# process; TensorFlow may be unable to acquire it. If initialization fails,
# restart the runtime and run this cell on its own — TODO confirm.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# 2. Place the computation on the first TPU core explicitly. Simply creating ops
# under `strategy.scope()` does not route them through `strategy.run`, so
# explicit placement is used, following the TF TPU guide.
with tf.device("/TPU:0"):
    a = tf.random.normal([MATRIX_SIZE, MATRIX_SIZE])
    b = tf.random.normal([MATRIX_SIZE, MATRIX_SIZE])

    # Warm-up: the first call pays one-time setup/compilation cost.
    tf.reduce_sum(tf.matmul(a, b)).numpy()

    start = time.perf_counter()
    c = tf.matmul(a, b)
    # Reading a scalar reduction back to the host forces the matmul to finish
    # without copying the full product off the device.
    tf.reduce_sum(c).numpy()
    end = time.perf_counter()

print(f"TF TPU execution time: {end - start:.4f} seconds")