In [1]:
import torch
import torch.nn.functional as F

# Define example tensors
c1 = torch.tensor([1.0, 0.0])
c2 = torch.tensor([0.0, 1.0])
d1 = torch.tensor([1.0, 0.5])
d2 = torch.tensor([0.5, 1.0])

# Combine into one tensor
feats = torch.stack([c1, c2, d1, d2])

def info_nce_loss(feats, temperature):
    # Calculate cosine similarity
    cos_sim_in = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)

    # Create masks
    batch_size = cos_sim_in.size(0)
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=cos_sim_in.device)

    pos_mask = self_mask.roll(shifts=batch_size // 2, dims=0)  # Positive pairs mask
    
    # Mask out self-similarity
    cos_sim_out = cos_sim_in.masked_fill(self_mask, -9e15)  # Mask diagonal elements

    return cos_sim_in, cos_sim_out, self_mask, pos_mask, feats

# Compute and print the matrices
temperature = 0.07  # Temperature is not used in the computation of masks but is needed for the full loss function

cos_sim_in, cos_sim_out, self_mask, pos_mask, feats = info_nce_loss(feats, temperature)

print('Feats:')
print(feats)
print('\nCosine Similarity Matrix (cos_sim_in):')
print(cos_sim_in)
print('\nCosine Similarity Matrix with Self-similarity Masked (cos_sim_out):')
print(cos_sim_out)
print('\nSelf Mask:')
print(self_mask)
print('\nPositive Mask:')
print(pos_mask)


Feats:
tensor([[1.0000, 0.0000],
        [0.0000, 1.0000],
        [1.0000, 0.5000],
        [0.5000, 1.0000]])

Cosine Similarity Matrix (cos_sim_in):
tensor([[1.0000, 0.0000, 0.8944, 0.4472],
        [0.0000, 1.0000, 0.4472, 0.8944],
        [0.8944, 0.4472, 1.0000, 0.8000],
        [0.4472, 0.8944, 0.8000, 1.0000]])

Cosine Similarity Matrix with Self-similarity Masked (cos_sim_out):
tensor([[-9.0000e+15,  0.0000e+00,  8.9443e-01,  4.4721e-01],
        [ 0.0000e+00, -9.0000e+15,  4.4721e-01,  8.9443e-01],
        [ 8.9443e-01,  4.4721e-01, -9.0000e+15,  8.0000e-01],
        [ 4.4721e-01,  8.9443e-01,  8.0000e-01, -9.0000e+15]])

Self Mask:
tensor([[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False],
        [False, False, False,  True]])

Positive Mask:
tensor([[False, False,  True, False],
        [False, False, False,  True],
        [ True, False, False, False],
        [False,  True, False, False]])


In [2]:
import torch
import torch.nn.functional as F

# Define example tensors
c1 = torch.tensor([1.0, 0.0])
c2 = torch.tensor([0.0, 1.0])
d1 = torch.tensor([1.0, 0.5])
d2 = torch.tensor([0.5, 1.0])

# Combine into pairs
# In the original code, this represents concatenating augmented images
pair1 = torch.stack([c1, c2])  # First pair (e.g., c1 and c2 as a batch)
pair2 = torch.stack([d1, d2])  # Second pair (e.g., d1 and d2 as a batch)

# Concatenate pairs along the batch dimension
imgs = torch.cat([pair1, pair2], dim=0)  # Resulting tensor will have shape (4, 2)
print(imgs.shape)

def info_nce_loss(imgs, temperature):
    # Encode all images
    # For this example, feats is directly the imgs tensor
    feats = imgs

    # Calculate cosine similarity
    cos_sim_in = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
    print(cos_sim_in.shape)
    print(cos_sim_in.shape[0])
    # Create masks
    batch_size = cos_sim_in.size(0)
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=cos_sim_in.device)
    target = torch.arange(4)
    target[0::2] += 1
    target[1::2] -= 1
    index = target.reshape(4, 1).long()
    ground_truth_labels = torch.zeros(4, 4).long()
    src = torch.ones(4, 4).long()
    ground_truth_labels = torch.scatter(ground_truth_labels, 1, index, src)
    pos_mask = ground_truth_labels.bool()

    #pos_mask = self_mask.roll(shifts=batch_size//2, dims=0)  # Positive pairs mask
    
    # Mask out self-similarity
    cos_sim_out = cos_sim_in.masked_fill(self_mask, -9e15)  # Mask diagonal elements

    return cos_sim_in, cos_sim_out, self_mask, pos_mask, feats

