In [1]:
from keras.callbacks import BaseLogger
import matplotlib.pyplot as plt
import numpy as np
import json
import os

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
class TrainingMonitor(BaseLogger):
    """Callback that records per-epoch metrics and plots them to a file.

    After every epoch the values in ``logs`` are appended to the history
    dictionary ``self.H`` (metric name -> list of values). The history is
    optionally serialized to a JSON file, and once at least two epochs have
    been recorded a loss/accuracy figure is written to ``figPath``.
    """

    def __init__(self, figPath, jsonPath=None, startAt=0):
        """Store the output locations and the starting epoch.

        Args:
            figPath: Path the loss/accuracy figure is saved to.
            jsonPath: Optional path of the JSON file holding the serialized
                history. ``None`` disables loading and saving the history.
            startAt: Epoch that training is (re)started at. When greater
                than zero, a previously saved history is trimmed to this
                many entries in ``on_train_begin``.
        """
        super(TrainingMonitor, self).__init__()
        self.figPath = figPath
        self.jsonPath = jsonPath
        self.startAt = startAt

    def on_train_begin(self, logs=None):
        """Initialize the history, restoring it from ``jsonPath`` if present.

        Args:
            logs: Unused; accepted for compatibility with the callback API.
        """
        # Initialize the history dictionary
        self.H = {}

        # Nothing to restore without a JSON path or an existing file
        if self.jsonPath is None or not os.path.exists(self.jsonPath):
            return

        # Load the training history; the context manager guarantees the
        # file handle is released
        with open(self.jsonPath) as history_file:
            self.H = json.load(history_file)

        # If a starting epoch was supplied, drop any entries that are past
        # it so resumed training continues from the right point
        if self.startAt > 0:
            self.H = {
                key: values[:self.startAt] for key, values in self.H.items()
            }

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics, then save the history and the plot.

        Args:
            epoch: Index of the epoch that just finished (not used here; the
                plot is driven by the length of the recorded history).
            logs: Mapping of metric name to value for the finished epoch.
                ``None`` is treated as an empty mapping.
        """
        # Append every reported metric to its history list. Values are
        # coerced to built-in floats because NumPy scalar types (such as
        # float32) cannot be serialized by the json module.
        for key, value in (logs or {}).items():
            self.H.setdefault(key, []).append(float(value))

        # Serialize the training history to file if requested
        if self.jsonPath is not None:
            with open(self.jsonPath, "w") as history_file:
                history_file.write(json.dumps(self.H))

        # Ensure at least two epochs have been recorded before plotting
        losses = self.H.get("loss", [])
        if len(losses) > 1:
            # History key -> legend label; series that were not recorded
            # (e.g. no validation data or no accuracy metric) are skipped
            series = (
                ("loss", "train_loss"),
                ("val_loss", "validation_loss"),
                ("acc", "train_accuracy"),
                ("val_acc", "validation_accuracy"),
            )

            N = np.arange(0, len(losses))
            plt.style.use("ggplot")
            plt.figure()
            try:
                for key, label in series:
                    if key in self.H:
                        plt.plot(N, self.H[key], label=label)
                plt.title("Training loss and accuracy [Epoch {}]".format(
                    len(losses)))
                plt.xlabel("Epoch #")
                plt.ylabel("Loss/Accuracy")
                plt.legend()

                plt.savefig(self.figPath)
            finally:
                # Actually close the figure; without the call parentheses a
                # new figure would be leaked on every epoch
                plt.close()