In [1]:
import torch

# Small dimensions keep every tensor easy to inspect by eye.
B = 4  # batch size
L = 2  # J stacks L + 1 = 3 matrices; S has L + 3 = 5 rows per sample
N = 10  # feature dimension

# Target shapes, named once so construction and checks cannot drift apart.
J_shape = (L + 1, N, 3 * N)
S_shape = (B, L + 3, N)

# Fill both tensors with consecutive integers (stored as floats) so that any
# entry can be traced back to its flat index when reading printed values.
J = torch.arange(torch.Size(J_shape).numel(), dtype=torch.float32).view(J_shape)
S = torch.arange(torch.Size(S_shape).numel(), dtype=torch.float32).view(S_shape)

print(f"J: shape {list(J.shape)}")
print(f"\nS: shape {list(S.shape)}")

# Sanity checks on the layout the following cells rely on.
assert J.shape == J_shape, f"J shape mismatch: {J.shape}"
assert S.shape == S_shape, f"S shape mismatch: {S.shape}"

J: shape [3, 10, 30]

S: shape [4, 5, 10]


In [2]:
# Sliding windows of 3 consecutive rows along dim 1 (step 1).
# unfold appends the window axis last: (B, L+1, N, 3).
windows = S.unfold(dimension=1, size=3, step=1)

# Bring the window axis in front of the feature axis -> (B, L+1, 3, N),
# then merge the two trailing axes -> (B, L+1, 3*N).
S_unfolded = windows.permute(0, 1, 3, 2).flatten(start_dim=2)

print(f"\nS_unfolded: shape {list(S_unfolded.shape)}")

expected_unfolded_shape = (B, L + 1, 3 * N)
assert S_unfolded.shape == expected_unfolded_shape, (
    f"S_unfolded shape mismatch: {S_unfolded.shape}"
)


S_unfolded: shape [4, 3, 30]


In [3]:
# Batched matrix-vector product: for every sample b and index l, multiply the
# l-th matrix of J by the l-th unfolded row of S.
#   J:          (L+1, N, 3*N)
#   S_unfolded: (B, L+1, 3*N)
#   result:     (B, L+1, N)
equation = "lni,bli->bln"  # contract the shared 3*N axis `i`
result = torch.einsum(equation, J, S_unfolded)

print(f"\nresult: shape {list(result.shape)}")

expected_result_shape = (B, L + 1, N)
assert result.shape == expected_result_shape, f"result shape mismatch: {result.shape}"


result: shape [4, 3, 10]


In [6]:
# Spot-check one (sample, matrix) pair of the einsum result against an
# explicit matrix-vector product.
b, l = 0, -1

# einsum and `@` may accumulate in different orders, so compare with a
# floating-point tolerance instead of exact `==`. The displayed value is still
# an elementwise boolean tensor of length N.
torch.isclose(result[b, l, :], J[l, :, :] @ S_unfolded[b, l, :])

tensor([True, True, True, True, True, True, True, True, True, True])

## Debug

In [None]:
import torch

from src.batch_me_if_u_can import BatchMeIfUCan

# Global parameters.
# NOTE(review): these rebind L and N from the first section of the notebook.
# NOTE(review): the recorded outputs of the following cells (e.g. couplings of
# shape [64, 192] with 4 entries) do not look like they were produced with
# L = 5, N = 1000 -- re-run the notebook from a fresh kernel to refresh them.
L = 5
N = 1000
C = 10
J_D = 0.2
LAMBDA_LEFT = [0] + [2.0] * 4 + [1.0]
LAMBDA_RIGHT = [4.0] * 4 + [1.0, 4.0]
DEVICE = "cpu"
SEED = 42

# Per-entry hyperparameters, written so the repeated values are visible:
# 7 learning rates, 6 thresholds, 7 weight-decay values.
lr = torch.tensor([0.1] * 5 + [0.01] * 2)
threshold = torch.full((6,), 2.5)
weight_decay = torch.tensor([0.001] * 5 + [0.0] * 2)

# Collect the constructor arguments in one place, then build the classifier.
classifier_kwargs = dict(
    num_layers=L,
    N=N,
    C=C,
    J_D=J_D,
    lambda_left=LAMBDA_LEFT,
    lambda_right=LAMBDA_RIGHT,
    device=DEVICE,
    seed=SEED,
    lr=lr,
    threshold=threshold,
    weight_decay=weight_decay,
)
classifier = BatchMeIfUCan(**classifier_kwargs)

