In [1]:
import torch
N, H, T = 2, 6, 10
# Assume a is a tensor of shape [N, H, T]
a = torch.randn(N, H, T)  # Example tensor

# Step 1: Sum along the T dimension to get tensor b of shape [N, H]
b = torch.sum(a, dim=2)  # Shape of b is [N, H]

print(a)
print(b)
print(a.shape)
print(b.shape)

tensor([[[ 0.9041, -0.8796, -0.5632,  0.0797, -1.3168, -0.3419,  1.3030,
          -0.3451, -0.6063, -0.5245],
         [ 0.4215,  0.6848,  0.5073, -0.9174, -0.6083, -0.1182, -1.5017,
          -0.5311, -0.4541, -1.9571],
         [ 1.2158, -0.3853,  1.9686,  1.6066, -0.2141,  0.0164,  0.1152,
           0.6832, -0.5217, -0.2778],
         [ 0.4062, -0.3862,  0.2392, -0.8932, -0.3630,  0.6447, -0.3871,
          -0.2544, -0.4813,  1.3624],
         [-0.8946, -0.4769, -1.2336, -1.1485,  0.8232,  1.1682,  0.3541,
           1.4274,  1.1515, -1.1845],
         [ 1.5988, -1.7288, -0.7089, -1.3055, -1.1253, -1.4450, -0.4668,
          -0.6925,  1.6598, -0.0912]],

        [[ 0.9257,  0.5822, -1.3221, -0.8315,  0.1097, -0.6688, -1.7915,
           0.1704,  0.8910,  1.2546],
         [ 1.3104, -0.0158,  0.1452,  1.3375,  0.1436, -1.6218,  0.5001,
           0.9592,  1.8101,  1.2246],
         [-1.8643, -1.6465, -0.6212, -0.1397, -0.2657,  0.1240,  0.5665,
           1.3655, -1.0273,  1.9611],

In [2]:
0.9041 -0.8796-0.5632+  0.0797 -1.3168 -0.3419+  1.3030-0.3451-0.6063-0.5245

-2.2906000000000004

In [5]:
# Step 2: Select the top k values along the H dimension
# We use torch.topk to get the indices of the top k values
k = 4  # Specify the value of k
_, topk_indices = torch.topk(b, k, dim=1)  # Shape of topk_indices is [N, k]

print(topk_indices)
print(topk_indices.shape)

tensor([[2, 4, 3, 0],
        [1, 3, 5, 4]])
torch.Size([2, 4])


In [7]:
# Step 3: Use the indices to select the top k heads from the original tensor a
# Use torch.gather to select along the H dimension
topk_indices_expanded = topk_indices.unsqueeze(2).expand(-1, -1, T)  # Expand indices for proper gathering
a_topk = torch.gather(a, 1, topk_indices_expanded)  # Shape of a_topk is [N, k, T]

print(topk_indices_expanded)
print(topk_indices_expanded.shape)
print(a_topk)
print(a_topk.shape)
# Now a_topk has shape [N, k, T]

tensor([[[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
         [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
         [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
         [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
         [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]])
torch.Size([2, 4, 10])
tensor([[[ 1.2158, -0.3853,  1.9686,  1.6066, -0.2141,  0.0164,  0.1152,
           0.6832, -0.5217, -0.2778],
         [-0.8946, -0.4769, -1.2336, -1.1485,  0.8232,  1.1682,  0.3541,
           1.4274,  1.1515, -1.1845],
         [ 0.4062, -0.3862,  0.2392, -0.8932, -0.3630,  0.6447, -0.3871,
          -0.2544, -0.4813,  1.3624],
         [ 0.9041, -0.8796, -0.5632,  0.0797, -1.3168, -0.3419,  1.3030,
          -0.3451, -0.6063, -0.5245]],

        [[ 1.3104, -0.0158,  0.1452,  1.3375,  0.1436, -1.6218,  0.5001,
           0.9592,  1.8101,  1.2246],
         [ 0.6104,  1.1062, -0.2298,  0.2991,  0.2427,  2.1021, -0.7009,
          -0.24

In [9]:
import torch
N, H, T = 2, 10, 6

# Assume a is a tensor of shape [N, H, T]
a = torch.randn(N, H, T)  # Example tensor

# Step 1: Sum along the T dimension to get tensor b of shape [N, H]
b = torch.sum(a, dim=2)  # Shape of b is [N, H]

print(a)
print(b)
print(a.shape)
print(b.shape)


tensor([[[-1.6234,  0.5716, -3.0238, -0.1030, -0.6727, -0.0131],
         [ 0.3394, -0.5712,  1.8612,  0.1429, -2.4633, -0.9110],
         [ 0.1373,  0.6970, -1.1414,  1.7767,  0.6732,  1.2749],
         [-0.8944,  1.0215, -1.2825, -0.3689, -0.0744, -0.2944],
         [ 1.0076, -2.0619, -0.3731, -0.2602, -0.0738,  1.2605],
         [ 0.9844,  1.2328, -0.6156,  0.2858, -0.1262,  0.6503],
         [-0.8435,  0.4985,  1.4154,  1.6093, -1.0917,  0.1000],
         [ 1.2428, -0.0728, -0.5672, -0.2836, -0.2023,  0.7800],
         [ 2.0718, -1.2521,  0.1359, -1.3609,  0.6042,  0.7212],
         [-0.5903,  0.2496, -0.1047,  0.9970, -0.3020,  0.2568]],

        [[ 0.3867,  0.4413, -0.1570,  0.6820, -0.4854,  1.6211],
         [-0.1357,  1.1282, -0.9414,  1.3190, -0.0502, -0.7298],
         [ 1.5035,  0.0686, -0.3460, -1.1703,  0.4468, -2.4669],
         [-0.3202,  0.7076,  0.8119, -0.9865,  0.3860,  2.0085],
         [ 1.3758, -1.3889,  0.2530,  0.2001,  0.1194,  0.6831],
         [-0.7684, -0.4

In [10]:
# Step 2: Reshape the tensor b into [N, k, H//k]
k = 5  # Specify the value of k (number of groups)
H_group_size = H // k  # Each group will have H//k elements

# Reshape b to [N, k, H//k]
b_reshaped = b.view(N, k, H_group_size)  # Shape of b_reshaped is [N, k, H//k]

print(b_reshaped)
print(b_reshaped.shape)

tensor([[[-4.8645, -1.6021],
         [ 3.4178, -1.8931],
         [-0.5008,  2.4115],
         [ 1.6880,  0.8970],
         [ 0.9200,  0.5064]],

        [[ 2.4887,  0.5902],
         [-1.9644,  2.6073],
         [ 1.2425, -1.5704],
         [ 1.4897, -2.5452],
         [-0.7116, -1.4165]]])
torch.Size([2, 5, 2])


In [11]:
# Step 3: Select the top value in each group (1 value per group)
# Use torch.topk to get the indices of the top values within each group
_, topk_indices = torch.topk(b_reshaped, 1, dim=2)  # Shape of topk_indices is [N, k, 1]

# Step 4: Use the indices to gather the top values from the original tensor a
# First, expand topk_indices to match the shape for gathering from a
topk_indices_expanded = topk_indices.squeeze(2).unsqueeze(2).expand(-1, -1, T)  # Shape [N, k, T]

# Gather the topk heads from the original tensor a
a_topk = torch.gather(a.view(N, k, H_group_size, T), 2, topk_indices_expanded.unsqueeze(2)).squeeze(2)

# Now, a_topk has shape [N, k, T]

print(topk_indices)
print(topk_indices_expanded)
print(topk_indices_expanded.shape)
print(a_topk)
print(a_topk.shape)

tensor([[[1],
         [0],
         [1],
         [0],
         [0]],

        [[0],
         [1],
         [0],
         [0],
         [0]]])
tensor([[[1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0]]])
torch.Size([2, 5, 6])
tensor([[[ 0.3394, -0.5712,  1.8612,  0.1429, -2.4633, -0.9110],
         [ 0.1373,  0.6970, -1.1414,  1.7767,  0.6732,  1.2749],
         [ 0.9844,  1.2328, -0.6156,  0.2858, -0.1262,  0.6503],
         [-0.8435,  0.4985,  1.4154,  1.6093, -1.0917,  0.1000],
         [ 2.0718, -1.2521,  0.1359, -1.3609,  0.6042,  0.7212]],

        [[ 0.3867,  0.4413, -0.1570,  0.6820, -0.4854,  1.6211],
         [-0.3202,  0.7076,  0.8119, -0.9865,  0.3860,  2.0085],
         [ 1.3758, -1.3889,  0.2530,  0.2001,  0.1194,  0.6831],
         [-0.8428,