# Compute and print the matrices
temperature = 0.07  # Temperature is not used in the computation of masks but is needed for the full loss function

cos_sim_in, cos_sim_out, self_mask, pos_mask, feats = info_nce_loss(imgs, temperature)

print('Feats:')
print(feats)
print('\nCosine Similarity Matrix (cos_sim_in):')
print(cos_sim_in)
print('\nCosine Similarity Matrix with Self-similarity Masked (cos_sim_out):')
print(cos_sim_out)
print('\nSelf Mask:')
print(self_mask)
print('\nPositive Mask:')
print(pos_mask)


torch.Size([4, 2])
torch.Size([4, 4])
4
Feats:
tensor([[1.0000, 0.0000],
        [0.0000, 1.0000],
        [1.0000, 0.5000],
        [0.5000, 1.0000]])

Cosine Similarity Matrix (cos_sim_in):
tensor([[1.0000, 0.0000, 0.8944, 0.4472],
        [0.0000, 1.0000, 0.4472, 0.8944],
        [0.8944, 0.4472, 1.0000, 0.8000],
        [0.4472, 0.8944, 0.8000, 1.0000]])

Cosine Similarity Matrix with Self-similarity Masked (cos_sim_out):
tensor([[-9.0000e+15,  0.0000e+00,  8.9443e-01,  4.4721e-01],
        [ 0.0000e+00, -9.0000e+15,  4.4721e-01,  8.9443e-01],
        [ 8.9443e-01,  4.4721e-01, -9.0000e+15,  8.0000e-01],
        [ 4.4721e-01,  8.9443e-01,  8.0000e-01, -9.0000e+15]])

Self Mask:
tensor([[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False],
        [False, False, False,  True]])

