Skip to content

Commit

Permalink
Commented the Colab magic syntax in the TPU benchmark.
Browse files · Browse the repository at this point in the history
  • Loading branch information
jeanfeydy committed Dec 2, 2020
1 parent f0e72d7 commit e266fd2
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions benchmarks/PyTorch_TPU.py
Expand Up @@ -9,17 +9,21 @@


import os

# Fail fast with an actionable message when the notebook is not attached to a
# TPU runtime: Colab only sets COLAB_TPU_ADDR when a TPU backend is selected.
assert os.environ[
    "COLAB_TPU_ADDR"
], "Make sure to select TPU from Edit > Notebook settings > Hardware accelerator"

###################################################
#

# Locations of the Colab-compatible PyTorch/XLA wheels used by the install
# commands below (referenced as shell variables by the "!" Colab magics).
DIST_BUCKET = "gs://tpu-pytorch/wheels"
TORCH_WHEEL = "torch-1.15-cp36-cp36m-linux_x86_64.whl"
TORCH_XLA_WHEEL = "torch_xla-1.15-cp36-cp36m-linux_x86_64.whl"
TORCHVISION_WHEEL = "torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl"

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
"""
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
Expand All @@ -28,7 +32,7 @@
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5

"""

###################################################
#
Expand Down Expand Up @@ -58,24 +62,21 @@
p = torch.randn(N, 1, device=xm.xla_device())

def KP(x, y, p):
    """Gaussian kernel product: K_xy @ p with K_xy[i, j] = exp(-|x_i - y_j|^2).

    Args:
        x: (N, D) tensor of query points.
        y: (M, D) tensor of source points.
        p: (M, 1) tensor of weights.

    Returns:
        (N, 1) tensor equal to the kernel matrix applied to p.
    """
    # Squared Euclidean distances via the |x|^2 - 2<x,y> + |y|^2 expansion.
    D_xx = (x * x).sum(-1).unsqueeze(1)  # (N,1)
    D_xy = torch.matmul(x, y.permute(1, 0))  # (N,D) @ (D,M) = (N,M)
    D_yy = (y * y).sum(-1).unsqueeze(0)  # (1,M)
    D_xy = D_xx - 2 * D_xy + D_yy  # (N,M) squared distances
    K_xy = (-D_xy).exp()  # Gaussian (RBF) kernel matrix

    return K_xy @ p


import time

start = time.time()

# Apply the kernel product nits times in sequence; the benchmark reports the
# average wall-clock time per application.
# NOTE(review): with XLA, execution is lazy — printing p below forces the
# computation before `end` is read; keep the print ahead of the final timing.
for _ in range(nits):
    p = KP(x, y, p)

print(p)
end = time.time()
print("Timing with {} points: {} x {:.4f}s".format(N, nits, (end - start) / nits))

0 comments on commit e266fd2

Please sign in to comment.