In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

from torchvision import datasets
from torch.utils.data import DataLoader
from torch import optim
from early_stopping import EarlyStopping

import numpy as np
import matplotlib.pyplot as plt

### Dataset

In [2]:
## Load the MNIST training and test sets.
# Both datasets share the same root directory and the same ToTensor transform;
# they differ only in the `train` flag. The generator below is consumed
# immediately by the tuple unpacking, so the training set is built first and
# the test set second, and no loop variable leaks into the notebook namespace.
train_data, test_data = (
    datasets.MNIST(
        root='./mnist',
        train=is_train,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    for is_train in (True, False)
)

In [3]:
## split training data into training and validation subsets
from torch.utils.data.sampler import SubsetRandomSampler

SEED = 42               # fixes the train/validation partition across kernel restarts
VALID_FRACTION = 0.25   # share of the training set held out for validation
BATCH_SIZE = 100

# Shuffle the sample indices with a dedicated, seeded generator so that the
# partition is reproducible and does not depend on (or disturb) the global
# NumPy random state. `.tolist()` keeps the indices as plain Python ints.
num_train = len(train_data)
indices = np.random.default_rng(SEED).permutation(num_train).tolist()
split = int(np.floor(num_train * VALID_FRACTION))
train_idx, valid_idx = indices[split:], indices[:split]

# The two subsets must partition the training set without overlap.
assert len(train_idx) + len(valid_idx) == num_train
assert set(train_idx).isdisjoint(valid_idx)

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

## load data into batches
# Training and validation batches are both drawn from `train_data`; the
# samplers restrict each loader to its own index subset.
train_loader = DataLoader(dataset=train_data,
                          batch_size=BATCH_SIZE,
                          sampler=train_sampler,
                          num_workers=0)

validation_loader = DataLoader(dataset=train_data,
                               batch_size=BATCH_SIZE,
                               sampler=valid_sampler,
                               num_workers=0)

# The test loader has no sampler, so it iterates `test_data` in dataset order.
test_loader = DataLoader(dataset=test_data,
                         batch_size=BATCH_SIZE,
                         num_workers=0)

### Autoencoder Model

In [4]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder for 28x28 inputs with a 2-dimensional code.

    The encoder maps 784 -> 218 -> 64 -> 16 -> 2 and the decoder mirrors it
    (2 -> 16 -> 64 -> 218 -> 784). Every linear layer, including the code
    layer and the output layer, is followed by a ReLU.

    NOTE(review): the ReLU after ``fc4`` restricts the code to non-negative
    values, and the ReLU after ``fc8`` gives zero gradient wherever the output
    pre-activation is negative. A linear code layer and a sigmoid output may
    be worth trying, but either change alters trained results, so the
    activations are intentionally left as they are.
    """

    def __init__(self):
        super(Autoencoder, self).__init__()
        # encoder layers
        self.fc1 = nn.Linear(28*28, 218)
        self.fc2 = nn.Linear(218, 64)
        self.fc3 = nn.Linear(64, 16)
        self.fc4 = nn.Linear(16, 2)

        # decoder layers
        self.fc5 = nn.Linear(2, 16)
        self.fc6 = nn.Linear(16, 64)
        self.fc7 = nn.Linear(64, 218)
        self.fc8 = nn.Linear(218, 28*28)

    def encode(self, x):
        """Flatten ``x`` to rows of 784 values and return the 2-D code.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (for example a transposed batch) are accepted instead of raising.
        """
        x = x.reshape(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return F.relu(self.fc4(x))

    def decode(self, z):
        """Map a batch of 2-D codes back to flattened 784-value outputs."""
        z = F.relu(self.fc5(z))
        z = F.relu(self.fc6(z))
        z = F.relu(self.fc7(z))
        return F.relu(self.fc8(z))

    def forward(self, x):
        """Return the reconstruction of ``x`` with shape (-1, 784)."""
        return self.decode(self.encode(x))


model = Autoencoder()


In [5]:
# Reconstruction objective: mean squared error between output and target.
# reduction='mean' is the MSELoss default, spelled out here for the reader.
criterion = nn.MSELoss(reduction='mean')

# Adam over all parameters of `model`; lr=1e-3 is Adam's default step size,
# made explicit so it is visible as a tunable value.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [6]:
def model_train(model, patience, max_epochs):
    """Train ``model`` as an autoencoder with early stopping on validation loss.

    Uses the notebook-level ``train_loader``, ``validation_loader``,
    ``optimizer`` and ``criterion``. Every batch is passed through the model
    and scored against its own input flattened to ``(-1, 28*28)``; the labels
    delivered by the loaders are ignored.

    Args:
        model: network to train. It is updated in place.
        patience: forwarded to ``EarlyStopping(patience=...)``.
        max_epochs: upper bound on the number of training epochs.

    Returns:
        A tuple ``(model, avg_train_losses, avg_valid_losses)``: the same
        model object after loading the weights stored in ``'checkpoint.pt'``,
        and two lists holding, per epoch, the mean of the batch losses on the
        training and validation loaders.
    """
    avg_train_losses = []  # mean training batch loss, one entry per epoch
    avg_valid_losses = []  # mean validation batch loss, one entry per epoch

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    # Width used to right-align the epoch counter in the progress line.
    epoch_len = len(str(max_epochs))

    for epoch in range(1, max_epochs + 1):
        # Per-batch losses of the current epoch only.
        train_losses = []
        valid_losses = []

        #### 1. TRAINING PROCESS
        model.train()
        for x, _ in train_loader:
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, x.view(-1, 28*28))
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        #### 2. VALIDATION PROCESS
        model.eval()  # prep model for evaluation
        # No parameter update happens here, so gradient tracking is disabled
        # to avoid building an autograd graph for every validation batch.
        with torch.no_grad():
            for x, _ in validation_loader:
                output = model(x)
                loss = criterion(output, x.view(-1, 28*28))
                valid_losses.append(loss.item())

        #### 3. training/validation statistics
        # average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        print_msg = (f'[{epoch:>{epoch_len}}/{max_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')

        print(print_msg)

        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # load the last checkpoint with the best model
    # NOTE(review): assumes EarlyStopping saves to 'checkpoint.pt' in the
    # working directory -- confirm against the early_stopping module.
    model.load_state_dict(torch.load('checkpoint.pt'))

    return model, avg_train_losses, avg_valid_losses

    

In [7]:
# Early-stopping patience and the hard cap on the number of epochs.
patience, max_epochs = 35, 500

# model_train hands back the model object it was given (after reloading the
# checkpointed weights) together with the per-epoch average training and
# validation losses.
model_, train_loss, val_loss = model_train(model,
                                           patience=patience,
                                           max_epochs=max_epochs)


[  1/500] train_loss: 0.07025 valid_loss: 0.06075
Validation loss decreased (inf --> 0.060749).  Saving model ...
[  2/500] train_loss: 0.05627 valid_loss: 0.05388
Validation loss decreased (0.060749 --> 0.053884).  Saving model ...
[  3/500] train_loss: 0.05261 valid_loss: 0.05131
Validation loss decreased (0.053884 --> 0.051307).  Saving model ...
[  4/500] train_loss: 0.04964 valid_loss: 0.04882
Validation loss decreased (0.051307 --> 0.048817).  Saving model ...
[  5/500] train_loss: 0.04812 valid_loss: 0.04797
Validation loss decreased (0.048817 --> 0.047973).  Saving model ...
[  6/500] train_loss: 0.04705 valid_loss: 0.04716
Validation loss decreased (0.047973 --> 0.047161).  Saving model ...
[  7/500] train_loss: 0.04636 valid_loss: 0.04652
Validation loss decreased (0.047161 --> 0.046524).  Saving model ...
[  8/500] train_loss: 0.04576 valid_loss: 0.04565
Validation loss decreased (0.046524 --> 0.045650).  Saving model ...
[  9/500] train_loss: 0.04513 valid_loss: 0.04559
Val

[ 80/500] train_loss: 0.03706 valid_loss: 0.03809
EarlyStopping counter: 1 out of 35
[ 81/500] train_loss: 0.03705 valid_loss: 0.03811
EarlyStopping counter: 2 out of 35
[ 82/500] train_loss: 0.03700 valid_loss: 0.03799
EarlyStopping counter: 3 out of 35
[ 83/500] train_loss: 0.03698 valid_loss: 0.03810
EarlyStopping counter: 4 out of 35
[ 84/500] train_loss: 0.03690 valid_loss: 0.03799
EarlyStopping counter: 5 out of 35
[ 85/500] train_loss: 0.03677 valid_loss: 0.03837
EarlyStopping counter: 6 out of 35
[ 86/500] train_loss: 0.03683 valid_loss: 0.03811
EarlyStopping counter: 7 out of 35
[ 87/500] train_loss: 0.03711 valid_loss: 0.03796
Validation loss decreased (0.037977 --> 0.037958).  Saving model ...
[ 88/500] train_loss: 0.03675 valid_loss: 0.03800
EarlyStopping counter: 1 out of 35
[ 89/500] train_loss: 0.03689 valid_loss: 0.03778
Validation loss decreased (0.037958 --> 0.037780).  Saving model ...
[ 90/500] train_loss: 0.03695 valid_loss: 0.03795
EarlyStopping counter: 1 out of 

[173/500] train_loss: 0.03632 valid_loss: 0.03736
EarlyStopping counter: 26 out of 35
[174/500] train_loss: 0.03628 valid_loss: 0.03742
EarlyStopping counter: 27 out of 35
[175/500] train_loss: 0.03622 valid_loss: 0.03752
EarlyStopping counter: 28 out of 35
[176/500] train_loss: 0.03624 valid_loss: 0.03756
EarlyStopping counter: 29 out of 35
[177/500] train_loss: 0.03606 valid_loss: 0.03728
EarlyStopping counter: 30 out of 35
[178/500] train_loss: 0.03610 valid_loss: 0.03730
EarlyStopping counter: 31 out of 35
[179/500] train_loss: 0.03619 valid_loss: 0.03730
EarlyStopping counter: 32 out of 35
[180/500] train_loss: 0.03605 valid_loss: 0.03723
EarlyStopping counter: 33 out of 35
[181/500] train_loss: 0.03623 valid_loss: 0.03753
EarlyStopping counter: 34 out of 35
[182/500] train_loss: 0.03597 valid_loss: 0.03702
EarlyStopping counter: 35 out of 35
Early stopping