Positive Mask:
tensor([[False,  True, False, False],
        [ True, False, False, False],
        [False, False, False,  True],
        [False, False,  Tr

In [3]:
import torch

# Define 2D tensors
c1 = torch.tensor([[1.0], [2.0]])
c2 = torch.tensor([[3.0], [4.0]])
d1 = torch.tensor([[5.0], [6.0]])
d2 = torch.tensor([[7.0], [8.0]])

# Simulate batch with these tensors
imgs = [c1, c2, d1, d2]



In [4]:
# Concatenate along dim=0
concatenated_dim0 = torch.cat(imgs, dim=0)
print(concatenated_dim0)
concatenated_dim0.shape


tensor([[1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.]])


torch.Size([8, 1])

In [16]:
# Concatenate along dim=1
concatenated_dim1 = torch.cat(imgs, dim=1)
print(concatenated_dim1)
concatenated_dim1.shape


tensor([[1., 3., 5., 7.],
        [2., 4., 6., 8.]])


torch.Size([2, 4])

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
_ = torch.manual_seed(21)

In [22]:
eye = torch.eye(6)
eye, ~eye.bool()

(tensor([[1., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 1.]]),
 tensor([[False,  True,  True,  True,  True,  True],
         [ True, False,  True,  True,  True,  True],
         [ True,  True, False,  True,  True,  True],
         [ True,  True,  True, False,  True,  True],
         [ True,  True,  True,  True, False,  True],
         [ True,  True,  True,  True,  True, False]]))

In [23]:
x = torch.randn(6, 2)
x

tensor([[-1.0173, -1.6891],
        [-0.5188,  1.1591],
        [ 2.2763,  0.7654],
        [ 2.4068,  0.8689],
        [-0.3305,  0.5863],
        [ 0.1290,  0.2027]])

In [24]:
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
xcs

tensor([[ 1.0000, -0.5711, -0.7620, -0.7762, -0.4929, -0.9997],
        [-0.5711,  1.0000, -0.0964, -0.0743,  0.9957,  0.5507],
        [-0.7620, -0.0964,  1.0000,  0.9998, -0.1878,  0.7777],
        [-0.7762, -0.0743,  0.9998,  1.0000, -0.1660,  0.7914],
        [-0.4929,  0.9957, -0.1878, -0.1660,  1.0000,  0.4713],
        [-0.9997,  0.5507,  0.7777,  0.7914,  0.4713,  1.0000]])

In [25]:
y = xcs.clone()
y[eye.bool()] = float("-inf")
y

tensor([[   -inf, -0.5711, -0.7620, -0.7762, -0.4929, -0.9997],
        [-0.5711,    -inf, -0.0964, -0.0743,  0.9957,  0.5507],
        [-0.7620, -0.0964,    -inf,  0.9998, -0.1878,  0.7777],
        [-0.7762, -0.0743,  0.9998,    -inf, -0.1660,  0.7914],
        [-0.4929,  0.9957, -0.1878, -0.1660,    -inf,  0.4713],
        [-0.9997,  0.5507,  0.7777,  0.7914,  0.4713,    -inf]])

In [26]:
target = torch.arange(6)
target[0::2] += 1
target[1::2] -= 1
target

tensor([1, 0, 3, 2, 5, 4])

In [11]:
index = target.reshape(4, 1).long()
index

tensor([[1],
        [0],
        [3],
        [2]])

In [12]:
ground_truth_labels = torch.zeros(4, 4).long()
src = torch.ones(4, 4).long()
ground_truth_labels = torch.scatter(ground_truth_labels, 1, index, src)
ground_truth_labels

tensor([[0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 0]])

In [15]:
A = [[ True, False, False, False],
     [False,  True, False, False],
     [False, False,  True, False],
     [False, False, False,  True]]
A = torch.tensor(A)


In [19]:
A.roll(shifts=2, dims=0)

tensor([[False, False,  True, False],
        [False, False, False,  True],
        [ True, False, False, False],
        [False,  True, False, False]])

In [27]:
import torch

# Number of images in batch
n = 4

# Create self_mask as an identity matrix of shape (2n, 2n)
self_mask = torch.eye(2 * n)

# Define shift to be n (number of images per pair)
shift = n

# Roll self_mask to create pos_mask
pos_mask = self_mask.roll(shifts=shift, dims=0)

# pos_mask should be the desired ground_truth_labels
ground_truth_labels = pos_mask.long()

print("Self Mask:")
print(self_mask)
print("Positive Mask:")
print(pos_mask)
print("Ground Truth Labels:")
print(ground_truth_labels)


Self Mask:
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.]])
Positive Mask:
tensor([[0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.]])
Ground Truth Labels:
tensor([[0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0]])


In [2]:
import torch

In [3]:
target =torch.zeros(8,8)
target

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [4]:
pos_indices = torch.tensor([
    (0, 0), (0, 2), (0, 4),
    (1, 4), (1, 6), (1, 1),
    (2, 3),
    (3, 7),
    (4, 3),
    (7, 6),
])

In [5]:
pos_indices

tensor([[0, 0],
        [0, 2],
        [0, 4],
        [1, 4],
        [1, 6],
        [1, 1],
        [2, 3],
        [3, 7],
        [4, 3],
        [7, 6]])

In [6]:
pos_indices = torch.cat([pos_indices, torch.arange(8).reshape(8, 1).expand(-1, 2)], dim=0)
print("\nPositive indexes list")
print(pos_indices)



Positive indexes list
tensor([[0, 0],
        [0, 2],
        [0, 4],
        [1, 4],
        [1, 6],
        [1, 1],
        [2, 3],
        [3, 7],
        [4, 3],
        [7, 6],
        [0, 0],
        [1, 1],
        [2, 2],
        [3, 3],
        [4, 4],
        [5, 5],
        [6, 6],
        [7, 7]])


In [7]:
target[pos_indices[:,0], pos_indices[:,1]] = 1
print(f"\nGround Truth labels for positive and negative pairs for BCE Loss")
print(target)


Ground Truth labels for positive and negative pairs for BCE Loss
tensor([[1., 0., 1., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 1., 0., 1., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 1.]])


Multilabel setting

