In [10]:
import torch

def build_upper_triangle_index_matrix(n: int) -> torch.Tensor:
    """
    Returns an (n x n) matrix 'index_matrix' with:
      index_matrix[r, c] = k  if r < c
      index_matrix[r, c] = -1 if r >= c
    where k matches the ordering of torch.triu_indices(n,n,offset=1).
    """
    index_matrix = torch.full((n, n), -1, dtype=torch.long)
    # (row, col) pairs for r < c, in the same order as square_to_flat
    r, c = torch.triu_indices(n, n, offset=1)  # length m = n*(n-1)//2
    # Fill those positions with 0..m-1
    index_matrix[r, c] = torch.arange(r.numel(), dtype=torch.long)
    return index_matrix


def permutation_to_flat_via_index_matrix(perm: torch.Tensor) -> torch.Tensor:
    """
    For a given permutation 'perm' of [0..n-1], return the
    index array that can reorder the original flat vector
    to match permuting the square matrix, then flattening.
    """
    n = perm.size(0)
    # 1) Build the lookup
    index_matrix = build_upper_triangle_index_matrix(n)
    # 2) Permute the lookup the same way as the data
    perm_index_matrix = index_matrix[perm][:, perm]
    # 3) Now flatten the upper triangle of this permuted matrix
    r, c = torch.triu_indices(n, n, offset=1)
    new_indices = perm_index_matrix[r, c]
    return new_indices
def square_to_flat(square_rdm: torch.Tensor) -> torch.Tensor:
    # The same flattening order as build_upper_triangle_index_matrix
    n_ = square_rdm.size(0)
    r_, c_ = torch.triu_indices(n_, n_, offset=1)
    return square_rdm[r_, c_]

# ---------------------------------------------



In [78]:

n = 5
matrix = torch.arange(n*n).reshape(n, n).triu(1)
matrix = matrix + matrix.T
matrix = matrix.fill_diagonal_(0)

permutation = torch.randperm(n)

matrix[permutation, :][:, permutation]

tensor([[ 0,  3, 13,  8, 19],
        [ 3,  0,  2,  1,  4],
        [13,  2,  0,  7, 14],
        [ 8,  1,  7,  0,  9],
        [19,  4, 14,  9,  0]])

In [13]:
flat_permuted
flat_permuted_via_flat

tensor([19, 19, 19, 19, 19, 19, 14, 19,  9,  4])

In [6]:
n = 10
example_matrix = torch.arange(n*n).reshape(n, n)
flat_example_matrix = square_to_flat(example_matrix)

permutation = torch.randperm(n)



# Permute the full matrix, then flatten
permuted_matrix = example_matrix[permutation][:, permutation]
flat_permuted = square_to_flat(permuted_matrix)

# Directly apply the "flat" permutation
flat_perm_indices = permutation_to_flat(permutation)
flat_permuted_via_flat = flat_example_matrix[flat_perm_indices]

# Check they match
print(torch.allclose(flat_permuted, flat_permuted_via_flat))

print(flat_permuted)

print(flat_permuted_via_flat)

False
tensor([81, 83, 85, 89, 84, 87, 86, 82, 80, 13, 15, 19, 14, 17, 16, 12, 10, 35,
        39, 34, 37, 36, 32, 30, 59, 54, 57, 56, 52, 50, 94, 97, 96, 92, 90, 47,
        46, 42, 40, 76, 72, 70, 62, 60, 20])
tensor([18, 38, 58, 89, 48, 78, 68, 28,  8, 13, 15, 19, 14, 17, 16, 12,  1, 35,
        39, 34, 37, 36, 23,  3, 59, 45, 57, 56, 25,  5, 49, 79, 69, 29,  9, 47,
        46, 24,  4, 67, 27,  7, 26,  6,  2])


In [81]:
permutation = torch.randperm(10)
random_matrix = torch.arange(100).reshape(10, 10)


# First method to permute and convert to flat
permuted_random_matrix = random_matrix[permutation, :][:, permutation]
flat_permuted_random_matrix = square_to_flat(permuted_random_matrix)

