In [13]:
import torch
import time

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parameters
# batch_size = 2 ** 20
batch_size = 2 ** 10
input_dim = 4096
output_dim = 4096

dtype = torch.bfloat16
# Model and data
linear = torch.nn.Linear(input_dim, output_dim, dtype=dtype).to(device)
x = torch.randn(batch_size, input_dim, dtype=dtype,device=device)
grad_output = torch.randn(batch_size, output_dim, dtype=dtype, device=device)

# Warm-up
for _ in range(3):
    out = linear(x)
    out.backward(grad_output, retain_graph=True)

# Measure forward time
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
out = linear(x)
end_event.record()

torch.cuda.synchronize()
forward_time = start_event.elapsed_time(end_event)

# Measure backward time
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
out.backward(grad_output, retain_graph=True)
end_event.record()

torch.cuda.synchronize()
backward_time = start_event.elapsed_time(end_event)

# Timing results
backward_ratio = backward_time / forward_time

print(f"Batch Size: {batch_size}")
print(f"Forward Time: {forward_time:.6f} ms")
print(f"Backward Time: {backward_time:.6f} ms")
print(f"Backward/Forward Ratio: {backward_ratio:.2f}")

Batch Size: 1024
Forward Time: 0.232736 ms
Backward Time: 0.342944 ms
Backward/Forward Ratio: 1.47


In [None]:
"""
Batch Size: 1024
Forward Time: 0.232736 ms
Backward Time: 0.342944 ms
Backward/Forward Ratio: 1.47

Batch Size: 1048576
Forward Time: 1838.430542 ms
Backward Time: 2196.389404 ms
Backward/Forward Ratio: 1.19
"""