# train_validate.py test
---

In [6]:
# turn off Jedi-based tab completion in the IPython completer
%config Completer.use_jedi = False
# import libraries
import torch
import numpy as np

# seed torch's global RNG; the later random_split call and the nn.Linear weight
# initialisation both draw from it, so a fresh-kernel run is repeatable
# (being the last expression of the cell, the returned Generator is echoed as output)
torch.manual_seed(7)

<torch._C.Generator at 0x29b58977570>

### Train/Validation split with `random_split`

In [7]:
from torchvision import datasets
import torchvision.transforms as transforms

# number of subprocesses to use for data loading (0 = load in the main process)
num_workers = 0
# how many samples per batch to load
batch_size = 20
# fraction (0-1, not a percentage) of the loaded dataset to hold out as validation
valid_size = 0.2

# convert data to torch.FloatTensor
transform = transforms.ToTensor()

# NOTE: train=False selects the MNIST *test* split, not the training split.
# The size assertions in the following cell (8000 + 2000 samples) rely on this.
full_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# split the loaded dataset into training and validation subsets
# (no generator is passed, so random_split draws from torch's global RNG seeded above)
n = len(full_data)
n_val = int(n * valid_size)
train_data, valid_data = torch.utils.data.random_split(full_data, [n-n_val, n_val])

# prepare data loaders
# shuffle=False: each loader walks its subset in the same order on every epoch
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=num_workers, shuffle=False)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [8]:
# Sanity-check the split: with valid_size = 0.2 the 10,000 loaded samples
# should be divided into 8000 training and 2000 validation samples.
# The messages report the actual size so a failure is self-explanatory.
assert len(train_loader.sampler) == 8000, f'expected 8000 training samples, got {len(train_loader.sampler)}'
assert len(valid_loader.sampler) == 2000, f'expected 2000 validation samples, got {len(valid_loader.sampler)}'

---
## Define the Network [Architecture](http://pytorch.org/docs/stable/nn.html)

The architecture will be responsible for seeing as input a 784-dim Tensor of pixel values for each image, and producing a Tensor of length 10 (our number of classes) that indicates the class scores for an input image. This particular example uses one hidden layer (512 units) and dropout to avoid overfitting.

In [9]:
import torch.nn as nn
import torch.nn.functional as F

