# Early Stopping

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sys
sys.path.insert(0, "../src")

In [3]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from sklearn import metrics
from sklearn import model_selection

import torch

import albumentations as A

import config
import dataset
import models
import engine

In [4]:
import numpy as np
import torch


class EarlyStopping:
    def __init__(self, patience=7, mode="max", delta=0.0001):
        self.patience = patience
        self.mode = mode
        self.delta = delta
        
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        
        if mode == "max":
            self.val_score = -np.inf
        elif mode == "min":
            self.val_score = np.inf
        
    def __call__(self, epoch_score, model, model_path):
        if self.mode == "max":
            score = np.copy(epoch_score)
        elif self.mode == "min":
            score = -1.0 * epoch_score
            
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(epoch_score, model, model_path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
            self.save_checkpoint(epoch_score, model, model_path)
    
    def save_checkpoint(self, epoch_score, model, model_path):
        if epoch_score not in (np.inf, -np.inf, -np.nan, np.nan):
            print(f"Validation score improved ({self.val_score} --> {epoch_score}). Saving model!")
            torch.save(model.state_dict(), model_path)
        self.val_score = epoch_score

In [5]:
df = pd.read_csv(config.TRAIN_CSV)
df_train, df_valid = model_selection.train_test_split(df, test_size=0.1, stratify=df.digit)

train_dataset = dataset.EMNISTDataset(df_train)
valid_dataset = dataset.EMNISTDataset(df_valid)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.TRAIN_BATCH_SIZE, shuffle=True)
valid_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.TEST_BATCH_SIZE)

In [6]:
device = torch.device(config.DEVICE)
# model = models.SpinalVGG()
model = models.Model()
model.to(device)

Model(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2), padding=(13, 13), bias=False)
  (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn2_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn2_2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_3): Conv2d(64, 64, kernel_size=(5, 5), stride=(2, 2), padding=(11, 11), bias=False)
  (bn2_3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [7]:
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', verbose=True, patience=7, factor=0.5
)
early_stopping = EarlyStopping(patience=7, mode="max")

In [8]:
EPOCHS = 25

for epoch in range(EPOCHS):
    engine.train(train_loader, model, optimizer, device)
    predictions, targets = engine.evaluate(valid_loader, model, device)

    predictions = np.array(predictions)
    predictions = np.argmax(predictions, axis=1)
    accuracy = metrics.accuracy_score(targets, predictions)
    print(f"Epoch: {epoch}, Accuracy={accuracy}")
    
    scheduler.step(accuracy)

Epoch: 0, Accuracy=0.09875203472599023
Epoch: 1, Accuracy=0.21052631578947367
Epoch: 2, Accuracy=0.5789473684210527
Epoch: 3, Accuracy=0.7287032013022247
Epoch: 4, Accuracy=0.8638090070537168
Epoch: 5, Accuracy=0.9180683667932719
Epoch: 6, Accuracy=0.8209441128594682
Epoch: 7, Accuracy=0.9045035268583831
Epoch: 8, Accuracy=0.9614758545849159
Epoch: 9, Accuracy=0.9799240368963646
Epoch: 10, Accuracy=0.9972870320130223
Epoch: 11, Accuracy=0.9180683667932719
Epoch: 12, Accuracy=0.9750406945198047
Epoch: 13, Accuracy=0.9772110689093869
Epoch: 14, Accuracy=0.9820944112859469
Epoch: 15, Accuracy=0.962561041779707
Epoch: 16, Accuracy=0.9153553988062941
Epoch: 17, Accuracy=0.9636462289744981
Epoch: 18, Accuracy=0.9902333152468801
Epoch: 19, Accuracy=0.9978296256104178
Epoch: 20, Accuracy=0.9956592512208355
Epoch: 21, Accuracy=0.9972870320130223
Epoch: 22, Accuracy=1.0
Epoch: 23, Accuracy=0.9924036896364623
Epoch: 24, Accuracy=1.0


In [11]:
torch.save(model.state_dict(), "../models/spinalvgg.pt")

In [12]:
model = models.SpinalVGG()
model.load_state_dict(torch.load("../models/spinalvgg.pt"))

<All keys matched successfully>

In [13]:
model.to(device)
df_test = pd.read_csv(config.TEST_CSV)
test_dataset = dataset.EMNISTTestDataset(df_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config.TEST_BATCH_SIZE)

In [14]:
predictions = engine.infer(test_loader, model, device)
predictions = np.array(predictions)
predictions = np.argmax(predictions, axis=1)

In [15]:
submission = pd.DataFrame({"id": df_test.id, "digit": predictions})
submission.to_csv("../output/spinalvgg.csv", index=False)
submission.head()

Unnamed: 0,id,digit
0,2049,6
1,2050,9
2,2051,8
3,2052,0
4,2053,3
