In [1]:
import torch

In [3]:
import torch

def weighted_qr_ignore_zero_rows(A, W):
    """
    Perform a 'weighted QR' on matrix A w.r.t. diag(W), ignoring rows where W=0.
    Returns Q, R such that:
      1) A = Q R   (reconstructs only the rows where W>0 exactly)
      2) Q^T diag(W) Q = I in the subspace of nonzero W.
         Rows with W=0 get filled with zeros in Q by default.
    """
    # A: (m, n)
    # W: (m,) diagonal entries, may have zeros
    assert A.shape[0] == W.shape[0], "Mismatched shapes between A and W."

    # 1) Identify nonzero-weight rows
    mask = (W > 0)
    
    # 2) Extract submatrix and subweights
    A_sub = A[mask, :]             # shape (m_sub, n)
    W_sub = W[mask]                # shape (m_sub,)

    # If no zero weights, just do a normal weighted QR
    if A_sub.shape[0] == A.shape[0]:
        # i.e., mask is all True
        # standard weighted qr (assumes W_sub all positive)
        return weighted_qr(A, W)   # from the previous helper you wrote

    # 3) Compute Weighted QR on submatrix
    Q_sub, R = weighted_qr(A_sub, W_sub)  # as previously defined

    # 4) Reassemble full-size Q
    m, n = A.shape
    Q = A.new_zeros((m, n))  # same dtype/device
    Q[mask, :] = Q_sub       # put back the valid rows

    return Q, R

def weighted_qr(A, W):
    """
    Standard Weighted QR: A = Q R, with Q^T diag(W) Q = I.
    (Requires that W > 0 on all diagonal entries.)
    """
    W_sqrt = W.sqrt()      # shape (m,)
    W_inv_sqrt = 1.0 / W_sqrt

    # Multiply each row i by sqrt(W[i])
    B = W_sqrt.unsqueeze(-1) * A
    # Standard QR on B
    Q_std, R = torch.linalg.qr(B, mode='reduced')
    # Map Q_std back
    Q = W_inv_sqrt.unsqueeze(-1) * Q_std
    return Q, R

# -------------------------------
# Example usage
m, n = 6, 3
A = torch.randn(m, n)
# Suppose some rows have zero weight
W_vals = torch.tensor([1.0, 0.0, 2.5, 0.0, 0.1, 0.0])  

Q, R = weighted_qr_ignore_zero_rows(A, W_vals)

print("A.shape =", A.shape)
print("Q.shape =", Q.shape)
print("R.shape =", R.shape)
A_approx = Q @ R

# Reconstruction check on the nonzero-weight rows
mask = (W_vals > 0)
err_reconstruction = (A_approx[mask] - A[mask]).norm().item()
print("Reconstruction error (rows with W>0) =", err_reconstruction)

# Weighted orthogonality check on subspace
W_mat = torch.diag(W_vals)
check = Q.T @ W_mat @ Q  # shape (n, n)
I_n = torch.eye(n, dtype=Q.dtype, device=Q.device)
err_QTWQ = (check - I_n).norm().item()
print("||Q^T W Q - I|| on subspace (nonzero rows) =", err_QTWQ)

# For rows where W=0, we didn't constrain Q at all, so Q is zero in those rows.


A.shape = torch.Size([6, 3])
Q.shape = torch.Size([6, 3])
R.shape = torch.Size([3, 3])
Reconstruction error (rows with W>0) = 1.3677447441295953e-07
||Q^T W Q - I|| on subspace (nonzero rows) = 1.88081486385272e-07


In [6]:
torch.tensor((0,1,2,3))[1:]

tensor([1, 2, 3])

In [None]:
weights = torch.tensor([1,0,3,4,5])

mask = weights > 0

W_og = torch.diag(weights).float()

W = torch.diag(torch.tensor(weights[mask]))
W_inv = torch.diag(1/torch.tensor(weights[mask]))

W_sqrt = W.sqrt()      # shape (batch_size, m)
W_inv_sqrt = W_inv.sqrt()      # shape (batch_size, m)


alpha = torch.ones(2,1) * 2

true_energy = 3

A = torch.randn(5, 2)

A_mask = A[mask, :]

A = W_sqrt @ A_mask

Q, R = torch.linalg.qr(A, mode='reduced')

Q_full = torch.zeros((4,4))
R_full = torch.zeros((4,2))


Q_full[:,:2] = Q
R_full[:2,:] = R


alpha = R_full @ alpha
print(alpha.shape)
alpha = alpha / (alpha.T @ alpha)
alpha = alpha * torch.sqrt(torch.tensor(true_energy))

Q_full = W_inv_sqrt @ Q_full

print((Q_full @ R_full).shape)
print((Q_full @ alpha).shape)

full_QRa = torch.zeros((5, 1)).float()
full_QRa[mask, :] = Q_full @ alpha

energy = (full_QRa).T @ W_og @ (full_QRa)

print(Q_full @ R_full)
print(A)
print(energy)


