In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import plotly.graph_objects as go
from sklearn.model_selection import train_test_split
from early_stopping_pytorch import EarlyStopping

# Define the synthetic 1D function
def target_function(x):
    return np.sin(3 * x) + 0.5 * np.cos(5 * x)

# Generate synthetic data
np.random.seed(42)
X = np.linspace(-2, 2, 50).reshape(-1, 1)
y = target_function(X) + 0.1 * np.random.normal(size=X.shape)

# Split into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert data to PyTorch tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

# Define the MLP model
class MLPRegressor(nn.Module):
    def __init__(self):
        super(MLPRegressor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.net(x)

# Initialize the model, loss, and optimizer
model = MLPRegressor()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Initialize early stopping
early_stopping = EarlyStopping(patience=20, verbose=True)

# Training loop
num_epochs = 1000
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    predictions = model(X_train_tensor)
    loss = criterion(predictions, y_train_tensor)
    loss.backward()
    optimizer.step()

    # Evaluate on validation set
    model.eval()
    with torch.no_grad():
        val_predictions = model(X_test_tensor)
        val_loss = criterion(val_predictions, y_test_tensor).item()

    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")

    # Check for early stopping
    early_stopping(val_loss, model)

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

# Load the best model
model.load_state_dict(torch.load('checkpoint.pt'))

# Evaluate the model
model.eval()
X_plot = torch.tensor(np.linspace(-2, 2, 500).reshape(-1, 1), dtype=torch.float32)
with torch.no_grad():
    y_plot = model(X_plot).numpy()

# Visualization with Plotly
fig = go.Figure()

# Plot the true function
fig.add_trace(go.Scatter(x=X_plot.flatten().numpy(), y=target_function(X_plot.flatten().numpy()),
                         mode='lines', name='True Function'))

# Plot training data
fig.add_trace(go.Scatter(x=X_train.flatten(), y=y_train.flatten(),
                         mode='markers', name='Training Data'))

# Plot model predictions
fig.add_trace(go.Scatter(x=X_plot.flatten().numpy(), y=y_plot.flatten(),
                         mode='lines', name='Model Prediction'))

# Customize layout
fig.update_layout(title='MLP Regression on 1D Function with Early Stopping',
                  xaxis_title='x', yaxis_title='y',
                  legend=dict(x=0.01, y=0.99))

fig.show()


Epoch 1, Loss: 0.7467, Val Loss: 0.7755
Validation loss decreased (inf --> 0.775496).  Saving model ...
Epoch 2, Loss: 0.5378, Val Loss: 0.8541
EarlyStopping counter: 1 out of 20
Epoch 3, Loss: 0.5598, Val Loss: 0.8266
EarlyStopping counter: 2 out of 20
Epoch 4, Loss: 0.5345, Val Loss: 0.7492
Validation loss decreased (0.775496 --> 0.749175).  Saving model ...
Epoch 5, Loss: 0.4881, Val Loss: 0.7026
Validation loss decreased (0.749175 --> 0.702641).  Saving model ...
Epoch 6, Loss: 0.4675, Val Loss: 0.6825
Validation loss decreased (0.702641 --> 0.682471).  Saving model ...
Epoch 7, Loss: 0.4623, Val Loss: 0.6612
Validation loss decreased (0.682471 --> 0.661151).  Saving model ...
Epoch 8, Loss: 0.4495, Val Loss: 0.6243
Validation loss decreased (0.661151 --> 0.624275).  Saving model ...
Epoch 9, Loss: 0.4245, Val Loss: 0.5785
Validation loss decreased (0.624275 --> 0.578539).  Saving model ...
Epoch 10, Loss: 0.3951, Val Loss: 0.5364
Validation loss decreased (0.578539 --> 0.536431). 

  model.load_state_dict(torch.load('checkpoint.pt'))
