In [9]:
import torch

# Example data
BLOCK_BK = 5  # Assuming BLOCK_BK is 5 for this example
t_group_size = 4  # Maximum number of group sizes

# Example tensors
indices = torch.tensor([2, 5, 8, 10, 15])  # Shape (BLOCK_BK,)
group_size = torch.tensor([3, 2, 4, 1, 3])  # Shape (BLOCK_BK,)

# Step 1: Create a base tensor with garbage values (-1)
dupped_indices = torch.full((BLOCK_BK, t_group_size), -1, dtype=torch.int64)

print(indices)
print(indices.shape)
print(group_size)
print(group_size.shape)
print(dupped_indices)
print(dupped_indices.shape)

tensor([ 2,  5,  8, 10, 15])
torch.Size([5])
tensor([3, 2, 4, 1, 3])
torch.Size([5])
tensor([[-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1]])
torch.Size([5, 4])


In [7]:
# Step 2: Generate ranges for each group in a flat tensor
group_indices = torch.arange(t_group_size).unsqueeze(0)  # Shape (1, t_group_size)
group_indices = group_indices.expand(BLOCK_BK, t_group_size)  # Shape (BLOCK_BK, t_group_size)

print(group_indices)
print(group_indices.shape)

tensor([[0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3]])
torch.Size([5, 4])


In [8]:
# Step 3: Create a tensor where the group sizes define valid ranges for each row
group_indices = group_indices + indices.unsqueeze(1)  # Broadcasting addition of indices

# Step 4: Scatter values into dupped_indices according to group_size
scatter_indices = torch.arange(t_group_size).expand(BLOCK_BK, t_group_size)  # Prepare scatter indices
dupped_indices.scatter_(1, scatter_indices, group_indices)

# Step 5: Zero out the invalid entries (group_size < t_group_size)
valid_ranges = torch.arange(t_group_size).unsqueeze(0) < group_size.unsqueeze(1)
dupped_indices[~valid_ranges] = -1  # Fill the invalid positions with -1

print(indices)
print(indices.shape)
print(group_size)
print(group_size.shape)
print(dupped_indices)
print(dupped_indices.shape)

tensor([[ 2,  3,  4,  5],
        [ 5,  6,  7,  8],
        [ 8,  9, 10, 11],
        [10, 11, 12, 13],
        [15, 16, 17, 18]])
torch.Size([5, 4])
tensor([[ 2,  3,  4, -1],
        [ 5,  6, -1, -1],
        [ 8,  9, 10, 11],
        [10, -1, -1, -1],
        [15, 16, 17, -1]])
torch.Size([5, 4])


In [1]:
import torch

# Example data
BLOCK_BK = 5  # Assuming BLOCK_BK is 5 for this example
t_group_size = 4  # Maximum number of group sizes
multi_branch_ratio_per_layer = 3  # Example multi_branch_ratio_per_layer

# Derived size based on multi_branch_ratio_per_layer
output_size = 2 * multi_branch_ratio_per_layer  # Ensure it's larger or equal to t_group_size

# Example tensors
indices = torch.tensor([2, 5, 8, 10, 15])  # Shape (BLOCK_BK,)
group_size = torch.tensor([3, 2, 4, 1, 3])  # Shape (BLOCK_BK,)

# Step 1: Create a base tensor with garbage values (-1)
dupped_indices = torch.full((BLOCK_BK, output_size), -1, dtype=torch.int64)

print(indices)
print(indices.shape)
print(group_size)
print(group_size.shape)
print(dupped_indices)
print(dupped_indices.shape)

tensor([ 2,  5,  8, 10, 15])
torch.Size([5])
tensor([3, 2, 4, 1, 3])
torch.Size([5])
tensor([[-1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1]])
torch.Size([5, 6])


In [2]:
# Step 2: Generate ranges for each group up to the max `group_size`
range_tensor = torch.arange(output_size).unsqueeze(0).expand(BLOCK_BK, output_size)  # Shape (BLOCK_BK, output_size)

# Step 3: Add the base `indices` to the range_tensor
group_indices = indices.unsqueeze(1) + range_tensor

# Step 4: Ensure we don't add more indices than specified by `group_size`
# Create scatter indices for the first `group_size` elements of each row
scatter_indices = torch.arange(output_size).expand(BLOCK_BK, output_size)

