In [None]:
import pickle
import numpy as np

# Restore the Stable-Baselines3 replay buffer that was pickled to disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load buffers produced by a run you trust.
BUFFER_PATH = 'buffer_with_embedd.pkl'
with open(BUFFER_PATH, 'rb') as buffer_file:
    sb3_buffer = pickle.load(buffer_file)

In [None]:
import gym
from stable_baselines3.common.buffers import ReplayBuffer
import d3rlpy
from d3rlpy.dataset import MDPDataset

def to_mdp_dataset(
    replay_buffer: ReplayBuffer,
    obs_key: str = "rgb",
    n_frames: int = 4,
) -> MDPDataset:
    """Convert a Stable-Baselines3 (dict) replay buffer into a d3rlpy ``MDPDataset``.

    SB3 stores every field as ``(buffer_size, n_envs, ...)`` in a circular
    buffer, while ``MDPDataset`` expects one flat stream in which each episode
    is a contiguous run of rows ending in a terminal or timeout flag. This
    function therefore:

    * keeps only the ``replay_buffer.size()`` rows that were actually written
      and, if the buffer has wrapped around, rotates them oldest-to-newest;
    * lays the data out env-by-env (all of env 0, then env 1, ...) so steps
      from parallel environments are not interleaved into one trajectory;
    * maps SB3 ``dones``/``timeouts`` onto d3rlpy ``terminals``/``timeouts``.
      SB3 sets ``dones`` on both real terminations and time-limit truncations
      and uses ``timeouts`` to mark the truncations, so a row with both set is
      a *timeout* (value should still bootstrap), not a terminal;
    * closes each environment's trailing, possibly unfinished episode with a
      timeout so it is not merged with the next environment's rows.

    Args:
        replay_buffer: SB3 replay buffer whose ``observations`` is a dict
            (as in ``DictReplayBuffer``) containing ``obs_key``.
        obs_key: Observation entry to export. Defaults to ``"rgb"``.
        n_frames: Number of consecutive frames stacked by d3rlpy's
            ``FrameStackTransitionPicker``. Defaults to 4.

    Returns:
        An ``MDPDataset`` that yields frame-stacked transitions.

    Raises:
        ValueError: If the replay buffer holds no transitions.
    """
    n_stored = replay_buffer.size()
    if n_stored == 0:
        raise ValueError("replay buffer is empty; nothing to convert")
    n_envs = replay_buffer.dones.shape[1]
    # Once the buffer has wrapped, `pos` is the next row to overwrite, i.e. the oldest one.
    start = replay_buffer.pos if replay_buffer.full else 0

    def _ordered(arr: np.ndarray) -> np.ndarray:
        # Drop unwritten rows, restore chronological order, then flatten env-major:
        # (n_stored, n_envs, *rest) -> (n_envs * n_stored, *rest).
        arr = arr[:n_stored]
        if start:
            arr = np.roll(arr, -start, axis=0)
        return np.swapaxes(arr, 0, 1).reshape(-1, *arr.shape[2:])

    observations = _ordered(replay_buffer.observations[obs_key])
    actions = _ordered(replay_buffer.actions)
    rewards = _ordered(replay_buffer.rewards)
    dones = _ordered(replay_buffer.dones).astype(bool)
    # A truncation only counts where the episode actually ended.
    timeouts = _ordered(replay_buffer.timeouts).astype(bool) & dones
    # Only episodes that ended for a reason other than the time limit are terminal.
    terminals = dones & ~timeouts

    # The last stored row of each env is usually mid-episode: cut it with a
    # timeout so the next env's first row is not glued onto that episode.
    env_ends = np.arange(1, n_envs + 1) * n_stored - 1
    timeouts[env_ends] |= ~terminals[env_ends]

    return MDPDataset(
        observations=observations,
        actions=actions,
        rewards=rewards,
        terminals=terminals.astype(np.float32),
        timeouts=timeouts.astype(np.float32),
        transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=n_frames),
    )

In [None]:
# Convert the loaded SB3 buffer into the d3rlpy dataset used by the training/evaluation cells below.
mdp_dataset = to_mdp_dataset(sb3_buffer)

In [None]:
from d3rlpy.algos import DQNConfig, SACConfig, IQLConfig

import torch

# Seed Python/NumPy/torch RNGs so network initialisation and minibatch sampling are reproducible.
d3rlpy.seed(0)
# Fall back to CPU instead of crashing on machines without a CUDA device.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu:0"

# The algorithm is IQL; the variable keeps the historical name `sac` because later cells use it.
# NOTE(review): d3rlpy's IQL is documented for continuous action spaces — if the SB3 buffer
# holds Discrete actions, switch to a discrete algorithm (e.g. DQNConfig). Confirm the action space.
sac = IQLConfig().create(device=DEVICE)
sac.build_with_dataset(mdp_dataset)

In [None]:
from d3rlpy.metrics import TDErrorEvaluator, DiscountedSumOfAdvantageEvaluator

# Both metrics are computed on the training episodes themselves (there is no held-out split),
# so they describe fit to the logged data rather than generalisation.
train_episodes = mdp_dataset.episodes
td_error_evaluator = TDErrorEvaluator(episodes=train_episodes)
discounted_sum_of_advantage_evaluator = DiscountedSumOfAdvantageEvaluator(
    episodes=train_episodes
)

In [None]:
# Train offline for a fixed number of gradient steps, reporting the evaluators above
# under the metric names used as dictionary keys.
N_TRAIN_STEPS = 10000
sac.fit(
    mdp_dataset,
    n_steps=N_TRAIN_STEPS,
    evaluators=dict(
        td_error=td_error_evaluator,
        discounted_sum_of_advantage=discounted_sum_of_advantage_evaluator,
    ),
)