In [3]:
import torch
from torch import nn
from tqdm import tqdm

class MLP(nn.Module):
  """Fully connected feed-forward network (multi-layer perceptron).

  One ``nn.Linear`` is created for every consecutive pair in ``sizes``. The
  activation ``f`` is applied after every linear layer except the last, so the
  output of the final layer is returned without an activation.

  Args:
    *sizes: Layer widths, starting with the input size. At least two values
      are required (the input size and one output size).
    f: Activation module applied between linear layers. When omitted, each
      instance creates its own ``nn.ReLU()``.

  Raises:
    ValueError: If fewer than two sizes are given.
  """
  l: nn.ModuleList  # linear layers, in the order they are applied
  f: nn.Module      # activation applied between linear layers

  def __init__(self, *sizes: int, f: "nn.Module | None" = None):
    super().__init__() # type: ignore

    if len(sizes) < 2:
      raise ValueError(f"`sizes` must contain at least 2 elements (input size and at least one output size), got {sizes} instead.")

    # A module instance used as a default argument is built once, when the
    # function is defined, and would then be registered as a submodule of every
    # MLP created without an explicit `f`. Build a fresh one per instance.
    self.f = nn.ReLU() if f is None else f

    # Create the linear layers: pair each width with the one that follows it.
    self.l = nn.ModuleList(
      [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run ``x`` through every linear layer, with ``f`` between layers.

    Args:
      x: Input tensor passed to the first linear layer.

    Returns:
      The output of the last linear layer (no activation applied to it).
    """
    last = len(self.l) - 1
    for i, layer in enumerate(self.l):
      x = layer(x)
      # No activation after the final layer.
      if i < last:
        x = self.f(x)
    return x

In [4]:
# Seed torch's RNG first: weight initialisation, the dummy data and the
# per-epoch shuffling below all draw from it, so this makes a fresh
# Restart & Run All reproduce the same numbers.
torch.manual_seed(0)

# Create a simple MLP model
input_size = 10
hidden_size = 20
output_size = 2
model = MLP(input_size, hidden_size, output_size)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Generate some dummy data (random features, random class labels)
batch_size = 32
num_samples = 100
X = torch.randn(num_samples, input_size)
y = torch.randint(0, output_size, (num_samples,))

# Training loop
num_epochs = 50
for epoch in (bar := tqdm(range(num_epochs), desc="Training")):
    # Create random batches: a new permutation every epoch
    indices = torch.randperm(num_samples)
    epoch_loss = 0.0
    for i in range(0, num_samples, batch_size):
        # Get batch indices (the final slice is shorter when batch_size
        # does not divide num_samples)
        idx = indices[i:i+batch_size]

        # Get batch data
        X_batch = X[idx]
        y_batch = y[idx]

        # Forward pass
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `criterion` returns the batch mean, so weight it by the batch size
        # to get a correct per-sample average over the whole epoch.
        epoch_loss += loss.item() * len(idx)

    # Show the mean loss over the epoch rather than the loss of the last
    # (smaller, hence noisier) mini-batch.
    bar.set_description(f"L{epoch_loss / num_samples:.4f}")

# Test the model
with torch.no_grad():
    test_data = torch.randn(10, input_size)
    predictions = model(test_data)
    print("Test predictions shape:", predictions.shape)
    print(predictions[:3])  # Show first 3 predictions

L0.4321: 100%|██████████| 50/50 [00:01<00:00, 30.12it/s]


Test predictions shape: torch.Size([10, 2])
tensor([[ 0.4419,  0.2484],
        [ 0.2633, -0.2450],
        [ 0.1177, -0.1921]])
