In [None]:
import os
from pathlib import Path

import pandas as pd
from dotenv import load_dotenv
from tqdm.auto import tqdm

import wandb

In [None]:
# Read secrets (e.g. WANDB_API_KEY used by get_data) from the .env file that
# sits one directory above the notebook's working directory.
curr_path = Path.cwd()
env_path = curr_path.parent.absolute().joinpath(".env")
load_dotenv(dotenv_path=env_path)

In [None]:
def _fetch_project(api, entity, project):
    """Return ``(runs_df, hash_df)`` for one W&B project, downloading every logged artifact."""
    summaries, configs, names = [], [], []
    run_hashes = []
    for run in tqdm(api.runs(f"{entity}/{project}"), desc=project):
        # ._json_dict holds the summary metrics and omits large files.
        summaries.append(run.summary._json_dict)
        # Special keys starting with "_" are dropped from the hyperparameters.
        configs.append({k: v for k, v in run.config.items() if not k.startswith("_")})
        # .name is the human-readable run name; it is the merge key below.
        names.append(run.name)
        # Fetch the artifact list once instead of once for a length check and
        # again for the loop.
        for artifact in run.logged_artifacts():
            run_hash = Path(artifact.download()).name
            if "vocab" in run_hash:
                run_hashes.append({"name": run.name, "hash": run_hash})

    runs_df = pd.DataFrame({"summary": summaries, "config": configs, "name": names})
    # Explicit columns keep the "name" merge key present even when no vocab
    # artifact was found (otherwise pd.merge raises KeyError).
    hash_df = pd.DataFrame(run_hashes, columns=["name", "hash"])
    return runs_df, hash_df


def get_data(
    projects=("ccm_project",),
    entity="jpetty",
    good_names=("colorful-morning-3",),
):
    """Fetch W&B runs and return one flat row per (run, vocab artifact).

    For every project, collects each run's summary metrics, its config
    (keys not starting with ``_``) and its name. Every artifact the run
    logged is downloaded. Artifacts whose download path's final component
    contains ``"vocab"`` become that run's ``hash``.

    Parameters
    ----------
    projects : iterable of str
        W&B project names under ``entity``.
    entity : str
        W&B entity owning the projects.
    good_names : str or iterable of str
        A project's runs are kept only if at least one run in that project
        has a name in this collection. A single string is treated as one
        name.

    Returns
    -------
    pandas.DataFrame
        Columns ``name`` and ``hash``, followed by the flattened summary
        keys and then the flattened config keys. Runs without a vocab
        artifact are dropped (inner merge). A run with several vocab
        artifacts yields several rows.

    Raises
    ------
    KeyError
        If ``WANDB_API_KEY`` is not set in the environment.
    ValueError
        If no project contains a run named in ``good_names``.
    """
    if isinstance(good_names, str):
        good_names = [good_names]
    good_names = list(good_names)

    api = wandb.Api(api_key=os.environ["WANDB_API_KEY"])

    merged_dfs = []
    for project in projects:
        runs_df, hash_df = _fetch_project(api, entity, project)
        # NOTE(review): this keeps or drops whole projects. Every run of a
        # project survives if any one of its runs is in good_names. If the
        # intent is to keep only the listed runs, filter rows instead — TODO confirm.
        if runs_df["name"].isin(good_names).any():
            # Merge per project so each project's runs are joined with that
            # project's hashes only (run names may repeat across projects).
            merged_dfs.append(pd.merge(runs_df, hash_df, on="name"))

    if not merged_dfs:
        raise ValueError(
            f"no project contains a run named in good_names={good_names!r}"
        )

    runs_df = pd.concat(merged_dfs, ignore_index=True)

    # Both frames below have a fresh RangeIndex matching runs_df, so the
    # column-wise concat aligns row by row.
    summary_df = pd.json_normalize(runs_df["summary"].tolist())
    config_df = pd.json_normalize(runs_df["config"].tolist())

    return pd.concat(
        [runs_df.drop(columns=["summary", "config"]), summary_df, config_df], axis=1
    )

In [None]:
# Network call: queries W&B and downloads every artifact the runs logged.
runs_df = get_data(projects=["ccm_project"], entity="jpetty")

In [None]:
# Concise summary: index, column dtypes, non-null counts, memory usage.
runs_df.info(verbose=None)

In [None]:
# Column labels (DataFrame.keys() returns the columns Index).
runs_df.keys()

In [None]:
# Preview the first five rows.
runs_df.head(n=5)