In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import optuna
import pandas as pd
import numpy as np
import plotly
from graphviz import Digraph

#pd.set_option('display.max_columns', 200)
#pd.set_option('display.max_rows', 200)

In [None]:
# Directory that holds the Optuna SQLite database and the per-trial log folders.
path = "optimized_graph"

# Row label looked up in each epoch_log.csv when reading "test_accuracy".
epochs = 1

# Create study object

In [None]:
# Open the study stored in the SQLite database under `path`.
# load_if_exists=True reuses the stored study of that name instead of raising.
study = optuna.create_study(
    storage=f"sqlite:///{path}/optuna.db",
    study_name="experiment01",
    # pruner=pruner,
    direction="minimize",
    load_if_exists=True,
)

# Trials

In [None]:
# Export all trials of the study as a table.
df = study.trials_dataframe()

# Names of every column whose label contains "params".
params_df = [col for col in df.columns if "params" in col]

# Keep only the bookkeeping columns followed by the parameter columns.
df = df.loc[
    :,
    ["number", "state", "value", "datetime_start", "datetime_complete", "user_attrs_seed"]
    + params_df,
]

In [None]:
# Display the column-trimmed trials table built in the previous cell.
df

In [None]:
# Objective value of every completed trial, plotted against the trial index.
font_size = 14
fig, ax = plt.subplots()
df.loc[df["state"] == 'COMPLETE', "value"].plot(ax=ax, grid=True, figsize=(8, 5))
ax.set_xlabel('Trial', size=font_size)
ax.set_ylabel('Error rate', size=font_size)

In [None]:
# Restrict the analysis to trials that finished successfully.
complete = df[df["state"] == 'COMPLETE']

# Number of columns whose name contains "model"; used below as the number of
# networks (net0, net1, ...) logged per trial.
no_nodes = len(complete.loc[:, complete.columns.str.contains("model")].columns)
params   = complete[params_df]
model    = params.loc[:, params.columns.str.contains("model")]

# For every completed trial, read each network's epoch log and take the
# "test_accuracy" entry at row label `epochs`.
model_acc = pd.DataFrame(
    [
        [
            pd.read_csv(
                f"./{path}/{id_:04d}/log/net{i}/epoch_log.csv",
                index_col="Unnamed: 0",
            ).at[epochs, "test_accuracy"]
            for i in range(no_nodes)
        ]
        for id_ in complete["number"]
    ],
    index=model.index,
    columns=[f"model_{i}_acc" for i in range(no_nodes)],
)

# Rank trials by objective value (ascending) and show, side by side, the
# objective value, the chosen model per node and the per-node accuracy.
# The former `.rename(columns={0: "max_accuracy"})` was dropped: the Series is
# named "value", so no column labelled 0 ever existed and the rename was a no-op.
sorted_acc = complete.sort_values(by="value")["value"]
sorted_df  = pd.concat(
    [sorted_acc, model.loc[sorted_acc.index], model_acc.loc[sorted_acc.index]],
    axis=1,
)
sorted_df

# Graph

In [None]:
# Zero-based rank of the trial to inspect in `sorted_df`, which is ordered by
# ascending objective value; 0 selects the trial with the lowest value.
top = 0

## Loss

In [None]:
# Loss and model parameters of the selected trial (row `top` of the ranking).
loss = params.filter(like="loss").loc[sorted_df.index[top]]
model_name = params.filter(like="model").loc[sorted_df.index[top]]

# The loss parameters are laid out as a square table labelled by model name
# on both axes.
wh = int(np.sqrt(loss.size))
df_loss = pd.DataFrame(loss.values.reshape(wh, wh), index=model_name, columns=model_name)
df_loss

## Gate

In [None]:
# Gate parameters of the selected trial (row `top` of the ranking).
gate = params.loc[:, params.columns.str.contains("gate")].loc[sorted_df.index[top]]

# All-None table with the same labels as df_loss, built directly instead of
# using the deprecated DataFrame.applymap.
df_gate = pd.DataFrame(
    [[None] * df_loss.shape[1] for _ in range(df_loss.shape[0])],
    index=df_loss.index,
    columns=df_loss.columns,
    dtype=object,
)