torch.Size([4, 1])


  W = torch.diag(torch.tensor(weights[mask]))
  W_inv = torch.diag(1/torch.tensor(weights[mask]))


IndexError: The shape of the mask [5] at index 0 does not match the shape of the indexed tensor [4, 1] at index 0

In [39]:
import torch

# 1) Setup
weights = torch.tensor([1, 0, 3, 4, 5], dtype=torch.float32)
mask = (weights > 0)               # e.g. [True, False, True, True, True]

W_og = torch.diag(weights)         # (5x5) original weighting
true_energy = 3.0

A_orig = torch.randn(5, 2)         # The original full A

# 2) Submatrix for nonzero weights
A_sub = A_orig[mask, :]            # shape (4,2)
w_sub = weights[mask]              # shape (4,)

# 3) Weighted matrix B = W_sqrt * A_sub
W_sqrt = torch.diag(w_sub.sqrt())         # (4x4)
B = W_sqrt @ A_sub                        # (4x2)

# 4) Standard QR on B
Q_std, R = torch.linalg.qr(B, mode='reduced')
# Q_std is (4,2),  R is (2,2)

# 5) Weighted Q_sub = W_inv_sqrt @ Q_std
W_inv_sqrt = torch.diag(1.0 / w_sub.sqrt())
Q_sub = W_inv_sqrt @ Q_std   # (4,2), and Q_sub^T diag(w_sub) Q_sub = I

# 6) Insert Q_sub into a full Q_full that has 5 rows
Q_full = torch.zeros((5, 2), dtype=A_orig.dtype)
Q_full[mask, :] = Q_sub  # fill the 4 rows that matter

# Check: Q_sub@R ~ A_sub
recon_sub = Q_sub @ R
recon_error_sub = (recon_sub - A_sub).norm().item()

print(f"Reconstruction error on nonzero-weight rows: {recon_error_sub:.3e}")

# 7) Example alpha
alpha = torch.ones((2,1)) * 2  # shape (2,1), just a test
# Suppose you want to normalize alpha so that (Q_full alpha) has energy = true_energy

# Weighted norm = alpha^T (Q_full^T W_og Q_full) alpha
# But Q_full^T W_og Q_full = "I" in the subspace, and 0 for the zero row
# So effectively it's alpha^T alpha. Let's check:

M = Q_full.T @ W_og @ Q_full  # shape (2,2)
# If only 1 zero weight row, rank might still be 2. Let's see:
print("Q^T W Q =")
print(M)

# 8) Normalize alpha
current_energy = alpha.T @ M @ alpha
alpha = alpha / torch.sqrt(current_energy)  # => energy is now 1
alpha = alpha * torch.sqrt(torch.tensor(true_energy))
# => energy is now true_energy

# 9) form the final vector = Q_full @ alpha
full_QRa = Q_full @ alpha    # shape (5,1)
# Weighted energy:
energy = (full_QRa.T @ W_og @ full_QRa).item()

print(f"Final energy = {energy:.3f} (target = {true_energy})")


Reconstruction error on nonzero-weight rows: 3.039e-07
Q^T W Q =
tensor([[1.0000e+00, 3.7253e-09],
        [1.8626e-09, 1.0000e+00]])
Final energy = 3.000 (target = 3.0)


In [40]:
import torch

# 1) Setup random A and diagonal weights (some zero)
m, n = 5, 3
A = torch.randn(m, n)
W_vals = torch.tensor([2.0, 0.0, 3.0, 5.0, 0.0])  # shape (5,)

print("A:\n", A, "\n")
print("W (diagonal entries):\n", W_vals, "\n")

# 2) Identify nonzero rows
mask = (W_vals > 0)
A_sub = A[mask, :]          # submatrix with nonzero-weight rows
w_sub = W_vals[mask]        # corresponding nonzero weights

# 3) Form B = sqrt(W_sub)*A_sub, then do unweighted QR
W_sub_sqrt = w_sub.sqrt()              # shape (m_sub,)
B = W_sub_sqrt.unsqueeze(-1) * A_sub   # multiply each row by sqrt(w_sub)
Q_std, R = torch.linalg.qr(B, mode='reduced')   # unweighted QR

# 4) Convert Q_std -> Q_sub = invsqrt(W_sub)*Q_std
W_sub_invsqrt = 1.0 / W_sub_sqrt
Q_sub = W_sub_invsqrt.unsqueeze(-1) * Q_std

# 5) Reassemble a full Q of shape (m, n), zero in rows where W=0
Q = torch.zeros_like(A)  # same shape (5,3)
Q[mask, :] = Q_sub

# Check Q^T W Q = I
W_mat = torch.diag(W_vals)
lhs = Q.T @ W_mat @ Q
print("Q^T W Q =\n", lhs, "\n")

# 6) Demonstrate that (Q R alpha)^T W (Q R alpha) = ||R alpha||^2
alpha = torch.randn(n, 1)
u_weighted = Q @ (R @ alpha)             # (m x 1)
energy_weighted = (u_weighted.T @ W_mat @ u_weighted).item()
energy_unweighted = (R @ alpha).norm().pow(2).item()