# define the NN architecture
class Net(nn.Module):
    """Fully connected classifier: 784 inputs -> 512 hidden units -> 10 class scores.

    The hidden layer is followed by ReLU and dropout (p=0.2); the output
    layer returns raw, un-normalised scores.
    """

    def __init__(self):
        super().__init__()
        n_pixels = 28 * 28  # size of one flattened 28x28 image
        n_hidden = 512      # width of the single hidden layer
        n_classes = 10      # one output score per class

        # hidden layer (784 -> 512)
        self.fc1 = nn.Linear(n_pixels, n_hidden)
        # output layer (512 -> 10); it is named fc3 even though there is no fc2,
        # and the name is kept because it is part of the state_dict keys
        self.fc3 = nn.Linear(n_hidden, n_classes)
        # dropout (p=0.2) to reduce overfitting
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return one row of 10 scores per image."""
        flat = x.view(-1, 28 * 28)
        # hidden layer -> ReLU -> dropout
        hidden = self.dropout(F.relu(self.fc1(flat)))
        # output layer, no activation
        return self.fc3(hidden)

# initialize the NN
# (layer weights are randomly initialised when the module is constructed)
model = Net()
# printing a module lists its registered sub-layers
print(model)

Net(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=10, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)


###  Specify [Loss Function](http://pytorch.org/docs/stable/nn.html#loss-functions) and [Optimizer](http://pytorch.org/docs/stable/optim.html)

It's recommended that you use cross-entropy loss for classification. If you look at the documentation (linked above), you can see that PyTorch's cross-entropy function applies a softmax function to the output layer *and* then calculates the log loss. This is why the network's `forward` returns raw class scores rather than probabilities.

In [10]:
# loss function: categorical cross-entropy, applied to the raw class scores
criterion = nn.CrossEntropyLoss()

# optimizer: stochastic gradient descent over all model parameters
learning_rate = 0.01
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

### Train using the `train_validate` function from train_validate.py

In [11]:
import torchmetrics
import matplotlib.pyplot as plt
from train_validate import train_validate

%matplotlib inline

# per-epoch loss histories, filled in by on_epoch_end
train_losses, valid_losses = [], []


def on_epoch_end(model, epoch, epoch_train_loss, epoch_valid_loss, best_valid_loss, train_metrics, valid_metrics):
    """Record the epoch's losses and print a validation summary.

    Passed to train_validate as its `on_epoch_end` hook. The function only
    appends to the module-level `train_losses` / `valid_losses` lists and
    prints; it does not itself save the model or stop training.

    Reads `valid_metrics["Acc"]` and reports it as a percentage. `model` and
    `train_metrics` are accepted but not used.
    """
    for history, loss in ((train_losses, epoch_train_loss), (valid_losses, epoch_valid_loss)):
        history.append(loss)

    shown_loss = round(epoch_valid_loss, 4)
    shown_acc = round(valid_metrics["Acc"] * 100, 1)
    print(f'Epoch {epoch}. Val loss: {shown_loss}. Val Accuracy: {shown_acc}%')

    # anything that is not a strict increase (including an equal loss) is reported as a decrease
    got_worse = epoch_valid_loss > best_valid_loss
    if not got_worse:
        print('   Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(best_valid_loss, epoch_valid_loss))
        return
    print(f'   Validation loss increased from {best_valid_loss} to {epoch_valid_loss}. Stopping...')

# run training and validation; the returned value is a state dict, which is
# loaded back into the model and then written to disk
best_model_state = train_validate(
    model,
    train_loader,
    valid_loader,
    criterion,
    optimizer,
    epochs=2,
    on_epoch_end=on_epoch_end,
    metric_factories={'Acc': torchmetrics.Accuracy},
    stop_on_loss_increase=True,
)
model.load_state_dict(best_model_state)
torch.save(model.state_dict(), 'model.pt')

print('Done. Saved to model.pt')


Epoch 0. Val loss: 1.0574. Val Accuracy: 81.3%
   Validation loss decreased (inf --> 1.057446).  Saving model ...
Epoch 1. Val loss: 0.5999. Val Accuracy: 86.9%
   Validation loss decreased (1.057446 --> 0.599906).  Saving model ...
Epoch 2. Val loss: 0.4669. Val Accuracy: 88.3%
   Validation loss decreased (0.599906 --> 0.466851).  Saving model ...
Epoch 3. Val loss: 0.4061. Val Accuracy: 89.4%
   Validation loss decreased (0.466851 --> 0.406059).  Saving model ...
Epoch 4. Val loss: 0.3701. Val Accuracy: 90.2%
   Validation loss decreased (0.406059 --> 0.370137).  Saving model ...
Epoch 5. Val loss: 0.3478. Val Accuracy: 90.5%
   Validation loss decreased (0.370137 --> 0.347828).  Saving model ...
Epoch 6. Val loss: 0.3304. Val Accuracy: 90.9%
   Validation loss decreased (0.347828 --> 0.330384).  Saving model ...
Epoch 7. Val loss: 0.3185. Val Accuracy: 91.0%
   Validation loss decreased (0.330384 --> 0.318505).  Saving model ...
Epoch 8. Val loss: 0.3077. Val Accuracy: 91.4%
   Val

KeyboardInterrupt: 

In [None]:
# per-epoch training losses collected by on_epoch_end
train_losses

In [None]:
# per-epoch validation losses collected by on_epoch_end
valid_losses

In [None]:
# TODO: leftover placeholder - `out` and `expected_out` are not defined in any
# visible cell of this notebook; define them or remove this cell
# assert np.allclose(out, expected_out)