# Deep Sets example

Each sample in the dataset is a set of points drawn from a random 12th-order polynomial. The objective is to learn the polynomial's coefficients from that (unordered) set of points.

In [None]:
## The packages we need
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Define the dataset
class PolynomialDataset(Dataset):
    """Sets of (x, y) points sampled from random polynomials.

    Each item is a pair ``(points, coefficients)``:

    * ``points`` -- float32 tensor of shape ``(set_size, 2)``; column 0 holds
      ``x`` drawn uniformly from [-1, 1], column 1 holds ``y = p(x)``.
    * ``coefficients`` -- float32 tensor of shape ``(order + 1,)`` with the
      polynomial's coefficients drawn from a standard normal distribution,
      ordered highest power first (the ``np.polyval`` convention), including
      the constant term.

    When ``set_size < order + 1`` (e.g. the defaults, 10 points for 13
    coefficients) a single set does not determine the coefficients uniquely.

    Args:
        num_samples: Number of sets in the dataset.
        set_size: Number of points per set.
        order: Polynomial degree; each target has ``order + 1`` coefficients.
        seed: Optional seed for a private ``np.random.RandomState``. When
            ``None`` (the default) NumPy's global random state is used, so
            ``np.random.seed`` still controls the generated data.
    """

    def __init__(self, num_samples=1000, set_size=10, order=12, seed=None):
        self.num_samples = num_samples
        self.set_size = set_size
        self.order = order
        self.data = []
        self.targets = []

        # The np.random module and RandomState expose the same randn/uniform API,
        # so the default path draws exactly the same numbers as before.
        rng = np.random if seed is None else np.random.RandomState(seed)

        for _ in range(num_samples):
            coefficients = rng.randn(order + 1)
            x = rng.uniform(-1, 1, set_size)
            y = np.polyval(coefficients, x)
            # One row per point, columns (x, y) -> shape (set_size, 2).
            points = np.column_stack((x, y))
            self.data.append(torch.tensor(points, dtype=torch.float32))
            self.targets.append(torch.tensor(coefficients, dtype=torch.float32))

    def __len__(self):
        """Return the number of sets in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(points, coefficients)`` pair at position ``idx``."""
        return self.data[idx], self.targets[idx]


In [None]:
## Deep Sets model: psi encodes each set element, the encodings are summed over the set, and phi maps the sum to the output

class DeepSets(nn.Module):
    """Permutation-invariant set model computing ``phi(sum_i psi(x_i))``.

    ``psi`` encodes every set element independently. The encodings are summed
    over the set axis, so the result does not depend on element order. ``phi``
    then maps the pooled vector to the prediction.

    Args:
        input_dim: Feature size of each set element.
        hidden_dim: Width of the hidden layers and of the pooled encoding.
        output_dim: Size of the prediction vector.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Per-element encoder, applied to every element of the set.
        self.psi = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Decoder applied to the pooled (summed) encoding.
        self.phi = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Map a set, or a batch of sets, to a prediction.

        Args:
            x: Tensor of shape ``(..., set_size, input_dim)``, e.g. a batch
                ``(batch, set_size, input_dim)`` or a single unbatched set
                ``(set_size, input_dim)``.

        Returns:
            Tensor of shape ``(..., output_dim)``.
        """
        # Sum over the set axis. For batched 3-D input, dim=-2 is the same as
        # dim=1; for a single 2-D set it still pools over elements, not features.
        pooled = self.psi(x).sum(dim=-2)
        return self.phi(pooled)

In [None]:
# Hyperparameters
input_dim = 2        # each set element is an (x, y) pair
hidden_dim = 64
output_dim = 13      # number of polynomial coefficients (degree 12 + constant term)
num_samples = 1000
set_size = 100
batch_size = 32
learning_rate = 0.001
num_epochs = 50

# Prepare the dataset and dataloader
dataset = PolynomialDataset(num_samples=num_samples, set_size=set_size, order=output_dim - 1)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Instantiate the model, loss function, and optimizer
model = DeepSets(input_dim, hidden_dim, output_dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for data, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # MSELoss (reduction='mean') averages within the batch. Weight by the
        # batch size so the smaller final batch does not skew the epoch mean.
        total_loss += loss.item() * data.size(0)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataset)}')

# Evaluation on a freshly generated polynomial the model never saw during
# training; dataset[0] would only show how well it fits a training sample.
model.eval()
test_dataset = PolynomialDataset(num_samples=1, set_size=set_size, order=output_dim - 1)
test_set, true_coefficients = test_dataset[0]
test_set = test_set.unsqueeze(0)  # Add batch dimension -> (1, set_size, 2)

with torch.no_grad():
    # Drop only the batch dimension so the result is always (output_dim,).
    predicted_coefficients = model(test_set).squeeze(0).numpy()

print(f'True Coefficients: {true_coefficients.numpy()}')
print(f'Predicted Coefficients: {predicted_coefficients}')

In [None]:
# Generate points for plotting the true and predicted polynomials
x_plot = np.linspace(-1, 1, 400)
true_y_plot = np.polyval(true_coefficients.numpy(), x_plot)
predicted_y_plot = np.polyval(predicted_coefficients, x_plot)

# Flatten any leading batch dimension to (num_points, 2). A bare squeeze()
# would also collapse the set axis when set_size == 1 and break the column
# indexing below.
points = test_set.reshape(-1, test_set.shape[-1]).numpy()

# Plot the points and polynomials
plt.figure(figsize=(10, 6))
plt.scatter(points[:, 0], points[:, 1], color='blue', label='Data Points')
plt.plot(x_plot, true_y_plot, color='green', label='True Polynomial', linewidth=2)
plt.plot(x_plot, predicted_y_plot, color='red', label='Predicted Polynomial', linestyle='dashed', linewidth=2)
plt.xlabel('x')
plt.ylabel('y')
plt.title('True vs Predicted Polynomial')
plt.legend()
plt.grid(True)
plt.show()