In [1]:
import torch
from torch import nn

class MLP(nn.Module):
    """Fully connected feed-forward network.

    Each pair of consecutive entries in ``sizes`` defines one ``nn.Linear``
    layer. The activation ``f`` is applied after every layer except the last,
    so the network output is the raw output of the final linear layer.
    """

    l: nn.ModuleList  # linear layers, in application order
    f: nn.Module      # activation applied between layers

    def __init__(self, *sizes: int, f: nn.Module = nn.ReLU()):
        """Build the stack of linear layers.

        Args:
            *sizes: Layer widths, starting with the input size. At least two
                values are required (input size and one output size).
            f: Activation module applied after every layer but the last.

        Raises:
            ValueError: If fewer than two sizes are given.
        """
        super().__init__() # type: ignore
        if not len(sizes) >= 2:
            raise ValueError(f"`sizes` must contain at least 2 elements (input size and at least one output size), got {sizes} instead.")
        self.f = f
        # Pair every width with its successor: (s0, s1), (s1, s2), ...
        self.l = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through all layers, activating after all but the last."""
        # The constructor guarantees at least one layer, so `head` always exists.
        *hidden, head = self.l
        for layer in hidden:
            x = self.f(layer(x))
        return head(x)

In [None]:
import numpy as np
import torch
from tqdm import tqdm

def train(
        X: np.typing.ArrayLike, y: np.typing.ArrayLike, 
        model: nn.Module, loss_fn: nn.Module, 
        optim: torch.optim.Optimizer, epochs: int = 50, step: float = .2,
        verbose: bool = True
    ) -> None:
    """Train ``model`` on the full data set for a fixed number of epochs.

    Every epoch performs one full-batch forward pass, one backward pass and
    one optimizer step. The current loss is shown in the progress bar.

    Args:
        X: Input samples; converted to a ``float32`` tensor.
        y: Targets; converted to a ``float32`` tensor.
        model: Module to optimise. It is switched to training mode.
        loss_fn: Criterion called as ``loss_fn(prediction, target)``.
        optim: Optimizer already bound to the parameters of ``model``.
        epochs: Number of full-batch updates to run.
        step: Currently unused; accepted so existing callers keep working.
        verbose: Show the progress bar when true.
    """
    # `as_tensor` avoids the copy (and the warning) `torch.tensor` produces
    # when the caller already passes tensors.
    X = torch.as_tensor(X, dtype=torch.float32) # type: ignore
    y = torch.as_tensor(y, dtype=torch.float32)

    # Make sure layers such as dropout / batch norm behave as in training,
    # even if the model was left in eval mode by an earlier prediction.
    model.train()

    for _ in (bar := tqdm(range(epochs), desc="Training", unit="epoch", disable=not verbose)):
        optim.zero_grad()

        y_ = model(X)
        # PyTorch criteria expect (input, target) in that order; swapping them
        # breaks asymmetric losses such as cross-entropy.
        loss = loss_fn(y_, y)

        loss.backward()
        optim.step()

        bar.set_description(f"Loss: {loss.item():.4f}")

In [None]:
class FFNN:
    def __init__(self, hidden_layer_sizes=(100,), activation=nn.ReLU(), 
                 loss_fn=nn.MSELoss(), optimizer=torch.optim.Adam, lr=0.001, 
                 epochs=50, step=0.2, verbose=True):
        """
        A feed-forward neural network with a scikit-learn style
        ``fit``/``predict`` interface, built on PyTorch.

        Only ``fit`` and ``predict`` are provided; this class does not define
        ``get_params``/``set_params``.
        
        Parameters:
        -----------
        hidden_layer_sizes : tuple or list of int
            The size of hidden layers.
        activation : nn.Module
            The activation function for the hidden layers.
        loss_fn : nn.Module
            The loss function to use.
        optimizer : type
            The optimizer class to use; it is called as
            ``optimizer(parameters, lr=lr)``.
        lr : float
            The learning rate for the optimizer.
        epochs : int
            Number of training epochs.
        step : float
            Step size parameter passed to train function.
        verbose : bool
            Whether to display training progress.
        """
        self.hidden_layer_sizes = hidden_layer_sizes
        self.activation = activation
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.lr = lr
        self.epochs = epochs
        self.step = step
        self.verbose = verbose
        self.model = None
        self.is_fitted_ = False
    
    def fit(self, X, y):
        """Fit the model to the data.

        Parameters:
        -----------
        X : array-like of shape (n_samples, n_features)
            Training inputs. Nested lists are accepted.
        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            Training targets. A 1-D ``y`` is treated as a single output.

        Returns:
        --------
        self : FFNN
            The fitted estimator.

        Raises:
        -------
        ValueError
            If ``X`` and ``y`` do not contain the same number of samples.
        """
        # A failed (re)fit must not leave the estimator flagged as fitted.
        self.is_fitted_ = False

        # Accept plain Python sequences; objects that already expose `.shape`
        # (arrays, tensors, ...) are passed through untouched.
        if not hasattr(X, "shape"):
            X = np.asarray(X)
        n_samples = X.shape[0]

        # Convert y to 2D if needed
        y = np.asarray(y)
        if y.ndim < 2:
            y = y.reshape(-1, 1)
        elif y.ndim == 2 and y.shape[0] == 1 and n_samples != 1:
            # Row vector of targets: one value per sample. With a single
            # sample a (1, k) array is a genuine multi-output target and must
            # not be transposed.
            y = y.T

        if y.shape[0] != n_samples:
            # Without this check the loss may silently broadcast the targets.
            raise ValueError(
                f"X and y have inconsistent numbers of samples: {n_samples} != {y.shape[0]}"
            )
        
        # Create model architecture
        input_size = X.shape[1]
        output_size = y.shape[1]
        
        # tuple(...) lets callers pass the hidden sizes as a list as well
        sizes = (input_size,) + tuple(self.hidden_layer_sizes) + (output_size,)
        self.model = MLP(*sizes, f=self.activation)
        
        # Create optimizer
        optim = self.optimizer(self.model.parameters(), lr=self.lr)
        
        # Train the model
        train(X, y, self.model, self.loss_fn, optim, 
              epochs=self.epochs, step=self.step, verbose=self.verbose)
        
        self.is_fitted_ = True
        return self
    
    def predict(self, X):
        """Generate predictions with the fitted model.

        Parameters:
        -----------
        X : array-like
            Input samples; converted to a ``float32`` tensor.

        Returns:
        --------
        numpy.ndarray
            Model outputs. A 2-D result with a single output column is
            flattened to 1-D.

        Raises:
        -------
        ValueError
            If ``fit`` has not completed successfully.
        """
        if not self.is_fitted_:
            raise ValueError("Model not fitted. Call 'fit' first.")
        
        # Set model to evaluation mode
        self.model.eval()
        
        # Make predictions
        with torch.no_grad():
            # as_tensor avoids a needless copy/warning for tensor inputs
            X_tensor = torch.as_tensor(X, dtype=torch.float32)
            predictions = self.model(X_tensor).numpy()
        
        # Flatten if single output (guard against 1-D results, which have no
        # second axis to inspect)
        if predictions.ndim == 2 and predictions.shape[1] == 1:
            predictions = predictions.ravel()
            
        return predictions