In [72]:
import torch
from torch.utils.data import DataLoader, Dataset

In [73]:
NUM_SAMPLES = 1000

POINT_DIM = 3
NUM_POINTS = 5

INPUT_DIM = POINT_DIM * NUM_POINTS
OUTPUT_DIM = 5

BATCH_SIZE = 32
NUM_EPOCHS = 10000
LEARNING_RATE = 0.02

In [74]:
points = torch.rand(NUM_SAMPLES, NUM_POINTS, POINT_DIM)
print(points.shape)

labels = torch.randint(0, 2, (NUM_SAMPLES, OUTPUT_DIM))
print(labels.shape)

torch.Size([1000, 5, 3])
torch.Size([1000, 5])


In [75]:
class PermutedPointsDataset(Dataset):
    def __init__(self, points, labels):
        """
        Args:
            points (Tensor): The points tensor of shape (NUM_SAMPLES, NUM_POINTS, POINT_DIM).
            labels (Tensor): The labels tensor of shape (NUM_SAMPLES, OUTPUT_DIM).
        """
        self.points = points
        self.labels = labels

    def __len__(self):
        return len(self.points)

    def __getitem__(self, idx):
        current_points = self.points[idx]
        # Generate a permutation
        perm = torch.randperm(current_points.size(0))
        # Apply the permutation to points
        permuted_points = current_points[perm]
        return permuted_points.view(-1), self.labels[idx]

# Instantiate the custom dataset
dataset = PermutedPointsDataset(points, labels)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [76]:
class SimpleLinearNetwork(torch.nn.Module):
    def __init__(self, layer_widths):
        super(SimpleLinearNetwork, self).__init__()
        self.layers = torch.nn.ModuleList()
        for a,b in zip(layer_widths, layer_widths[1:]):
            self.layers.append(torch.nn.Linear(a,b))
    
    def forward(self, x):
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)

In [77]:
model = SimpleLinearNetwork([INPUT_DIM, 100, 100, OUTPUT_DIM])

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    total_loss = 0
    for i, (x, y) in enumerate(dataloader):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss}")

Epoch 0, Loss: 22.29525649547577
Epoch 100, Loss: 22.180070638656616
Epoch 200, Loss: 22.177550673484802
Epoch 300, Loss: 22.17050838470459
Epoch 400, Loss: 22.161996364593506
Epoch 500, Loss: 22.17396253347397
Epoch 600, Loss: 22.16419118642807
Epoch 700, Loss: 22.180211007595062
Epoch 800, Loss: 22.159494817256927
Epoch 900, Loss: 22.171850621700287
Epoch 1000, Loss: 22.169557094573975
Epoch 1100, Loss: 22.168573439121246
Epoch 1200, Loss: 22.187703251838684
Epoch 1300, Loss: 22.169793784618378
Epoch 1400, Loss: 22.171943366527557
Epoch 1500, Loss: 22.169249713420868
Epoch 1600, Loss: 22.185550212860107
Epoch 1700, Loss: 22.186049163341522
Epoch 1800, Loss: 22.167471945285797
Epoch 1900, Loss: 22.173479437828064
Epoch 2000, Loss: 22.17038232088089
Epoch 2100, Loss: 22.18236756324768
Epoch 2200, Loss: 22.17024201154709
Epoch 2300, Loss: 22.182519674301147
