In [1]:
import numpy as np


def _load_dataset_minari(env_name): # jensk
    """Load a Minari offline dataset and return its episodes plus state statistics.

    Args:
        env_name: Minari dataset id. The Gymnasium-style name
            ``'PointMaze_UMaze-v3'`` is translated to the Minari id
            ``'pointmaze-umaze-v1'``.

    Returns:
        A 3-tuple ``(trajectories, state_mean, state_std)``:
          * ``trajectories`` - list of episode dicts ordered by ascending
            episode return (see the NOTE on the selection loop below; it is
            not necessarily every episode in the dataset).
          * ``state_mean`` / ``state_std`` - per-dimension mean and standard
            deviation (plus 1e-6) of the observations, computed over ALL
            loaded episodes, i.e. before the selection loop runs.

    Side effects:
        Downloads the dataset through minari, prints summary statistics, and
        for pointmaze datasets rewrites ``path['observations']`` in place.
    """
    # Imported lazily so the notebook can be opened without minari installed.
    import minari
    if env_name == 'PointMaze_UMaze-v3' :
        env_name = 'pointmaze-umaze-v1'
    # NOTE(review): load_dataset below is already called with download=True,
    # so this explicit download looks redundant - confirm before removing.
    minari.download_dataset(env_name)
    dataset = minari.load_dataset(env_name, download=True)
    # Reaches into the underscore-prefixed `_data` attribute (not public API).
    # The code below iterates `trajectories` twice and calls len() on it, so it
    # relies on get_episodes returning a list-like, not a one-shot iterator.
    trajectories = dataset._data.get_episodes(dataset.episode_indices)
    states, traj_lens, returns = [], [], []
    if 'pointmaze' in env_name :
        # Re-label observations: replace the observation dict by the flat array
        # [achieved_goal, desired_goal] concatenated along axis 1.
        print("re-label observation. (achieved_goal, desired_goal) -> observation")
        for path in trajectories :
            # The first row of each goal array is dropped.
            # Presumably this aligns observations with rewards (one fewer
            # reward than observation per episode) - TODO confirm.
            achieved_goal = path['observations']['achieved_goal'][1:]
            desired_goal = path['observations']['desired_goal'][1:]
            observation = np.concatenate([achieved_goal, desired_goal], axis=1)
            # In-place mutation: the episode dict now holds the flat array.
            path['observations'] = observation

    # Collect per-episode observations, lengths and undiscounted returns.
    for path in trajectories:
        states.append(path["observations"])
        traj_lens.append(len(path["observations"]))
        returns.append(path["rewards"].sum())
    traj_lens, returns = np.array(traj_lens), np.array(returns)
    states = np.concatenate(states, axis=0)
    # 1e-6 is added to the std, presumably to avoid division by zero when a
    # caller normalizes with it.
    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
    num_timesteps = sum(traj_lens)

    print("=" * 50)
    print(f"Starting new experiment: {env_name}")
    print(f"{len(traj_lens)} trajectories, {num_timesteps} timesteps found")
    print(f"Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}")
    print(f"Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}")
    print(f"Average length: {np.mean(traj_lens):.2f}, std: {np.std(traj_lens):.2f}")
    print(f"Max length: {np.max(traj_lens):.2f}, min: {np.min(traj_lens):.2f}")
    print("=" * 50)

    # Walk episodes from highest to lowest return, keeping them while the
    # running timestep total stays strictly below num_timesteps.
    # NOTE(review): num_timesteps is the total over all episodes and the
    # comparison is strict, so with two or more episodes the loop can never
    # admit all of them - at least the lowest-return episode is dropped.
    # Confirm whether `<=` was intended.
    sorted_inds = np.argsort(returns)  # lowest to highest
    num_trajectories = 1
    timesteps = traj_lens[sorted_inds[-1]]
    ind = len(trajectories) - 2
    while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] < num_timesteps:
        timesteps += traj_lens[sorted_inds[ind]]
        num_trajectories += 1
        ind -= 1
    # Keep the `num_trajectories` highest-return episodes, still ascending.
    sorted_inds = sorted_inds[-num_trajectories:]

    trajectories = [trajectories[ii] for ii in sorted_inds]
    return trajectories, state_mean, state_std
    

In [2]:
# Load the PointMaze U-maze offline episodes together with the per-dimension
# observation mean/std; the loader maps this name to 'pointmaze-umaze-v1'.
offline_trajs, state_mean, state_std = _load_dataset_minari('PointMaze_UMaze-v3')

  from .autonotebook import tqdm as notebook_tqdm
