In [3]:
!nvidia-smi
!lspci | grep nvidia

/bin/bash: line 1: nvidia-smi: command not found
/bin/bash: line 1: lspci: command not found


In [1]:
import torch

# Check if CUDA is available
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

    # Memory info
    print(f"Total memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Allocated memory: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"Reserved memory: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

CUDA available: True
GPU count: 1
Current GPU: 0
GPU name: Tesla T4
Total memory: 15.83 GB
Allocated memory: 0.00 GB
Reserved memory: 0.00 GB


In [2]:
# Install TensorFlow first
!pip install tensorflow

Collecting tensorflow
  Downloading tensorflow-2.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting astunparse>=1.6.0 (from tensorflow)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting google-pasta>=0.1.1 (from tensorflow)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow)
  Downloading libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl.metadata (5.2 kB)
Collecting tensorboard~=2.19.0 (from tensorflow)
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorflow-io-gcs-filesystem>=0.23.1 (from tensorflow)
  Downloading tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Collecting wheel<1.0,>=0.23.0 (from astunparse>=1.6.0->tensorflow

In [1]:

# You can check your TPU setup in Colab like this:
import jax
print(f"TPU devices: {jax.device_count()}")
print(f"TPU cores: {jax.local_device_count()}")

# Or with TensorFlow:
import tensorflow as tf
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    print(f"Number of TPU cores: {tpu.num_accelerators()}")
except:
    print("No TPU detected")

TPU devices: 1
TPU cores: 1
No TPU detected


In [3]:
import jax
import jax.numpy as jnp

def print_tpu_status():
    print("=" * 50)
    print("JAX TPU STATUS")
    print("=" * 50)

    print(f"JAX version: {jax.__version__}")
    print(f"Backend: {jax.default_backend()}")
    print(f"Devices: {len(jax.devices())}")

    for i, device in enumerate(jax.devices()):
        print(f"\nDevice {i}:")
        print(f"  Platform: {device.platform}")
        print(f"  Kind: {device.device_kind}")
        print(f"  ID: {device.id}")

    # Test computation
    print(f"\nTesting computation...")
    x = jnp.array([1, 2, 3, 4])
    y = jnp.sum(x**2)
    print(f"Test result: {y}")
    print(f"Computed on: {y.device}")  # Remove the () - device is a property, not a method

    print("=" * 50)

print_tpu_status()

# More specific TPU details
print("Additional TPU Info:")
print(f"Device type: {jax.devices()[0].device_kind}")
print(f"Platform: {jax.devices()[0].platform}")

# Test TPU performance
import time

@jax.jit
def benchmark_tpu():
    x = jnp.ones((1000, 1000))
    return jnp.sum(x @ x @ x)  # Matrix multiplications

# Warm up
_ = benchmark_tpu()

# Time the computation
start = time.time()
result = benchmark_tpu()
end = time.time()

print(f"Benchmark result: {result}")
print(f"Time taken: {end - start:.4f} seconds")

# Check available memory on TPU
def get_tpu_memory():
    try:
        # Create a small array to test memory
        test_array = jnp.ones((1000, 1000))

        # Get memory stats
        backend = jax.lib.xla_bridge.get_backend()
        for device in jax.devices():
            try:
                memory_info = backend.buffer_from_pyval(test_array, device).device_buffer.memory_stats()
                print(f"Device {device}: {memory_info}")
            except:
                print(f"Device {device}: Memory info not available")
    except Exception as e:
        print(f"Memory check failed: {e}")

get_tpu_memory()


JAX TPU STATUS
JAX version: 0.5.2
Backend: tpu
Devices: 1

Device 0:
  Platform: tpu
  Kind: TPU v5 lite
  ID: 0

Testing computation...
Test result: 30
Computed on: TPU_0(process=0,(0,0,0,0))
Additional TPU Info:
Device type: TPU v5 lite
Platform: tpu
Benchmark result: 999999799296.0
Time taken: 0.0001 seconds
Device TPU_0(process=0,(0,0,0,0)): Memory info not available


  backend = jax.lib.xla_bridge.get_backend()


In [4]:
import jax
import jax.numpy as jnp

def print_tpu_status():
    print("=" * 50)
    print("JAX TPU STATUS")
    print("=" * 50)

    print(f"JAX version: {jax.__version__}")
    print(f"Backend: {jax.default_backend()}")
    print(f"Devices: {len(jax.devices())}")

    for i, device in enumerate(jax.devices()):
        print(f"\nDevice {i}:")
        print(f"  Platform: {device.platform}")
        print(f"  Kind: {device.device_kind}")
        print(f"  ID: {device.id}")

    # Test computation
    print(f"\nTesting computation...")
    x = jnp.array([1, 2, 3, 4])
    y = jnp.sum(x**2)
    print(f"Test result: {y}")
    print(f"Computed on: {y.device}")  # Fixed: removed ()

    print("=" * 50)

# Updated memory info function with new API
def get_tpu_memory():
    try:
        # Use the new backend API
        backend = jax.extend.backend.get_backend()
        print(f"Backend platform: {backend.platform}")

        # Create a test array
        test_array = jnp.ones((1000, 1000))

        for i, device in enumerate(jax.devices()):
            print(f"Device {i} ({device.device_kind}): Active")

    except Exception as e:
        print(f"Backend info error: {e}")

# Comprehensive status
print_tpu_status()
print("\nBackend Information:")
get_tpu_memory()

# TPU performance test
print("\nTPU Performance Test:")
import time

@jax.jit
def tpu_benchmark():
    x = jnp.ones((2000, 2000))
    return jnp.sum(x @ x)

# Warm up JIT compilation
_ = tpu_benchmark()

# Actual benchmark
start = time.time()
result = tpu_benchmark()
end = time.time()

print(f"Matrix operation result: {result}")
print(f"Computation time: {end - start:.4f} seconds")
print(f"Executed on: {result.device}")

JAX TPU STATUS
JAX version: 0.5.2
Backend: tpu
Devices: 1

Device 0:
  Platform: tpu
  Kind: TPU v5 lite
  ID: 0

Testing computation...
Test result: 30
Computed on: TPU_0(process=0,(0,0,0,0))

Backend Information:
Backend info error: module 'jax' has no attribute 'extend'

TPU Performance Test:
Matrix operation result: 8000000000.0
Computation time: 0.0001 seconds
Executed on: TPU_0(process=0,(0,0,0,0))


In [5]:
import jax
import jax.numpy as jnp

def print_tpu_status():
    print("=" * 50)
    print("JAX TPU STATUS")
    print("=" * 50)

    print(f"JAX version: {jax.__version__}")
    print(f"Backend: {jax.default_backend()}")
    print(f"Devices: {len(jax.devices())}")

    for i, device in enumerate(jax.devices()):
        print(f"\nDevice {i}:")
        print(f"  Platform: {device.platform}")
        print(f"  Kind: {device.device_kind}")
        print(f"  ID: {device.id}")

    # Test computation
    print(f"\nTesting computation...")
    x = jnp.array([1, 2, 3, 4])
    y = jnp.sum(x**2)
    print(f"Test result: {y}")
    print(f"Computed on: {y.device}")

    print("=" * 50)

# Simple backend info without deprecated calls
def get_backend_info():
    print("Backend Information:")
    print(f"Default backend: {jax.default_backend()}")
    print(f"Available backends: {jax.lib.xla_bridge.get_backend().platform}")

    # Device info
    for i, device in enumerate(jax.devices()):
        print(f"Device {i}: {device.device_kind} on {device.platform}")

# Run the status check
print_tpu_status()
print()
get_backend_info()

# TPU performance test
print("\nTPU Performance Test:")
import time

@jax.jit
def tpu_benchmark():
    x = jnp.ones((2000, 2000))
    return jnp.sum(x @ x)

# Warm up JIT compilation
print("Warming up JIT...")
_ = tpu_benchmark()

# Actual benchmark
print("Running benchmark...")
start = time.time()
result = tpu_benchmark()
end = time.time()

print(f"Matrix operation result: {result}")
print(f"Computation time: {end - start:.4f} seconds")
print(f"Executed on: {result.device}")

JAX TPU STATUS
JAX version: 0.5.2
Backend: tpu
Devices: 1

Device 0:
  Platform: tpu
  Kind: TPU v5 lite
  ID: 0

Testing computation...
Test result: 30
Computed on: TPU_0(process=0,(0,0,0,0))

Backend Information:
Default backend: tpu
Available backends: tpu
Device 0: TPU v5 lite on tpu

TPU Performance Test:
Warming up JIT...
Running benchmark...
Matrix operation result: 8000000000.0
Computation time: 0.0003 seconds
Executed on: TPU_0(process=0,(0,0,0,0))
