In [1]:
import time
import torch
torch.cuda.device_count()

8

In [2]:
batchSize = 64
n = 2048
times = 10
flops = 1e-12 * times * 2 * batchSize*n*n*n

In [3]:
cpu = torch.device('cpu')
devs = [
  #(torch.float32, cpu),
  #(torch.float64, cpu)
]
if torch.cuda.is_available():
  for i in range(torch.cuda.device_count()):
    cuda = torch.device('cuda:{}'.format(i))
    devs.extend([
      (torch.float16, cuda),
      (torch.float32, cuda),
      (torch.float64, cuda)
    ])
print([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
warmup = [torch.rand((n, n), dtype=dtype, device=device, requires_grad=True) for dtype, device in devs]
for t in warmup:
    t.cpu()
    del t

['GeForce RTX 2080 Ti', 'GeForce RTX 2080 Ti', 'GeForce RTX 2080 Ti', 'GeForce RTX 2080 Ti', 'GeForce RTX 2080 Ti', 'GeForce RTX 2080 Ti', 'GeForce RTX 2080 Ti', 'GeForce RTX 2080 Ti']


In [4]:
for dtype, device in devs:
  try:
    torch.cuda.empty_cache()
    t = torch.rand((batchSize, n, n), dtype=dtype, device=device, requires_grad=True) * 1e-2
    eye = torch.eye(n, dtype=dtype, device=device).expand(batchSize, -1, -1)
    loss = torch.tensor(0, dtype=dtype, device=device)
    print('Test on device {} with dtype {}'.format(device, dtype))
    start = time.perf_counter()
    for i in range(times):
        diff = torch.bmm(t, eye).sum(dim=2).sum(dim=1)
        loss += (diff * diff).mean()
        del diff
    loss.cpu()
    elapsed = time.perf_counter() - start
    print('Time passed: {:.5f}s, TFlops: {:.2f}'.format(elapsed, flops / elapsed))
    del t, eye, loss
  except Exception as e:
    print(e)

Test on device cuda:0 with dtype torch.float16
Time passed: 0.73451s, TFlops: 14.97
Test on device cuda:0 with dtype torch.float32
Time passed: 0.77424s, TFlops: 14.20
Test on device cuda:0 with dtype torch.float64
Time passed: 21.46311s, TFlops: 0.51
Test on device cuda:1 with dtype torch.float16
Time passed: 0.62822s, TFlops: 17.50
Test on device cuda:1 with dtype torch.float32
Time passed: 0.78910s, TFlops: 13.93
Test on device cuda:1 with dtype torch.float64
Time passed: 21.62666s, TFlops: 0.51
Test on device cuda:2 with dtype torch.float16
Time passed: 0.59111s, TFlops: 18.60
Test on device cuda:2 with dtype torch.float32
Time passed: 0.77038s, TFlops: 14.27
Test on device cuda:2 with dtype torch.float64
Time passed: 21.59772s, TFlops: 0.51
Test on device cuda:3 with dtype torch.float16
Time passed: 0.65787s, TFlops: 16.71
Test on device cuda:3 with dtype torch.float32
Time passed: 0.77997s, TFlops: 14.10
Test on device cuda:3 with dtype torch.float64
Time passed: 21.51671s, TFlop