In [4]:
import copy

import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split

# EarlyStopping class implementation (replace with your implementation if already defined)
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Each call compares ``val_loss`` with the best loss seen so far. A decrease
    of more than ``delta`` snapshots the model weights into
    ``best_model_state`` and resets the patience counter. Otherwise the
    counter is incremented, and ``early_stop`` becomes True once it reaches
    ``patience``.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: If True, report (via ``trace_func``) each time a new best
            state is saved. Counter messages are always reported.
        delta: Minimum decrease in validation loss that counts as an
            improvement.
        trace_func: Callable used for all messages (defaults to ``print``).
    """

    def __init__(self, patience=7, verbose=False, delta=0, trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0  # consecutive calls without sufficient improvement
        self.best_val_loss = None  # None until the first call
        self.early_stop = False
        self.delta = delta
        self.trace_func = trace_func
        self.best_model_state = None  # Hold the best model state in memory

    def __call__(self, val_loss, model):
        """Record this epoch's validation loss and update the stop flag.

        Args:
            val_loss: Validation loss for the current epoch; must support
                ``<`` and ``-`` (the script passes a Python float).
            model: Module being trained. Its weights are snapshotted whenever
                ``val_loss`` is a new best.
        """
        if self.best_val_loss is None:
            # The first observation always becomes the baseline.
            self.best_val_loss = val_loss
            self.save_checkpoint(val_loss, model)
        elif val_loss < self.best_val_loss - self.delta:
            self.best_val_loss = val_loss
            self.save_checkpoint(val_loss, model)
            self.counter = 0  # Reset counter since improvement occurred
        else:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Store an independent in-memory copy of ``model.state_dict()``.

        ``state_dict()`` returns tensors that share storage with the live
        parameters. Keeping that dict directly would let every later
        ``optimizer.step()`` silently overwrite the "best" snapshot. A deep
        copy freezes the weights as they are at this moment.

        Args:
            val_loss: The new best validation loss (unused; kept for API
                compatibility).
            model: Module whose current weights are snapshotted.
        """
        if self.verbose:
            self.trace_func('Validation loss decreased. Saving model state...')
        self.best_model_state = copy.deepcopy(model.state_dict())

# Define the synthetic 1D function
def target_function(x):
    """Evaluate sin(3x) + 0.5 * cos(5x) element-wise on ``x``."""
    low_freq = np.sin(3 * x)
    high_freq = np.cos(5 * x)
    return low_freq + 0.5 * high_freq

# Synthetic regression data: 50 evenly spaced points on [-2, 2], with
# Gaussian noise (std 0.1) added to the target curve. Seeding NumPy's global
# RNG makes the noise reproducible.
np.random.seed(42)
X = np.linspace(-2, 2, 50).reshape(-1, 1)
noise = 0.1 * np.random.normal(size=X.shape)
y = target_function(X) + noise

# Hold out 20% of the points (fixed random_state). The held-out split is later
# used as the early-stopping validation set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# float32 tensors for PyTorch, one per split array.
X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = (
    torch.tensor(array, dtype=torch.float32)
    for array in (X_train, y_train, X_test, y_test)
)

# Define the MLP model
class MLPRegressor(nn.Module):
    """Fully connected regressor: 1 input -> 64 -> 64 -> 1 output, ReLU between layers."""

    def __init__(self):
        super().__init__()
        hidden = 64
        # Layers are created in the same order as before, so random
        # initialisation and state_dict keys (net.0, net.2, net.4) are unchanged.
        self.net = nn.Sequential(
            nn.Linear(1, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Map inputs whose last dimension is 1 to outputs whose last dimension is 1."""
        output = self.net(x)
        return output

# Model, mean-squared-error objective and Adam optimiser (lr=0.01).
model = MLPRegressor()
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Stop after 20 consecutive epochs in which validation loss does not improve.
early_stopping = EarlyStopping(verbose=True, patience=20)

# Full-batch training: one optimiser step per epoch on the whole training
# split, then a held-out evaluation that drives early stopping.
num_epochs = 1000
for epoch in range(num_epochs):
    # Gradient step on the training split.
    model.train()
    optimizer.zero_grad()
    predictions = model(X_train_tensor)
    loss = criterion(predictions, y_train_tensor)
    loss.backward()
    optimizer.step()

    # Score the held-out split without building an autograd graph.
    model.eval()
    with torch.no_grad():
        val_predictions = model(X_test_tensor)
        val_loss = criterion(val_predictions, y_test_tensor).item()

    train_loss = loss.item()
    print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # EarlyStopping tracks the best state and decides whether to continue.
    early_stopping(val_loss, model)
    if not early_stopping.early_stop:
        continue
    print("Early stopping triggered")
    break

# Roll the model back to the snapshot with the lowest validation loss, if any.
best_state = early_stopping.best_model_state
if best_state is not None:
    model.load_state_dict(best_state)
    print("Best model state restored.")

# Predict on a dense 500-point grid over [-2, 2] with the model in eval mode.
model.eval()
grid = np.linspace(-2, 2, 500).reshape(-1, 1)
X_plot = torch.tensor(grid, dtype=torch.float32)
with torch.no_grad():
    y_plot = model(X_plot).numpy()
x_grid = X_plot.flatten().numpy()

# Overlay the ground-truth curve, the training samples and the fitted curve.
fig = go.Figure()
traces = [
    go.Scatter(x=x_grid, y=target_function(x_grid),
               mode='lines', name='True Function'),
    go.Scatter(x=X_train.flatten(), y=y_train.flatten(),
               mode='markers', name='Training Data'),
    go.Scatter(x=x_grid, y=y_plot.flatten(),
               mode='lines', name='Model Prediction'),
]
for trace in traces:
    fig.add_trace(trace)

fig.update_layout(
    title='MLP Regression on 1D Function with Early Stopping (In-Memory State)',
    xaxis_title='x',
    yaxis_title='y',
    legend={'x': 0.01, 'y': 0.99},
)

fig.show()


Epoch 1, Loss: 0.6870, Val Loss: 0.8433
Validation loss decreased. Saving model state...
Epoch 2, Loss: 0.5347, Val Loss: 0.8588
EarlyStopping counter: 1 out of 20
Epoch 3, Loss: 0.5414, Val Loss: 0.7754
Validation loss decreased. Saving model state...
Epoch 4, Loss: 0.4947, Val Loss: 0.6962
Validation loss decreased. Saving model state...
Epoch 5, Loss: 0.4670, Val Loss: 0.6447
Validation loss decreased. Saving model state...
Epoch 6, Loss: 0.4591, Val Loss: 0.6042
Validation loss decreased. Saving model state...
Epoch 7, Loss: 0.4352, Val Loss: 0.5813
Validation loss decreased. Saving model state...
Epoch 8, Loss: 0.4029, Val Loss: 0.5673
Validation loss decreased. Saving model state...
Epoch 9, Loss: 0.3841, Val Loss: 0.5374
Validation loss decreased. Saving model state...
Epoch 10, Loss: 0.3701, Val Loss: 0.4782
Validation loss decreased. Saving model state...
Epoch 11, Loss: 0.3418, Val Loss: 0.4095
Validation loss decreased. Saving model state...
Epoch 12, Loss: 0.3132, Val Loss: