In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

In [2]:
# Generate some random data for classification
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2, random_state=42)

# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)


In [3]:
# Convert the data to PyTorch tensors and create DataLoader objects
train_dataset = TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train))
val_dataset = TensorDataset(torch.FloatTensor(X_val), torch.LongTensor(y_val))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [4]:
# Define a simple neural network with one hidden layer
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [5]:
# Initialize the model and optimizer
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [8]:
# Set the number of epochs and early stopping parameters
num_epochs = 100
patience = 5 #after patience epochs of worse performaces, STOP
tolerance = 1e-4 #tollerance over the validation loss 

# Initialize variables for early stopping
best_val_loss = float('inf')
n_epochs_no_improvement = 0 #accumulator
best_model_weights = None

In [9]:
# Train the model for a fixed number of epochs
for epoch in range(num_epochs):

    # Train the model on the training set
    running_loss = 0.0
    for i, data in enumerate(train_loader, 1):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Evaluate the model on the validation set
    with torch.no_grad():
        val_loss = 0.0
        for i, data in enumerate(val_loader, 1):
            inputs, labels = data
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    # Check if the validation loss has improved
    if val_loss < best_val_loss - tolerance:   #Early Stopping criterion
        best_val_loss = val_loss
        n_epochs_no_improvement = 0
        best_model_weights = net.state_dict()
    else:
        n_epochs_no_improvement += 1

    # Check if early stopping criteria have been met
    if n_epochs_no_improvement >= patience:
        print(f"Early stopping after epoch {epoch}")
        net.load_state_dict(best_model_weights)
        break

    # Print the training and validation loss for each epoch
    print(f"Epoch {epoch}: training loss = {running_loss / len(train_loader):.4f}, validation loss = {val_loss / len(val_loader):.4f}")


Epoch 0: training loss = 0.4292, validation loss = 0.4563
Epoch 1: training loss = 0.4076, validation loss = 0.4389
Epoch 2: training loss = 0.3888, validation loss = 0.4243
Epoch 3: training loss = 0.3731, validation loss = 0.4127
Epoch 4: training loss = 0.3600, validation loss = 0.4022
Epoch 5: training loss = 0.3491, validation loss = 0.3948
Epoch 6: training loss = 0.3410, validation loss = 0.3880
Epoch 7: training loss = 0.3336, validation loss = 0.3830
Epoch 8: training loss = 0.3279, validation loss = 0.3788
Epoch 9: training loss = 0.3235, validation loss = 0.3768
Epoch 10: training loss = 0.3197, validation loss = 0.3740
Epoch 11: training loss = 0.3169, validation loss = 0.3706
Epoch 12: training loss = 0.3143, validation loss = 0.3710
Epoch 13: training loss = 0.3120, validation loss = 0.3690
Epoch 14: training loss = 0.3102, validation loss = 0.3686
Epoch 15: training loss = 0.3090, validation loss = 0.3671
Epoch 16: training loss = 0.3075, validation loss = 0.3664
Epoch 1

In [None]:

# Save the final model weights
torch.save(net.state_dict(), "final_model_weights.pt")