# Step 5: Use scatter_ to add only up to the valid `group_size`
valid_mask = scatter_indices < group_size.unsqueeze(1)
dupped_indices.scatter_(1, scatter_indices, torch.where(valid_mask, group_indices, dupped_indices))

print(dupped_indices)

tensor([[ 2,  3,  4, -1, -1, -1],
        [ 5,  6, -1, -1, -1, -1],
        [ 8,  9, 10, 11, -1, -1],
        [10, -1, -1, -1, -1, -1],
        [15, 16, 17, -1, -1, -1]])


In [4]:
torch.iinfo(indices.dtype).max

9223372036854775807

In [7]:
BLOCK_BK = 5  # Assuming BLOCK_BK is 5 for this example
t_group_size = 4  # Maximum number of group sizes
multi_branch_ratio_per_layer = 3  # Example multi_branch_ratio_per_layer

# Derived size based on multi_branch_ratio_per_layer

# Example tensors
indices = torch.tensor([2, 5, 8, 10, 15])  # Shape (BLOCK_BK,)
group_sizes = torch.tensor([3, 2, 4, 1, 3])  # Shape (BLOCK_BK,)

garbage_value = torch.iinfo(indices.dtype).max
M2 = 2 * multi_branch_ratio_per_layer
dupped_indices = torch.full((BLOCK_BK, M2), garbage_value).to(indices.device)
group_indices = torch.arange(M2).unsqueeze(0).expand(BLOCK_BK, M2).to(indices.device)  # Shape (1, t_group_size)

group_indices = group_indices + indices.unsqueeze(1)
scatter_indices = torch.arange(M2).expand(BLOCK_BK, M2).to(indices.device)

valid_mask = scatter_indices < group_sizes.unsqueeze(1)
dupped_indices.scatter_(1, scatter_indices, torch.where(valid_mask, group_indices, dupped_indices))

print(dupped_indices)

tensor([[                  2,                   3,                   4,
         9223372036854775807, 9223372036854775807, 9223372036854775807],
        [                  5,                   6, 9223372036854775807,
         9223372036854775807, 9223372036854775807, 9223372036854775807],
        [                  8,                   9,                  10,
                          11, 9223372036854775807, 9223372036854775807],
        [                 10, 9223372036854775807, 9223372036854775807,
         9223372036854775807, 9223372036854775807, 9223372036854775807],
        [                 15,                  16,                  17,
         9223372036854775807, 9223372036854775807, 9223372036854775807]])


In [10]:
group_sizes.type

<function Tensor.type>

In [12]:
dupped_group_sizes = torch.ones((BLOCK_BK, M2), dtype=group_sizes.dtype)


print(dupped_group_sizes)

tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1]])


In [13]:
print(group_indices)

tensor([[ 2,  3,  4,  5,  6,  7],
        [ 5,  6,  7,  8,  9, 10],
        [ 8,  9, 10, 11, 12, 13],
        [10, 11, 12, 13, 14, 15],
        [15, 16, 17, 18, 19, 20]])


In [14]:
dupped_group_sizes.scatter_(1, scatter_indices, torch.where(valid_mask, dupped_group_sizes, dupped_group_sizes))

print(dupped_group_sizes)

tensor([[ 2,  3,  4,  1,  1,  1],
        [ 5,  6,  1,  1,  1,  1],
        [ 8,  9, 10, 11,  1,  1],
        [10,  1,  1,  1,  1,  1],
        [15, 16, 17,  1,  1,  1]])


In [11]:
import torch

# Assuming you have the following variables defined elsewhere in your code
BLOCK_BK = 16  # Example value
M2 = 8  # Threshold for group sizes
multi_branch_ratio_per_layer = 0.75  # Example value

# Example tensors for demonstration purposes
indices = torch.tensor([5, 8, 12, 15, 22, 25, 30, 35], dtype=torch.int32)
group_sizes = torch.tensor([4, 6, 9, 3, 7, 12, 14, 5], dtype=torch.int32)

# Define garbage value
garbage_value = torch.iinfo(indices.dtype).max

# Separate small and large groups
small_group_sizes = group_sizes <= M2
large_group_sizes = group_sizes > M2

# Small group tensors
indices_small = indices[small_group_sizes]
group_sizes_small = group_sizes[small_group_sizes]


print(indices)
print(indices.shape)

print(indices_small)
print(indices_small.shape)

tensor([ 5,  8, 12, 15, 22, 25, 30, 35], dtype=torch.int32)
torch.Size([8])
tensor([ 5,  8, 15, 22, 35], dtype=torch.int32)
torch.Size([5])


