In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

## Data preparation

Generate the training set $f(x) = (m x + b) \bmod N$ for $x = 0, 1, \dots, n-1$.

In [3]:


# --- Parameters ---
SEED = 0       # fixes model initialisation and DataLoader shuffle order
modulus = 17   # N in mod N
m = 3          # multiplier
b = 5          # bias term
n_samples = 1000  # number of training examples

# Seed every RNG before any model is built or any data is shuffled, so that
# Restart & Run All reproduces the same weights and the same training curve.
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Generate the data ---
# x will be integers 0,1,...,n_samples-1.
x = np.arange(n_samples)
# f(x) = (m*x + b) mod modulus
y = (m * x + b) % modulus

# Every label must be a valid class index for a `modulus`-way classifier.
assert ((0 <= y) & (y < modulus)).all(), "labels out of range [0, modulus)"

# Convert to torch tensors.
# For the network we use float inputs and integer targets (for cross-entropy).
# NOTE(review): inputs are raw integers up to n_samples-1 and are not
# normalised; large-magnitude inputs can make MLP training slow or unstable.
# Consider scaling them if the model fails to learn.
x_train = torch.tensor(x, dtype=torch.float32).unsqueeze(1)  # shape: [n_samples, 1]
y_train = torch.tensor(y, dtype=torch.long)                  # shape: [n_samples]


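## Simple neural network

Predicting $f(x)$ is treated as classification over the $N$ possible residues: the network outputs one score per class and is trained with cross-entropy.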

In [7]:

# --- Define a Simple MLP ---
class SimpleMLP(nn.Module):
    """Fully connected network with three Linear layers and ReLU in between.

    Layout: input_dim -> hidden_dim -> 2 * hidden_dim -> output_dim.
    The last layer has no activation, so ``forward`` returns raw scores
    (logits) suitable for ``nn.CrossEntropyLoss``.

    Args:
        input_dim: size of the last dimension of the input tensor.
        hidden_dim: width of the first hidden layer; the second is twice as wide.
        output_dim: number of output scores (one per class for classification).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        wide_dim = 2 * hidden_dim
        # A single nn.Sequential stored as `net` keeps state_dict keys as
        # net.0 / net.2 / net.4 (the ReLUs at 1 and 3 have no parameters).
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, wide_dim),
            nn.ReLU(),
            nn.Linear(wide_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a [..., input_dim] tensor to [..., output_dim] logits."""
        return self.net(x)

# One input feature (x) and one output score per residue class 0..modulus-1.
model = SimpleMLP(
    input_dim=1,
    hidden_dim=32,
    output_dim=modulus,
)

# Cross-entropy over the `modulus` classes, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


## Training

In [8]:

# --- Training ---
batch_size = 32
dataset = torch.utils.data.TensorDataset(x_train, y_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        dataloader: iterable of ``(batch_x, batch_y)`` pairs.
        criterion: loss taking ``(outputs, batch_y)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(avg_loss, accuracy_percent)`` averaged over all samples seen.
        Accuracy is measured on each batch *before* its optimizer step, so it
        is a running training estimate, not the accuracy of the final weights.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        outputs = model(batch_x)  # shape: [batch_size, num_classes]
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        # Assumes a mean-reduced loss (as nn.CrossEntropyLoss() is by default);
        # re-weighting by batch size yields a true per-sample epoch average.
        total_loss += loss.item() * batch_x.size(0)
        correct += (outputs.argmax(dim=1) == batch_y).sum().item()
        total += batch_y.size(0)

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total, correct / total * 100


@torch.no_grad()
def evaluate(model, inputs, targets, criterion):
    """Loss and accuracy (%) of the *current* weights on ``inputs``/``targets``.

    Runs in eval mode without building an autograd graph.
    """
    model.eval()
    outputs = model(inputs)
    loss = criterion(outputs, targets).item()
    accuracy = (outputs.argmax(dim=1) == targets).float().mean().item() * 100
    return loss, accuracy


epochs = 100
for epoch in range(epochs):
    avg_loss, accuracy = train_one_epoch(model, dataloader, criterion, optimizer)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

# The per-epoch numbers above mix weights from across each epoch; report the
# trained model's actual performance with a clean evaluation pass.
final_loss, final_accuracy = evaluate(model, x_train, y_train, criterion)
print(f"Final (train set, eval mode) - Loss: {final_loss:.4f}, Accuracy: {final_accuracy:.2f}%")
