https://github.com/jeffheaton/app_deep_learning/blob/main/t81_558_class_03_4_early_stop.ipynb

In [15]:
import torch

# Detect Google Colab by attempting the import; only a failed import means
# "not Colab", so other exceptions (e.g. KeyboardInterrupt) are not swallowed.
try:
    import google.colab

    COLAB = True
    print("Note: using Google CoLab")
except ImportError:
    print("Note: not using Google CoLab")
    COLAB = False

# Make use of a GPU or MPS (Apple) if one is available.  (see module 3.2)
# Ask the MPS backend whether it is usable at runtime; the getattr guard keeps
# this working on PyTorch builds that do not ship torch.backends.mps.
mps_backend = getattr(torch.backends, "mps", None)
if mps_backend is not None and mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Note: not using Google CoLab
Using device: cpu


## Iris dataset

In [9]:
from sklearn import datasets

# Load the Iris dataset, then print its available keys followed by its
# description text.
iris = datasets.load_iris()
print(iris.keys(), iris.DESCR, sep="\n")

dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename', 'data_module'])
.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
                
    :Summary Statistics:

                    Min  Max   Mean    SD   Class Correlation
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for ea

In [10]:
def namedata(data):
    """Label one measurement row with its feature names.

    Positions 0-3 of ``data`` become the values of ``sepal_l``, ``sepal_w``,
    ``petal_l`` and ``petal_w`` (in that order) in the returned dict.
    """
    labels = ('sepal_l', 'sepal_w', 'petal_l', 'petal_w')
    return {label: data[pos] for pos, label in enumerate(labels)}

# Convert every row of the feature matrix into a dict keyed by feature name.
table = [namedata(row) for row in iris.data]
table