In [None]:
# Large group tensors
indices_large = indices[large_group_sizes]
group_sizes_large = group_sizes[large_group_sizes]

# The code operation applied to `indices_small` and `group_sizes_small`
dupped_indices = torch.full((BLOCK_BK, M2), garbage_value).to(indices.device)
group_indices = torch.arange(M2).unsqueeze(0).expand(BLOCK_BK, M2).to(indices.device)  # Shape (BLOCK_BK, M2)

group_indices = group_indices + indices_small.unsqueeze(1)
scatter_indices = torch.arange(M2).expand(BLOCK_BK, M2).to(indices.device)

valid_mask = scatter_indices < group_sizes_small.unsqueeze(1)

# Update dupped_indices using scatter
dupped_indices.scatter_(1, scatter_indices, torch.where(valid_mask, group_indices, dupped_indices))

# Create dupped_group_sizes and reshape
dupped_group_sizes = torch.ones((BLOCK_BK, M2), dtype=group_sizes_small.dtype)
dupped_indices = dupped_indices.view(BLOCK_BK * M2)
dupped_group_sizes = dupped_group_sizes.view(BLOCK_BK * M2)

# Create mask
mask_bk_dup = (dupped_indices < garbage_value)

# Store dupped_indices and dupped_group_sizes (for tl.store-like behavior, assume some output tensor)
DUPPED_INDICES = torch.empty_like(dupped_indices)
DUPPED_GROUP_SIZE = torch.empty_like(dupped_group_sizes)

# Mocking `tl.store` behavior (this stores the values)
DUPPED_INDICES[mask_bk_dup] = dupped_indices[mask_bk_dup]
DUPPED_GROUP_SIZE[mask_bk_dup] = dupped_group_sizes[mask_bk_dup]

# Now you have `dupped_indices` and `dupped_group_sizes` ready for further operations
print("Dupped Indices:", DUPPED_INDICES)
print("Dupped Group Sizes:", DUPPED_GROUP_SIZE)

# Handle `indices_large` and `group_sizes_large` with your other operations
# (Assume you will apply a different set of operations for large group sizes here)

In [12]:
import torch

# Assuming you have the following variables defined elsewhere in your code
BLOCK_BK = 16  # Example value
M2 = 8  # Threshold for group sizes
multi_branch_ratio_per_layer = 0.75  # Example value

# Example tensors for demonstration purposes
indices = torch.tensor([5, 8, 12, 15, 22, 25, 30, 35], dtype=torch.int32)
group_sizes = torch.tensor([4, 6, 9, 3, 7, 12, 14, 5], dtype=torch.int32)

# Define garbage value
garbage_value = torch.iinfo(indices.dtype).max

# Separate small and large groups
small_group_sizes = group_sizes <= M2
large_group_sizes = group_sizes > M2

# Small group tensors
indices_small = indices[small_group_sizes]
group_sizes_small = group_sizes[small_group_sizes]


print(indices_small)

tensor([ 5,  8, 15, 22, 35], dtype=torch.int32)


In [None]:
# Initialize the dupped_indices with garbage values
dupped_indices = torch.full((BLOCK_BK, len(indices)), garbage_value).to(indices.device)

# Create a group_indices tensor and apply the same operation only for small indices
group_indices = torch.arange(M2).unsqueeze(0).expand(BLOCK_BK, M2).to(indices.device)  # Shape (BLOCK_BK, M2)
group_indices = group_indices + indices_small.unsqueeze(1)  # Operate only on small group sizes

scatter_indices = torch.arange(M2).expand(BLOCK_BK, M2).to(indices.device)

# Valid mask for small group sizes
valid_mask = scatter_indices < group_sizes_small.unsqueeze(1)

# Only scatter into `dupped_indices` for positions corresponding to small group sizes
# Use torch.where to select valid indices, or keep the garbage value for invalid ones
dupped_indices[:, small_group_sizes] = torch.where(valid_mask, group_indices, dupped_indices[:, small_group_sizes])

# Create dupped_group_sizes and reshape it
dupped_group_sizes = torch.ones((BLOCK_BK, len(indices)), dtype=group_sizes.dtype).to(indices.device)

# Reshape to match expected output dimensions
dupped_indices = dupped_indices.view(BLOCK_BK * len(indices))
dupped_group_sizes = dupped_group_sizes.view(BLOCK_BK * len(indices))

