In [2]:
import torch
import apex.amp as amp
import time # for timing execution

# Benchmark configuration
niter = 100    # timed iterations per configuration
n_warmup = 10  # untimed iterations run first so one-time start-up costs are not measured

# Data sizes: batch size, input features, output features
N, D_in, D_out = 1024, 8192, 2048

# Fixed seed so input data and weight initialization are reproducible across runs
torch.manual_seed(0)

# Results vectors: one entry per configuration, Float32 baseline first
results_list = []
results_names = []


def train_step(model, optimizer, x, y, use_amp=False):
    """Run one forward pass, MSE loss, backward pass and optimizer step.

    Args:
        model: module mapping ``x`` to predictions comparable with ``y``.
        optimizer: optimizer over ``model``'s parameters.
        x: input batch.
        y: regression target for ``x``.
        use_amp: if True, backpropagate through ``amp.scale_loss`` as required
            for a model/optimizer pair returned by ``amp.initialize``.
    """
    y_pred = model(x)
    loss = torch.nn.functional.mse_loss(y_pred, y)
    optimizer.zero_grad()
    if use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()


def benchmark(model, optimizer, x, y, use_amp=False):
    """Time ``niter`` training steps after ``n_warmup`` untimed warm-up steps.

    Arguments are forwarded to ``train_step``.

    Returns:
        Elapsed wall-clock time of the timed steps, in seconds.
    """
    for _ in range(n_warmup):
        train_step(model, optimizer, x, y, use_amp)
    # CUDA kernels are launched asynchronously. Without synchronizing, the
    # timer could start while warm-up work is still queued and stop before the
    # timed work has finished executing on the GPU.
    torch.cuda.synchronize()
    tic = time.perf_counter()
    for _ in range(niter):
        train_step(model, optimizer, x, y, use_amp)
    torch.cuda.synchronize()
    return time.perf_counter() - tic


# Full precision (float32) training
x = torch.randn(N, D_in, device="cuda")
y = torch.randn(N, D_out, device="cuda")
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

print("Running full precision...")
elapsed = benchmark(model, optimizer, x, y)
results_list.append(elapsed)
results_names.append("Float32")
print("Full precision: %0.2f seconds" % elapsed)

# Training with AMP, reusing the same x/y so every configuration is timed on
# identical data. A fresh model/optimizer is built for each opt level.
# NOTE(review): apex documents that amp.initialize should only be called once
# per process, and O1 patches torch functions globally, which may influence
# the opt levels timed after it. Running each opt level in its own process
# would give cleaner numbers.
for opt_level in ["O0", "O1", "O2", "O3"]:
    model = torch.nn.Linear(D_in, D_out).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    print("Running AMP with opt level %s" % opt_level)
    elapsed = benchmark(model, optimizer, x, y, use_amp=True)
    results_list.append(elapsed)
    results_names.append("AMP " + opt_level)
    print("AMP (opt level %s): %0.2f seconds" % (opt_level, elapsed))

# Print results; speed-up is relative to the Float32 baseline (first entry)
print("\nResults summary (%d iterations)\n===============" % niter)
baseline = results_list[0]
for name, seconds in zip(results_names, results_list):
    print("%s: %0.2f seconds  (%0.2fx full precision speed)" % (name, seconds, baseline / seconds))

Running full precision...
Full precision: 5.43 seconds
Selected optimization level O0:  Pure FP32 training.

Defaults for this optimization level are:
enabled                : True
opt_level              : O0
cast_model_type        : torch.float32
patch_torch_functions  : False
keep_batchnorm_fp32    : None
master_weights         : False
loss_scale             : 1.0
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O0
cast_model_type        : torch.float32
patch_torch_functions  : False
keep_batchnorm_fp32    : None
master_weights         : False
loss_scale             : 1.0
Running AMP with opt level O0
AMP (opt level O0): 5.50 seconds
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : No

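# Expected Results
Timings are for 100 training iterations of a single `Linear(8192, 2048)` layer on a batch of 1024 samples. Speed-up is the Float32 time divided by each configuration's time. It is computed from the unrounded timings, so it can differ in the last digit from the ratio of the rounded times shown (e.g. 5.43 / 3.22 ≈ 1.69, reported as 1.68).

## Jetson AGX Xavier (Python-only `apex` build)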
| Configuration | Execution time (sec) | Speed-up vs. Float32 |
|:----------:|:----------------------:|:----------:|
|   Float32 |        5.43        |   1.00   |
|   AMP O0 |        5.50        |   0.99   |
|   AMP O1 |        3.22        |   1.68   |
|   AMP O2 |        2.89        |   1.88   |
|   AMP O3 |        1.16        |   4.69   |

## Jetson AGX Xavier (Full `apex` build)
| Configuration | Execution time (sec) | Speed-up vs. Float32 |
|:----------:|:----------------------:|:----------:|
|   Float32 |        5.43        |   1.00   |
|   AMP O0 |        5.50        |   0.99   |
|   AMP O1 |        2.61        |   2.08   |
|   AMP O2 |        2.16        |   2.52   |
|   AMP O3 |        1.15        |   4.72   |

