In [100]:
import torch
from torch.utils.data import DataLoader, Dataset

In [101]:
# --- Synthetic data -------------------------------------------------------
# Number of samples in the generated dataset.
NUM_SAMPLES = 1_000

# Every sample holds NUM_POINTS points, each with POINT_DIM coordinates.
NUM_POINTS = 5
POINT_DIM = 3

# --- Network sizes --------------------------------------------------------
# The model consumes the points of one sample flattened into a single vector
# and emits OUTPUT_DIM logits.
OUTPUT_DIM = 5
INPUT_DIM = NUM_POINTS * POINT_DIM

# --- Optimisation ---------------------------------------------------------
LEARNING_RATE = 0.01
NUM_EPOCHS = 10_000
BATCH_SIZE = 32

In [102]:
# Seed the global RNG before drawing any data so that the synthetic dataset
# is identical on every fresh run of the notebook (Restart & Run All).
torch.manual_seed(0)

# Random point sets: coordinates sampled uniformly from [0, 1).
points = torch.rand(NUM_SAMPLES, NUM_POINTS, POINT_DIM)
print(points.shape)

# Random binary targets (values in {0, 1}), drawn independently of `points`.
labels = torch.randint(0, 2, (NUM_SAMPLES, OUTPUT_DIM))
print(labels.shape)

# Peek at the first label vector as a sanity check.
print(labels[0])

torch.Size([1000, 5, 3])
torch.Size([1000, 5])
tensor([1, 1, 1, 1, 1])


In [103]:
class PermutedPointsDataset(Dataset):
    """Dataset that returns each sample's points in a fresh random order.

    Every ``__getitem__`` call draws a new permutation with
    ``torch.randperm`` (using the global RNG), reorders the sample's points
    along their first dimension, and returns them flattened to 1-D together
    with the sample's label. The label itself is never permuted.
    """

    def __init__(self, points, labels):
        """
        Args:
            points (Tensor): The points tensor of shape (NUM_SAMPLES, NUM_POINTS, POINT_DIM).
            labels (Tensor): The labels tensor of shape (NUM_SAMPLES, OUTPUT_DIM).

        Raises:
            ValueError: If ``points`` and ``labels`` do not hold the same
                number of samples.
        """
        # Fail early: a mismatch would otherwise surface as an IndexError deep
        # inside a DataLoader worker, or silently ignore the surplus labels.
        if len(points) != len(labels):
            raise ValueError(
                f"points and labels must have the same number of samples, "
                f"got {len(points)} and {len(labels)}"
            )
        self.points = points
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.points)

    def __getitem__(self, idx):
        """Return ``(flattened_permuted_points, label)`` for sample ``idx``."""
        sample_points = self.points[idx]
        # A new permutation of the point order is drawn on every access.
        perm = torch.randperm(sample_points.size(0))
        permuted_points = sample_points[perm]
        # reshape (rather than view) so flattening does not depend on the
        # memory layout of the indexed tensor.
        return permuted_points.reshape(-1), self.labels[idx]

# Wrap the raw tensors so that every access yields a freshly permuted sample
dataset = PermutedPointsDataset(points=points, labels=labels)

# Serve the dataset in shuffled mini-batches of BATCH_SIZE samples
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=BATCH_SIZE)

In [104]:
class SimpleLinearNetwork(torch.nn.Module):
    """Stack of ``torch.nn.Linear`` layers with ReLU between them.

    Consecutive entries of ``layer_widths`` define one linear layer each, so
    ``[a, b, c]`` builds ``Linear(a, b)`` followed by ``Linear(b, c)``. ReLU is
    applied after every layer except the last, whose raw output is returned.
    """

    def __init__(self, layer_widths):
        """
        Args:
            layer_widths: Iterable of ints giving the input width, any hidden
                widths, and the output width, in order.

        Raises:
            ValueError: If fewer than two widths are given, since no layer
                can be built from them.
        """
        super().__init__()
        # Materialise once so any iterable (not only sliceable sequences) works.
        layer_widths = list(layer_widths)
        if len(layer_widths) < 2:
            # Without this check the network would be created with no layers
            # and only fail with an IndexError on the first forward pass.
            raise ValueError(
                f"layer_widths needs at least an input and an output width, "
                f"got {layer_widths!r}"
            )
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(in_features, out_features)
                for in_features, out_features in zip(layer_widths, layer_widths[1:])
            ]
        )

    def forward(self, x):
        """Apply all layers to ``x``; ReLU follows every layer but the last."""
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = torch.relu(layer(x))
        return output_layer(x)

In [105]:
# Network with a single hidden layer of width 100, trained with a per-output
# binary cross-entropy on the raw logits.
model = SimpleLinearNetwork([INPUT_DIM, 100, OUTPUT_DIM])

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): the value printed below is the SUM of the per-batch losses, so
# it scales with the number of batches. A sum that stays near
# (number of batches) * ln 2 corresponds to chance-level predictions.
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        logits = model(inputs)
        # BCEWithLogitsLoss needs floating-point targets.
        loss = criterion(logits, targets.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report every 100th epoch only.
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss}")

Epoch 0, Loss: 22.320633351802826
Epoch 100, Loss: 22.11175447702408
Epoch 200, Loss: 22.116070926189423
Epoch 300, Loss: 22.07632404565811
Epoch 400, Loss: 22.086860239505768
Epoch 500, Loss: 22.096828997135162
Epoch 600, Loss: 22.091706454753876
Epoch 700, Loss: 22.05479669570923
Epoch 800, Loss: 22.063742101192474


KeyboardInterrupt: 