In [51]:
import torch

# Set small dimensions for clarity
B = 4  # batch size
L = 2  # so J has L+1=3 matrices and S has L+3=5 rows
N = 10  # feature dimension

# Create J and S with sequential numbers
J = torch.arange((L + 1) * N * (3 * N), dtype=torch.float).reshape(L + 1, N, 3 * N)
S = torch.arange(B * (L + 3) * N, dtype=torch.float).reshape(B, L + 3, N)

print(f"J: shape {list(J.shape)}")
# print(J)
print(f"\nS: shape {list(S.shape)}")
# print(S)

assert J.shape == (L + 1, N, 3 * N), f"J shape mismatch: {J.shape}"
assert S.shape == (B, L + 3, N), f"S shape mismatch: {S.shape}"

J: shape [3, 10, 30]

S: shape [4, 5, 10]


In [52]:
S_unfolded = S.unfold(1, 3, 1).transpose(-2, -1).flatten(2)  # Shape: (B, L+1, 3*N)

print(f"\nS_unfolded: shape {list(S_unfolded.shape)}")
# print(S_unfolded)

assert S_unfolded.shape == (B, L + 1, 3 * N), (
    f"S_unfolded shape mismatch: {S_unfolded.shape}"
)


S_unfolded: shape [4, 3, 30]


In [53]:
# batched matrix-vector multiplication of J and S_unfolded.
# J: (L+1, N, 3*N) * S_unfolded: (B, L+1, 3*N) -> result: (B, L+1, N)

result = torch.einsum("lni,bli->bln", J, S_unfolded)
print(f"\nresult: shape {list(result.shape)}")
# print(result)

assert result.shape == (B, L + 1, N), f"result shape mismatch: {result.shape}"


result: shape [4, 3, 10]


In [56]:
b, l = 2, 1

result[b, l, :] == J[l, :, :] @ S_unfolded[b, l, :]

tensor([True, True, True, True, True, True, True, True, True, True])

## Debug

In [1]:
import torch

from src.batch_me_if_u_can import BatchMeIfUCan

# Global parameters
NUM_LAYERS = 3
N = 64
C = 10
J_D = 0.5
LAMBDA_LEFT = [2.0] * (NUM_LAYERS + 1)
LAMBDA_RIGHT = [3.0] * (NUM_LAYERS + 1)
DEVICE = "cpu"
SEED = 42

# Instantiate the class
classifier = BatchMeIfUCan(
    num_layers=NUM_LAYERS,
    N=N,
    C=C,
    J_D=J_D,
    lambda_left=LAMBDA_LEFT,
    lambda_right=LAMBDA_RIGHT,
    device=DEVICE,
    seed=SEED,
)

In [2]:
J = classifier.couplings
J[0].shape, len(J)

(torch.Size([64, 192]), 4)

In [3]:
i = -2
j = 2

J[i, :, j * N : (j + 1) * N]

tensor([[-0.3162,  0.3162, -0.3162,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3162, -0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3162,  0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.3162, -0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3162, -0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3162, -0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000]])

In [6]:
state = classifier.initialize_state(4, torch.ones(4, N), 2 * torch.ones(4, C))
state.shape

torch.Size([4, 6, 64])

In [18]:
b = 0
i = -2
state[b, i]

tensor([-1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])