[{'sepal_l': 5.1, 'sepal_w': 3.5, 'petal_l': 1.4, 'petal_w': 0.2},
 {'sepal_l': 4.9, 'sepal_w': 3.0, 'petal_l': 1.4, 'petal_w': 0.2},
 {'sepal_l': 4.7, 'sepal_w': 3.2, 'petal_l': 1.3, 'petal_w': 0.2},
 {'sepal_l': 4.6, 'sepal_w': 3.1, 'petal_l': 1.5, 'petal_w': 0.2},
 {'sepal_l': 5.0, 'sepal_w': 3.6, 'petal_l': 1.4, 'petal_w': 0.2},
 {'sepal_l': 5.4, 'sepal_w': 3.9, 'petal_l': 1.7, 'petal_w': 0.4},
 {'sepal_l': 4.6, 'sepal_w': 3.4, 'petal_l': 1.4, 'petal_w': 0.3},
 {'sepal_l': 5.0, 'sepal_w': 3.4, 'petal_l': 1.5, 'petal_w': 0.2},
 {'sepal_l': 4.4, 'sepal_w': 2.9, 'petal_l': 1.4, 'petal_w': 0.2},
 {'sepal_l': 4.9, 'sepal_w': 3.1, 'petal_l': 1.5, 'petal_w': 0.1},
 {'sepal_l': 5.4, 'sepal_w': 3.7, 'petal_l': 1.5, 'petal_w': 0.2},
 {'sepal_l': 4.8, 'sepal_w': 3.4, 'petal_l': 1.6, 'petal_w': 0.2},
 {'sepal_l': 4.8, 'sepal_w': 3.0, 'petal_l': 1.4, 'petal_w': 0.1},
 {'sepal_l': 4.3, 'sepal_w': 3.0, 'petal_l': 1.1, 'petal_w': 0.1},
 {'sepal_l': 5.8, 'sepal_w': 4.0, 'petal_l': 1.2, 'petal_w': 0

In [11]:
# Map each integer class id in iris.target to its species name.
tmap = dict(enumerate(('Iris-setosa', 'Iris-versicolor', 'Iris-virginica')))
targets = [tmap[item] for item in iris.target]

In [12]:
# Pair each measurement row with its species name and store it under 'species'.
table = [dict(row, species=name) for row, name in zip(table, targets)]
table

[{'sepal_l': 5.1,
  'sepal_w': 3.5,
  'petal_l': 1.4,
  'petal_w': 0.2,
  'species': 'Iris-setosa'},
 {'sepal_l': 4.9,
  'sepal_w': 3.0,
  'petal_l': 1.4,
  'petal_w': 0.2,
  'species': 'Iris-setosa'},
 {'sepal_l': 4.7,
  'sepal_w': 3.2,
  'petal_l': 1.3,
  'petal_w': 0.2,
  'species': 'Iris-setosa'},
 {'sepal_l': 4.6,
  'sepal_w': 3.1,
  'petal_l': 1.5,
  'petal_w': 0.2,
  'species': 'Iris-setosa'},
 {'sepal_l': 5.0,
  'sepal_w': 3.6,
  'petal_l': 1.4,
  'petal_w': 0.2,
  'species': 'Iris-setosa'},
 {'sepal_l': 5.4,
  'sepal_w': 3.9,
  'petal_l': 1.7,
  'petal_w': 0.4,
  'species': 'Iris-setosa'},
 {'sepal_l': 4.6,
  'sepal_w': 3.4,
  'petal_l': 1.4,
  'petal_w': 0.3,
  'species': 'Iris-setosa'},
 {'sepal_l': 5.0,
  'sepal_w': 3.4,
  'petal_l': 1.5,
  'petal_w': 0.2,
  'species': 'Iris-setosa'},
 {'sepal_l': 4.4,
  'sepal_w': 2.9,
  'petal_l': 1.4,
  'petal_w': 0.2,
  'species': 'Iris-setosa'},
 {'sepal_l': 4.9,
  'sepal_w': 3.1,
  'petal_l': 1.5,
  'petal_w': 0.1,
  'species': 'Iris-

In [13]:
import pandas as pd
# Build the DataFrame from the list of row dicts and display it.
df = pd.DataFrame(data=table)
df

Unnamed: 0,sepal_l,sepal_w,petal_l,petal_w,species
0,5.1,3.5,1.4,0.2,Iris-setosa
1,4.9,3.0,1.4,0.2,Iris-setosa
2,4.7,3.2,1.3,0.2,Iris-setosa
3,4.6,3.1,1.5,0.2,Iris-setosa
4,5.0,3.6,1.4,0.2,Iris-setosa
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,Iris-virginica
146,6.3,2.5,5.0,1.9,Iris-virginica
147,6.5,3.0,5.2,2.0,Iris-virginica
148,6.2,3.4,5.4,2.3,Iris-virginica


In [16]:
import copy


class EarlyStopping:
    """Track validation loss across epochs and signal when training should stop.

    Call the instance once per epoch with the model and that epoch's
    validation loss. A loss counts as an improvement when
    ``best_loss - val_loss >= min_delta``. After ``patience`` consecutive
    calls without an improvement the call returns ``True``.

    Args:
        patience: Number of consecutive non-improving calls that triggers
            the stop.
        min_delta: Minimum drop relative to the best loss that counts as an
            improvement. With the default of 0, a loss equal to the best
            one still counts as an improvement and resets the counter.
        restore_best_weights: If true, the best recorded ``state_dict`` is
            loaded back into the model when the stop is triggered.
    """

    def __init__(self, patience=5, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        # Deep copy of the model's state_dict taken at the best loss so far.
        self.best_model = None
        # Lowest validation loss seen so far; None until the first call.
        self.best_loss = None
        # Consecutive calls without an improvement.
        self.counter = 0
        # Human-readable summary of the most recent call.
        self.status = ""

    def __call__(self, model, val_loss):
        """Record one epoch's validation loss.

        Args:
            model: Object exposing ``state_dict()`` and ``load_state_dict()``.
            val_loss: Validation loss for the epoch just finished.

        Returns:
            bool: ``True`` when ``patience`` has been exhausted (training
            should stop), ``False`` otherwise.
        """
        if self.best_loss is None:
            # First call: nothing to compare against, so this loss is the baseline.
            self.best_loss = val_loss
            self.best_model = copy.deepcopy(model.state_dict())
            # Report something on the first epoch instead of leaving status empty.
            self.status = "Initial validation loss recorded"
        elif self.best_loss - val_loss >= self.min_delta:
            # Snapshot the weights; deepcopy so later training steps cannot
            # alter the saved copy.
            self.best_model = copy.deepcopy(model.state_dict())
            self.best_loss = val_loss
            self.counter = 0
            self.status = f"Improvement found, counter reset to {self.counter}"
        else:
            self.counter += 1
            self.status = f"No improvement in the last {self.counter} epochs"
            if self.counter >= self.patience:
                self.status = f"Early stopping triggered after {self.counter} epochs."
                if self.restore_best_weights:
                    model.load_state_dict(self.best_model)
                return True
        return False

In [17]:
import time

import numpy as np
import pandas as pd
import torch
import tqdm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset

# Seed NumPy's and PyTorch's global random number generators with the same
# fixed value for reproducibility.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

def load_data():
    """Prepare the Iris DataFrame for training.

    Reads the module-level ``df`` and ``device``. The four measurement
    columns become the features and the ``species`` column is label-encoded.
    The data is split 75/25, the features are standardised with statistics
    fitted on the training split only, and everything is converted to
    tensors on ``device``.

    Returns:
        tuple: ``(x_train, x_test, y_train, y_test, species)``. The ``x_*``
        tensors are float32, the ``y_*`` tensors are long, and ``species``
        holds the class names in encoder order.
    """
    src = df.copy()

    le = LabelEncoder()

    x = src[["sepal_l", "sepal_w", "petal_l", "petal_w"]].values
    # Take the labels from the same copy the features were read from.
    y = le.fit_transform(src["species"])
    species = le.classes_

    # Split into validation and training sets
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.25, random_state=42
    )

    # Fit the scaler on the training split only, then apply it to both splits.
    scaler = StandardScaler()
    x_train = scaler.fit_transform(x_train)
    x_test = scaler.transform(x_test)

    # Numpy to Torch Tensor
    x_train = torch.tensor(x_train, device=device, dtype=torch.float32)
    y_train = torch.tensor(y_train, device=device, dtype=torch.long)

    x_test = torch.tensor(x_test, device=device, dtype=torch.float32)
    y_test = torch.tensor(y_test, device=device, dtype=torch.long)

    return x_train, x_test, y_train, y_test, species


x_train, x_test, y_train, y_test, species = load_data()

In [18]:
import time

import numpy as np
import pandas as pd
import torch
import tqdm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset

# Wrap the training and test tensors as datasets and batch them with
# shuffling enabled.
BATCH_SIZE = 16

dataset_train = TensorDataset(x_train, y_train)
dataset_test = TensorDataset(x_test, y_test)

loader_options = dict(batch_size=BATCH_SIZE, shuffle=True)
dataloader_train = DataLoader(dataset_train, **loader_options)
dataloader_test = DataLoader(dataset_test, **loader_options)

# Create model using nn.Sequential
# Input width follows the feature count; output width follows the class count.
model = nn.Sequential(
    nn.Linear(x_train.shape[1], 50),
    nn.ReLU(),
    nn.Linear(50, 25),
    nn.ReLU(),
    nn.Linear(25, len(species)),
    nn.LogSoftmax(dim=1),
)

model = torch.compile(model,backend="aot_eager").to(device)

# The network already ends in LogSoftmax, so pair it with NLLLoss.
# CrossEntropyLoss expects raw logits and would apply log-softmax a second time.
loss_fn = nn.NLLLoss()  # negative log-likelihood on log-probabilities

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
es = EarlyStopping()

# Train for at most 1000 epochs, or until EarlyStopping asks to stop.
epoch = 0
done = False
while epoch < 1000 and not done:
    epoch += 1
    # Materialise the batches up front so the last one can be recognised
    # via len(steps) below.
    steps = list(enumerate(dataloader_train))
    pbar = tqdm.tqdm(steps)
    model.train()
    for i, (x_batch, y_batch) in pbar:
        y_batch_pred = model(x_batch.to(device))
        loss = loss_fn(y_batch_pred, y_batch.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss, current = loss.item(), (i + 1) * len(x_batch)
        if i == len(steps) - 1:
            # Last batch of the epoch: evaluate on the held-out data.
            model.eval()
            # No gradients are needed for validation; disabling autograd keeps
            # the loss passed to EarlyStopping from holding a computation graph.
            with torch.no_grad():
                pred = model(x_test)
                vloss = loss_fn(pred, y_test)
            if es(model, vloss):
                done = True
            pbar.set_description(
                f"Epoch: {epoch}, tloss: {loss}, vloss: {vloss:>7f}, {es.status}"
            )
        else:
            pbar.set_description(f"Epoch: {epoch}, tloss {loss:}")

Epoch: 1, tloss: 0.6026307344436646, vloss: 0.536555, : 100%|██████████| 7/7 [00:00<00:00,  7.70it/s]
Epoch: 2, tloss: 0.3658648133277893, vloss: 0.277725, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 401.73it/s]
Epoch: 3, tloss: 0.15603026747703552, vloss: 0.187535, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 414.63it/s]
Epoch: 4, tloss: 0.057948920875787735, vloss: 0.154333, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 388.62it/s]
Epoch: 5, tloss: 0.18528974056243896, vloss: 0.076723, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 396.41it/s]
Epoch: 6, tloss: 0.12420050799846649, vloss: 0.061499, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 384.29it/s]
Epoch: 7, tloss: 0.03340417519211769, vloss: 0.045322, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 426.23it/s]
Epoch: 8, tloss: 0.09452513605356216, vloss: 0.03297