In [1]:
import os
import random

import numpy as np
import ray
import wandb
from ray import train, tune
from ray.air.integrations.wandb import WandbLoggerCallback, setup_wandb

In [2]:
def train_function(config):
    """Function-API trainable that reports 30 synthetic losses to Tune.

    Every loss is ``config["mean"] + config["sd"] * N(0, 1)``. This function
    makes no W&B calls itself; any W&B logging has to come from the Tuner's
    callbacks.
    """
    mean, sd = config["mean"], config["sd"]
    for _step in range(30):
        sample = mean + sd * np.random.randn()
        train.report({"loss": sample})

In [3]:
def tune_with_callback():
    """Tune ``train_function`` while ``WandbLoggerCallback`` sends results to W&B.

    The search is a 5-point grid over ``mean`` (1..5); each grid point is paired
    with an ``sd`` sampled from U(0.2, 0.8). The reported ``loss`` is minimized.
    The trainable itself contains no W&B code; the callback handles logging.

    Returns:
        The ``ResultGrid`` produced by ``tuner.fit()``, so callers can inspect
        the finished trials.
    """
    tuner = tune.Tuner(
        train_function,
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
        ),
        run_config=train.RunConfig(
            callbacks=[WandbLoggerCallback(project="Wandb_example")]
        ),
        param_space={
            "mean": tune.grid_search([1, 2, 3, 4, 5]),
            "sd": tune.uniform(0.2, 0.8),
        },
    )
    return tuner.fit()

In [4]:
def train_function_wandb(config):
    """Function-API trainable that logs to W&B itself via ``setup_wandb``.

    It draws 30 synthetic losses (``mean + sd * N(0, 1)``). Each one is
    reported to Tune and also logged through the object returned by
    ``setup_wandb``.
    """
    # Named ``run`` so it does not shadow the module-level ``wandb`` import.
    run = setup_wandb(config, project="Wandb_example")

    for _step in range(30):
        sample = config["mean"] + config["sd"] * np.random.randn()
        train.report({"loss": sample})
        run.log({"loss": sample})

In [5]:
def tune_with_setup():
    """Tune ``train_function_wandb``, which creates its own W&B run with ``setup_wandb``.

    The search is a 5-point grid over ``mean`` (1..5); each grid point is paired
    with an ``sd`` sampled from U(0.2, 0.8). The reported ``loss`` is minimized.
    No W&B callback is registered because the trainable logs to W&B directly.

    Returns:
        The ``ResultGrid`` produced by ``tuner.fit()``, so callers can inspect
        the finished trials.
    """
    tuner = tune.Tuner(
        train_function_wandb,
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
        ),
        param_space={
            "mean": tune.grid_search([1, 2, 3, 4, 5]),
            "sd": tune.uniform(0.2, 0.8),
        },
    )
    return tuner.fit()

In [6]:
class WandbTrainable(tune.Trainable):
    """Class-API trainable that logs to W&B through ``setup_wandb``.

    A trial is a single ``step``. That step:
    - draws 30 synthetic losses ``mean + sd * N(0, 1)``;
    - logs each loss to W&B;
    - reports the last draw to Tune with ``done=True``, so the trial ends
      after one iteration.

    The W&B run opened in ``setup`` is closed in ``cleanup``.
    """

    def setup(self, config):
        """Open the W&B run for this trial.

        Tune's trial id and name are forwarded so the W&B run can be matched to
        the trial. Runs are grouped under "Example" in project "Wandb_example".
        """
        self.wandb = setup_wandb(
            config,
            trial_id=self.trial_id,
            trial_name=self.trial_name,
            group="Example",
            project="Wandb_example",
        )

    def step(self):
        """Log 30 noisy losses to W&B and report the final one to Tune."""
        for _ in range(30):
            loss = self.config["mean"] + self.config["sd"] * np.random.randn()
            self.wandb.log({"loss": loss})
        # ``done=True`` tells Tune this trial is complete after one step.
        return {"loss": loss, "done": True}

    def save_checkpoint(self, checkpoint_dir: str):
        # Nothing to persist: each trial finishes in a single step.
        pass

    def load_checkpoint(self, checkpoint_dir: str):
        # Nothing to restore; no checkpoint state is ever written.
        pass

    def cleanup(self):
        """Close the W&B run when Tune stops the trial.

        Finishing explicitly flushes the run instead of leaving it open until
        the worker process exits. ``self.wandb`` may be missing if ``setup``
        raised before assigning it, so that case is skipped.
        """
        run = getattr(self, "wandb", None)
        if run is not None:
            run.finish()

