In [1]:
import torch
from torch import nn

In [46]:
class MLinear(nn.Module):
    letters = "abcdefghijklmopqrstuvwxyz" # all letters except n

    def _index_generator(self, num_indices: int | None=None):
        for c in (self.available_indexing_names[:num_indices] if num_indices is not None else self.available_indexing_names):
            yield c

    def __init__(self, N: int, *S: int):
        """class for computing a linear transformation between $M$ vectors, where vector i lives in S[i]-dimensional space
        Args:
            N (int): the dimension of the output
            *S (int): the dimensions of the inputs. Should be of length $M$
        """
        super().__init__()

        self.S = S
        self.M = len(S)
        self.N = N

        self.Z = nn.Parameter(torch.zeros(N, *S)) # (N, S_{1}, ..., S_{M})

        self.summing_names = self.letters[-self.M:]
        self.available_indexing_names = self.letters[:-self.M]

    def forward(self, *X: torch.Tensor):
        """compute linear transformation given the M tensors in N
        Args:
            *X (torch.Tensor): sequence of M tensors, where X[i].shape[-1] == S[i], and X[i].shape[:-1]
        Returns:
            torch.Tensor: _description_
        """
        total_indexing_dim = sum(map(lambda x: x.dim() - 1, X))
        i_gen = self._index_generator()

        z_indices = "n" + self.summing_names
        x_indices = ",".join(["".join([next(i_gen) for j in range(m.dim() - 1)]) + self.summing_names[i] for i, m in enumerate(X)])
        r_indices = self.available_indexing_names[:total_indexing_dim] + "n"

        return torch.einsum(f"{z_indices},{x_indices}->{r_indices}", self.Z, *X)

In [71]:
class MAffine(nn.Module):
    """Affine counterpart of ``MLinear`` over M input vectors (forward not implemented yet).

    The constructor creates one ``MLinear`` for every non-empty subset of the M
    inputs plus a bias of size N. Sub-layers are stored in ``self.linears`` under
    a key made of the subset's input positions joined by ``_`` (for example
    ``"0_2"`` for inputs 0 and 2), so each subset keeps its own weights.
    """

    @classmethod
    def _combinations(cls, n, c):
        """Yield every size-``c`` subset of ``range(n)`` in lexicographic order.

        Each subset is a freshly allocated 1-D ``torch.long`` tensor, so results
        can be collected safely (no tensor is mutated after it was yielded).

        Args:
            n: number of items to choose from.
            c: size of each subset.
        """
        from itertools import combinations

        for indices in combinations(range(n), c):
            yield torch.tensor(indices, dtype=torch.long)

    def __init__(self, N: int, *S: int):
        """class for computing an affine transformation between $M$ vectors, where vector i lives in S[i]-dimensional space
        Args:
            N (int): the dimension of the output
            *S (int): the dimensions of the inputs. Should be of length $M$
        """

        super().__init__()

        self.S = torch.tensor(S, dtype=torch.long)
        self.M = len(S)
        self.N = N

        # The bias is added to the N-dimensional output, so it has N entries.
        self.b = nn.Parameter(torch.zeros(self.N))  # the bias term

        self.linears = nn.ModuleDict()

        for size in range(1, self.M + 1):
            for comb in self._combinations(self.M, size):
                positions = comb.tolist()
                # Key on the subset itself: keying on the subset size alone would
                # let every subset of that size overwrite the previous one.
                key = "_".join(str(p) for p in positions)
                self.linears[key] = MLinear(self.N, *(S[p] for p in positions))

    def forward(self, *X: torch.Tensor):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Sorry!") # TODO
        # make sure you broadcast things together!

In [73]:
# Scratch cell: build an MAffine with output dimension 10 and input dimensions 2, 3 and 4.
a = MAffine(10, 2, 3, 4)

In [49]:
# NOTE(review): MAffine.forward as defined in this file raises NotImplementedError,
# so the output recorded below presumably comes from an earlier version of the class
# (this cell's execution count, 49, predates the class cell's 71) -- re-run to confirm.
a.forward(torch.randn(2), torch.randn(3), torch.randn(4)).shape

torch.Size([10])

In [54]:
# Scratch cell: rebinds `a` (previously the MAffine instance) to a small integer tensor.
a = torch.tensor([1,2,3])

In [58]:
# Checks that star-unpacking the NumPy view of a 1-D tensor yields its individual elements.
[*a.numpy()]

[1, 2, 3]