In [None]:
import tensorboard
from tensorflow.core.util import event_pb2
from tensorflow.data import TFRecordDataset
from tensorflow.python.summary.summary_iterator import summary_iterator
import tensorflow as tf

from pathlib import Path
from multiprocessing import Pool
import os
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import pprint
# Shared pretty-printer; used further down to list the loaded run keys.
p = pprint.PrettyPrinter(indent=4)
# Make the project root ($MASTER) the working directory; raises KeyError if
# the MASTER environment variable is not set in the kernel's environment.
os.chdir(Path(os.environ["MASTER"]))
# Only show TensorFlow log messages at ERROR severity or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [None]:
""" Taken from https://gist.github.com/laszukdawid/62656cf7b34cac35b325ba21d46ecfcd

    https://laszukdawid.com/blog/2021/01/26/parsing-tensorboard-data-locally/
    
    Updated with https://stackoverflow.com/questions/58248787/reading-tf2-summary-file-with-tf-data-tfrecorddataset#58314091
"""


def convert_tfevent(filepath):
    """Read one TensorBoard event file into a DataFrame.

    Only events that carry at least one summary value are kept; each such
    event becomes one row via ``parse_tfevent`` (which reads the event's
    first summary value only).

    Parameters:
        filepath: (str) path to an ``events.out.tfevents.*`` file.

    Returns:
        pandas.DataFrame with [wall_time, name, step, value] columns. The
        columns are present even when the file has no summary values, so
        callers can always select them.
    """
    # NOTE(review): `summary_iterator` is deprecated in TF2 (public alias:
    # tf.compat.v1.train.summary_iterator); kept because it is what the
    # notebook imports.
    records = [
        parse_tfevent(event)
        for event in summary_iterator(filepath)
        if len(event.summary.value)
    ]
    return pd.DataFrame(records, columns=['wall_time', 'name', 'step', 'value'])


def parse_tfevent(tfevent):
    """Flatten a single TensorBoard event into a plain record.

    Only the event's first summary value is used; any further values in the
    same event are ignored.

    Parameters:
        tfevent: event object exposing ``wall_time``, ``step`` and
            ``summary.value[i].tag`` / ``.simple_value``.

    Returns:
        dict with keys wall_time, name, step, value (value cast to float).
    """
    head = tfevent.summary.value[0]
    record = {}
    record["wall_time"] = tfevent.wall_time
    record["name"] = head.tag
    record["step"] = tfevent.step
    record["value"] = float(head.simple_value)
    return record


def convert_tb_data(root_dir, sort_by=None):
    """Convert local TensorBoard data into a Pandas DataFrame.

    The function walks ``root_dir`` recursively and parses every file whose
    name contains ``events.out.tfevents``.
    If ``sort_by`` is provided, that column is used to sort the values;
    typically ``wall_time`` or ``step``.

    *Note* that all the data is loaded into a single DataFrame. Depending
    on the data size this might take a while; if it takes too long, narrow
    it down to some sub-directories.

    Parameters:
        root_dir: (str or path-like) root directory with TensorBoard data.
        sort_by: (optional str) column name to sort by.

    Returns:
        pandas.DataFrame with [wall_time, name, step, value] columns and a
        fresh 0..n-1 index. When no event file (or no summary value) is
        found, an empty DataFrame with those columns is returned instead of
        raising.
    """
    columns_order = ['wall_time', 'name', 'step', 'value']

    out = []
    for (root, _, filenames) in os.walk(root_dir):
        for filename in filenames:
            if "events.out.tfevents" not in filename:
                continue
            file_full_path = os.path.join(root, filename)
            out.append(convert_tfevent(file_full_path))

    # Event files without summary values yield empty frames (possibly without
    # columns), and pd.concat raises on an empty list - handle both up front.
    out = [frame for frame in out if not frame.empty]
    if not out:
        return pd.DataFrame(columns=columns_order)

    # Concatenate (and sort) all partial individual dataframes
    all_df = pd.concat(out)[columns_order]
    if sort_by is not None:
        all_df = all_df.sort_values(sort_by)

    return all_df.reset_index(drop=True)


In [None]:
def parallel(input_dir, output_dir):
    """Pool worker: load one run's TensorBoard data and tag it with a key.

    Parameters:
        input_dir: directory handed to ``convert_tb_data``.
        output_dir: Path whose last two components form the key.

    Returns:
        (key, DataFrame) where key is "<output_dir.parent.name>_<output_dir.name>".
    """
    frame = convert_tb_data(input_dir)
    run_key = "_".join((output_dir.parent.name, output_dir.name))
    return run_key, frame


