In [None]:
def _evaluate_model(model, data_loader, loss_function=None):
    """Run ``model`` over every batch of ``data_loader`` in eval mode, without gradients.

    Args:
        model: the network; called as ``model(inputs)``.
        data_loader: iterable yielding ``(inputs, labels)`` pairs.
        loss_function: optional ``loss_function(outputs, labels)`` callable
            returning a scalar tensor. When ``None`` no loss is computed.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the average of the
        per-minibatch loss values as a Python float, or ``None`` when
        ``loss_function`` is ``None``. ``accuracy`` is the percentage of
        samples whose arg-max over dim 1 of the outputs equals the label.

    Raises:
        ValueError: if ``data_loader`` yields no samples (the averages would
            otherwise be undefined).
    """
    # Disable dropout / use running batch-norm statistics during evaluation.
    model.eval()

    running_loss = 0.0
    n_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():  # no gradient tracking: less memory, faster
        for inputs, labels in data_loader:
            outputs = model(inputs)
            if loss_function is not None:
                # .item() yields a Python float so we don't accumulate tensors
                # (which would stay on the model's device).
                running_loss += loss_function(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            n_batches += 1

    if total == 0:
        raise ValueError("data_loader yielded no samples; cannot evaluate")

    mean_loss = running_loss / n_batches if loss_function is not None else None
    return mean_loss, 100 * correct / total


def train_model(model, EPOCHS, train_loader, val_loader, optimizer, loss_function):
    """Train ``model`` for ``EPOCHS`` epochs, checkpointing whenever validation loss improves.

    Each epoch performs three steps:
      1. One optimisation pass over ``train_loader`` via ``train_one_epoch``
         (defined elsewhere in this file).
      2. A validation pass over ``val_loader``, computing the mean
         per-minibatch loss and the accuracy.
      3. An accuracy-only pass over ``train_loader``, run in eval mode.

    Whenever the validation loss is lower than the best seen so far,
    ``model.state_dict()`` is written with ``torch.save`` to
    ``model_<YYYYmmdd_HHMMSS>_<epoch>`` in the current working directory.

    Args:
        model: the network to train.
        EPOCHS: number of epochs to run.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: iterable of ``(inputs, labels)`` validation batches.
        optimizer: optimizer passed through to ``train_one_epoch``.
        loss_function: ``loss_function(outputs, labels)`` returning a scalar
            tensor; used for training and validation loss.

    Returns:
        ``(model_path, loss_train, loss_val, accuracies_train, accuracies_val)``:
          - ``model_path``: path of the last (i.e. best) checkpoint written,
            or ``None`` if none was written. That happens when
            ``EPOCHS == 0`` or when the validation loss never compared
            below ``inf`` (e.g. NaN).
          - ``loss_train``: per-epoch values returned by ``train_one_epoch``.
          - ``loss_val``: per-epoch mean validation losses, as floats.
          - ``accuracies_train`` / ``accuracies_val``: per-epoch accuracies
            in percent.

    Raises:
        ValueError: if ``val_loader`` or ``train_loader`` yields no samples.
    """
    best_validation_loss = np.inf
    # Initialised so the function can return even if no checkpoint is saved.
    model_path = None

    loss_train = []        # training loss per epoch
    loss_val = []          # mean validation loss per epoch
    accuracies_train = []  # training-set accuracy (%) per epoch
    accuracies_val = []    # validation-set accuracy (%) per epoch

    for epoch in range(EPOCHS):
        print('EPOCH {}:'.format(epoch + 1))

        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        train_loss = train_one_epoch(model, train_loader, optimizer, loss_function)
        loss_train.append(train_loss)

        validation_loss, accuracy = _evaluate_model(model, val_loader, loss_function)
        loss_val.append(validation_loss)
        accuracies_val.append(accuracy)

        print('LOSS: train: {}, validation: {}; accuracy validation set: {}%\n'.format(train_loss, validation_loss,
                                                                                        accuracy))

        # Track best performance (based on validation), and save the model
        if validation_loss < best_validation_loss:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            best_validation_loss = validation_loss
            model_path = 'model_{}_{}'.format(timestamp, epoch)
            torch.save(model.state_dict(), model_path)

        # Accuracy on the training set with the weights after this epoch's
        # update. Loss is not recomputed here; train_loss above is used instead.
        _, train_accuracy = _evaluate_model(model, train_loader)
        accuracies_train.append(train_accuracy)

    return model_path, loss_train, loss_val, accuracies_train, accuracies_val