In [7]:
def tune_trainable():
    """Tune the class-API ``WandbTrainable``, which sets up W&B via ``setup_wandb``.

    The search is a 5-point grid over ``mean`` (1..5); each grid point is paired
    with an ``sd`` sampled from U(0.2, 0.8). The reported ``loss`` is minimized.

    Returns:
        dict: The config of the best trial. It is selected by
        ``get_best_result()``, which uses the ``metric="loss"`` /
        ``mode="min"`` given in the ``TuneConfig``.
    """
    tuner = tune.Tuner(
        WandbTrainable,
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
        ),
        param_space={
            "mean": tune.grid_search([1, 2, 3, 4, 5]),
            "sd": tune.uniform(0.2, 0.8),
        },
    )
    best_result = tuner.fit().get_best_result()
    return best_result.config


In [17]:
# Set ``mock_api = True`` to run the examples without a W&B account. W&B is
# then put in "disabled" mode with a dummy API key, both in this process and
# (via Ray's ``runtime_env``) in the Ray workers that execute the trials.
mock_api = False

if mock_api:
    os.environ.setdefault("WANDB_MODE", "disabled")
    os.environ.setdefault("WANDB_API_KEY", "abcd")
    # ``ray.init`` raises if Ray is already running. That happens when this
    # cell is re-run, or when Ray was auto-started without this runtime_env.
    # Restart Ray so the workers are guaranteed to receive the mock variables.
    if ray.is_initialized():
        ray.shutdown()
    ray.init(
        runtime_env={"env_vars": {"WANDB_MODE": "disabled", "WANDB_API_KEY": "abcd"}}
    )
else:
    # Authenticate before any trial starts. The standalone login cell comes
    # later in the notebook, so under Restart & Run All it would run too late.
    wandb.login()

tune_with_callback()
tune_with_setup()
tune_trainable()

0,1
Current time:,2025-01-25 09:38:28
Running for:,00:00:06.65
Memory:,6.2/8.0 GiB

Trial name,status,loc,mean,sd,iter,total time (s),loss
WandbTrainable_6831c_00000,TERMINATED,127.0.0.1:64142,1,0.279629,1,0.00273991,1.07313
WandbTrainable_6831c_00001,TERMINATED,127.0.0.1:64143,2,0.474353,1,0.00225115,2.32919
WandbTrainable_6831c_00002,TERMINATED,127.0.0.1:64144,3,0.470929,1,0.00273681,2.70396
WandbTrainable_6831c_00003,TERMINATED,127.0.0.1:64145,4,0.277263,1,0.00398517,3.71797
WandbTrainable_6831c_00004,TERMINATED,127.0.0.1:64146,5,0.63165,1,0.00171804,5.36523


2025-01-25 09:38:28,131	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/hongsupshin/ray_results/WandbTrainable_2025-01-25_09-38-21' in 0.0205s.
2025-01-25 09:38:28,144	INFO tune.py:1041 -- Total run time: 6.68 seconds (6.63 seconds for the tuning loop).


{'mean': 1, 'sd': 0.2796293119514601}

In [8]:
# Authenticate the W&B client; the cell displays the boolean returned by login.
logged_in = wandb.login()
logged_in

[34m[1mwandb[0m: Currently logged in as: [33mauth0-1wgih[0m ([33mauth0-1wgih-hs[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
# api = wandb.Api()

In [15]:
# run = wandb.init()

In [19]:
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="my-awesome-project",

#     # track hyperparameters and run metadata
#     config={
#     "learning_rate": 0.02,
#     "architecture": "CNN",
#     "dataset": "CIFAR-100",
#     "epochs": 10,
#     }
# )


In [18]:
# # simulate training
# epochs = 10
# offset = random.random() / 5
# for epoch in range(2, epochs):
#     acc = 1 - 2 ** -epoch - random.random() / epoch - offset
#     loss = 2 ** -epoch + random.random() / epoch + offset

#     # log metrics to wandb
#     wandb.log({"acc": acc, "loss": loss})

# # [optional] finish the wandb run, necessary in notebooks
# wandb.finish()