# Create mask
mask_bk_dup = (dupped_indices < garbage_value)

# Mocking tl.store-like behavior (assuming some output tensors)
DUPPED_INDICES = torch.empty_like(dupped_indices)
DUPPED_GROUP_SIZE = torch.empty_like(dupped_group_sizes)

# Store values in the result tensor
DUPPED_INDICES[mask_bk_dup] = dupped_indices[mask_bk_dup]
DUPPED_GROUP_SIZE[mask_bk_dup] = dupped_group_sizes[mask_bk_dup]

# Now you have `dupped_indices` and `dupped_group_sizes` with the same size as the original tensors,
# while only operating on the small group sizes.
print("Dupped Indices:", DUPPED_INDICES)
print("Dupped Group Sizes:", DUPPED_GROUP_SIZE)

# Handle `indices_large` and `group_sizes_large` with your other operations
# (You can apply other operations for large group sizes here)

In [18]:
BLOCK_BK = 5  # Assuming BLOCK_BK is 5 for this example
t_group_size = 4  # Maximum number of group sizes
multi_branch_ratio_per_layer = 3  # Example multi_branch_ratio_per_layer

# Derived size based on multi_branch_ratio_per_layer

# Example tensors
indices = torch.tensor([2, 5, 8, 10, 15])  # Shape (BLOCK_BK,)
group_sizes = torch.tensor([3, 2, 4, 1, 3])  # Shape (BLOCK_BK,)

garbage_value = -1 # torch.iinfo(indices.dtype).max
M2 = 2 * multi_branch_ratio_per_layer
dupped_indices = torch.full((BLOCK_BK, M2), garbage_value).to(indices.device)
group_indices = torch.arange(M2).unsqueeze(0).expand(BLOCK_BK, M2).to(indices.device)  # Shape (1, t_group_size)

group_indices = group_indices + indices.unsqueeze(1)
scatter_indices = torch.arange(M2).expand(BLOCK_BK, M2).to(indices.device)

valid_mask = scatter_indices < group_sizes.unsqueeze(1)
dupped_indices.scatter_(1, scatter_indices, torch.where(valid_mask, group_indices, dupped_indices))

print(indices)
print(group_sizes)
print(group_indices)
print(scatter_indices)
print('---------')
print(dupped_indices)

tensor([ 2,  5,  8, 10, 15])
tensor([3, 2, 4, 1, 3])
tensor([[ 2,  3,  4,  5,  6,  7],
        [ 5,  6,  7,  8,  9, 10],
        [ 8,  9, 10, 11, 12, 13],
        [10, 11, 12, 13, 14, 15],
        [15, 16, 17, 18, 19, 20]])
tensor([[0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5]])
---------
tensor([[ 2,  3,  4, -1, -1, -1],
        [ 5,  6, -1, -1, -1, -1],
        [ 8,  9, 10, 11, -1, -1],
        [10, -1, -1, -1, -1, -1],
        [15, 16, 17, -1, -1, -1]])


In [52]:
import torch

# Example data
BLOCK_BK = 5  # Assuming BLOCK_BK is 5 for this example
M2 = 4  # Threshold value

# Example tensors
indices = torch.tensor([2, 5, 8, 10, 15], dtype=torch.int64)  # Shape (BLOCK_BK,)
group_sizes = torch.tensor([3, 2, 5, 1, 4], dtype=torch.int64)  # Shape (BLOCK_BK,)

# Step 1: Define the garbage value
garbage_value = -1 # torch.iinfo(indices.dtype).max

# Step 2: Initialize the dupped_indices tensor with garbage values
dupped_indices = torch.full((BLOCK_BK, M2), garbage_value).to(indices.device)

# Step 3: Create the group_indices tensor
group_indices = torch.arange(M2).unsqueeze(0).expand(BLOCK_BK, M2).to(indices.device)
group_indices = group_indices + indices.unsqueeze(1)

# Step 4: Create scatter indices
scatter_indices = torch.arange(M2).expand(BLOCK_BK, M2).to(indices.device)

# Step 5: Create a mask where group_sizes <= M2
mask = group_sizes.unsqueeze(1) <= M2

# Step 6: For the valid mask, apply the operation using torch.where
valid_mask = scatter_indices < group_sizes.unsqueeze(1)
# dupped_indices.scatter_(1, scatter_indices, torch.where(valid_mask & mask, group_indices, dupped_indices))
dupped_indices = torch.where(valid_mask & mask, group_indices, dupped_indices)

