In [6]:
import json
import plotly.express as px
import dash
from jupyter_dash import JupyterDash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import glob

In [7]:
f_names = glob.glob("./**", recursive=True)

In [8]:
f_names = [fn for fn in f_names if "trainer_state" in fn]

In [9]:
f_names

['./training_states/6184d3d8-3de9-4b16-b11d-7f7bd8ec1821_TransformerPlusTemporal/checkpoint-60000/trainer_state.json',
 './training_states/6184d3d8-3de9-4b16-b11d-7f7bd8ec1821_TransformerPlusTemporal/checkpoint-40000/trainer_state.json',
 './training_states/6184d3d8-3de9-4b16-b11d-7f7bd8ec1821_TransformerPlusTemporal/checkpoint-100000/trainer_state.json',
 './training_states/6184d3d8-3de9-4b16-b11d-7f7bd8ec1821_TransformerPlusTemporal/checkpoint-80000/trainer_state.json',
 './training_states/3bfa095a-7fe5-42aa-8d89-e780c3933abb_TransformerPlusTemporal/checkpoint-6000/trainer_state.json',
 './training_states/3bfa095a-7fe5-42aa-8d89-e780c3933abb_TransformerPlusTemporal/checkpoint-9000/trainer_state.json',
 './training_states/3bfa095a-7fe5-42aa-8d89-e780c3933abb_TransformerPlusTemporal/checkpoint-15000/trainer_state.json',
 './training_states/3bfa095a-7fe5-42aa-8d89-e780c3933abb_TransformerPlusTemporal/checkpoint-12000/trainer_state.json',
 './training_states/a530e81d-59f1-4d2e-b1a8-8025d

In [10]:
test_results = {
    "Roberta-likeFu": {"hash": "6184d3d8-3de9-4b16-b11d-7f7bd8ec1821", "p": 0.67, "r": 0.24, "f": 0.36},
    "TLB-500-likeFu": {"hash": "1f61ff31-6eba-45ea-a9b0-64e0a55e6417", "p": 0.97, "r": 0.11, "f": 0.20},
    "TLB-1000-likeFu": {"hash": "f8572cc1-c3a8-4c48-9fbd-546d50c91a47", "p": 0.84,"r": 0.13,"f": 0.22},
    "Roberta-plusTemporal": {"hash": "3bfa095a-7fe5-42aa-8d89-e780c3933abb", "p": 1.0, "r": 0.15, "f": 0.26},
    "TLB-500-plusTemporal": {"hash": "3d32b82c-8880-4cc6-8041-e51fe9e3ca2c", "p": 1.0, "r": 0.15, "f": 0.26},
    "TLB-1000-plusTemporal": {"hash": "dfbf73de-a3e3-48a6-841a-3e926031e249", "p": 1.0, "r": 0.15, "f": 0.26}
}

In [27]:
training_data = []
for name, t in test_results.items():
    training_files = {fn.split("/")[-2].split("-")[-1]: fn for fn in f_names if t["hash"] in fn}
    mx_trf = max(training_files.keys())
    d = {"name": name,
         "training_file": training_files[mx_trf]}
    d.update(t)
    training_data.append(d)

In [28]:
training_data

[{'name': 'Roberta-likeFu',
  'training_file': './training_states/6184d3d8-3de9-4b16-b11d-7f7bd8ec1821_TransformerPlusTemporal/checkpoint-80000/trainer_state.json',
  'hash': '6184d3d8-3de9-4b16-b11d-7f7bd8ec1821',
  'p': 0.67,
  'r': 0.24,
  'f': 0.36},
 {'name': 'TLB-500-likeFu',
  'training_file': './training_states/1f61ff31-6eba-45ea-a9b0-64e0a55e6417_TransformerLikeFu/checkpoint-80000/trainer_state.json',
  'hash': '1f61ff31-6eba-45ea-a9b0-64e0a55e6417',
  'p': 0.97,
  'r': 0.11,
  'f': 0.2},
 {'name': 'TLB-1000-likeFu',
  'training_file': './training_states/f8572cc1-c3a8-4c48-9fbd-546d50c91a47_TransformerLikeFu/checkpoint-80000/trainer_state.json',
  'hash': 'f8572cc1-c3a8-4c48-9fbd-546d50c91a47',
  'p': 0.84,
  'r': 0.13,
  'f': 0.22},
 {'name': 'Roberta-plusTemporal',
  'training_file': './training_states/3bfa095a-7fe5-42aa-8d89-e780c3933abb_TransformerPlusTemporal/checkpoint-9000/trainer_state.json',
  'hash': '3bfa095a-7fe5-42aa-8d89-e780c3933abb',
  'p': 1.0,
  'r': 0.15,
  

In [38]:
# Build App
app = JupyterDash(__name__)
# Create server variable with Flask server object for use with gunicorn
server = app.server

app.layout = html.Div([
    html.H1("Model training progress"),
    # graph component
    dcc.Graph(id='graph'),
    html.Label([
        "choose model training checkpoint",
        dcc.Dropdown(
            id='model-select', clearable=False,
            value=training_data[0]["hash"], options=[{'label': d["name"], 'value': d["hash"]} for d in training_data])
    ]),
       # output evaluation metrics
    html.Div([
        "Evaluation on test set",
        html.Div(id="results")
    ])

])

# Define callback to update graph
@app.callback(
    Output('graph', 'figure'),
    Output('results', 'children'),
    [Input("model-select", "value")]
)
def update_figure(run_hash):
    # check which component triggered the update    
    """
    x = list(range(len(hl_msg_data[match1]["msg_dens"][::scale])))
    msg_dens = hl_msg_data[match1]["msg_dens"][::scale]
    spp = ScipyPeaks(**ssp_params).predict(msg_dens)
    """

    tr_dat = [td for td in training_data if td["hash"] == run_hash][0]
    
    with open(tr_dat["training_file"], "r") as in_file:
        data = json.load(in_file)
    
    epochs_train = [elm["epoch"] for elm in data["log_history"] if "loss" in elm]
    loss_train = [elm["loss"] for elm in data["log_history"] if "loss" in elm]
    
    epochs_val = [elm["epoch"] for elm in data["log_history"] if "eval_loss" in elm]
    loss_val = [elm["eval_loss"] for elm in data["log_history"] if "eval_loss" in elm]
    f1_val = [elm["eval_f1"] for elm in data["log_history"] if "eval_loss" in elm]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=epochs_train, y=loss_train,
                        mode='lines',
                        name='training loss'))
    
    fig.add_trace(go.Scatter(x=epochs_val, y=loss_val,
                        mode='lines',
                        name='validation loss'))
    
    fig.add_trace(go.Scatter(x=epochs_val, y=f1_val,
                        mode='lines',
                        name='vaildation f1'))
    
    eval_result = html.Table([
                    html.Tr([html.Th('Precision'),
                             html.Th('Recall'),
                             html.Th('F-score'),
                             #html.Th('Accuracy')
                            ]),
                    html.Tr([html.Td("{:.5f}".format(tr_dat["p"])),
                             html.Td("{:.5f}".format(tr_dat["r"])),
                             html.Td("{:.5f}".format(tr_dat["f"])),
                             #html.Td("{:.5f}".format(evaluation["accuracy"]))
                            ])
    ])
    
    return fig, eval_result

In [39]:
# Run app and display result inline in the notebook
app.run_server(mode='jupyterlab')


The 'environ['werkzeug.server.shutdown']' function is deprecated and will be removed in Werkzeug 2.1.



In [18]:
with open("./training_states/1f61ff31-6eba-45ea-a9b0-64e0a55e6417_TransformerLikeFu/checkpoint-80000/trainer_state.json", "r") as in_file:
    data = json.load(in_file)

In [20]:
data["log_history"]

[{'epoch': 0.01,
  'learning_rate': 1.9973901923428246e-05,
  'loss': 0.1753,
  'step': 500},
 {'epoch': 0.03,
  'learning_rate': 1.9947803846856487e-05,
  'loss': 0.1542,
  'step': 1000},
 {'epoch': 0.04,
  'learning_rate': 1.992170577028473e-05,
  'loss': 0.1456,
  'step': 1500},
 {'epoch': 0.05,
  'learning_rate': 1.9895607693712976e-05,
  'loss': 0.1336,
  'step': 2000},
 {'epoch': 0.07,
  'learning_rate': 1.9869509617141217e-05,
  'loss': 0.1283,
  'step': 2500},
 {'epoch': 0.08,
  'learning_rate': 1.984341154056946e-05,
  'loss': 0.1222,
  'step': 3000},
 {'epoch': 0.09,
  'learning_rate': 1.9817313463997706e-05,
  'loss': 0.1197,
  'step': 3500},
 {'epoch': 0.1,
  'learning_rate': 1.9791215387425947e-05,
  'loss': 0.1123,
  'step': 4000},
 {'epoch': 0.12,
  'learning_rate': 1.976511731085419e-05,
  'loss': 0.1139,
  'step': 4500},
 {'epoch': 0.13,
  'learning_rate': 1.9739019234282436e-05,
  'loss': 0.1004,
  'step': 5000},
 {'epoch': 0.14,
  'learning_rate': 1.9712921157710677e

In [12]:
[elm["epoch"] for elm in data["log_history"] if "loss" in elm]
[elm["loss"] for elm in data["log_history"] if "loss" in elm]

[0.1753,
 0.1542,
 0.1456,
 0.1336,
 0.1283,
 0.1222,
 0.1197,
 0.1123,
 0.1139,
 0.1004,
 0.098,
 0.096,
 0.0917,
 0.0872,
 0.0869,
 0.0807,
 0.0766,
 0.0758,
 0.07,
 0.07,
 0.0677,
 0.065,
 0.065,
 0.0632,
 0.0595,
 0.0601,
 0.0603,
 0.054,
 0.055,
 0.0538,
 0.0507,
 0.0513,
 0.0505,
 0.0481,
 0.0451,
 0.0476,
 0.0424,
 0.0432,
 0.0443,
 0.0418,
 0.0428,
 0.0407,
 0.0404,
 0.0346,
 0.0405,
 0.0377,
 0.0397,
 0.0408,
 0.0376,
 0.0341,
 0.0354,
 0.0366,
 0.0371,
 0.0348,
 0.0323,
 0.0311,
 0.0331,
 0.0283,
 0.0295,
 0.0307,
 0.0294,
 0.0292,
 0.0322,
 0.0288,
 0.0278,
 0.0231,
 0.0283,
 0.0278,
 0.0249,
 0.0246,
 0.0244,
 0.0272,
 0.0245,
 0.0251,
 0.0259,
 0.026,
 0.0238,
 0.0189,
 0.0174,
 0.0211,
 0.0177,
 0.0204,
 0.0179,
 0.0171,
 0.0155,
 0.0182,
 0.0166,
 0.0195,
 0.0171,
 0.0189,
 0.018,
 0.0174,
 0.0168,
 0.0181,
 0.0163,
 0.0166,
 0.0197,
 0.0154,
 0.0178,
 0.0158,
 0.0157,
 0.0161,
 0.0173,
 0.0146,
 0.0147,
 0.0151,
 0.0143,
 0.0164,
 0.0163,
 0.0144,
 0.0165,
 0.0139,
 0.0