In [1]:
import torch
import csv
import numpy as np
import matplotlib.pyplot as plt

Ensure CUDA is available

In [2]:
assert torch.cuda.is_available(), "CUDA must be available to run this code on GPU."
device = torch.device("cuda")

User-configurable parameters

In [3]:
#dtype = torch.complex128
dtype = torch.float64

Iterate over the batch size

In [6]:
m = 4               # number of rows (fixed)
n = 4               # number of columns (fixed)
batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]  # iterate over powers of 2

# Warm-up to initialize CUDA context
warm_up = torch.randn(1, m, n, dtype=dtype, device=device)
U, S, Vh = torch.linalg.svd(warm_up, full_matrices=False)
torch.cuda.synchronize()

# Warm-up to initialize CUDA context
if dtype in [torch.complex64, torch.complex128]:
  # Generate complex data (real + imaginary parts)
  real_part = np.random.randn(1, m, n).astype(np.float32 if dtype == torch.complex64 else np.float64)
  imag_part = np.random.randn(1, m, n).astype(np.float32 if dtype == torch.complex64 else np.float64)
  batch_cpu = real_part + 1j * imag_part
else:
  # Generate real data
  batch_cpu = np.random.randn(1, m, n).astype(np.float32 if dtype == torch.float32 else np.float64)

X = torch.from_numpy(batch_cpu).to(device=device, dtype=dtype)  # shape (B, m, n)
torch.cuda.synchronize()  # Synchronize before starting timing
U, S, Vh = torch.linalg.svd(X, full_matrices=False)
torch.cuda.synchronize()  # Ensure operations complete before measuring

elapsed_times = []
avg_errors    = []

for B in batch_sizes:
    # ── 1) Create a random batch on CPU (NumPy) with specified dtype ──
    if dtype in [torch.complex64, torch.complex128]:
        # Generate complex data (real + imaginary parts)
        real_part = np.random.randn(B, m, n).astype(np.float32 if dtype == torch.complex64 else np.float64)
        imag_part = np.random.randn(B, m, n).astype(np.float32 if dtype == torch.complex64 else np.float64)
        batch_cpu = real_part + 1j * imag_part
    else:
        # Generate real data
        batch_cpu = np.random.randn(B, m, n).astype(np.float32 if dtype == torch.float32 else np.float64)

    # ── 2) Move the batch to GPU as a PyTorch tensor ──
    X = torch.from_numpy(batch_cpu).to(device=device, dtype=dtype)  # shape (B, m, n)

    # ── 3) Time the batched SVD on GPU using torch.cuda.Event ──
    start_event = torch.cuda.Event(enable_timing=True)
    end_event   = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()  # Synchronize before starting timing
    start_event.record()
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    end_event.record()
    torch.cuda.synchronize()  # Ensure operations complete before measuring

    elapsed_time = start_event.elapsed_time(end_event) / 1000.0  # Convert milliseconds to seconds
    elapsed_times.append(elapsed_time)

    # ── 4) Move results back to CPU for reconstruction error checking ──
    U_cpu  = U.cpu().numpy()   # shape (B, m, k)
    S_cpu  = S.cpu().numpy()   # shape (B, k)
    Vh_cpu = Vh.cpu().numpy()  # shape (B, k, n)

    # ── 5) Compute reconstruction errors on CPU (NumPy) ──
    errors = []
    k = min(m, n)
    for i in range(B):
        Ai = batch_cpu[i]            # Original matrix

        # For complex: Convert singular values to diagonal matrix in complex form
        if np.iscomplexobj(batch_cpu):
            Si = np.zeros((k, k), dtype=batch_cpu.dtype)
            np.fill_diagonal(Si, S_cpu[i])
        else:
            Si = np.diag(S_cpu[i])       # Real diagonal matrix

        recon_i = U_cpu[i] @ Si @ Vh_cpu[i]  # Reconstructed matrix

        # Calculate Frobenius norm of the difference
        diff = recon_i - Ai
        if np.iscomplexobj(diff):
            # For complex: norm = sqrt(sum(|z|^2))
            err = np.sqrt(np.sum(np.real(diff * np.conj(diff))))
        else:
            err = 100. * np.linalg.norm(diff) / np.linalg.norm(Ai)

        errors.append(err)

    avg_error = np.mean(errors)
    avg_errors.append(avg_error)

    print(f"Batch size B = {B:4d} → elapsed time = {elapsed_time:.6f} s, avg reconstruction error = {avg_error:.3e}")

Batch size B =   32 → elapsed time = 0.033380 s, avg reconstruction error = 1.570e-13
Batch size B =   64 → elapsed time = 0.002306 s, avg reconstruction error = 1.723e-13
Batch size B =  128 → elapsed time = 0.004438 s, avg reconstruction error = 1.579e-13
Batch size B =  256 → elapsed time = 0.007315 s, avg reconstruction error = 1.638e-13
Batch size B =  512 → elapsed time = 0.013620 s, avg reconstruction error = 1.589e-13
Batch size B = 1024 → elapsed time = 0.026597 s, avg reconstruction error = 1.617e-13
Batch size B = 2048 → elapsed time = 0.052818 s, avg reconstruction error = 1.610e-13
Batch size B = 4096 → elapsed time = 0.076113 s, avg reconstruction error = 1.615e-13
Batch size B = 8192 → elapsed time = 0.115998 s, avg reconstruction error = 1.609e-13


Summary of errors and timings

In [None]:
B_array = np.array(batch_sizes)
elapsed_array = np.array(elapsed_times)
error_array = np.array(avg_errors)

# Plot Timing vs. Batch Size (log-log scale)
plt.figure(figsize=(8, 4))
plt.plot(B_array, elapsed_array, marker='o')
plt.title(f"Batched SVD Timing vs. Batch Size (m={m}, n={n})")
plt.xlabel("Batch Size (B)")
plt.ylabel("Elapsed Time (seconds)")
plt.xscale('log', base=2)
plt.yscale('log')
plt.grid(True, which='both', ls='--')
plt.tight_layout()
plt.show()

# Plot Average Reconstruction Error vs. Batch Size (log-log scale)
plt.figure(figsize=(8, 4))
plt.plot(B_array, error_array, marker='o')
plt.title(f"Average Reconstruction Error vs. Batch Size (m={m}, n={n})")
plt.xlabel("Batch Size (B)")
plt.ylabel("Average Error (Frobenius Norm)")
plt.xscale('log', base=2)
plt.yscale('log')
plt.grid(True, which='both', ls='--')
plt.tight_layout()
plt.show()

Saving the results

In [None]:
# Save data to CSV
filename = "resultsPyTorch.csv"
with open(filename, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["elapsed_time", "error"])  # Header row
    writer.writerows(zip(elapsed_times, errors))

print(f"Data saved to {filename}")