In [16]:
import torch
import torch.nn.functional as F

In [6]:
x = torch.randn(8, 2)
x

tensor([[ 1.4042,  1.1766],
        [ 0.3409, -1.6624],
        [-1.4203, -0.8864],
        [-2.1777,  1.9969],
        [-0.7660,  1.7046],
        [ 0.8047,  1.4360],
        [ 1.4015, -0.7339],
        [ 1.8676, -0.2438]])

In [7]:
assert len(x.size()) == 2

In [8]:
pos_indices = torch.tensor([
    (0, 0), (0, 2), (0, 4),
    (1, 4), (1, 6), (1, 1),
    (2, 3),
    (3, 7),
    (4, 3),
    (7, 6),
])

In [9]:
 pos_indices = torch.cat([
        pos_indices,
        torch.arange(x.size(0)).reshape(x.size(0), 1).expand(-1, 2),
    ], dim=0)

In [10]:
pos_indices

tensor([[0, 0],
        [0, 2],
        [0, 4],
        [1, 4],
        [1, 6],
        [1, 1],
        [2, 3],
        [3, 7],
        [4, 3],
        [7, 6],
        [0, 0],
        [1, 1],
        [2, 2],
        [3, 3],
        [4, 4],
        [5, 5],
        [6, 6],
        [7, 7]])

In [12]:
target = torch.zeros(x.size(0), x.size(0))
target[pos_indices[:,0], pos_indices[:,1]] = 1.0

In [13]:
target

tensor([[1., 0., 1., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 1., 0., 1., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 1.]])

In [17]:
# Cosine similarity
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
    # Set logit of diagonal element to "inf" signifying complete
    # correlation. sigmoid(inf) = 1.0 so this will work out nicely
    # when computing the Binary Cross Entropy Loss.
xcs[torch.eye(x.size(0)).bool()] = float("inf")

In [18]:
xcs

tensor([[    inf, -0.4752, -0.9903, -0.1309,  0.2717,  0.9350,  0.3811,  0.6769],
        [-0.4752,     inf,  0.3482, -0.8101, -0.9759, -0.7564,  0.6324,  0.3260],
        [-0.9903,  0.3482,     inf,  0.2675, -0.1352, -0.8766, -0.5060, -0.7727],
        [-0.1309, -0.8101,  0.2675,     inf,  0.9186,  0.2293, -0.9665, -0.8183],
        [ 0.2717, -0.9759, -0.1352,  0.9186,     inf,  0.5954, -0.7862, -0.5245],
        [ 0.9350, -0.7564, -0.8766,  0.2293,  0.5954,     inf,  0.0284,  0.3718],
        [ 0.3811,  0.6324, -0.5060, -0.9665, -0.7862,  0.0284,     inf,  0.9385],
        [ 0.6769,  0.3260, -0.7727, -0.8183, -0.5245,  0.3718,  0.9385,     inf]])

In [20]:
temperature = 0.1

In [21]:
loss = F.binary_cross_entropy((xcs / temperature).sigmoid(), target, reduction="none")
    

In [22]:
loss

tensor([[-0.0000e+00, 8.5972e-03, 9.9029e+00, 2.3916e-01, 6.4005e-02, 9.3493e+00,
         3.8328e+00, 6.7703e+00],
        [8.5972e-03, -0.0000e+00, 3.5126e+00, 3.0312e-04, 9.7589e+00, 5.1866e-04,
         1.7914e-03, 3.2975e+00],
        [5.0031e-05, 3.5126e+00, -0.0000e+00, 6.6665e-02, 2.3011e-01, 1.5598e-04,
         6.3285e-03, 4.4072e-04],
        [2.3916e-01, 3.0312e-04, 2.7412e+00, -0.0000e+00, 9.1855e+00, 2.3892e+00,
         6.3491e-05, 8.1835e+00],
        [2.7806e+00, 5.7782e-05, 2.3011e-01, 1.0253e-04, -0.0000e+00, 5.9562e+00,
         3.8487e-04, 5.2602e-03],
        [9.3493e+00, 5.1866e-04, 1.5598e-04, 2.3892e+00, 5.9562e+00, -0.0000e+00,
         8.4504e-01, 3.7421e+00],
        [3.8328e+00, 6.3256e+00, 6.3285e-03, 6.3491e-05, 3.8487e-04, 8.4504e-01,
         -0.0000e+00, 9.3842e+00],
        [6.7703e+00, 3.2975e+00, 4.4072e-04, 2.7927e-04, 5.2602e-03, 3.7421e+00,
         8.4046e-05, -0.0000e+00]])

