In [10]:
import wandb
import plotly.express as px 
import plotly.graph_objects as go
import numpy as np 
import pandas as pd 
from tqdm import tqdm
import matplotlib.pyplot as plt 

In [11]:
api = wandb.Api()

In [12]:
runs = [run for run in api.runs('jlehrer1/Ablation Study') if run.state == 'finished']

In [18]:
help(runs[0])

Help on Run in module wandb.apis.public object:

3gwl3y7u = class Run(Attrs)
 |  3gwl3y7u(client, entity, project, run_id, attrs={})
 |  
 |  A single run associated with an entity and project.
 |  
 |  Attributes:
 |      tags ([str]): a list of tags associated with the run
 |      url (str): the url of this run
 |      id (str): unique identifier for the run (defaults to eight characters)
 |      name (str): the name of the run
 |      state (str): one of: running, finished, crashed, killed, preempting, preempted
 |      config (dict): a dict of hyperparameters associated with the run
 |      created_at (str): ISO timestamp when the run was started
 |      system_metrics (dict): the latest system metrics recorded for the run
 |      summary (dict): A mutable dict-like property that holds the current summary.
 |                  Calling update will persist any changes.
 |      project (str): the project associated with the run
 |      entity (str): the name of the entity associated wi

In [20]:
r = runs[0].scan_history(keys=[
    'val_loss_epoch',
    'val_weighted_accuracy',
    'val_median_f1'
])

In [26]:
pd.DataFrame(r)

Unnamed: 0,val_loss_epoch,val_weighted_accuracy,val_median_f1
0,2.949212,0.379011,0.000000
1,2.863232,0.434996,0.000000
2,1.612395,0.721363,0.484848
3,0.749035,0.910049,0.885864
4,0.378545,0.956547,0.971252
...,...,...,...
79,0.200611,0.974350,0.981535
80,0.167124,0.974116,0.981132
81,0.159982,0.977747,0.983678
82,0.170713,0.976107,0.985401


In [43]:
loss, wacc, bacc, f1 = pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

for run in tqdm(runs):
    r = run.scan_history(keys=[
        'val_loss_epoch',
        'val_weighted_accuracy',
        'val_balanced_accuracy',
        'val_median_f1'
    ])
    
    df = pd.DataFrame(r)
    df = df.loc[0:50, :]
        
    loss[run.name] = df['val_loss_epoch']
    wacc[run.name] = df['val_weighted_accuracy']
    bacc[run.name] = df['val_balanced_accuracy']
    f1[run.name] = df['val_median_f1']


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:09<00:00,  7.70s/it]


In [47]:
loss = loss.sort_index(axis=1)
f1 = f1.sort_index(axis=1)
wacc = wacc.sort_index(axis=1)
bacc = bacc.sort_index(axis=1)

In [58]:
fig = go.Figure(
    layout=go.Layout(
        title='Validation Loss For Ablative Models',
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Loss')
    )
)

for col in loss:
    fig.add_trace(
        go.Scatter(x=loss.index, y=loss[col], name=col)
    )

fig.show()

In [59]:
fig = go.Figure(
    layout=go.Layout(
        title='Median F1 Score For Ablative Models',
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Median F1')
    )
)

for col in loss:
    fig.add_trace(
        go.Scatter(x=f1.index, y=f1[col], name=col)
    )

fig.show()

In [66]:
avg_f1 = f1.iloc[-5:, :].mean()

fig = go.Figure(
    data=go.Bar(
        x=avg_f1.index, y=avg_f1.values
    ),
    layout=go.Layout(
        title='Average of Median F1 over 5 Final Epochs (Validation Set)',
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Median F1')
    )
)

fig.show()

In [67]:
fig = go.Figure(
    layout=go.Layout(
        title='Weighted Accuracy For Ablative Models',
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Loss')
    )
)

for col in loss:
    fig.add_trace(
        go.Scatter(x=wacc.index, y=wacc[col], name=col)
    )

fig.show()

In [17]:
train_loss = []
val_loss = []
for x in tqdm(df):
    print(x.keys())
    if 'train_loss_step' in x:
        train_loss.append(x['train_loss_step'])
    
    if 'val_loss_step' in x:
        val_loss.append(x['val_loss_step'])

1it [00:00,  4.55it/s]