# Step 7: Create dupped_group_sizes
dupped_group_sizes = torch.where(mask, torch.ones_like(group_sizes).unsqueeze(1).expand(BLOCK_BK, M2),
                                 group_sizes.unsqueeze(1).expand(BLOCK_BK, M2))

dupped_indices = torch.where(mask, dupped_indices, indices.unsqueeze(1).expand(BLOCK_BK, M2))

# Step 8: Reshape the tensors as requested
# dupped_indices = dupped_indices.reshape(BLOCK_BK * M2)
# dupped_group_sizes = dupped_group_sizes.reshape(BLOCK_BK * M2)

# Print the output tensors
print('indices', indices)
print('group_sizes', group_sizes)
print("dupped_indices:", dupped_indices)
print("dupped_group_sizes:", dupped_group_sizes)
print('valid_mask', valid_mask)
print('mask', mask)
print('mask_final', mask & valid_mask)
print(mask.shape)

indices tensor([ 2,  5,  8, 10, 15])
group_sizes tensor([3, 2, 5, 1, 4])
dupped_indices: tensor([[ 2,  3,  4, -1],
        [ 5,  6, -1, -1],
        [ 8,  8,  8,  8],
        [10, -1, -1, -1],
        [15, 16, 17, 18]])
dupped_group_sizes: tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [5, 5, 5, 5],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])
valid_mask tensor([[ True,  True,  True, False],
        [ True,  True, False, False],
        [ True,  True,  True,  True],
        [ True, False, False, False],
        [ True,  True,  True,  True]])
mask tensor([[ True],
        [ True],
        [False],
        [ True],
        [ True]])
mask_final tensor([[ True,  True,  True, False],
        [ True,  True, False, False],
        [False, False, False, False],
        [ True, False, False, False],
        [ True,  True,  True,  True]])
torch.Size([5, 1])


In [54]:
mask.shape

torch.Size([5, 1])

In [53]:
~mask

tensor([[False],
        [False],
        [ True],
        [False],
        [False]])

In [31]:
group_indices = torch.arange(M2).unsqueeze(0).expand(BLOCK_BK, M2).to(indices.device)
print(group_indices)
print(group_indices.shape)

tensor([[0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3]])
torch.Size([5, 4])


In [32]:
scatter_indices = torch.arange(M2).expand(BLOCK_BK, M2).to(indices.device)

print(scatter_indices)
print(scatter_indices.shape)

tensor([[0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3],
        [0, 1, 2, 3]])
torch.Size([5, 4])


In [46]:
import torch

# Example values for demonstration
BLOCK_BK = 5  # Example BLOCK_BK value
M2 = 4  # Example M2 value
indices = torch.tensor([2, 5, 8, 10, 15], dtype=torch.int64)
group_sizes = torch.tensor([3, 2, 4, 1, 4], dtype=torch.int64)

# Step 1: Define the garbage value
garbage_value = -1 # torch.iinfo(indices.dtype).max

# Step 2: Initialize the dupped_indices tensor with garbage values
dupped_indices = torch.full((BLOCK_BK, M2), garbage_value).to(indices.device)

# Step 3: Create the group_indices tensor
group_indices = torch.arange(M2).unsqueeze(0).expand(BLOCK_BK, M2).to(indices.device)
group_indices = group_indices + indices.unsqueeze(1)

# Step 4: Create scatter indices
scatter_indices = torch.arange(M2).expand(BLOCK_BK, M2).to(indices.device)

# Step 5: Create a valid mask
valid_mask = scatter_indices < group_sizes.unsqueeze(1)

# Step 6: Apply the logic without using scatter_
# Instead of scatter_, we will directly assign values where `valid_mask` is True
dupped_indices = torch.where(valid_mask, group_indices, dupped_indices)

# Step 7: Create dupped_group_sizes
dupped_group_sizes = torch.ones((BLOCK_BK, M2), dtype=group_sizes.dtype)

# Step 8: Reshape dupped_indices and dupped_group_sizes
# dupped_indices = dupped_indices.view(BLOCK_BK * M2)
# dupped_group_sizes = dupped_group_sizes.view(BLOCK_BK * M2)

# Step 9: Create mask_bk_dup to filter out garbage values
mask_bk_dup = dupped_indices < garbage_value