In [2]:
# Grab the classifier's couplings and display the shape of the first entry
# together with the number of entries along the leading axis.
# NOTE(review): this rebinds `J` from the first section of the notebook; it is
# a reference to the classifier's own attribute, not a copy.
J = classifier.couplings
J[0].shape, len(J)

(torch.Size([64, 192]), 4)

In [3]:
# Look at one N-wide column block of a single coupling matrix:
# i selects the matrix (second to last), j selects the block along the last axis.
i, j = -2, 2

column_block = slice(j * N, (j + 1) * N)
J[i, :, column_block]

tensor([[-0.3162,  0.3162, -0.3162,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3162, -0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3162,  0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.3162, -0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3162, -0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3162, -0.3162,  0.3162,  ...,  0.0000,  0.0000,  0.0000]])

In [20]:
# Build an initial state for a batch of B samples.
# The two tensor arguments are constant fills of shape (B, N) and (B, C);
# presumably they stand in for inputs and targets -- confirm against
# BatchMeIfUCan.initialize_state.
B = 4

# Derive both leading dimensions from B rather than repeating the literal 4,
# so that changing the batch size cannot produce mismatched arguments.
state = classifier.initialize_state(B, torch.ones(B, N), torch.full((B, C), 2.0))
state.shape

torch.Size([4, 6, 64])

In [21]:
# Display one row of the state: sample b, second-to-last row along dim 1.
b, i = 0, -2
state[b, i]

tensor([-1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])

In [22]:
# Evaluate the classifier's `fields` on the current state.
# NOTE(review): `field` is not read by any later cell visible here.
field = classifier.fields(state)

In [23]:
# Work on a copy of the couplings so later edits to `J` do not write into the
# classifier's own tensor, then display its shape.
J = classifier.couplings.clone()
J.shape

torch.Size([4, 64, 192])

In [31]:
## Perceptron Rule

In [38]:
# Replace the classifier-produced state with a synthetic one made of
# consecutive integers (as floats), shaped (B, L+3, N), so individual entries
# are easy to trace in the checks below.
state_shape = (B, L + 3, N)
state = torch.arange(B * (L + 3) * N, dtype=torch.float32).view(state_shape)

In [41]:
# Interior rows of the state (first and last row along dim 1 dropped): (B, L+1, N).
neurons = state[:, 1:-1]

# Length-3 sliding windows along dim 1. unfold appends the window axis last,
# so swap the two trailing axes to get (B, L+1, 3, N).
S_unfolded = state.unfold(dimension=1, size=3, step=1).permute(0, 1, 3, 2)

In [43]:
# Explicit-loop reference: for every sample b, index l and window slot c, store
# the (N, N) outer product of neurons[b, l] with S_unfolded[b, l, c].
result = torch.empty((B, L + 1, 3, N, N), dtype=neurons.dtype)

for b in range(B):
    for l in range(L + 1):
        col_vec = neurons[b, l]  # (N,), indexes the rows of the product
        for c in range(3):
            row_vec = S_unfolded[b, l, c]  # (N,), indexes the columns
            result[b, l, c] = torch.outer(col_vec, row_vec)

In [55]:
# Vectorised version: outer products of neurons[b, l] with each of the three
# window rows, summed over the batch axis b -> (L+1, N, 3, N).
batch_summed = torch.einsum("bli,blcj->licj", neurons, S_unfolded)

# Merge the (window slot, feature) axes so slot c occupies columns
# c*N:(c+1)*N -> (L+1, N, 3*N).
out = batch_summed.flatten(start_dim=2)

In [56]:
# Display the shape of `out`; from its construction this is (L+1, N, 3*N).
# NOTE(review): the recorded output [4, 64, 192] corresponds to L = 3, N = 64,
# not to the L = 5, N = 1000 set in the configuration cell -- re-run to refresh.
out.shape

torch.Size([4, 64, 192])

In [66]:
# Compare the loop reference with the einsum result for one (l, c) pair:
# the outer products summed over the batch must match the c-th N-wide column
# block of out[l].
l = 0
c = 2

# The loop sums over the batch after forming each product, while einsum is free
# to accumulate in a different order; with float32 data this can differ in the
# last bits, so compare with a tolerance rather than exact `==`. The displayed
# value is still an elementwise (N, N) boolean tensor.
torch.isclose(result[:, l, c].sum(dim=0), out[l, :, c * N : (c + 1) * N])

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])