def get_file_list(root_path=None):
    """Collect the run directories under ``$MASTER/save``.

    Expected layout: ``<root>/<experiment>/runs/<group>/<run>/``; one Path per
    ``<run>`` directory is returned.

    Parameters:
        root_path: (optional str or Path) directory to scan instead of
            ``$MASTER/save``.

    Returns:
        list[Path] of run directories, sorted so the result (and therefore
        plot/legend order downstream) does not depend on filesystem order.
        Non-directory entries at any level and experiments without a
        ``runs`` sub-directory are skipped instead of raising.
    """
    if root_path is None:
        root_path = Path(os.environ["MASTER"], "save")
    root_path = Path(root_path)

    def _subdirs(path):
        # Sorted child directories of `path`; stray files are ignored.
        return sorted(child for child in path.iterdir() if child.is_dir())

    run_dirs = []
    for experiment_dir in _subdirs(root_path):
        runs_dir = experiment_dir / "runs"
        if not runs_dir.is_dir():
            continue
        for group_dir in _subdirs(runs_dir):
            run_dirs.extend(_subdirs(group_dir))
    return run_dirs

def filter_list(path_list, filters):
    """Select the paths whose string form contains any of ``filters``.

    Parameters:
        path_list: iterable of paths (anything with a meaningful ``str``).
        filters: iterable of substrings to look for.

    Returns:
        list of matching entries, grouped by filter in the order given. Each
        path appears at most once: a path matching several filters is kept
        at its first match only, so it is not processed twice downstream.
    """
    selected = []
    seen = set()
    for pattern in filters:
        for path in path_list:
            key = str(path)
            if pattern in key and key not in seen:
                seen.add(key)
                selected.append(path)
    return selected

def get_dataframes(input_dirs):
    """Load every run directory in parallel and index the frames by run key.

    For each run directory ``.../<experiment>/runs/<group>/<run>`` a matching
    ``$MASTER/results/<experiment>/<run>`` directory is created if missing
    (nothing in this notebook writes into it), and ``parallel`` is mapped
    over the (input, output) pairs with a process pool.

    NOTE(review): the pool has to pickle ``parallel`` by reference; that
    works with the 'fork' start method, but notebook-defined functions are
    usually not importable by 'spawn' workers (macOS/Windows default) -
    confirm on the target platform.

    Returns:
        dict mapping "<experiment>_<run>" to that run's DataFrame. Runs that
        share a key overwrite each other (last one wins).
    """
    jobs = []
    for run_dir in input_dirs:
        target_dir = Path(os.environ["MASTER"], "results",
                          run_dir.parents[2].name, run_dir.name)
        if not target_dir.exists():
            target_dir.mkdir(parents=True)
        jobs.append((run_dir, target_dir))

    with Pool() as pool:
        keyed_frames = pool.starmap(parallel, jobs)
    return dict(keyed_frames)


In [None]:
# Discover every run directory under $MASTER/save, optionally restrict it to
# the experiments named in `filters`, then parse all event files in parallel.
l = get_file_list()
filters = [
    "2023-02-11_Word_level",
    "2023-02-11_Sentence_level",
]
APPLY_FILTERS = False  # set True to load only runs whose path matches `filters`
d = get_dataframes(filter_list(l, filters) if APPLY_FILTERS else l)

In [None]:
# TensorBoard tags to plot, one figure per tag. Runs are encoded as
# colour = level (red: key contains "Word_level", green otherwise) and
# marker = regularisation ("." when the key contains "RegTrue", "|" otherwise).
all_keys = [
    'Disentanglement/Interpretability',
    'Disentanglement/Mutual Information Gap',
    'Disentanglement/Separated Attribute Predictability',
    "Disentanglement/Spearman's Rank Correlation",
    'accuracy/training',
    'accuracy/validation',
    'loss_KLD/training',
    'loss_KLD/validation',
    'loss_KLD_batchwise/training',
    'loss_KLD_batchwise/validation',
    'loss_KLD_unscaled/training',
    'loss_KLD_unscaled/validation',
    'loss_KLD_unscaled_batchwise/training',
    'loss_KLD_unscaled_batchwise/validation',
    'loss_reconstruction/training',
    'loss_reconstruction/validation',
    'loss_regularization/training',
    'loss_regularization/validation',
    'loss_sum/training',
    'loss_sum/validation'
]

p.pprint(list(d.keys()))

for title in all_keys:
    fig, ax = plt.subplots()
    for key, value in d.items():
        marker = "." if "RegTrue" in key else "|"
        color = "red" if "Word_level" in key else "green"
        # Plot against the logged step, sorted: a run can consist of several
        # event files whose rows are concatenated in filesystem order, so the
        # row position is not a reliable x-axis.
        series = (
            value.loc[value["name"] == title, ["step", "value"]]
            .sort_values("step", kind="stable")
        )
        if series.empty:
            continue
        ax.plot(series["step"].to_numpy(), series["value"].to_numpy(),
                marker=marker, color=color, label=key)
    ax.set(title=title, xlabel="step", ylabel="value")
    if ax.lines:
        ax.legend(fontsize="x-small", loc="upper left", bbox_to_anchor=(1.02, 1))
    plt.show()
    plt.close(fig)