# Output the results
print('indices', indices)
print('group_sizes', group_sizes)
print('group_indices', group_indices)
print("Dupped Indices:", dupped_indices)
print("Dupped Group Sizes:", dupped_group_sizes)
print("Mask BK Dup", mask_bk_dup)


indices tensor([ 2,  5,  8, 10, 15])
group_sizes tensor([3, 2, 4, 1, 4])
group_indices tensor([[ 2,  3,  4,  5],
        [ 5,  6,  7,  8],
        [ 8,  9, 10, 11],
        [10, 11, 12, 13],
        [15, 16, 17, 18]])
Dupped Indices: tensor([[ 2,  3,  4, -1],
        [ 5,  6, -1, -1],
        [ 8,  9, 10, 11],
        [10, -1, -1, -1],
        [15, 16, 17, 18]])
Dupped Group Sizes: tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])
Mask BK Dup tensor([[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]])


In [3]:
import torch

# Example values for demonstration
BLOCK_BK = 5  # Example BLOCK_BK value
M2 = 4  # Example M2 value
indices = torch.tensor([2, 5, 8, 10, 15], dtype=torch.int64)
group_sizes = torch.tensor([3, 2, 4, 1, 4], dtype=torch.int64)

# Set constants
garbage_value = -1 # MAX_TSRC + 1

# Create tensors with appropriate shapes and data types
dupped_indices = torch.full((BLOCK_BK, M2), garbage_value, dtype=indices.dtype)
group_indices = torch.arange(0, M2).unsqueeze(0).expand(BLOCK_BK, M2)  # Shape (BLOCK_BK, M2)

# Add indices to group_indices along a new dimension
group_indices = group_indices + indices.unsqueeze(1)

# Scatter indices and create the valid mask
scatter_indices = torch.arange(0, M2).unsqueeze(0).expand(BLOCK_BK, M2)
valid_mask = scatter_indices < group_sizes.unsqueeze(1)

# Apply the valid mask using torch.where
dupped_indices = torch.where(valid_mask, group_indices, dupped_indices)

# Initialize duplicated group sizes with ones
dupped_group_sizes = torch.ones((BLOCK_BK, M2), dtype=group_sizes.dtype)

# Reshape the indices and group sizes to flatten the first two dimensions
# dupped_indices = dupped_indices.reshape(BLOCK_BK * M2)
# dupped_group_sizes = dupped_group_sizes.reshape(BLOCK_BK * M2)

print(dupped_indices)
print(dupped_group_sizes)

tensor([[ 2,  3,  4, -1],
        [ 5,  6, -1, -1],
        [ 8,  9, 10, 11],
        [10, -1, -1, -1],
        [15, 16, 17, 18]])
tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])


In [4]:
import torch

# Example values for demonstration
BLOCK_BK = 5  # Example BLOCK_BK value
M2 = 4  # Example M2 value
indices = torch.tensor([2, 5, 8, 10, 15], dtype=torch.int64)
group_sizes = torch.tensor([3, 2, 4, 1, 4], dtype=torch.int64)

# Create tensors with appropriate shapes and data types
# Instead of using the garbage_value, initialize with indices + group_size - 1
initial_value = indices.unsqueeze(1) + group_sizes.unsqueeze(1) - 1
dupped_indices = initial_value.expand(BLOCK_BK, M2).clone()

# Create group_indices
group_indices = torch.arange(0, M2).unsqueeze(0).expand(BLOCK_BK, M2)  # Shape (BLOCK_BK, M2)

# Add indices to group_indices along a new dimension
group_indices = group_indices + indices.unsqueeze(1)

# Scatter indices and create the valid mask
scatter_indices = torch.arange(0, M2).unsqueeze(0).expand(BLOCK_BK, M2)
valid_mask = scatter_indices < group_sizes.unsqueeze(1)

# Apply the valid mask using torch.where
dupped_indices = torch.where(valid_mask, group_indices, dupped_indices)

# Initialize duplicated group sizes with ones
dupped_group_sizes = torch.ones((BLOCK_BK, M2), dtype=group_sizes.dtype)

# Print the result
print("Dupped Indices:")
print(dupped_indices)
print("\nDupped Group Sizes:")
print(dupped_group_sizes)


Dupped Indices:
tensor([[ 2,  3,  4,  4],
        [ 5,  6,  6,  6],
        [ 8,  9, 10, 11],
        [10, 10, 10, 10],
        [15, 16, 17, 18]])

Dupped Group Sizes:
tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])
