In [None]:
from tensorflow.keras.callbacks import Callback
import matplotlib.pyplot as plt
import numpy as np

class PlotLosses(Callback):
    """Keras callback that saves loss/accuracy curves as a PNG after every epoch.

    After each epoch the training and validation metrics reported in ``logs``
    are appended to the history lists, and a two-panel figure (loss on the
    left, accuracy on the right) is written to
    ``<output_dir>/plot_epoch_<n>.png``. The best validation epoch (lowest
    loss / highest accuracy) is highlighted when validation values exist.

    Args:
        output_dir: Directory the PNG files are written to. It is created on
            demand. Defaults to ``'Losses'``.

    Attributes:
        tr_acc, tr_loss, val_acc, val_loss: Per-epoch metric history, one
            entry per finished epoch (``None`` when a metric was not reported).
    """

    def __init__(self, output_dir='Losses'):
        super().__init__()
        self.output_dir = output_dir
        self.tr_acc = []
        self.tr_loss = []
        self.val_acc = []
        self.val_loss = []

    @staticmethod
    def _as_float_array(values):
        """Return *values* as a float array, mapping missing (None) entries to NaN."""
        return np.array([np.nan if v is None else v for v in values], dtype=float)

    def _draw_panel(self, ax, epochs, train_values, val_values,
                    train_label, val_label, title, ylabel, pick_best):
        """Draw one metric panel and mark the best validation epoch on it."""
        train = self._as_float_array(train_values)
        val = self._as_float_array(val_values)

        ax.plot(epochs, train, 'r', label=train_label)
        ax.plot(epochs, val, 'g', label=val_label)

        # nanargmin/nanargmax raise on all-NaN input, so only mark a best
        # epoch when at least one real validation value has been recorded.
        if not np.isnan(val).all():
            best = int(pick_best(val))
            ax.scatter(best + 1, val[best], s=150, c='blue',
                       label=f'best epoch= {best + 1}')

        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend()

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and save the updated curves to disk.

        Args:
            epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
            logs: Metric dict for the epoch. The keys ``'accuracy'``,
                ``'loss'``, ``'val_accuracy'`` and ``'val_loss'`` are read;
                absent keys are recorded as ``None``.
        """
        # Local import keeps this block self-contained within the file.
        from pathlib import Path

        # `logs` defaults to None; treat that as "no metrics reported".
        logs = logs or {}

        # Collect the training history for this epoch.
        self.tr_acc.append(logs.get('accuracy'))
        self.tr_loss.append(logs.get('loss'))
        self.val_acc.append(logs.get('val_accuracy'))
        self.val_loss.append(logs.get('val_loss'))

        epochs = list(range(1, len(self.tr_acc) + 1))

        # Apply the style before creating the figure so figure-level settings
        # affect the first epoch's plot as well as later ones.
        plt.style.use('fivethirtyeight')
        fig = plt.figure(figsize=(20, 8))
        try:
            # Best epoch for loss is the minimum, for accuracy the maximum.
            self._draw_panel(
                fig.add_subplot(1, 2, 1), epochs, self.tr_loss, self.val_loss,
                'Training loss', 'Validation loss',
                'Training and Validation Loss', 'Loss', np.nanargmin,
            )
            self._draw_panel(
                fig.add_subplot(1, 2, 2), epochs, self.tr_acc, self.val_acc,
                'Training Accuracy', 'Validation Accuracy',
                'Training and Validation Accuracy', 'Accuracy', np.nanargmax,
            )

            fig.tight_layout()

            # Make sure the target directory exists before saving.
            out_dir = Path(self.output_dir)
            out_dir.mkdir(parents=True, exist_ok=True)
            fig.savefig(out_dir / f'plot_epoch_{epoch+1}.png')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

# Module-level PlotLosses instance; presumably passed to the model's
# `fit(..., callbacks=[...])` call elsewhere — confirm against the caller.
plot_losses_callback = PlotLosses()