In [1]:
import torch
from time import perf_counter
from torch.utils import benchmark
from itertools import product

torch.cuda.seed_all()

In [2]:
def landing_step_inplace(
    param: torch.Tensor, grad: torch.Tensor, I: torch.Tensor
) -> torch.Tensor:

    G = grad @ param.t()          # G shape: (n, n)
    G = G - G.t()

    distance = param.t() @ param  
    distance.sub_(I)              # in-place distance = distance - I

    return param - (G @ param + param @ distance)


"""
Polar retraction from pymanopt: https://pymanopt.org/docs/stable/_modules/pymanopt/manifolds/stiefel.html#Stiefel
    def _retraction_polar(self, point, tangent_vector):
        Y = point + tangent_vector
        u, _, vt = np.linalg.svd(Y, full_matrices=False)
        return u @ vt
"""

# def conditioner_step(param: torch.Tensor, grad: torch.Tensor, I: torch.Tensor):
#     cond = torch.matmul(grad.T, grad)
#     cond.add_(I)
#     L, Q = torch.linalg.eigh(cond)
#     L.rsqrt_()
#     polar_factor = Q @ torch.diag(L) @ Q.T
#     return torch.matmul(param, polar_factor)

def conditioner_step(param: torch.Tensor, grad: torch.Tensor, I: torch.Tensor):
    Y = param + grad
    u, _, vt = torch.linalg.svd(Y, full_matrices=False)
    return torch.matmul(u, vt)

In [None]:
device = 'cuda:7'

In [None]:
param = torch.randn(4096, 32, device=device)
n, p = param.shape
I = torch.eye(p, device=param.device)
comp = torch.compile(landing_step_inplace)
for _ in range(10):
    _ = comp(param, param, I)

comp_conditioner = torch.compile(conditioner_step)
for _ in range(10):
    _ = comp_conditioner(param,param, I)

In [5]:


results = []

n_grid = [768, 4096, 11008]
r_grid = [4, 32, 64, 256]
for n, r in product(n_grid, r_grid):
    # label and sub_label are the rows
    # description is the column
    label = 'retraction/landing'
    sub_label = f'[{n}, {r}]'
    param = torch.randn((n, r), device=device)
    I = torch.eye(r, device=param.device)
    results.append(benchmark.Timer(
        stmt='comp_conditioner(param, param, I)',
        setup='from __main__ import comp_conditioner',
        label=label,
        sub_label=sub_label,
        description='retraction',
        globals={'param': param, "I": I}).adaptive_autorange(threshold=0.15, min_run_time=1.0, max_run_time=5.0))
    results.append(benchmark.Timer(
        stmt='comp(param, param, I)',
        setup='from __main__ import comp',
        label=label,
        sub_label=sub_label,
        description='landing',
        globals={'param': param, "I": I}).adaptive_autorange(threshold=0.15, min_run_time=1.0, max_run_time=5.0))

compare = benchmark.Compare(results)
#compare.trim_significant_figures()
compare.print()

[------------ retraction/landing -----------]
                    |  retraction  |  landing
1 threads: ----------------------------------
      [768, 4]      |     507.2    |    167.5
      [768, 32]     |    1160.3    |    100.4
      [768, 64]     |    2178.1    |    102.7
      [768, 256]    |    8317.5    |    109.9
      [4096, 4]     |     541.6    |    180.9
      [4096, 32]    |    1050.3    |    186.4
      [4096, 64]    |    1916.2    |    238.4
      [4096, 256]   |    9889.4    |    539.7
      [11008, 4]    |     606.7    |   1785.8
      [11008, 32]   |    1312.7    |   1233.6
      [11008, 64]   |    2237.7    |   1632.7
      [11008, 256]  |   11653.9    |   3606.4

Times are in microseconds (us).



In [6]:
for i, measurement in enumerate(results):
    print(f"\n---- Measurement {i} details ----")
    print(measurement)
    print("Number of runs per repeat:", measurement.number_per_run)
    print("Times (seconds):", measurement.times)
    print("Median (seconds):", measurement.median)
    print("IQR (seconds):", measurement.iqr)


---- Measurement 0 details ----
<torch.utils.benchmark.utils.common.Measurement object at 0x70e0bd5df610>
retraction/landing: [768, 4]
retraction
setup: from __main__ import comp_conditioner
  Median: 507.22 us
  IQR:    6.37 us (504.86 to 511.22)
  1967 measurements, 1 runs per measurement, 1 thread
Number of runs per repeat: 1
Times (seconds): [0.0007048351690173149, 0.0005112066864967346, 0.0005045570433139801, 0.0005030771717429161, 0.0005342960357666016, 0.000523085705935955, 0.0005134660750627518, 0.0005134269595146179, 0.000519556924700737, 0.000515456311404705, 0.0005209362134337425, 0.0005116164684295654, 0.0005141366273164749, 0.0005128160119056702, 0.0005121259018778801, 0.0005117766559123993, 0.0005052676424384117, 0.0005690157413482666, 0.0005187159404158592, 0.0005131857469677925, 0.0005185771733522415, 0.0005084965378046036, 0.0005048653110861778, 0.0005125869065523148, 0.000506197102367878, 0.0005051456391811371, 0.0005056271329522133, 0.000507027842104435, 0.000501936

In [7]:
# Benchmark takaes care of warm up
param = torch.randn(4096, 32, device=device)
n, p = param.shape
I = torch.eye(p, device=param.device)

t0 = benchmark.Timer(
    stmt='comp_conditioner(param, param, I)',
    setup='from __main__ import comp_conditioner',
    globals={'param': param, "I": I})

t1 = benchmark.Timer(
    stmt='comp(param, param, I)',
    setup='from __main__ import comp',
    globals={'param': param, "I": I})

# Run only once since benchmark module does warm-up for us
print(t0.timeit(1000))
print(t1.timeit(1000))

<torch.utils.benchmark.utils.common.Measurement object at 0x70e0bc3e0a00>
comp_conditioner(param, param, I)
setup: from __main__ import comp_conditioner
  1.06 ms
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x70e0bd570bb0>
comp(param, param, I)
setup: from __main__ import comp
  165.64 us
  1 measurement, 1000 runs , 1 thread


In [8]:
num_iter = 1000

torch.cuda.synchronize(device=device)
start = perf_counter()
for _ in range(num_iter):
    out = comp(param, param, I)
torch.cuda.synchronize(device=device)
end = perf_counter()

print(f"Landing (comp()) execution time: {end - start} seconds")

## Retraction
torch.cuda.synchronize(device=device)
start = perf_counter()
for _ in range(num_iter):
    out = comp_conditioner(param, param, I)
torch.cuda.synchronize(device=device)
end = perf_counter()

print(f"Retraction (comp_conditioner()) execution time: {end - start} seconds")

Landing (comp()) execution time: 0.1880705077201128 seconds
Retraction (comp_conditioner()) execution time: 1.0508270850405097 seconds
