In [1]:
import torch

In [2]:
# Goal: compute tensor + alpha * (F + F^T) in place, where F is a matrix whose
# columns `masks` hold `other_tensor` (all other columns zero), touching only
# the rows/columns listed in `masks`.
torch.manual_seed(0)

tensor = torch.zeros(6, 6)
print(f"tensor:\n{tensor}")

other_tensor = torch.rand(6, 2)
print(f"other_tensor:\n{other_tensor}")

masks = torch.tensor([0, 2])  # column/row indices to update; must be unique
print(f"masks: {masks}")

alpha = 0.1

# The copy step below is only valid when the starting tensor is symmetric.
assert torch.allclose(tensor, tensor.T), "tensor must start symmetric"

# Reference result: scatter other_tensor's columns into a dense F, then
# expected = tensor + alpha * (F + F^T).
full_other_tensor = torch.zeros_like(tensor)
full_other_tensor[:, masks] = other_tensor
print(f"full_other_tensor:\n{full_other_tensor}")
expected_result = tensor + alpha * (full_other_tensor + full_other_tensor.T)

# Step 1 (the F half): tensor[:, masks[i]] += alpha * other_tensor[:, i].
tensor.index_add_(1, masks, other_tensor, alpha=alpha)
print(f"tensor after first add:\n{tensor}")


# Step 2 (the F^T half, off-diagonal part): copy tensor[not_masks, masks]
# onto tensor[masks, not_masks] so the result is symmetric, modifying only
# the necessary entries. Copying (not adding) is correct because index_add_
# left tensor[masks, not_masks] untouched and the start tensor was symmetric,
# so tensor[not_masks, masks] already equals the target value there.
all_indices = torch.arange(tensor.shape[0])
not_masks_bool = torch.ones(tensor.shape[0], dtype=torch.bool)
not_masks_bool[masks] = False
not_masks = all_indices[not_masks_bool]

# masks[:, None] (k, 1) broadcasts with not_masks (n-k,) into a (k, n-k) grid;
# the right-hand side reads the transposed positions
# (equivalently: tensor[not_masks[:, None], masks].T).
tensor[masks[:, None], not_masks] = tensor[not_masks, masks[:, None]]
print(f"tensor after copy:\n{tensor}")

# Step 3 (the F^T half, masks x masks block): entry [masks[a], masks[b]]
# still needs alpha * other_tensor[masks[b], a], i.e. the transpose below.
tensor[masks[:, None], masks] += alpha * other_tensor[masks, :].T
print(f"tensor after add diag:\n{tensor}")

print(f"expected result:\n{expected_result}")
assert torch.allclose(tensor, expected_result)



tensor:
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
other_tensor:
tensor([[0.4963, 0.7682],
        [0.0885, 0.1320],
        [0.3074, 0.6341],
        [0.4901, 0.8964],
        [0.4556, 0.6323],
        [0.3489, 0.4017]])
masks: tensor([0, 2])
full_other_tensor:
tensor([[0.4963, 0.0000, 0.7682, 0.0000, 0.0000, 0.0000],
        [0.0885, 0.0000, 0.1320, 0.0000, 0.0000, 0.0000],
        [0.3074, 0.0000, 0.6341, 0.0000, 0.0000, 0.0000],
        [0.4901, 0.0000, 0.8964, 0.0000, 0.0000, 0.0000],
        [0.4556, 0.0000, 0.6323, 0.0000, 0.0000, 0.0000],
        [0.3489, 0.0000, 0.4017, 0.0000, 0.0000, 0.0000]])
tensor after first add:
tensor([[0.0496, 0.0000, 0.0768, 0.0000, 0.0000, 0.0000],
        [0.0088, 0.0000, 0.0132, 0.0000, 0.0000, 0.0000],
        [0.0307, 0.0000, 0.0634, 0.0000, 0.0000, 0.0000],
        [0.0490, 0.0000, 0.0

In [4]:
# Symmetrize the masks x masks block by averaging it with its transpose.
# (The previous `diag_bloc + diag_bloc` was an identity operation: the
# transpose was missing.) The block is already symmetric after the update
# cell, so the result must still match expected_result.
diag_bloc = tensor[masks[:, None], masks]
tensor[masks[:, None], masks] = (diag_bloc + diag_bloc.T) / 2
assert torch.allclose(tensor, expected_result)