In [2]:
import d3rlpy
import minari
import logging
import numpy as np
import matplotlib.pyplot as plt
from d3rlpy.dataset import MDPDataset
from d3rlpy.metrics import EnvironmentEvaluator

logging.basicConfig(level=logging.DEBUG)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def train():
    # Load dataset
    dataset = minari.load_dataset("D4RL/antmaze/medium-play-v1")
    episodes = list(dataset.iterate_episodes())
    print("Number of episodes:", len(episodes))

    # Extract data
    # with properties 'id', 'observations', 'actions', 'rewards', 'terminations', 'truncations', 'infos']
    observations = np.vstack([ep.observations["observation"] for ep in episodes])
    actions = np.vstack([ep.actions for ep in episodes])
    rewards = np.hstack([ep.rewards for ep in episodes])
    terminals = np.hstack([ep.terminations for ep in episodes])
    truncations = np.hstack([ep.truncations for ep in episodes])
    print("Number of observations in episode 0:", len(episodes[0].observations["observation"]))
    print("Number of actions in episode 0:", len(episodes[0].actions))
    assert len(episodes[0].observations["observation"]) == len(episodes[0].actions) + 1

    # Convert to d3rlpy format
    mdp_dataset = MDPDataset(observations, actions, rewards, terminals, truncations)
    print("Dataset size:", len(mdp_dataset.episodes))
    # import pdb; pdb.set_trace()
    print("MDP Dataset Episodes:", len(mdp_dataset.episodes))
    assert len(mdp_dataset.episodes) > 0, "ERROR: No episodes found in MDPDataset!"
    print("Sample episode:", mdp_dataset.episodes[:5])  # Print first 5 episodes

    # Initialize model
    awac = d3rlpy.algos.AWACConfig().create(device="mps")
    print("model initialized:", awac)

    # Ensure the dataset is not empty
    if len(mdp_dataset.episodes) == 0:
        raise ValueError("The MDPDataset is empty. Please check the dataset loading process.")


    # for debugging 
    from d3rlpy.datasets import get_pendulum
    mdp_dataset, env = get_pendulum()
    env_evaluator = EnvironmentEvaluator(env)
    print("TEST Dataset loaded successfully!")

    # Run one training step to check if it works
    awac.build_with_dataset(mdp_dataset)
    history = awac.fit(mdp_dataset, n_steps=200, n_steps_per_epoch=100, show_progress=True, evaluators={'environment': env_evaluator})
    # Use `fitter()` for step-by-step training
    # for epoch, metrics in enumerate(awac.fitter(mdp_dataset, n_steps=5000, show_progress=True, evaluators={'environment': env_evaluator})):
    #     print(f"Epoch {epoch}: {metrics}")
    #     # Optionally, break early if training starts logging NaNs
    #     if np.isnan(metrics["loss"]):
    #         print("NaN detected, stopping training early.")
    #         break
    print("Done")

    return history

In [4]:
def plot(history):

    epochs = list(range(len(history["critic_loss"])))

    plt.figure(figsize=(12, 4))

    # Plot Critic Loss
    plt.subplot(1, 3, 1)
    plt.plot(epochs, history["critic_loss"], label="Critic Loss", color="blue")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Critic Loss Over Training")
    plt.legend()

    # Plot Actor Loss
    plt.subplot(1, 3, 2)
    plt.plot(epochs, history["actor_loss"], label="Actor Loss", color="red")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Actor Loss Over Training")
    plt.legend()

    # Plot Environment Rewards
    plt.subplot(1, 3, 3)
    plt.plot(epochs, history["environment"], label="Environment Reward", color="green")
    plt.xlabel("Epochs")
    plt.ylabel("Reward")
    plt.title("Reward Over Training")
    plt.legend()

    plt.tight_layout()
    plt.show()

In [5]:
history = train()

Number of episodes: 1000
Number of observations in episode 0: 1001
Number of actions in episode 0: 1000
[2m2025-02-05 15:41.24[0m [[32m[1minfo     [0m] [1mSignatures have been automatically determined.[0m [36maction_signature[0m=[35mSignature(dtype=[dtype('float32')], shape=[(8,)])[0m [36mobservation_signature[0m=[35mSignature(dtype=[dtype('float64')], shape=[(27,)])[0m [36mreward_signature[0m=[35mSignature(dtype=[dtype('float64')], shape=[(1,)])[0m
[2m2025-02-05 15:41.24[0m [[32m[1minfo     [0m] [1mAction-space has been automatically determined.[0m [36maction_space[0m=[35m<ActionSpace.CONTINUOUS: 1>[0m
[2m2025-02-05 15:41.24[0m [[32m[1minfo     [0m] [1mAction size has been automatically determined.[0m [36maction_size[0m=[35m8[0m
Dataset size: 1000
MDP Dataset Episodes: 1000
Sample episode: [Episode(observations=array([[ 7.50000000e-01,  1.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 7.

Epoch 1/2: 100%|██████████| 100/100 [00:04<00:00, 24.70it/s, critic_loss=10.6, actor_loss=1.37e+5, temp=0, temp_loss=0]
  if not isinstance(terminated, (bool, np.bool8)):


[2m2025-02-05 15:41.32[0m [[32m[1minfo     [0m] [1mAWAC_20250205154126: epoch=1 step=100[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.007736170291900634, 'time_algorithm_update': 0.032507174015045166, 'critic_loss': 10.016016206741334, 'actor_loss': 134167.12390625, 'temp': 0.0, 'temp_loss': 0.0, 'time_step': 0.04031641721725464, 'environment': -1209.6452679429296}[0m [36mstep[0m=[35m100[0m
[2m2025-02-05 15:41.32[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/AWAC_20250205154126/model_100.d3[0m


Epoch 2/2: 100%|██████████| 100/100 [00:03<00:00, 25.85it/s, critic_loss=3.3, actor_loss=1.01e+5, temp=0, temp_loss=0]


[2m2025-02-05 15:41.38[0m [[32m[1minfo     [0m] [1mAWAC_20250205154126: epoch=2 step=200[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.009033420085906983, 'time_algorithm_update': 0.029361975193023682, 'critic_loss': 3.252927803993225, 'actor_loss': 100288.04125, 'temp': 0.0, 'temp_loss': 0.0, 'time_step': 0.03847856998443604, 'environment': -699.5642436670706}[0m [36mstep[0m=[35m200[0m
[2m2025-02-05 15:41.38[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/AWAC_20250205154126/model_200.d3[0m
Done


In [9]:
print(type(history))
history


<class 'list'>


[(1,
  {'time_sample_batch': 0.007736170291900634,
   'time_algorithm_update': 0.032507174015045166,
   'critic_loss': 10.016016206741334,
   'actor_loss': 134167.12390625,
   'temp': 0.0,
   'temp_loss': 0.0,
   'time_step': 0.04031641721725464,
   'environment': -1209.6452679429296}),
 (2,
  {'time_sample_batch': 0.009033420085906983,
   'time_algorithm_update': 0.029361975193023682,
   'critic_loss': 3.252927803993225,
   'actor_loss': 100288.04125,
   'temp': 0.0,
   'temp_loss': 0.0,
   'time_step': 0.03847856998443604,
   'environment': -699.5642436670706})]