In [128]:
import torch
import torch.nn as nn

In [105]:
def get_pairs(x):
    """
    Build every ordered pair (i, j) of objects within each batch element.

    Args:
        x: tensor of shape [batch, n, d].

    Returns:
        Tensor of shape [batch, n, n, 2*d] where
        output[b, i, j] = concat(x[b, i], x[b, j]).
        All n*n ordered pairs are produced, including the self-pairs i == j.

    Raises:
        ValueError: if x is not 3-dimensional.
    """
    if x.dim() != 3:
        raise ValueError(
            f"get_pairs expects a [batch, n, d] tensor, got shape {tuple(x.shape)}"
        )
    batch_size, n, d = x.shape
    # expand() returns broadcast views without copying; torch.cat materializes the result.
    x1 = x.unsqueeze(2).expand(batch_size, n, n, d)  # x1[b, i, j] == x[b, i]
    x2 = x.unsqueeze(1).expand(batch_size, n, n, d)  # x2[b, i, j] == x[b, j]
    # Concatenate along the feature dimension: [batch, n, n, 2*d]
    return torch.cat([x1, x2], dim=-1)



In [106]:
class Dense4(nn.Module):
    """Four-layer MLP: Linear -> ReLU -> Linear -> ReLU -> Linear -> ReLU -> Linear.

    nn.Linear acts on the last dimension, so inputs of shape
    [..., in_features] map to [..., out_features].

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension.
        hidden_dim: width of the three hidden layers (defaults to 128, the
            previously hard-coded width).
    """

    def __init__(self, in_features=32, out_features=32, hidden_dim=128):
        super().__init__()
        # A single nn.Sequential keeps the state_dict keys (net.0, net.2, ...) unchanged.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_features),
        )

    def forward(self, x):
        return self.net(x)

In [107]:
class DensSkip4(nn.Module):
    """Residual four-layer MLP: forward(x) returns x + MLP(x).

    The MLP is Linear -> ReLU -> Linear -> ReLU -> Linear -> ReLU -> Linear
    acting on the last dimension.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the MLP output's last dimension; must equal
            in_features for the residual sum.
        hidden_dim: width of the three hidden layers (defaults to 128, the
            previously hard-coded width).

    Raises:
        ValueError: if in_features != out_features. The residual x + net(x)
            needs matching last dimensions. A mismatch would otherwise either
            fail inside forward or, when one size is 1, silently broadcast.
    """

    def __init__(self, in_features=32, out_features=32, hidden_dim=128):
        super().__init__()
        if in_features != out_features:
            raise ValueError(
                f"DensSkip4 needs in_features == out_features for its residual "
                f"connection, got {in_features} and {out_features}"
            )
        # A single nn.Sequential keeps the state_dict keys (net.0, net.2, ...) unchanged.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_features),
        )

    def forward(self, x):
        return x + self.net(x)

In [108]:
class PermutationLayer(nn.Module):
    """Pairwise layer over a set of n objects.

    For an input x of shape [batch, n, d], the layer works in three steps:
      1. get_pairs(x) builds a [batch, n, n, 2*d] tensor. Entry [b, i, j] is
         concat(x[b, i], x[b, j]), and self-pairs i == j are included.
      2. `net` is applied to that tensor, so it must accept a last dimension
         of size 2*d.
      3. `pool_fn` reduces over dim=2 (the partner index j). This yields one
         output row per object i: [batch, n, d_out].

    When `net` acts independently on each pair (as the Linear/ReLU stacks in
    this notebook do) and `pool_fn` is a symmetric reduction (mean, sum, max),
    permuting the n input objects permutes the outputs the same way.

    Args:
        net: module applied to every pair.
        pool_fn: reduction called as pool_fn(t, dim=2); defaults to torch.mean.
            Reductions that return a (values, indices) tuple, such as
            torch.max, torch.min or torch.median, are also accepted. Only the
            values are kept.
    """

    def __init__(self, net: nn.Module, pool_fn=torch.mean):
        super().__init__()
        self.net = net
        self.pool_fn = pool_fn

    def forward(self, x):
        pairs = get_pairs(x)
        pair_features = self.net(pairs)
        pooled = self.pool_fn(pair_features, dim=2)
        # torch.max/min/median with `dim` return a (values, indices) named tuple.
        # Forwarding that tuple would break the next layer, so keep only the values.
        if isinstance(pooled, tuple):
            pooled = pooled[0]
        return pooled


In [121]:
class Perm31(nn.Module):
    """Three stacked PermutationLayers, each wrapping one Linear layer.

    Maps inputs whose last dimension is `features` to outputs whose last
    dimension is `features`. The first two pair networks are Linear + ReLU;
    the last one is a bare Linear.

    Args:
        features: size of the input and output last dimension.
        hidden_dim: width of the intermediate per-object representations.
    """

    def __init__(self, features=4, hidden_dim=128):
        super().__init__()
        # (fan_in, fan_out, follow with ReLU?) for each PermutationLayer, in order.
        # Each pair network sees concatenated pairs, hence the 2x fan-in.
        layer_specs = [
            (2 * features, hidden_dim, True),
            (2 * hidden_dim, hidden_dim, True),
            (2 * hidden_dim, features, False),
        ]
        blocks = []
        for fan_in, fan_out, with_relu in layer_specs:
            pair_net = [nn.Linear(fan_in, fan_out)]
            if with_relu:
                pair_net.append(nn.ReLU())
            blocks.append(PermutationLayer(nn.Sequential(*pair_net)))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)

In [122]:
l = Perm31()
x0 = torch.rand(1, 10, 4)
x = l(x0)
# Perm31 maps [batch, n, features] -> [batch, n, features]; fail loudly if that changes.
assert x.shape == x0.shape, f"unexpected output shape {tuple(x.shape)}"
x.shape

torch.Size([1, 10, 4])

In [126]:
class PermSkip34(nn.Module):
    """Residual stack of three PermutationLayers, each wrapping a 4-Linear MLP.

    forward(x) returns x + net(x). The stack maps the last dimension
    `features` back to `features`, so the residual sum lines up.

    Args:
        features: size of the input and output last dimension.
        hidden_dim: width of every hidden Linear layer.
    """

    def __init__(self, features=4, hidden_dim=128):
        super().__init__()

        def pair_mlp(fan_in, fan_out, trailing_relu):
            # Linear(fan_in,H) ReLU Linear(H,H) ReLU Linear(H,H) ReLU Linear(H,fan_out) [ReLU]
            widths = [fan_in, hidden_dim, hidden_dim, hidden_dim, fan_out]
            modules = []
            for k in range(4):
                modules.append(nn.Linear(widths[k], widths[k + 1]))
                if k < 3 or trailing_relu:
                    modules.append(nn.ReLU())
            return nn.Sequential(*modules)

        # Pair networks see concatenated pairs, hence the 2x fan-in on each.
        self.net = nn.Sequential(
            PermutationLayer(pair_mlp(2 * features, hidden_dim, trailing_relu=True)),
            PermutationLayer(pair_mlp(2 * hidden_dim, hidden_dim, trailing_relu=True)),
            PermutationLayer(pair_mlp(2 * hidden_dim, features, trailing_relu=False)),
        )

    def forward(self, x):
        residual = self.net(x)
        return x + residual

In [127]:
l = PermSkip34()
x0 = torch.rand(100, 10, 4)
x = l(x0)
# The residual output must keep the input shape [batch, n, features].
assert x.shape == x0.shape, f"unexpected output shape {tuple(x.shape)}"
x.shape

torch.Size([100, 10, 4])