In [9]:
import numpy as np

import ray
from ray import tune
from ray.air.integrations.wandb import WandbLoggerCallback, setup_wandb

In [11]:
def train_function(config):
    """Toy objective: report 30 noisy loss samples to Tune.

    Args:
        config: Mapping with numeric ``"mean"`` and ``"sd"`` entries. Each
            reported loss is ``mean + sd * z`` where ``z`` comes from
            ``np.random.randn()`` (NumPy's global random state).
    """
    mean, sd = config["mean"], config["sd"]
    for _ in range(30):
        noisy_loss = mean + sd * np.random.randn()
        tune.report({"loss": noisy_loss})

In [12]:
def tune_with_callback():
    """Run the demo sweep with a ``WandbLoggerCallback`` attached to the run.

    Uses the function-API trainable ``train_function`` and passes the
    callback (project ``"Wandb_example"``) through the run config, so the
    trainable itself contains no W&B calls.
    """
    search_space = {
        "mean": tune.grid_search([1, 2, 3, 4, 5]),
        "sd": tune.uniform(0.2, 0.8),
    }
    objective = tune.TuneConfig(metric="loss", mode="min")
    wandb_run_config = tune.RunConfig(
        callbacks=[WandbLoggerCallback(project="Wandb_example")]
    )

    tune.Tuner(
        train_function,
        tune_config=objective,
        run_config=wandb_run_config,
        param_space=search_space,
    ).fit()

In [13]:
def train_function_wandb(config):
    """Toy objective that reports each loss to Tune and logs it via ``setup_wandb``.

    Args:
        config: Mapping with numeric ``"mean"`` and ``"sd"`` entries; it is
            also forwarded to ``setup_wandb`` together with the project name
            ``"Wandb_example"``.
    """
    run = setup_wandb(config, project="Wandb_example")

    mean, sd = config["mean"], config["sd"]
    for _ in range(30):
        noisy_loss = mean + sd * np.random.randn()
        # Report to Tune first, then log the same value to the W&B handle.
        tune.report({"loss": noisy_loss})
        run.log({"loss": noisy_loss})

In [14]:
def tune_with_setup():
    """Run the demo sweep with ``train_function_wandb``.

    No callback is configured here: the trainable obtains its own W&B handle
    through ``setup_wandb``.
    """
    search_space = {
        "mean": tune.grid_search([1, 2, 3, 4, 5]),
        "sd": tune.uniform(0.2, 0.8),
    }
    objective = tune.TuneConfig(metric="loss", mode="min")

    tune.Tuner(
        train_function_wandb,
        tune_config=objective,
        param_space=search_space,
    ).fit()

In [15]:
class WandbTrainable(tune.Trainable):
    """Class-API trainable that logs noisy loss samples through ``setup_wandb``."""

    def setup(self, config):
        """Create the W&B handle for this trial and store it on ``self.wandb``."""
        wandb_kwargs = dict(
            trial_id=self.trial_id,
            trial_name=self.trial_name,
            group="Example",
            project="Wandb_example",
        )
        self.wandb = setup_wandb(config, **wandb_kwargs)

    def step(self):
        """Log 30 loss samples, then return the last one with ``done=True``."""
        mean, sd = self.config["mean"], self.config["sd"]
        for _ in range(30):
            loss = mean + sd * np.random.randn()
            self.wandb.log({"loss": loss})
        # Only the final sample is returned; "done" is set on this single result.
        return {"loss": loss, "done": True}

    def save_checkpoint(self, checkpoint_dir: str):
        """Intentionally a no-op: this example writes nothing to ``checkpoint_dir``."""
        return None

    def load_checkpoint(self, checkpoint_dir: str):
        """Intentionally a no-op: this example reads nothing from ``checkpoint_dir``."""
        return None

In [16]:
def tune_trainable():
    """Run the demo sweep with the class-API ``WandbTrainable``.

    Returns:
        The ``config`` of the best result, where "best" is determined by the
        ``metric="loss"`` / ``mode="min"`` settings passed to ``TuneConfig``.
    """
    search_space = {
        "mean": tune.grid_search([1, 2, 3, 4, 5]),
        "sd": tune.uniform(0.2, 0.8),
    }
    tuner = tune.Tuner(
        WandbTrainable,
        tune_config=tune.TuneConfig(metric="loss", mode="min"),
        param_space=search_space,
    )

    best_result = tuner.fit().get_best_result()
    return best_result.config

In [None]:
import os

# Set to False to talk to the real W&B service (requires real credentials).
mock_api = True

if mock_api:
    # Single source of truth for the mock settings: W&B is put in "disabled"
    # mode and given a placeholder (not real) API key.
    mock_wandb_env = {"WANDB_MODE": "disabled", "WANDB_API_KEY": "abcd"}

    # Driver process: only fill in values the user has not already set.
    for env_name, env_value in mock_wandb_env.items():
        os.environ.setdefault(env_name, env_value)

    # Worker processes receive the same variables through the runtime env.
    # ignore_reinit_error=True keeps this cell re-runnable: without it, a
    # second execution in the same kernel raises because Ray is already
    # initialized.
    ray.init(
        runtime_env={"env_vars": mock_wandb_env},
        ignore_reinit_error=True,
    )

tune_with_callback()
tune_with_setup()
# Kept as the final expression so the best config is shown as the cell output.
tune_trainable()

0,1
Current time:,2025-06-27 00:38:59
Running for:,00:00:01.27
Memory:,38.1/125.5 GiB

Trial name,status,loc,mean,sd,iter,total time (s),loss
WandbTrainable_acd88_00000,TERMINATED,192.168.0.25:890315,1,0.274597,1,4.00543e-05,0.779812
WandbTrainable_acd88_00001,TERMINATED,192.168.0.25:890317,2,0.676938,1,4.17233e-05,2.14902
WandbTrainable_acd88_00002,TERMINATED,192.168.0.25:890316,3,0.557322,1,4.48227e-05,3.32984
WandbTrainable_acd88_00003,TERMINATED,192.168.0.25:890318,4,0.491615,1,4.26769e-05,4.34077
WandbTrainable_acd88_00004,TERMINATED,192.168.0.25:890314,5,0.257322,1,3.40939e-05,4.75256


2025-06-27 00:38:59,424	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/qrbao/ray_results/WandbTrainable_2025-06-27_00-38-58' in 0.0041s.
2025-06-27 00:38:59,427	INFO tune.py:1041 -- Total run time: 1.28 seconds (1.27 seconds for the tuning loop).


{'mean': 1, 'sd': 0.27459703325174134}

[36m(WandbTrainable pid=890316)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/qrbao/ray_results/WandbTrainable_2025-06-27_00-38-58/WandbTrainable_acd88_00002_2_mean=3,sd=0.5573_2025-06-27_00-38-58/checkpoint_000000)
[33m(raylet)[0m [2025-06-27 00:55:42,762 E 887077 887109] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-27_00-38-48_678663_882799 is over 95% full, available space: 22.8069 GB; capacity: 456.175 GB. Object creation will fail if spilling is required.
[36m(WandbTrainable pid=890314)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/qrbao/ray_results/WandbTrainable_2025-06-27_00-38-58/WandbTrainable_acd88_00004_4_mean=5,sd=0.2573_2025-06-27_00-38-58/checkpoint_000000)[32m [repeated 4x across cluster][0m
[33m(raylet)[0m [2025-06-27 00:55:52,767 E 887077 887109] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-27_00-38-48_678663_882799 is over 95% full, available space: 22.7943 G