# Second method
flat_permutation = permutation_to_flat_perm(permutation)
flat_random_matrix = square_to_flat(random_matrix)
flat_permuted_random_matrix_2 = flat_random_matrix[flat_permutation]

assert (flat_permuted_random_matrix_2 == flat_permuted_random_matrix).all().item(), "NOT EQUAL"


AssertionError: NOT EQUAL

In [99]:
import torch

def square_to_flat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Given an n x n PyTorch tensor (symmetric matrix),
    return the flattened upper triangle (excluding the diagonal).

    The returned 1D tensor has length n*(n-1)//2 and
    follows the order of torch.triu_indices(n, n, offset=1).
    """
    n = matrix.shape[0]
    i, j = torch.triu_indices(n, n, offset=1)  # shape (m, ), m=n*(n-1)//2
    return matrix[i, j]

def permutation_to_flat(perm: torch.Tensor) -> torch.Tensor:
    """
    Given a permutation 'perm' of size n that is used to reorder
    both rows and columns of an n x n matrix, return the index
    array (length n*(n-1)//2) that can reorder the flat upper-triangle
    to produce the same result as permuting the matrix first, then flattening.

    Usage:
        flat_reordered = flat_original[ new_indices ]

    so that 'flat_reordered' matches:
        square_to_flat( matrix[perm][:, perm] ).
    """
    n = perm.numel()
    # Compute inverse permutation (map old_index -> new_index)
    # so if perm[new_idx] = old_idx, then invPerm[old_idx] = new_idx
    invPerm = torch.empty_like(perm)
    invPerm[perm] = torch.arange(n, device=perm.device)

    # Original upper-tri pairs
    i, j = torch.triu_indices(n, n, offset=1)  # shape: (m,)
    # Apply inverse perm to each coordinate
    new_i = invPerm[i]
    new_j = invPerm[j]
    # Enforce new_i < new_j
    lower = torch.min(new_i, new_j)
    upper = torch.max(new_i, new_j)

    # Convert (lower, upper) to new flat index
    # Formula consistent with torch.triu_indices(n, n, offset=1) ordering
    new_indices = ((n * (n - 1)) // 2) \
                  - ((n - lower) * (n - lower - 1)) // 2 \
                  + (upper - lower - 1)

    # new_indices tells us "old -> new" positions, but we typically
    # want an array 'order' s.t. flat_reordered = flat_orig[order].
    # That means 'order[new_pos] = old_pos'. So do argsort:
    order = torch.argsort(new_indices)
    return order

def permutation_to_flat_no_sort(perm: torch.Tensor) -> torch.Tensor:
    """
    Same logic as before, but we skip argsort by building
    the 'order' (new->old) array in O(N) time.
    """
    n = perm.numel()
    device = perm.device

    # Inverse permutation
    invPerm = torch.empty_like(perm)
    invPerm[perm] = torch.arange(n, device=device)

    # Original upper-tri indices
    i, j = torch.triu_indices(n, n, offset=1, device=device)
    new_i = invPerm[i]
    new_j = invPerm[j]

    # Ensure new_i < new_j
    lower = torch.min(new_i, new_j)
    upper = torch.max(new_i, new_j)

    # "old -> new" position
    new_indices = ((n*(n-1))//2
                   - ((n - lower)*(n - lower - 1))//2
                   + (upper - lower - 1))

    # Build the "new -> old" in O(N) with direct indexing
    # length = n*(n-1)//2
    N = new_indices.numel()
    order = torch.empty(N, dtype=torch.long, device=device)
    # for each old index k, new_indices[k] = new position
    # so we do order[new_pos] = old_pos
    old_positions = torch.arange(N, device=device)
    order[new_indices] = old_positions

    return order



# ---------------------------------------------------
# Demonstration
if __name__ == "__main__":
    n = 750
    # Create an n x n matrix where the upper triangle holds distinct values
    M = torch.zeros(n, n, dtype=torch.long)
    idx_upper_i, idx_upper_j = torch.triu_indices(n, n, offset=1)
    for k in range(idx_upper_i.numel()):
        i_ = idx_upper_i[k]
        j_ = idx_upper_j[k]
        # Assign the flat index k to positions (i_, j_) and (j_, i_)
        M[i_, j_] = k
        M[j_, i_] = k

    print("Original matrix:\n", M)

    # Flatten the upper triangle
    orig_flat = square_to_flat(M)
    print("Flattened upper-tri:", orig_flat)

    # Example permutation
    perm = torch.randperm(n)
    print("Permutation (new order of original indices):", perm)

    # Permute the full matrix, then flatten
    M_perm = M[perm][:, perm]
    new_flat = square_to_flat(M_perm)
    print("Permuted matrix flattened:", new_flat)

    # Compute the flat permutation indices to reorder the original flat vector
    flat_order = permutation_to_flat_no_sort(perm)
    flat_reordered = orig_flat[flat_order]
    print("Reordered flat directly:", flat_reordered)

    # Confirm both approaches match
    print("Matches?", torch.allclose(new_flat, flat_reordered))


Original matrix:
 tensor([[     0,      0,      1,  ...,    746,    747,    748],
        [     0,      0,    749,  ...,   1494,   1495,   1496],
        [     1,    749,      0,  ...,   2241,   2242,   2243],
        ...,
        [   746,   1494,   2241,  ...,      0, 280872, 280873],
        [   747,   1495,   2242,  ..., 280872,      0, 280874],
        [   748,   1496,   2243,  ..., 280873, 280874,      0]])
Flattened upper-tri: tensor([     0,      1,      2,  ..., 280872, 280873, 280874])
Permutation (new order of original indices): tensor([ 66, 467,  32, 134,   2,  90,  53, 385, 371, 109, 147, 340, 464, 159,
        563, 531, 382, 625, 267, 204, 338, 665,   5, 178, 503, 524, 237, 521,
        597, 556, 717, 437, 137, 570, 651,  96,  84, 589, 458, 598, 393, 146,
        553, 640, 522, 506,  78, 112, 601, 536,  77, 658, 350, 648, 525, 372,
        662, 461, 688, 493,  54, 343,  56, 220, 257, 115, 315, 101, 205, 497,
         22, 728, 446, 428, 180, 117, 124, 173, 449, 325, 473, 65

In [43]:
flat_permuted_random_matrix
flat_permuted_random_matrix_2

tensor([37, 57, 17, 78, 27, 47,  7, 67, 79, 35, 13, 38, 23, 34,  3, 36, 39, 15,
        58, 25, 45,  5, 56, 59, 18, 12, 14,  1, 16, 19, 28, 48,  8, 68, 89, 24,
         2, 26, 29,  4, 46, 49,  6,  9, 69])

In [None]:
import numpy as np

def square_to_flat(matrix: np.ndarray) -> np.ndarray:
    """
    Given an n x n square (symmetric) matrix,
    return the flattened upper triangle (excluding the diagonal).

    The returned vector has length n*(n-1)//2 and uses
    the same ordering as np.triu_indices(n, k=1).
    """
    n = matrix.shape[0]
    i, j = np.triu_indices(n, k=1)
    return matrix[i, j]

def permutation_to_flat(perm: np.ndarray) -> np.ndarray:
    """
    Given a permutation 'perm' of [0..n-1] that reorders rows and columns
    of an n x n matrix, return the corresponding permutation of the
    flattened upper-triangle indices (length n*(n-1)//2), so that:

       flat_permuted = flat_original[ new_indices ]

    is equivalent to permuting the full matrix first and then flattening.
    """
    # size
    n = len(perm)
    # Inverse permutation (maps original index -> new index)
    invPerm = np.argsort(perm)

    # Original upper-tri pairs
    i, j = np.triu_indices(n, k=1)  # shape (m,), m = n*(n-1)//2

    # Apply inverse perm to each coordinate
    new_i = invPerm[i]
    new_j = invPerm[j]

    # Ensure new_i < new_j (upper-tri)
    new_i, new_j = np.minimum(new_i, new_j), np.maximum(new_i, new_j)

    # Convert (new_i, new_j) to new flat index
    # The formula consistent with the ordering from np.triu_indices(n, k=1)
    new_indices = ((n * (n - 1)) // 2) \
                  - ((n - new_i) * (n - new_i - 1)) // 2 \
                  + (new_j - new_i - 1)

    # new_indices tells where each original flat index should go.
    # But typically we want the "inverse" permutation so that if we do
    #   flat_permuted = flat_original[some_indices]
    # we get the correct effect.
    #
    # The simplest approach: new_indices is "old -> new".  So we do an argsort
    # to get "new -> old", i.e. the array that we can use to reorder the original flat
    order = np.argsort(new_indices)

    return order

# ------------------------------
# Demonstration
if __name__ == "__main__":
    n = 4
    M = np.zeros((n, n), int)
    i_up, j_up = np.triu_indices(n, k=1)
    for k, (r, c) in enumerate(zip(i_up, j_up)):
        M[r, c] = k
        M[c, r] = k

    print("Original matrix:\n", M)
    orig_flat = square_to_flat(M)
    print("Flattened (upper-tri):", orig_flat)

    # Example permutation
    
    print("Permutation:", perm)

    # Permute the full matrix, then flatten
    M_perm = M[perm][:, perm]
    new_flat = square_to_flat(M_perm)
    print("Permuted matrix flattened:", new_flat)

    # Compute the flat permutation indices
    flat_order = permutation_to_flat(perm)
    flat_reordered = orig_flat[flat_order]
    print("Reordered flat directly:", flat_reordered)

    # They should match
    print("Matches?", np.allclose(new_flat, flat_reordered))


Original matrix:
 [[0 0 1 2]
 [0 0 3 4]
 [1 3 0 5]
 [2 4 5 0]]
Flattened (upper-tri): [0 1 2 3 4 5]
Permutation: [2 0 1 3]
Permuted matrix flattened: [1 3 5 0 2 4]
Reordered flat directly: [1 3 5 0 2 4]
Matches? True


In [100]:
import torch

def batched_flat_permutation_no_sort(perm_batch: torch.Tensor) -> torch.Tensor:
    """
    Given a batch of permutations of shape (n_perms, n),
    compute the corresponding 'flat permutation' of length N = n*(n-1)//2
    for each permutation. The result is an integer tensor of shape (n_perms, N).

    Specifically:
      - perm_batch[p] is a permutation of [0..n-1].
      - The returned 'order' has the same shape (n_perms, N).
      - For each p, order[p] is a 1D index array of length N such that:
            flat_reordered = flat_original[ order[p] ]
        matches flattening the matrix after applying perm to rows/columns.

    This version avoids any O(N log N) sorting by building 'new->old' in O(N) time.
    """

    # --------------------------------------------------------------------
    # 0) Setup and shapes
    n_perms, n = perm_batch.shape
    device = perm_batch.device
    N = n * (n - 1) // 2  # number of upper-tri (excl. diagonal) entries

    # --------------------------------------------------------------------
    # 1) Build inverse permutation for each row in a single vectorized pass
    #    - invPerm[p, :] is the inverse of perm_batch[p, :]
    #    - "scatter" approach: for each p, for each k in [0..n-1],
    #         invPerm[p, perm_batch[p,k]] = k
    invPerm = torch.empty_like(perm_batch)
    # Create row indices and col indices for assignment
    rows = torch.arange(n_perms, device=device).unsqueeze(1).expand(-1, n)  # (n_perms, n)
    cols = perm_batch  # shape (n_perms, n)
    vals = torch.arange(n, device=device).unsqueeze(0).expand(n_perms, -1)  # (n_perms, n)
    invPerm[rows, cols] = vals  # fill inverse permutations

    # --------------------------------------------------------------------
    # 2) Precompute all (i, j) for the upper triangle (i<j)
    #    We'll broadcast these across permutations in a vectorized manner
    i_j = torch.triu_indices(n, n, offset=1, device=device)  # shape (2, N)
    i, j = i_j[0], i_j[1]  # each shape (N,)

    # --------------------------------------------------------------------
    # 3) Map (i, j) -> (new_i, new_j) via inverse permutation, for all p
    #    new_i[p, k] = invPerm[p, i[k]]
    #    new_j[p, k] = invPerm[p, j[k]]
    #    shape => (n_perms, N)
    new_i = invPerm[:, i]  # (n_perms, N)
    new_j = invPerm[:, j]  # (n_perms, N)

    # --------------------------------------------------------------------
    # 4) Ensure new_i < new_j
    lower = torch.min(new_i, new_j)
    upper = torch.max(new_i, new_j)

    # --------------------------------------------------------------------
    # 5) Compute "old->new" positions:
    #    new_indices[p, k] = new position of the old index k in flat space
    #    The formula consistent with torch.triu_indices(n,n,offset=1):
    #
    #       new_idx = N - ((n-lower)*(n-lower-1))//2 + (upper - lower - 1)
    #
    #    This yields shape (n_perms, N).
    tmp = (n - lower) * (n - lower - 1) // 2
    new_indices = (N - tmp) + (upper - lower - 1)

    # --------------------------------------------------------------------
    # 6) Build the final "new->old" order in O(N) time (no sorting).
    #    For each row p, for each old index k, we do:
    #        new_pos = new_indices[p, k]
    #        order[p, new_pos] = k
    #    We'll do this with advanced indexing in one shot.
    order = torch.empty_like(new_indices)  # (n_perms, N)
    old_positions = torch.arange(N, device=device).unsqueeze(0).expand(n_perms, -1)  # shape (n_perms, N)

    # row indices => shape (n_perms, N)
    rows_big = torch.arange(n_perms, device=device).unsqueeze(1).expand(-1, N)
    # col indices => new_indices, shape (n_perms, N)

    order[rows_big, new_indices] = old_positions

    return order

# ------------------------------------------------------------------------------
# DEMO with smaller scale
if __name__ == "__main__":
    n_perms = 2
    n = 5
    # Example permutations
    perm_batch = torch.tensor([
        [3, 1, 4, 0, 2],  # each row is a permutation of [0..4]
        [1, 2, 3, 4, 0],
    ], dtype=torch.long, device='cpu')

    # Get the batched index reorder
    order_batched = batched_flat_permutation_no_sort(perm_batch)
    print("Batched order shape:", order_batched.shape)  # (2, n*(n-1)//2) = (2,10)

    # Let's verify correctness for the first permutation
    # 1) Build a test matrix M with distinct upper-tri entries
    M = torch.zeros(n, n, dtype=torch.long)
    i_up, j_up = torch.triu_indices(n, n, offset=1)
    for idx in range(i_up.numel()):
        M[i_up[idx], j_up[idx]] = idx
        M[j_up[idx], i_up[idx]] = idx
    print("Original matrix:\n", M)

    # Flatten upper triangle
    orig_flat = M[i_up, j_up]
    print("Original flat:", orig_flat)

    # Permute rows/cols with perm_batch[0] -> flatten
    M0 = M[perm_batch[0]][:, perm_batch[0]]
    perm_flat = M0[i_up, j_up]
    print("Permutation #0 flattened:", perm_flat)

    # Reorder original flat directly
    order_0 = order_batched[0]
    flat_0 = orig_flat[order_0]
    print("Reordered flat #0:", flat_0)
    print("Matches #0?", torch.allclose(perm_flat, flat_0))

    # Similarly check permutation #1
    M1 = M[perm_batch[1]][:, perm_batch[1]]
    perm_flat_1 = M1[i_up, j_up]
    order_1 = order_batched[1]
    flat_1 = orig_flat[order_1]
    print("Permutation #1 flattened:", perm_flat_1)
    print("Reordered flat #1:", flat_1)
    print("Matches #1?", torch.allclose(perm_flat_1, flat_1))


Batched order shape: torch.Size([2, 10])
Original matrix:
 tensor([[0, 0, 1, 2, 3],
        [0, 0, 4, 5, 6],
        [1, 4, 0, 7, 8],
        [2, 5, 7, 0, 9],
        [3, 6, 8, 9, 0]])
Original flat: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Permutation #0 flattened: tensor([5, 9, 2, 7, 6, 0, 4, 3, 8, 1])
Reordered flat #0: tensor([5, 9, 2, 7, 6, 0, 4, 3, 8, 1])
Matches #0? True
Permutation #1 flattened: tensor([4, 5, 6, 0, 7, 8, 1, 9, 2, 3])
Reordered flat #1: tensor([4, 5, 6, 0, 7, 8, 1, 9, 2, 3])
Matches #1? True


In [101]:
import torch

def batched_flat_permutation_no_sort(perm_batch: torch.Tensor) -> torch.Tensor:
    """
    Given a batch of permutations of shape (n_perms, n), compute
    the 'flat-permutation' index array of shape (n_perms, N), where
    N = n*(n-1)//2.

    - perm_batch[p] is a permutation of [0..n-1].
    - The returned array, say 'order', satisfies:
         flat_reordered = flat_original[ order[p] ]
      which is equivalent to:
         1) reshaping flat_original back to (n x n) (upper tri only)
         2) permuting rows/columns of that n x n with perm_batch[p]
         3) flattening the upper triangle again.
    
    This version avoids an O(N log N) sort, instead building the
    new->old mapping in O(N) time for each permutation.
    """

    # Shape info
    n_perms, n = perm_batch.shape
    device = perm_batch.device
    # Number of strictly upper-triangle entries
    N = n * (n - 1) // 2

    # 1) Build inverse permutation for each p in perm_batch
    #    invPerm[p, perm_batch[p, k]] = k
    invPerm = torch.empty_like(perm_batch)
    rows = torch.arange(n_perms, device=device).unsqueeze(1).expand(-1, n)  # (n_perms, n)
    cols = perm_batch
    vals = torch.arange(n, device=device).unsqueeze(0).expand(n_perms, -1)  # (n_perms, n)
    invPerm[rows, cols] = vals

    # 2) Precompute all (i, j) for upper triangle (i < j)
    i_j = torch.triu_indices(n, n, offset=1, device=device)
    i, j = i_j[0], i_j[1]  # each shape (N,)

    # 3) For each permutation p, map (i, j) -> (new_i, new_j)
    new_i = invPerm[:, i]  # (n_perms, N)
    new_j = invPerm[:, j]  # (n_perms, N)

    # 4) Ensure new_i < new_j
    lower = torch.min(new_i, new_j)
    upper = torch.max(new_i, new_j)

    # 5) "old->new" positions: new_indices[p, k] = new position of old index k
    tmp = (n - lower) * (n - lower - 1) // 2
    new_indices = (N - tmp) + (upper - lower - 1)  # shape (n_perms, N)

    # 6) Build "new->old" order in O(N) time (no sort):
    #    order[p, new_pos] = old_pos
    order = torch.empty_like(new_indices)
    old_positions = torch.arange(N, device=device).unsqueeze(0).expand(n_perms, -1)  # (n_perms, N)
    row_ids = torch.arange(n_perms, device=device).unsqueeze(1).expand(-1, N)        # (n_perms, N)

    order[row_ids, new_indices] = old_positions  # scatter: new->old

    return order


# ------------------------------
# Example usage
if __name__ == "__main__":
    n_perms = 2
    n = 5
    # Sample permutations
    perm_batch = torch.tensor([
        [3, 1, 4, 0, 2],
        [1, 2, 3, 4, 0],
    ], dtype=torch.long)

    order_batched = batched_flat_permutation_no_sort(perm_batch)
    print("order_batched.shape:", order_batched.shape)  # => (2, 10) for n=5

    # Demo for a single matrix M with distinct upper-tri values
    M = torch.zeros(n, n, dtype=torch.long)
    i_up, j_up = torch.triu_indices(n, n, offset=1)
    for k in range(i_up.numel()):
        r, c = i_up[k], j_up[k]
        M[r, c] = k
        M[c, r] = k

    # Flatten M
    orig_flat = M[i_up, j_up]

    # Perm #0
    perm0 = perm_batch[0]
    M0 = M[perm0][:, perm0]
    perm0_flat = M0[i_up, j_up]
    direct0 = orig_flat[order_batched[0]]
    print("Matches perm #0?", torch.allclose(perm0_flat, direct0))

    # Perm #1
    perm1 = perm_batch[1]
    M1 = M[perm1][:, perm1]
    perm1_flat = M1[i_up, j_up]
    direct1 = orig_flat[order_batched[1]]
    print("Matches perm #1?", torch.allclose(perm1_flat, direct1))


order_batched.shape: torch.Size([2, 10])
Matches perm #0? True
Matches perm #1? True
