In [None]:
import time
from itertools import repeat
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda")


# ==================== Polar Express ========================
# Per-iteration coefficients (a, b, c) of the odd quintic
#     p(x) = a*x + b*x^3 + c*x^5
# that polar_express applies (through X X^T products) at each iteration.
coeffs_list = [
    (8.28721201814563, -23.595886519098837, 17.300387312530933),
    (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
    (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
    (3.3184196573706015, -2.488488024314874, 0.51004894012372),
    (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
    (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
    (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
    (1.875, -1.25, 0.375),  # subsequent coeffs equal this numerically
]

# Safety factor for numerical stability: (a/1.01, b/1.01**3, c/1.01**5)
# evaluates p(x / 1.01) instead of p(x). The last polynomial is left unscaled.
coeffs_list = [(a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in coeffs_list[:-1]] + [coeffs_list[-1]]


# torch.compile is currently disabled; uncomment the decorator to enable it.
# @torch.compile
def polar_express(G: torch.Tensor, steps=6) -> torch.Tensor:
    """Approximate the polar (orthogonal) factor of G with Polar Express.

    Runs ``steps`` iterations of X <- a*X + b*(X X^T) X + c*(X X^T)^2 X.
    Iteration t uses ``coeffs_list[t]``; the last entry of ``coeffs_list``
    is repeated once the list is exhausted. The iterations run in bfloat16.

    Args:
        G: tensor with ``G.ndim >= 2``. The last two dims are the matrix;
            any leading dims are treated as a batch.
        steps: number of iterations, must be >= 0 (0 returns only the
            normalised input).

    Returns:
        float32 tensor with the same shape as ``G``.

    Raises:
        AssertionError: if ``G.ndim < 2``.
        ValueError: if ``steps`` is negative.
    """
    assert G.ndim >= 2
    if steps < 0:
        # Otherwise coeffs_list[:steps] would slice from the end and silently
        # run len(coeffs_list) + steps iterations.
        raise ValueError(f"steps must be non-negative, got {steps}")

    # Iterate on the wide orientation (rows <= cols) so that X @ X.mT is the
    # smaller Gram matrix (fewer FLOPs); transposed back before returning.
    transposed = G.size(-2) > G.size(-1)
    X = G.bfloat16()  # bfloat16 for speed
    if transposed:
        X = X.mT

    # Frobenius norm >= spectral norm, so after this the spectral norm is at
    # most 1/1.01; 1e-7 guards against division by zero.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-7)

    # First `steps` coefficient triples, padded with the last one if needed.
    hs = coeffs_list[:steps] + list(repeat(coeffs_list[-1], steps - len(coeffs_list)))
    for a, b, c in hs:
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X  # X <- aX + bX^3 + cX^5 (on the singular values)

    if transposed:
        X = X.mT
    return X.float()

# Matrix shapes (m, n) to benchmark: square, wide, tall and larger square.
sizes = [(4096, 4096), (4096, 16384), (16384, 4096), (8192, 8192), (12288, 12288)]

def svd_method(W):
    """Polar factor, nuclear norm and Frobenius norm of W via a thin SVD.

    Returns:
        ``(Q, nuc, fro)``: ``Q = U @ Vh`` from the reduced SVD, ``nuc`` the
        sum of the singular values, and ``fro`` the Frobenius norm of W.
    """
    left, sigma, right_h = torch.linalg.svd(W, full_matrices=False)
    fro = torch.linalg.norm(W, ord="fro")
    return left @ right_h, sigma.sum(), fro

def polar_method(W):
    """Polar factor, nuclear norm and Frobenius norm of W via Polar Express.

    Returns ``(Q, nuc, fro)`` in the same order as ``svd_method``. ``Q``
    comes from ``polar_express`` (bfloat16 iterations), so ``Q`` and ``nuc``
    are approximations.

    For the polar factor Q of W, the nuclear norm is ``trace(W^T Q)``. This is
    evaluated as ``sum(W * Q)``, which is mathematically identical but costs
    O(m*n) instead of forming the n x n product ``W.T @ Q``. That product is
    O(m*n^2) flops and 16384 x 16384 floats for a 4096 x 16384 input, and it
    would inflate this method's measured time. Assumes W is 2-D.
    """
    Q = polar_express(W)
    fro = torch.linalg.norm(W, ord="fro")
    nuc = (W * Q).sum()  # == torch.trace(W.T @ Q) without the n x n matmul
    return Q, nuc, fro

def time_runs(fn, W, repeats=5, warmup=1):
    """Time ``fn(W)`` and return the (mean, std) wall-clock time in seconds.

    Args:
        fn: callable invoked as ``fn(W)``; its return value is discarded.
        W: argument passed to ``fn``.
        repeats: number of timed calls (must be >= 1).
        warmup: number of untimed calls made first, so that one-off
            first-call costs (e.g. lazy library initialisation or
            compilation) do not end up in the statistics. Defaults to 1.

    Returns:
        ``(np.mean(times), np.std(times))`` over the ``repeats`` timed calls
        (population std, ddof=0).

    Raises:
        ValueError: if ``repeats < 1`` or ``warmup < 0``.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")

    # CUDA work is launched asynchronously, so synchronise before starting and
    # before stopping the clock. Skipped when CUDA is unavailable, because
    # torch.cuda.synchronize() fails there.
    use_cuda = torch.cuda.is_available()

    for _ in range(warmup):
        fn(W)

    times = []
    for _ in range(repeats):
        if use_cuda:
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        fn(W)
        if use_cuda:
            torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
    return np.mean(times), np.std(times)

# Benchmark both methods at every size. Seeding first means a re-run draws the
# same input matrices.
torch.manual_seed(0)
records = []
for m, n in sizes:
    W = torch.randn(m, n, device=device, dtype=torch.float32)

    mean_svd, std_svd = time_runs(svd_method, W)
    mean_polar, std_polar = time_runs(polar_method, W)

    records.append(dict(m=m, n=n, method="SVD", mean=mean_svd, std=std_svd))
    records.append(dict(m=m, n=n, method="Polar", mean=mean_polar, std=std_polar))

    # Release this size's input before the next randn allocates a new one,
    # so the two matrices are never alive at the same time.
    del W

df = pd.DataFrame(records)
df


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

# Global font sizes (base text, axis labels, title, ticks, legend) applied to
# every figure created after this point.
plt.rcParams.update(
    [
        ("font.size", 18),
        ("axes.labelsize", 20),
        ("axes.titlesize", 22),
        ("xtick.labelsize", 16),
        ("ytick.labelsize", 16),
        ("legend.fontsize", 16),
    ]
)

# Reshape the timings into one row per matrix size and one column per method.
df = df.copy()
df["size_label"] = df["m"].astype(str) + "×" + df["n"].astype(str)
df["elements"] = df["m"] * df["n"]

pivot = (
    df.groupby(["size_label", "elements", "method"])["mean"]
      .mean()
      .unstack("method")
      .sort_index(level="elements")  # order groups by element count
)

cols = [c for c in ["SVD", "Polar"] if c in pivot.columns]
pivot = pivot[cols]
# Named `size_labels`, not `sizes`: the benchmark loop iterates over a global
# `sizes` of (m, n) tuples. Overwriting it here with string labels would break
# that loop if it is re-run after this cell.
size_labels = pivot.index.get_level_values("size_label")

# Grouped-bar geometry, in x data units:
#   group_gap          distance between centres of neighbouring groups
#   inner_ratio        gap between bars in a group, as a fraction of bar width
#   group_padding_frac fraction of group_gap left empty between groups
k = len(cols)
group_gap = 0.62
inner_ratio = 0.06
group_padding_frac = 0.08

usable = group_gap * (1 - group_padding_frac)
bar_w = usable / (k + inner_ratio * (k - 1))
delta = bar_w * (1 + inner_ratio)  # centre-to-centre spacing of bars in a group

x = np.arange(len(size_labels)) * group_gap
offsets = [(i - (k - 1) / 2) * delta for i in range(k)]

fig, ax = plt.subplots(figsize=(9.2, 5.0))

for i, method in enumerate(cols):
    y = pivot[method].values
    ax.bar(x + offsets[i], y, width=bar_w, linewidth=0, label=method)

# Log-scale y axis, with ticks printed via %g rather than 10^k notation.
ax.set_yscale("log")
ax.yaxis.set_major_formatter(FuncFormatter(lambda v, _: f"{v:g}"))

half_span = ((k - 1) / 2) * delta + bar_w / 2  # group centre to outer bar edge
ax.set_xlim(x.min() - half_span * 1.03, x.max() + half_span * 1.03)

ax.set_xticks(x)
ax.set_xticklabels(size_labels, rotation=14, ha="right")
ax.set_ylabel("Avg time (s)")
ax.set_xlabel("Matrix size (m×n)")

for spine in ("top", "right"):
    ax.spines[spine].set_visible(False)

ax.legend(frameon=False, ncol=k, loc="upper left", bbox_to_anchor=(0, 1.02))

# With exactly two methods, write the speed-up factor (slower / faster mean
# time) above each size group, then save and show the figure.
if len(cols) == 2:
    t_a, t_b = (pivot[c].values.astype(float) for c in cols)
    ok = (t_a > 0) & (t_b > 0)
    a_is_slower = t_a >= t_b
    slow = np.where(a_is_slower, t_a, t_b)
    fast = np.where(a_is_slower, t_b, t_a)

    ratio = np.full_like(t_a, np.nan, dtype=float)
    ratio[ok] = slow[ok] / fast[ok]

    # Put labels 12% above the taller bar and grow the y-limit to fit them.
    label_y = np.nanmax(np.vstack([t_a, t_b]), axis=0) * 1.12
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(bottom=ymin, top=max(np.nanmax(label_y) * 1.08, ymax))

    for gx, gy, r in zip(x, label_y, ratio):
        if not np.isfinite(r):
            continue
        ax.text(gx, gy, f"×{r:.1f}", ha="center", va="bottom", fontsize=14)

plt.tight_layout()
plt.savefig("timing_comparison.pdf", bbox_inches="tight", pad_inches=0)
plt.show()

