In [169]:
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objs as go

In [170]:
## from optuna source code, modified
def get_parallel_coordinate_plot(title, dataframe, cols, target_col, reverse_cols=None,
                                 colorbar_title="BA (valid)"):
    """Build a plotly parallel-coordinates figure of hyper-parameters vs. a target metric.

    Adapted from optuna's parallel-coordinate plot.

    Parameters
    ----------
    title : str
        Figure title.
    dataframe : pandas.DataFrame
        One row per run; must contain every column in ``cols`` and ``target_col``.
    cols : list of str
        Columns drawn as axes, left to right. A column whose first value is a
        ``str`` is treated as categorical and encoded as integer codes in
        sorted (``np.unique``) order.
    target_col : str
        Numeric column drawn as the right-most (unlabelled) axis; its values
        also colour the lines.
    reverse_cols : iterable of str, optional
        Numeric columns whose axis is drawn flipped (max at the bottom).
    colorbar_title : str, optional
        Title shown on the colour bar (default ``"BA (valid)"``).

    Returns
    -------
    plotly.graph_objs.Figure

    Raises
    ------
    ValueError
        If ``dataframe`` has no rows.
    """
    if len(dataframe) == 0:
        raise ValueError("dataframe is empty; nothing to plot")

    # Copy into a set: avoids a shared mutable default and gives O(1) membership tests.
    reverse_cols = set(reverse_cols or ())

    layout = go.Layout(title=title, autosize=True)
    dims = []

    for col_name in cols:
        values = dataframe[col_name].values
        # Same human-readable label for categorical and numeric axes.
        label = ' '.join(col_name.split('_'))

        if isinstance(values[0], str):
            # Categorical axis: map each distinct string to its position in sorted order.
            unique = np.unique(values)
            name_dict = {name: ind for ind, name in enumerate(unique)}

            dims.append({
                "label": label,
                "values": [name_dict[name] for name in values],
                "range": [0, len(unique)],
                "tickvals": list(range(len(unique))),
                "ticktext": unique,
            })
            continue

        lo, hi = np.nanmin(values), np.nanmax(values)
        if col_name in reverse_cols:
            # Swapping the bounds (and hence the tick order) flips the axis.
            lo, hi = hi, lo

        dims.append({
            "tickvals": [np.around(val, 3) for val in np.linspace(lo, hi, 5)],
            "label": label,
            "values": tuple(values),
            "range": (lo, hi),
        })

    target_values = tuple(dataframe[target_col].values)
    # Round the bounds *outwards* to 3 decimals. Plain rounding can move the
    # lower bound above the true minimum (or the upper below the maximum),
    # pushing the extreme runs off the axis.
    target_range = (np.floor(np.nanmin(target_values) * 1000) / 1000,
                    np.ceil(np.nanmax(target_values) * 1000) / 1000)

    dims.append({
        "label": '',
        "values": target_values,
        "range": target_range,
    })

    traces = [
        go.Parcoords(
            tickfont=dict(
                family="Courier New, bold",
                size=25),
            dimensions=dims,
            labelangle=45,
            labelfont=dict(
                family="Courier New",
                size=35
            ),
            labelside="bottom",
            line={
                # Lines are coloured by the target metric (last dimension).
                "color": target_values,
                "colorscale": 'Jet',
                "colorbar": {"title": colorbar_title},
                "showscale": True,
                "reversescale": False
            },
        )
    ]

    figure = go.Figure(data=traces, layout=layout)

    return figure

In [171]:
def get_relevant_dataframe(path, target_col):
    """Read a sweep-results CSV and keep only rows with a positive target metric.

    Parameters
    ----------
    path : str or file-like
        Anything accepted by ``pandas.read_csv``.
    target_col : str
        Column to filter on. Rows where it is <= 0 or NaN are dropped; the
        intent noted in this notebook is to remove runs that reported 0
        validation accuracy.

    Returns
    -------
    pandas.DataFrame
        The surviving rows, keeping their original index labels (not reset).
    """
    runs = pd.read_csv(path)
    has_positive_target = runs[target_col].gt(0.0)
    return runs.loc[has_positive_target]