[0m
  logger.warn(


re-label observation. (achieved_goal, desired_goal) -> observation
Starting new experiment: pointmaze-umaze-v1
13289 trajectories, 999996 timesteps found
Average return: 1.00, std: 0.02
Max return: 1.00, min: 0.00
Average length: 75.25, std: 45.52
Max length: 193.00, min: 0.00


In [30]:
import torch

class SubTrajectory(torch.utils.data.Dataset):
    """Map-style dataset that serves trajectories in a pre-drawn order.

    Item ``i`` is ``trajectories[sampling_ind[i]]``; when a ``transform`` is
    supplied the trajectory is passed through it before being returned.
    """

    def __init__(self, trajectories, sampling_ind, transform=None):
        """
        Args:
            trajectories: indexable collection of trajectories.
            sampling_ind: sequence of indices into ``trajectories``; its
                length is the length of this dataset.
            transform: optional callable applied to each selected trajectory.
        """
        super().__init__()
        self.trajs = trajectories
        self.sampling_ind = sampling_ind
        self.transform = transform

    def __len__(self):
        """Number of pre-drawn samples (not the number of trajectories)."""
        return len(self.sampling_ind)

    def __getitem__(self, index):
        """Return the ``index``-th drawn trajectory, transformed if requested.

        Args:
            index (int): position in ``sampling_ind``.
        Returns:
            ``transform(trajectory)`` when a transform is set, otherwise the
            raw trajectory object.
        """
        picked = self.trajs[self.sampling_ind[index]]
        return self.transform(picked) if self.transform else picked

import random
# Upper bound on absolute timestep indices: TransformSamplingSubTraj01 clips
# every timestep >= MAX_EPISODE_LEN to MAX_EPISODE_LEN - 1.
MAX_EPISODE_LEN = 1000



def discount_cumsum(x, gamma):
    """Discounted reverse cumulative sum along the first axis.

    Computes ``ret[t] = x[t] + gamma * ret[t + 1]`` with ``ret[-1] = x[-1]``,
    i.e. the discounted return-to-go for every step.

    Args:
        x: array-like of per-step values; the recursion runs over axis 0.
        gamma: discount factor.

    Returns:
        ``np.ndarray`` with the same shape as ``x``. Floating-point inputs keep
        their dtype; integer/bool inputs are promoted as needed so that a
        fractional ``gamma`` is not silently truncated. An empty input yields
        an empty array instead of raising ``IndexError``.
    """
    x = np.asarray(x)
    # np.zeros_like(x) would inherit an integer dtype and truncate every
    # `gamma * ret[t + 1]` term on assignment; promote only in that case.
    if np.issubdtype(x.dtype, np.inexact):
        out_dtype = x.dtype
    else:
        out_dtype = np.result_type(x.dtype, gamma)
    ret = np.zeros(x.shape, dtype=out_dtype)
    if x.shape[0] == 0:
        # Nothing to accumulate; x[-1] below would raise on empty input.
        return ret
    ret[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        ret[t] = x[t] + gamma * ret[t + 1]
    return ret


class TransformSamplingSubTraj01:
    """Callable that cuts a random fixed-size window out of one trajectory.

    Given a trajectory dict, ``__call__`` picks a random start index, slices up
    to ``max_len`` steps, left-pads every sequence to a fixed length and
    returns torch tensors. It relies on the module-level ``random``,
    ``MAX_EPISODE_LEN`` and ``discount_cumsum``.
    """

    def __init__(
        self,
        max_len,
        state_dim,
        state_mean,
        state_std,
        reward_scale,
        state_range,
    ):
        """
        Args:
            max_len: window length; shorter windows are left-padded to it.
            state_dim: size each observation row is reshaped to.
            state_mean, state_std: stored, but only referenced by a
                commented-out normalization line in ``__call__``.
            reward_scale: multiplier applied to the returns-to-go only.
            state_range: pair unpacked into ``Tensor.clamp`` for the states.
        """
        super().__init__()
        self.max_len = max_len
        self.state_dim = state_dim
        self.state_mean = state_mean
        self.state_std = state_std
        self.reward_scale = reward_scale

        # (min, max) bounds; __call__ clamps the state tensor to this range.
        # NOTE(review): the previous comment here talked about clamping
        # *actions* for a SquashedNormal distribution, which this class does
        # not do - it looks carried over from an action-range variant.
        self.state_range = state_range

    def __call__(self, traj):
        """Sample one padded sub-trajectory from ``traj``.

        Args:
            traj: mapping with "observations", "rewards" and one of
                "terminals" / "terminations" / "dones". Assumes these are
                numpy arrays of equal length along axis 0 - TODO confirm.

        Returns:
            Tuple ``(ss, rr, dd, rtg, timesteps, ordering, padding_mask)`` of
            torch tensors. All have ``max_len`` rows except ``rtg``, which has
            ``max_len + 1`` (one extra trailing return-to-go entry).

        Raises:
            ValueError: from ``random.randint`` when the trajectory has no
                rewards (empty range).
        """

        # Uniformly random start index over the trajectory (inclusive bounds).
        si = random.randint(0, traj["rewards"].shape[0] - 1)


        # get sequences from dataset
        ss = traj["observations"][si : si + self.max_len].reshape(-1, self.state_dim)
        rr = traj["rewards"][si : si + self.max_len].reshape(-1, 1)

        # Done flags live under different keys depending on the dataset format.
        if "terminals" in traj:
            dd = traj["terminals"][si : si + self.max_len]  # .reshape(-1)
        elif "terminations" in traj:
            dd = traj["terminations"][si : si + self.max_len]  # .reshape(-1)
        else :
            dd = traj["dones"][si : si + self.max_len]  # .reshape(-1)

        # Number of real (unpadded) steps in this window; smaller than max_len
        # when the window runs past the end of the trajectory.
        tlen = ss.shape[0]


        # Absolute timestep indices of the window, and positions within it.
        timesteps = np.arange(si, si + tlen)  # .reshape(-1)
        ordering = np.arange(tlen)
        # Positions whose timestep exceeds the episode cap are collapsed onto
        # the largest position that is still within the cap.
        ordering[timesteps >= MAX_EPISODE_LEN] = -1
        ordering[ordering == -1] = ordering.max()
        timesteps[timesteps >= MAX_EPISODE_LEN] = MAX_EPISODE_LEN - 1  # padding cutoff

        # Undiscounted (gamma=1.0) return-to-go from `si`, keeping tlen + 1
        # entries; when the trajectory ends inside the window there is no
        # entry for the step after the last one, so a zero is appended.
        rtg = discount_cumsum(traj["rewards"][si:], gamma=1.0)[: tlen + 1].reshape(
            -1, 1
        )
        if rtg.shape[0] <= tlen:
            rtg = np.concatenate([rtg, np.zeros((1, 1))])

        # padding and state + reward normalization
        # NOTE(review): state_len and tlen are both read from ss.shape[0], so
        # this check can never fire.
        state_len = ss.shape[0]
        if tlen != state_len:
            raise ValueError

        # Left-pad the states with zero rows up to max_len.
        ss = np.concatenate([np.zeros((self.max_len - tlen, self.state_dim)), ss])
        #ss = (ss - self.state_mean) / self.state_std
        # manual normalization
        # NOTE(review): by operator precedence the line below evaluates as
        # ss + (ep / manual_std): it adds 5e-8 to every entry (padding rows
        # included) and does NOT divide the states by manual_std. Confirm
        # whether (ss + ep) / manual_std or ss / (manual_std + ep) was meant.
        manual_normalization = True
        if manual_normalization :
            manual_std = 2
            ep = 1e-7
            ss = ss + ep / manual_std
        # jesnk: do not normalize state?

        # Rewards are zero-padded and left unscaled; done flags are padded
        # with the sentinel value 2; returns-to-go are zero-padded and then
        # multiplied by reward_scale.
        rr = np.concatenate([np.zeros((self.max_len - tlen, 1)), rr])
        dd = np.concatenate([np.ones((self.max_len - tlen)) * 2, dd])
        rtg = (
            np.concatenate([np.zeros((self.max_len - tlen, 1)), rtg])
            * self.reward_scale
        )


        #print(f'{(rtg.shape[0] + self.max_len - tlen)} == {self.max_len})')

        # Padded positions get timestep 0 / ordering 0 and mask 0; real
        # positions get mask 1.
        timesteps = np.concatenate([np.zeros((self.max_len - tlen)), timesteps])
        ordering = np.concatenate([np.zeros((self.max_len - tlen)), ordering])
        padding_mask = np.concatenate([np.zeros(self.max_len - tlen), np.ones(tlen)])

        # Convert to tensors; states are clamped to state_range after the
        # float32 cast. padding_mask keeps the dtype numpy gave it (no cast).
        ss = torch.from_numpy(ss).to(dtype=torch.float32).clamp(*self.state_range)
        rr = torch.from_numpy(rr).to(dtype=torch.float32)
        dd = torch.from_numpy(dd).to(dtype=torch.long)
        rtg = torch.from_numpy(rtg).to(dtype=torch.float32)
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.long)
        ordering = torch.from_numpy(ordering).to(dtype=torch.long)
        padding_mask = torch.from_numpy(padding_mask)

        return ss, rr, dd, rtg, timesteps, ordering, padding_mask


def sample_trajs(trajectories, sample_size, min_len=5):
    """Draw trajectory indices with probability proportional to their length.

    Trajectories with fewer than ``min_len`` observations are never drawn.

    Args:
        trajectories: sequence of mappings that each hold an "observations"
            entry supporting ``len()``.
        sample_size: number of indices to draw (with replacement).
        min_len: minimum number of observations a trajectory needs to be
            eligible. Defaults to 5, the threshold that used to be hard-coded.

    Returns:
        ``np.ndarray`` of ``sample_size`` integer indices into ``trajectories``.

    Raises:
        ValueError: if no trajectory is eligible. Previously this case divided
            by zero and failed inside ``np.random.choice`` on NaN
            probabilities.
    """
    traj_lens = np.array([len(traj["observations"]) for traj in trajectories])
    eligible = traj_lens >= min_len

    if traj_lens[eligible].sum() <= 0:
        raise ValueError(
            f"no trajectory has at least {min_len} observations; "
            "cannot build a sampling distribution"
        )

    # Same two-step normalization as before (normalize, zero out the short
    # trajectories, renormalize) so the probabilities are numerically
    # unchanged and seeded runs draw the same indices.
    p_sample = traj_lens / np.sum(traj_lens)
    p_sample[~eligible] = 0
    p_sample = p_sample / np.sum(p_sample)

    return np.random.choice(
        np.arange(len(trajectories)),
        size=sample_size,
        replace=True,
        p=p_sample,
    )

def create_dataloader_01(
    trajectories,
    num_iters,
    batch_size,
    max_len,
    state_dim,
    state_mean,
    state_std,
    reward_scale,
    state_range,
    num_workers=24,
):
    """Build a DataLoader that yields batches of random padded sub-trajectories.

    Args:
        trajectories: trajectories to sample from.
        num_iters: number of batches the loader should produce.
        batch_size: number of sub-trajectories per batch.
        max_len, state_dim, state_mean, state_std, reward_scale, state_range:
            forwarded unchanged to ``TransformSamplingSubTraj01``.
        num_workers: worker process count passed to the DataLoader.

    Returns:
        ``torch.utils.data.DataLoader`` over ``batch_size * num_iters``
        pre-drawn trajectory indices, iterated without shuffling.
    """
    # One sub-trajectory is needed per batch element per iteration.
    drawn_inds = sample_trajs(trajectories, batch_size * num_iters)

    windowed = SubTrajectory(
        trajectories,
        sampling_ind=drawn_inds,
        transform=TransformSamplingSubTraj01(
            max_len=max_len,
            state_dim=state_dim,
            state_mean=state_mean,
            state_std=state_std,
            reward_scale=reward_scale,
            state_range=state_range,
        ),
    )

    # The indices are already a random draw, so the loader keeps their order.
    return torch.utils.data.DataLoader(
        windowed,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

# Smoke-test loader: a single batch of 256 windows, each 5 steps long.
# state_dim=4 matches the 4-wide state tensors this loader yields.
# state_mean is passed as the bare array (it used to be wrapped in a list),
# consistent with state_std and with the standalone transform below.
dataloader = create_dataloader_01(
    trajectories=offline_trajs,
    num_iters=1,
    batch_size=256,
    max_len=5,
    state_dim=4,
    state_mean=state_mean,
    state_std=state_std,
    reward_scale=0.001,
    state_range=[-2, 2],
)


# Standalone transform for inspecting a single trajectory by hand.
# NOTE(review): reward_scale is 0.01 here but 0.001 in the loader above -
# confirm which value is intended.
transform = TransformSamplingSubTraj01(
    max_len=5,
    state_dim=4,
    state_mean=state_mean,
    state_std=state_std,
    reward_scale=0.01,
    state_range=[-2, 2],
)
# transform_result = transform(offline_trajs[0])


In [34]:
# Iterate over the loader and print the shape of every tensor in each batch.
# The loop index has a real name instead of `_`: assigning `_` in a notebook
# shadows IPython's "last output" variable for the rest of the session.
for batch_idx, batch in enumerate(dataloader):
    # Unpacking keeps the 7-element contract of the transform explicit.
    ss, rr, dd, rtg, timesteps, ordering, padding_mask = batch
    for tensor in (ss, rr, dd, rtg, timesteps, ordering, padding_mask):
        print(tensor.shape)
    # Safety cap so a large loader does not flood the output.
    if batch_idx > 10:
        break
    


torch.Size([256, 5, 4])
torch.Size([256, 5, 1])
torch.Size([256, 5])
torch.Size([256, 6, 1])
torch.Size([256, 5])
torch.Size([256, 5])
torch.Size([256, 5])


In [28]:
# Rebuild the loader: trajectory indices are drawn once at construction time,
# so a fresh loader is needed for a fresh random sample.
# state_mean is passed as the bare array rather than wrapped in a list.
dataloader = create_dataloader_01(
    trajectories=offline_trajs,
    num_iters=1,
    batch_size=256,
    max_len=5,
    state_dim=4,
    state_mean=state_mean,
    state_std=state_std,
    reward_scale=0.001,
    state_range=[-2, 2],
)
# The cell used to end in a dangling `for`, which is a SyntaxError. A bare
# expression reproduces the DataLoader repr recorded as this cell's output.
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7fd9ff5f6d90>