dict_keys(['trainer/global_step', '_step', '_runtime', 'lr-Adam', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'train_loss_step', 'epoch', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'train_loss_step', 'epoch', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'train_loss_step', 'epoch', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'train_loss_step', 'epoch', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'train_loss_step', 'epoch', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'train_loss_step', 'epoch', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'train_loss_step', 'epoch', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'train_loss_step', 'epoch', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'train_loss_step', 'epoch', '_timestamp'])
dict_keys(['trainer/global_step', '_s

1001it [00:00, 2525.36it/s]

dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', 

2790it [00:00, 3988.04it/s]

dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', '_runtime', 'val_loss_step', '_timestamp'])
dict_keys(['trainer/global_step', '_step', 




In [15]:
len(train_loss), len(val_loss)

(151, 2376)

In [9]:
px.line(val_loss)

In [54]:
run.history(samples=10000).shape

(1491, 50)

In [59]:
runs = api.runs('jlehrer1/mouse_data Cortical Model')

for run in runs:
    df = run.history()
    print(df.shape)
    print(df.isna().sum())
#     fig = px.line(df, x=df.index, y='train_loss_step')
#     fig.show()
        

(472, 21)
trainer/global_step            0
_step                          0
_runtime                       0
val_loss_step                 40
_timestamp                     0
train_loss_step              440
epoch                        438
train_precision              470
train_f1                     470
train_auroc                  470
train_specificity            470
train_loss_epoch             470
train_confusion_matrix       470
train_weighted_accuracy      470
train_balanced_accuracy      470
train_total_accuracy         470
train_per_class_recall       470
train_per_class_precision    470
train_recall                 470
train_per_class_f1           470
test_loss_step               466
dtype: int64
(511, 22)
trainer/global_step          0
_step                        0
_runtime                     0
train_loss_step            474
epoch                      472
_timestamp                   0
val_loss_step               41
val_median_f1              509
val_auroc                 

In [72]:
d = np.ones([5, 5])
d[0, 0:3] = [5,5, 5]

d

array([[5., 5., 5., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]])

In [37]:
summary_list, config_list, name_list = [], [], []
for run in runs:
    if run.state == 'finished':
        df = run.history()
        
        px.scatter(df, x='trainer/global_step', y='train_loss_step').show()

ValueError: Value of 'y' is not the name of a column in 'data_frame'. Expected one of ['trainer/global_step', '_step', '_runtime', 'test_loss_step', '_timestamp', 'test_auroc', 'test_per_class_recall', 'test_specificity', 'test_total_accuracy', 'test_confusion_matrix', 'test_per_class_precision', 'test_median_f1', 'epoch', 'test_f1', 'test_weighted_accuracy', 'test_per_class_f1', 'test_recall', 'test_balanced_accuracy', 'test_precision', 'test_loss_epoch'] but received: train_loss_step

In [20]:
config_list

[{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]

In [17]:
import pandas as pd 
import wandb

api = wandb.Api()
entity, project = "<entity>", "<project>"  # set to your entity and project 
runs = api.runs(entity + "/" + project) 

summary_list, config_list, name_list = [], [], []
for run in runs: 
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files 
    summary_list.append(run.summary._json_dict)

    # .config contains the hyperparameters.
    #  We remove special values that start with _.

runs_df = pd.DataFrame({
    "summary": summary_list,
    "config": config_list,
    "name": name_list
    })

runs_df.to_csv("project.csv")

Help on Run in module wandb.apis.public object:

q9wg97g0 = class Run(Attrs)
 |  q9wg97g0(client, entity, project, run_id, attrs={})
 |  
 |  A single run associated with an entity and project.
 |  
 |  Attributes:
 |      tags ([str]): a list of tags associated with the run
 |      url (str): the url of this run
 |      id (str): unique identifier for the run (defaults to eight characters)
 |      name (str): the name of the run
 |      state (str): one of: running, finished, crashed, killed, preempting, preempted
 |      config (dict): a dict of hyperparameters associated with the run
 |      created_at (str): ISO timestamp when the run was started
 |      system_metrics (dict): the latest system metrics recorded for the run
 |      summary (dict): A mutable dict-like property that holds the current summary.
 |                  Calling update will persist any changes.
 |      project (str): the project associated with the run
 |      entity (str): the name of the entity associated wi