In [184]:
# Image-only sweep: load the runs, keeping only those with positive val_balanced_accuracy.
image_only_df = get_relevant_dataframe(
    path='/mnt/ncshare/ai4covid_hackathon/paper_utils/image_only_hyperparam_tuning_w_ES.csv',
    target_col='val_balanced_accuracy',
)

image_only_df.head()

Unnamed: 0,Name,backbone,batch_size,dataset_identifier,death_rate,epochs,img_size,last_dense_size,learning_rate,steps_per_epoch,val_balanced_accuracy
0,vital-sweep-409,ResNet50,8,population_average,0.359178,48,366,118,0.015183,541,0.638009
1,light-sweep-408,ResNet50,11,population_average,0.493626,36,475,77,0.013653,535,0.624434
2,amber-sweep-407,ResNet50,8,population_sampled,0.402759,38,483,73,0.010951,594,0.669683
3,dry-sweep-406,ResNet50,6,hospital_average,0.476867,34,266,110,0.009888,504,0.660634
5,eternal-sweep-404,InceptionV3,8,population_average,0.2932,40,479,64,0.009136,350,0.61991


In [185]:
# Hyper-parameters drawn as axes for the image-only sweep (left to right).
relevant_cols = [
    'backbone',
    'last_dense_size',
    'batch_size',
    'death_rate',
    'epochs',
    'img_size',
    'learning_rate',
    'steps_per_epoch',
]

# Metric drawn on the right-most axis and used to colour the lines.
target_col = 'val_balanced_accuracy'

In [186]:
# Parallel-coordinate view of the image-only sweep, with the "epochs" axis flipped.
fig = get_parallel_coordinate_plot(
    'Image only',
    image_only_df,
    relevant_cols,
    target_col,
    reverse_cols=["epochs"],
)

# Large canvas and fonts for the paper figure; the wide left/bottom margins are
# presumably there to fit the 45-degree, bottom-side axis labels.
fig.update_layout(
    font=dict(family="Courier New, bold", size=35),
    width=2500,
    height=800,
    overwrite=True,
    margin=dict(l=300, b=300),
)

fig.show()

In [187]:
# Image + metadata sweep: load the runs, keeping only those with positive val_balanced_accuracy.
image_meta_df = get_relevant_dataframe(
    path='/mnt/ncshare/ai4covid_hackathon/paper_utils/image_and_meta_hyperparam_tuning_w_ES.csv',
    target_col='val_balanced_accuracy',
)

image_meta_df.head()

Unnamed: 0,Name,backbone,batch_size,dataset_identifier,death_rate,epochs,img_size,last_dense_size,learning_rate,steps_per_epoch,val_balanced_accuracy
0,ethereal-sweep-895,EfficientNetB1,9,population_average,0.25926,30,383,52,0.01395,360,0.687783
1,brisk-sweep-894,ResNet50,8,population_average,0.447113,37,310,60,0.014459,447,0.674208
2,swept-sweep-893,EfficientNetB0,8,population_average,0.294745,36,416,76,0.015767,492,0.791855
3,colorful-sweep-892,ResNet50,6,population_average,0.367993,36,484,60,0.018318,487,0.687783
4,splendid-sweep-891,EfficientNetB0,7,population_average,0.160637,37,296,106,0.014676,469,0.701357


In [188]:
image_meta_df.keys()  # DataFrame.keys() returns the column Index

Index(['Name', 'backbone', 'batch_size', 'dataset_identifier', 'death_rate',
       'epochs', 'img_size', 'last_dense_size', 'learning_rate',
       'steps_per_epoch', 'val_balanced_accuracy'],
      dtype='object')

In [189]:
# Axes for the image + metadata sweep; unlike the image-only plot, this one
# also shows dataset_identifier.
relevant_cols = [
    'dataset_identifier',
    'backbone',
    'batch_size',
    'death_rate',
    'epochs',
    'img_size',
    'last_dense_size',
    'learning_rate',
    'steps_per_epoch',
]

fig = get_parallel_coordinate_plot(
    'Image and meta',
    image_meta_df,
    relevant_cols,
    target_col,
    reverse_cols=["batch_size", "steps_per_epoch", "last_dense_size", "epochs"],
)

# Same presentation settings as the image-only figure.
fig.update_layout(
    font=dict(family="Courier New, bold", size=35),
    width=2500,
    height=800,
    overwrite=True,
    margin=dict(l=300, b=300),
)

fig.show()