print("Energy (weighted)   = ", energy_weighted)
print("Energy (unweighted) = ", energy_unweighted)
print("Difference          = ", abs(energy_weighted - energy_unweighted))


A:
 tensor([[ 1.2382, -0.0457,  0.0034],
        [-0.0845,  1.1376,  0.3131],
        [ 0.5250, -0.8093,  0.4620],
        [ 0.3376,  2.1880, -0.6366],
        [ 1.0187,  0.5949, -0.8949]]) 

W (diagonal entries):
 tensor([2., 0., 3., 5., 0.]) 

Q^T W Q =
 tensor([[ 1.0000e+00,  0.0000e+00,  5.9605e-08],
        [ 0.0000e+00,  1.0000e+00, -2.9802e-08],
        [ 5.9605e-08,  0.0000e+00,  1.0000e+00]]) 

Energy (weighted)   =  3.2819647789001465
Energy (unweighted) =  3.2819645404815674
Difference          =  2.384185791015625e-07


In [142]:
batch_size = 2

# Create batched inputs
B_re = torch.rand((batch_size, 4, 2))
B_im = torch.rand((batch_size, 4, 2))
B = B_re + 1j * B_im

un_alpha = torch.rand((batch_size, 2))
un_alpha = torch.complex(un_alpha, torch.zeros_like(un_alpha))

W_vals = torch.tensor([2.0, 0.0, 3.0, 5.0])  # (batch_size, 4)

mask = (W_vals > 0)  # (batch_size, 4)

# Handle each batch independently since QR is not batched
Q = torch.zeros_like(B)  # (batch_size, 4, 3)

B_sub = B[:,mask,:]          # (n_nonzero, 3)

w_sub = W_vals[mask]        # (n_nonzero,)

W_sub_sqrt = w_sub.sqrt()
W_sub_invsqrt = 1.0 / W_sub_sqrt

W_sub_sqrt_mat = torch.diag(W_sub_sqrt).unsqueeze(0).expand(B.shape[0],-1,-1)  # (n_nonzero, n_nonzero)
W_sub_invsqrt = torch.diag(W_sub_invsqrt).unsqueeze(0).expand(B.shape[0],-1,-1)  # (n_nonzero, n_nonzero)

W_sub_sqrt_mat = torch.complex(W_sub_sqrt_mat, torch.zeros_like(W_sub_sqrt_mat))
W_sub_invsqrt = torch.complex(W_sub_invsqrt, torch.zeros_like(W_sub_invsqrt))

B_sub = W_sub_sqrt_mat @ B_sub

Q_std, R = torch.linalg.qr(B_sub, mode='reduced')

Q_sub = W_sub_invsqrt @ Q_std


Q[:,mask,:]   = Q_sub

print(R.shape)
print(un_alpha.shape)
R_alpha = (R @ un_alpha.unsqueeze(-1))
print(f'R_alpha shape is {R_alpha.shape}')
energy_unweighted = torch.norm(R_alpha, p=2, dim=1, keepdim=True)  # keeping batch dimension
R_alpha = R_alpha / energy_unweighted * torch.sqrt(torch.tensor(2.0))

# Create batched diagonal weight matrix
W_mat = torch.complex(torch.diag_embed(W_vals), torch.zeros_like(torch.diag_embed(W_vals)))

print(Q.shape)
print(R_alpha.shape)
u_weighted = (Q @ R_alpha)  # (batch_size, 4, 1)
energy_weighted = torch.diagonal(
    torch.bmm(
        torch.bmm(u_weighted.transpose(-2,-1).conj(), W_mat.unsqueeze(0).expand(batch_size, -1, -1)),
        u_weighted
    ),
    dim1=-2, dim2=-1
).real

#print the product Q^TWQ
QTWQ = torch.bmm(torch.bmm(Q.transpose(-2,-1).conj(), W_mat.unsqueeze(0).expand(batch_size, -1, -1)), Q)
print("Q^T W Q =")
#print rounded
print(QTWQ.real.round())



print(R_alpha.shape)

print("Energy (weighted) per batch = ", energy_weighted)


torch.Size([2, 2, 2])
torch.Size([2, 2])
R_alpha shape is torch.Size([2, 2, 1])
torch.Size([2, 4, 2])
torch.Size([2, 2, 1])
Q^T W Q =
tensor([[[1., -0.],
         [-0., 1.]],

        [[1., 0.],
         [0., 1.]]])
torch.Size([2, 2, 1])
Energy (weighted) per batch =  tensor([[2.0000],
        [2.0000]])


In [130]:
tensy = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)

print(tensy)
mask = (torch.ones_like(tensy, dtype=torch.bool)).bool()
mask[2] = False

print(mask)

print(tensy[mask])

tensor([1., 2., 3., 4., 5.])
tensor([ True,  True, False,  True,  True])
tensor([1., 2., 4., 5.])
