In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
# Switch matplotlib over to seaborn's default plot styling.
sns.set()


class Extractor(object):
    """Pull per-epoch and per-iteration scalar metrics out of a text log.

    The file at ``path`` is read once.  Lines containing ``"(train) Epoch"``,
    ``"(test) Epoch"``, ``"(train) Iter"`` or ``"(test) Iter"`` are kept, in
    file order, in ``self.epochlines[mode]`` / ``self.iterlines[mode]``.

    Calling the instance returns the values of one metric as a NumPy array.
    """

    _MODES = ("train", "test")

    def __init__(self, path):
        with open(path, "r") as f:
            lines = f.readlines()
        self.epochlines = self._lines_by_mode(lines, "Epoch")
        self.iterlines = self._lines_by_mode(lines, "Iter")

    @staticmethod
    def _lines_by_mode(lines, unit):
        """Group the lines tagged ``"(<mode>) <unit>"`` by mode."""
        grouped = {}
        for mode in Extractor._MODES:
            tag = "({}) {}".format(mode, unit)
            grouped[mode] = [line for line in lines if tag in line]
        return grouped

    @staticmethod
    def _read_metric(line, quoted_key):
        """Return the float stored under ``quoted_key`` in one log line.

        The metrics section is the text between the first brace and the next
        one; its entries are separated by ``", "`` and the value is the second
        whitespace-separated token of the first entry mentioning the key.
        """
        sections = re.split('[{}]', line)
        if len(sections) < 2:
            raise IndexError(
                "no '{...}' metrics section in log line: %r" % line.rstrip())
        for entry in sections[1].split(', '):
            if quoted_key in entry:
                tokens = entry.split()
                if len(tokens) < 2:
                    raise IndexError(
                        "metric {} has no value in log line: {!r}".format(
                            quoted_key, line.rstrip()))
                return float(tokens[1])
        raise IndexError(
            "metric {} not found in log line: {!r}".format(
                quoted_key, line.rstrip()))

    def __call__(self, mode, key, epoch=True):
        """Return every logged value of ``key`` for ``mode`` as an array.

        Args:
            mode: ``"train"`` or ``"test"``; any other value raises KeyError.
            key: metric name as it appears (single-quoted) in the log.
            epoch: read the epoch lines when true, the iteration lines
                otherwise.

        Returns:
            ``np.ndarray`` with one float per matching log line, in file
            order (empty when the log has no such lines).

        Raises:
            IndexError: a line lacks the metrics section or the metric.
            ValueError: the logged value cannot be converted to float.
        """
        source = self.epochlines if epoch else self.iterlines
        # Match the key together with its quotes so that e.g. 's_loss'
        # does not also hit 's_aux_loss'.
        quoted_key = "'{}'".format(key)
        return np.array(
            [self._read_metric(line, quoted_key) for line in source[mode]])


def get_b_t(data1, data2):
    """Return ``(bottom, top)``: the smallest and largest value across both arrays."""
    lows = (data1.min(), data2.min())
    highs = (data1.max(), data2.max())
    return min(lows), max(highs)


def _stable_tail(series, skip):
    """Return ``series[skip:]``, or all of ``series`` when that tail is empty."""
    tail = series[skip:]
    return tail if tail.size else series


def _apply_padded_ylim(ax, low, high, pad=0.1):
    """Set the y-limits of ``ax`` to [low, high] widened by ``pad`` of the span."""
    span = high - low
    # NaN/inf bounds or a zero span cannot form a valid window, so leave the
    # axis on autoscale instead of handing matplotlib unusable limits.
    if not np.isfinite(span) or span <= 0:
        return
    ax.set_ylim([low - span * pad, high + span * pad])


def _plot_loss_panel(ax, ext, key, skip):
    """Draw the train and test curves of one epoch-level metric on ``ax``."""
    curves = [ext("train", key), ext("test", key)]
    ax.set_title(key)
    tails = [_stable_tail(curve, skip) for curve in curves if curve.size]
    if tails:
        # With a single non-empty curve both arguments are that same curve.
        low, high = get_b_t(tails[0], tails[-1])
        _apply_padded_ylim(ax, low, high)
    for curve in curves:
        if curve.size:
            sns.lineplot(data=curve, ax=ax)


def show_hist(logfile, skip_epochs=10, skip_iters=1000, grad_norm_ref=1.2e+5):
    """Plot the loss and gradient-norm history stored in a training log.

    Four stacked panels are drawn: ``x_loss``, ``s_loss`` and ``s_aux_loss``
    (train and test, per epoch) and the per-iteration train ``g_grad_norm``.
    The y-range of each panel is fitted to the data after an initial stretch
    has been skipped; the skipped points are still plotted.

    Args:
        logfile: path of the log file; also used as the figure title.
        skip_epochs: leading epochs ignored when fitting the loss y-ranges.
        skip_iters: leading iterations ignored when fitting the gradient-norm
            y-range and computing the maximum shown in its title.
        grad_norm_ref: y position of the dotted horizontal reference line in
            the gradient-norm panel.

    When a series is shorter than its skip count the whole series is used
    for the y-range instead of raising on an empty slice.
    """
    ext = Extractor(logfile)
    fig, axes = plt.subplots(4, 1, figsize=(12, 16))
    fig.suptitle(logfile, x=0.5, y=0.92)

    for ax, key in zip(axes, ("x_loss", "s_loss", "s_aux_loss")):
        _plot_loss_panel(ax, ext, key, skip_epochs)

    grad_ax = axes[3]
    grad_norm = ext("train", "g_grad_norm", epoch=False)
    if grad_norm.size:
        tail = _stable_tail(grad_norm, skip_iters)
        low, high = tail.min(), tail.max()
        grad_ax.set_title("g_grad_norm (max:{})".format(high))
        _apply_padded_ylim(grad_ax, low, high)
        sns.lineplot(data=grad_norm, ax=grad_ax)
    else:
        grad_ax.set_title("g_grad_norm")
    # axhline's xmin/xmax are axes fractions in [0, 1]; the defaults already
    # span the full panel width, so no data-length value is passed.
    grad_ax.axhline(y=grad_norm_ref, c="k", linestyle=":")

    plt.show()

In [2]:
# Plot the loss and gradient-norm history parsed from this run's log file.
show_hist("logzero/Jul09_07-09-44.txt")