In [None]:
import torch

# Example alphas tensor: one row per sample (batch dimension), one column per cluster.
alphas = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                       [0.5, 0.6, 0.7, 0.8],
                       [0.9, 1.0, 1.1, 1.2]])
batch_size = alphas.size(0)  # number of samples (B)
n_clusters = alphas.size(1)  # number of clusters (K)

# Goal: for every sample b and cluster j compute
#     betas[b, j] = min over k != j of alphas[b, k]
# i.e. the smallest alpha among all the *other* clusters.

# unsqueeze(1) inserts a new dimension of size 1 at index 1: (B, K) -> (B, 1, K).
# repeat(...) builds a new tensor by tiling the original the given number of times
# along each dimension:
#   - dim 0 is repeated 1 time, so it stays the same (B),
#   - dim 1 is repeated n_clusters times, giving one copy of the row per cluster j (K),
#   - dim 2 is repeated 1 time, so it stays the same (K).
# Result: alphas_without_j[b, j, k] == alphas[b, k], shape (B, K, K).
# Row j still contains alphas[b, j] on the diagonal (k == j); the mask below removes it.
alphas_without_j = alphas.unsqueeze(1).repeat(1, n_clusters, 1)
print(alphas_without_j)

# torch.eye builds a K x K identity matrix: True on the main diagonal (row index ==
# column index) and False elsewhere; dtype=torch.bool makes it usable as a mask.
# The ~ operator inverts it, so `mask` is False on the diagonal and True everywhere else.
mask = ~torch.eye(n_clusters, dtype=torch.bool, device=alphas.device)
print(mask)

# unsqueeze(0) adds a leading dimension of size 1: (K, K) -> (1, K, K), so the same
# mask is broadcast over every sample in the batch.
unsqueezed_mask = mask.unsqueeze(0)
print(unsqueezed_mask)

# masked_select keeps only the elements where the mask is True and returns them as a
# 1-D tensor, traversed in row-major order (b, then j, then k). Because the diagonal
# entry k == j is dropped, each (b, j) row contributes exactly n_clusters - 1 values:
# the alphas of all the other clusters.
alphas_without_self = torch.masked_select(alphas_without_j, unsqueezed_mask)
print(alphas_without_self)

# Restore the structure: shape (B, K, K - 1), where [b, j, :] lists alphas[b, k]
# for every k != j. The shape is derived from alphas instead of being hard-coded.
alphas_without_self_reshaped = alphas_without_self.view(
    batch_size, n_clusters, n_clusters - 1
)
print(alphas_without_self_reshaped)

# Minimum over the last dimension (the "other clusters"). torch.min with `dim` returns
# (values, indices); the indices are not needed. Requires n_clusters >= 2, otherwise
# the reduced dimension is empty and torch.min raises an error.
betas, _ = torch.min(alphas_without_self_reshaped, dim=2)
print(betas)