Source: https://github.com/iSiddharth20/DeepLearning-ImageClassification-Toolkit

In [1]:
'''
Working Directories
Importing Libraries
'''

import os
import re
import matplotlib.pyplot as plt

# Folder holding one training-log text file per model ('<model name>.txt')
LOGS_DIR = '../Logs/'

# Folder that receives the exported comparison graphs
GRAPH_DIR = '../OutputGraphs/'
# Create the graph folder up front; an already-existing folder is not an error
os.makedirs(name=os.path.dirname(GRAPH_DIR), exist_ok=True)

In [2]:
'''
Helper Functions
'''

# Read Data from Text File
def read_data(model_name):
    model_name = LOGS_DIR+model_name+'.txt'
    with open(model_name, 'r') as file:
        return file.read()

def extract_data(content):
    """Parse per-epoch validation metrics from a training log.

    The log is presumably Keras ``fit`` output; confirm against the logs.
    A line counts as one epoch only when it reports all five of
    ``val_loss``, ``val_accuracy``, ``val_precision``, ``val_recall`` and
    ``val_auc`` in ``name: value`` form. Every other line is skipped. Skipping
    incomplete lines keeps the returned lists aligned epoch-by-epoch.

    Args:
        content: Full text of a training log.

    Returns:
        Tuple ``(val_loss, val_accuracy, f1_score, val_auc)`` of equal-length
        lists, one entry per parsed epoch. The parts are:

        * ``val_loss`` is rounded to 2 decimals; the other lists to 3.
        * ``f1_score`` is the harmonic mean of validation precision and
          recall, ``2PR / (P + R)``.
        * ``f1_score`` is 0.0 when P and R are both zero.
    """
    # Accept plain decimals ("0.8123"), integers ("1") and exponent form
    # ("1.2e-05"), in case very small values are printed that way.
    number = r'(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)'
    metric_names = ('val_loss', 'val_accuracy', 'val_precision', 'val_recall', 'val_auc')
    patterns = {name: re.compile(name + r': ' + number) for name in metric_names}

    val_loss = []
    val_accuracy = []
    f1_score = []
    val_auc = []
    for line in content.splitlines():
        matches = {name: pattern.search(line) for name, pattern in patterns.items()}
        if not all(matches.values()):
            # Not an epoch summary line, or an incomplete one.
            continue
        # The regex guarantees a numeric token, so float() cannot fail here.
        values = {name: float(match.group(1)) for name, match in matches.items()}

        precision = values['val_precision']
        recall = values['val_recall']
        denominator = precision + recall
        f1 = 2 * precision * recall / denominator if denominator else 0.0

        val_loss.append(round(values['val_loss'], 2))
        val_accuracy.append(round(values['val_accuracy'], 3))
        f1_score.append(round(f1, 3))
        val_auc.append(round(values['val_auc'], 3))
    return val_loss, val_accuracy, f1_score, val_auc

def max_epochs(lst):
    """Return the length of the longest sequence in ``lst``.

    Used to size the shared epoch axis. Returns 0 when ``lst`` is empty or
    every sequence in it is empty.
    """
    return max(map(len, lst), default=0)


# Set 'export_graph'=True to save each generated graph as a PNG in 'GRAPH_DIR'
export_graph = False


def generate_graphs(max_epochs, metrix_name, metric,
                    labels=('EfficientNetB0', 'InceptionResNetV2', 'ResNet50', 'VGG16')):
    """Plot one metric for several models on a shared epoch axis and show it.

    Args:
        max_epochs: Length of the epoch axis (epochs are numbered from 1).
        metrix_name: Metric name. It is used as the y-axis label and in the
            title. When ``export_graph`` is True it is also the PNG file name.
        metric: Sequence of per-epoch value lists, one per model, in the same
            order as ``labels``.
        labels: Legend label for each series; defaults to the four models
            compared in this notebook.

    Raises:
        ValueError: If ``metric`` and ``labels`` differ in length.
    """
    if len(metric) != len(labels):
        raise ValueError(
            f'expected {len(labels)} metric series, got {len(metric)}')
    epochs = list(range(1, max_epochs + 1))
    plt.figure(figsize=(10, 8))
    for series, label in zip(metric, labels):
        # A series shorter than the axis is drawn over its first len(series) epochs.
        plt.plot(epochs[:len(series)], series, label=label)
    plt.xlabel('Number of Epochs')
    plt.ylabel(metrix_name)
    plt.legend()
    plt.grid(True)
    plt.title(metrix_name + ' Comparison')
    if export_graph:
        plt.savefig(os.path.join(GRAPH_DIR, metrix_name + '.png'), bbox_inches='tight')
    plt.show()

In [3]:
'''
Importing Contents from Text Files
'''

# For every model, load its raw log and split it into the four
# validation series returned by extract_data (loss, accuracy, F-1, AUC).

# EfficientNetB0
content_EfficientNetB0 = read_data('EfficientNetB0')
(val_loss_EfficientNetB0,
 val_accuracy_EfficientNetB0,
 f1_score_EfficientNetB0,
 val_auc_EfficientNetB0) = extract_data(content_EfficientNetB0)

# InceptionResNetV2
content_InceptionResNetV2 = read_data('InceptionResNetV2')
(val_loss_InceptionResNetV2,
 val_accuracy_InceptionResNetV2,
 f1_score_InceptionResNetV2,
 val_auc_InceptionResNetV2) = extract_data(content_InceptionResNetV2)

# ResNet50
content_ResNet50 = read_data('ResNet50')
(val_loss_ResNet50,
 val_accuracy_ResNet50,
 f1_score_ResNet50,
 val_auc_ResNet50) = extract_data(content_ResNet50)

# VGG16
content_VGG16 = read_data('VGG16')
(val_loss_VGG16,
 val_accuracy_VGG16,
 f1_score_VGG16,
 val_auc_VGG16) = extract_data(content_VGG16)

In [4]:
'''
Generating and Saving Graphs (saved only when export_graph is True)
    - Validation Loss: lower is better
    - Validation Accuracy, F-1 Score and Validation AUC: higher is better
    - For every metric, reaching a good value in fewer epochs is better
'''
# Validation Loss
val_loss = [val_loss_EfficientNetB0,val_loss_InceptionResNetV2,val_loss_ResNet50,val_loss_VGG16]
generate_graphs(max_epochs(val_loss),'Validation Loss',val_loss)

# Validation Accuracy
val_accuracy = [val_accuracy_EfficientNetB0,val_accuracy_InceptionResNetV2,val_accuracy_ResNet50,val_accuracy_VGG16]
generate_graphs(max_epochs(val_accuracy),'Validation Accuracy',val_accuracy)

# F-1 Score
f1_score = [f1_score_EfficientNetB0,f1_score_InceptionResNetV2,f1_score_ResNet50,f1_score_VGG16]
generate_graphs(max_epochs(f1_score),'F-1 Score',f1_score)

# Validation AUC (Area under ROC Curve)
val_auc = [val_auc_EfficientNetB0,val_auc_InceptionResNetV2,val_auc_ResNet50,val_auc_VGG16]
generate_graphs(max_epochs(val_auc),'Validation AUC',val_auc)