In [23]:
target_pos = target.bool()

In [24]:
target_pos

tensor([[ True, False,  True, False,  True, False, False, False],
        [False,  True, False, False,  True, False,  True, False],
        [False, False,  True,  True, False, False, False, False],
        [False, False, False,  True, False, False, False,  True],
        [False, False, False,  True,  True, False, False, False],
        [False, False, False, False, False,  True, False, False],
        [False, False, False, False, False, False,  True, False],
        [False, False, False, False, False, False,  True,  True]])

In [26]:
target_neg = ~target_pos

In [27]:
target_neg

tensor([[False,  True, False,  True, False,  True,  True,  True],
        [ True, False,  True,  True, False,  True, False,  True],
        [ True,  True, False, False,  True,  True,  True,  True],
        [ True,  True,  True, False,  True,  True,  True, False],
        [ True,  True,  True, False, False,  True,  True,  True],
        [ True,  True,  True,  True,  True, False,  True,  True],
        [ True,  True,  True,  True,  True,  True, False,  True],
        [ True,  True,  True,  True,  True,  True, False, False]])

In [29]:
loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])
loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])
loss_pos = loss_pos.sum(dim=1)
loss_neg = loss_neg.sum(dim=1)

In [32]:
loss_pos

tensor([9.9669e+00, 9.7607e+00, 6.6665e-02, 8.1835e+00, 1.0253e-04, 0.0000e+00,
        0.0000e+00, 8.4046e-05])

In [33]:
loss_neg

tensor([20.2002,  6.8195,  3.7497, 14.5554,  8.9727, 22.2825, 20.3944, 13.8158])

In [34]:
num_pos = target.sum(dim=1)
num_neg = x.size(0) - num_pos

In [35]:
num_pos

tensor([3., 3., 2., 2., 2., 1., 1., 2.])

In [36]:
num_neg

tensor([5., 5., 6., 6., 6., 7., 7., 6.])

In [38]:
((loss_pos / num_pos) + (loss_neg / num_neg)).mean()


tensor(3.6313)

we have to find a way for  creating positive indices then apply this thats it thats our next step

In [3]:
torch.zeros(8,8)

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [4]:
import torch

def get_positive_indices_matrix(n, m):
    # Initialize a zero matrix of size (n*m) x (n*m)
    positive_matrix = torch.zeros(n*m, n*m)

    # Loop through each image
    for i in range(m):
        # Calculate the row and column ranges for this image
        start_idx = i * n
        end_idx = (i + 1) * n
        
        # Set the corresponding block to 1 (for positive indices)
        positive_matrix[start_idx:end_idx, start_idx:end_idx] = 1

    return positive_matrix

# Example usage:
n = 4  # Number of augmentations per image
m = 3  # Number of images

positive_indices_matrix = get_positive_indices_matrix(n, m)
print(positive_indices_matrix)


tensor([[1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.]])


link to new info method multi n:
    https://chatgpt.com/share/b08ad2c2-2c12-4519-aa18-9024809f3793

In [5]:
import torch

def get_positive_indices_matrix(n, m):
    # Initialize a zero matrix of size (n*m) x (n*m)
    positive_matrix = torch.zeros(n*m, n*m)

    # Loop through each image
    for i in range(m):
        # Calculate the row and column ranges for this image
        start_idx = i * n
        end_idx = (i + 1) * n
        
        # Set the corresponding block to 1 (for positive indices)
        positive_matrix[start_idx:end_idx, start_idx:end_idx] = 1

    return positive_matrix

# Example usage:
n = 4  # Number of augmentations per image
m = 3  # Number of images

positive_indices_matrix = get_positive_indices_matrix(n, m)
print(positive_indices_matrix)


tensor([[1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.]])