# Only names made of exactly four "_"-separated fields are placed in the
# table; the two middle fields are the integer row and column positions.
# The first field must NOT be unpacked into `params`: doing so replaced the
# `params` DataFrame with a string and broke re-running the cells above.
# NOTE(review): the graph cell reads this table as df_gate.iloc[target, source];
# confirm the field order in the parameter names matches that convention.
for gate_name, val in gate.items():
    fields = gate_name.split("_")
    if len(fields) == 4:
        _, row, col, _ = fields
        df_gate.iloc[int(row), int(col)] = val
df_gate

## Structure

In [None]:
# Colour index (within the "dark28" colour scheme passed to every edge) used
# for each gate type that is drawn.
edge_color = {
    "ThroughGate": "3",
    "LinearGate": "1",
    "CorrectGate": "2",
}

# Display text for each loss name after the "Loss" suffix has been stripped.
# Names that are not listed are shown unchanged.
loss_label = {
    "KL_P": "Prob(+)",  # "SoftTarget(+)"
    "KL_N": "Prob(-)",  # "SoftTarget(-)"
    "Att_P": "Attention(+)",
    "Att_N": "Attention(-)",
    "KL_Att_P_P": "Prob(+), Attention(+)",  # "SoftTarget(+), Attention(+)"
    # Fixed: this entry used to repeat the "(+)" text of KL_Att_P_P.
    "KL_Att_N_N": "Prob(-), Attention(-)",  # "SoftTarget(-), Attention(-)"
}

G = Digraph(format="pdf", engine="dot")

# accuracy: per-node test accuracy of the selected trial
acc = model_acc.loc[sorted_acc.index].iloc[top]


def _node_name(i):
    """Return the graph node name for model `i`: "<i+1>. <model> (<acc>%)".

    `acc` is labelled "model_<i>_acc", so it is indexed by position with
    `.iloc`; an integer key on a string-labelled Series relies on a deprecated
    positional fallback. `df_loss` carries the same model names on both axes,
    so the row index is used for every node.
    """
    return f"{i+1}. " + df_loss.index[i] + f" ({np.round(acc.iloc[i], decimals=2)}%)"


# node: every model in grey, then the first model re-declared in pink
for target in range(len(df_loss)):
    G.node(_node_name(target), color='gray90', fillcolor='gray90', style='filled')
G.node(_node_name(0), color='pink', fillcolor='pink', style='radial')

# edge: rows of df_gate / df_loss are targets, columns are sources
for target in range(len(df_loss)):
    for source in range(len(df_loss)):
        # Separate name so the `gate` Series from the Gate cell is not overwritten.
        gate_type = df_gate.iloc[target, source]
        if gate_type == "CutoffGate":
            continue  # cut-off connections are not drawn

        if source == target:  # label -> model
            label = gate_type.replace("Gate", "")  # ThroughGate -> Through
            if gate_type == "CorrectGate":
                # On the label edge a CorrectGate is drawn as a ThroughGate.
                gate_type = "ThroughGate"
                label = "Through"
            G.edge(f"{target}",
                   _node_name(target),
                   label=label, fontsize="13", fontcolor=edge_color[gate_type],
                   color=edge_color[gate_type], colorscheme="dark28")
            G.node(f"{target}", label="Label", color='white', style='filled')
        else:  # model -> model
            gate_name = gate_type.replace("Gate", "")  # ThroughGate -> Through
            loss_name = df_loss.iloc[target, source].replace("Loss", "")
            loss_name = loss_label.get(loss_name, loss_name)

            label = loss_name + "\n(" + gate_name + ")"
            G.edge(_node_name(source),
                   _node_name(target),
                   label=label, fontsize="13", fontcolor=edge_color[gate_type],
                   color=edge_color[gate_type], colorscheme="dark28")

# Models whose whole df_gate row is "CutoffGate" are highlighted in light blue.
for target in range(len(df_loss)):
    if (df_gate.iloc[target] == "CutoffGate").all():
        G.node(_node_name(target),
               color='lightblue', fillcolor='lightblue', style='radial')

G.render(filename=f"{top}", directory="./topn_graph", cleanup=True, format="png")   #format="pdf" png
G

In [None]:
# Report which trial (zero-padded to four digits) is shown. The rank in the
# label follows `top` instead of always reading "Top-1".
print(f"Top-{top + 1} graph :", str(sorted_df.index[top]).zfill(4))