In [1]:
import os
import random
import warnings
from dataclasses import dataclass
import numpy as np
import torch
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments
from datasets import load_from_disk
from tqdm.auto import trange, tqdm
import datasets


In [2]:
# Fetch the "halfcheetah-expert-v2" configuration of the
# edbeeching/decision_transformer_gym_replay dataset. The result is indexed by
# split name further down (`dataset["train"]`).
from datasets import load_dataset
dataset = load_dataset("edbeeching/decision_transformer_gym_replay", "halfcheetah-expert-v2")

In [3]:
class DecisionTransformerGymDataCollator:
    """Build fixed-length, left-padded training windows from stored trajectories.

    The collator keeps a reference to the whole dataset. When called it ignores
    the content of the incoming ``features`` (only their count is used), draws
    trajectories with probability proportional to their length, cuts a random
    window of at most ``max_len`` steps out of each one, and returns the
    stacked windows as torch tensors.
    """

    return_tensors: str = "pt"
    max_len: int = 20  # subsets of the episode we use for training
    state_dim: int = 17  # size of state space
    act_dim: int = 6  # size of action space
    max_ep_len: int = 1000  # max episode length in the dataset
    scale: float = 1000.0  # normalization of rewards/returns
    state_mean: np.ndarray = None  # to store state means
    state_std: np.ndarray = None  # to store state stds
    p_sample: np.ndarray = None  # a distribution to take account trajectory lengths
    n_traj: int = 0  # to store the number of trajectories in the dataset

    def __init__(self, dataset) -> None:
        """Record the dataset and derive normalisation / sampling statistics.

        Args:
            dataset: Object supporting ``dataset[i]`` (a mapping with
                "observations", "actions", "rewards" and "dones" sequences)
                and ``dataset["observations"]`` (all observation sequences).
        """
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["observations"][0])
        self.dataset = dataset
        # calculate dataset stats for normalization of states
        states = []
        traj_lens = []
        for obs in dataset["observations"]:
            states.extend(obs)
            traj_lens.append(len(obs))
        self.n_traj = len(traj_lens)
        states = np.vstack(states)
        # The epsilon keeps constant state dimensions from dividing by zero.
        self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

        traj_lens = np.array(traj_lens)
        # Longer trajectories are drawn proportionally more often.
        self.p_sample = traj_lens / traj_lens.sum()

    def _discount_cumsum(self, x, gamma):
        """Return the discounted suffix sums of ``x``.

        ``result[t] = x[t] + gamma * result[t + 1]`` with ``result[-1] = x[-1]``.

        Args:
            x: 1-D array-like of per-step rewards.
            gamma: Discount factor.

        Returns:
            Array with the same shape as ``x``. Floating inputs keep their
            dtype; any other dtype is accumulated in float64 so that a
            fractional ``gamma`` is not truncated by an integer buffer.
        """
        x = np.asarray(x)
        out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
        discount_cumsum = np.zeros(x.shape, dtype=out_dtype)
        if x.shape[0] == 0:
            # Nothing to accumulate; avoid indexing [-1] on an empty array.
            return discount_cumsum
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        """Sample ``len(features)`` windows and stack them into a batch.

        Args:
            features: Batch handed over by the caller; only its length is used.

        Returns:
            dict of torch tensors, each with leading dimensions
            ``(batch, max_len)``: "states", "actions", "rewards",
            "returns_to_go" (float), "timesteps" (long) and
            "attention_mask" (float, 0 on padding and 1 on real steps).
        """
        batch_size = len(features)
        # this is a bit of a hack to be able to sample of a non-uniform distribution
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )
        # a batch of dataset features
        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []

        for ind in batch_inds:
            feature = self.dataset[int(ind)]
            si = random.randint(0, len(feature["rewards"]) - 1)

            # get sequences from dataset
            s.append(np.array(feature["observations"][si : si + self.max_len]).reshape(1, -1, self.state_dim))
            a.append(np.array(feature["actions"][si : si + self.max_len]).reshape(1, -1, self.act_dim))
            r.append(np.array(feature["rewards"][si : si + self.max_len]).reshape(1, -1, 1))

            # `d` is assembled like the other fields but is not part of the returned batch.
            d.append(np.array(feature["dones"][si : si + self.max_len]).reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff
            # Return-to-go is computed over the whole remainder of the episode
            # and then cut down to the window length.
            rtg.append(
                self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)[
                    : s[-1].shape[1]
                ].reshape(1, -1, 1)
            )
            # Defensive: only reachable when "rewards" holds fewer remaining
            # steps than "observations" for this window.
            if rtg[-1].shape[1] < s[-1].shape[1]:
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # padding and state + reward normalization
            tlen = s[-1].shape[1]
            pad = self.max_len - tlen
            # States are padded before normalisation, so padded rows are
            # normalised along with the real ones.
            s[-1] = np.concatenate([np.zeros((1, pad, self.state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - self.state_mean) / self.state_std
            a[-1] = np.concatenate(
                [np.ones((1, pad, self.act_dim)) * -10.0, a[-1]],
                axis=1,
            )
            r[-1] = np.concatenate([np.zeros((1, pad, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, pad)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, pad, 1)), rtg[-1]], axis=1) / self.scale
            timesteps[-1] = np.concatenate([np.zeros((1, pad)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        d = torch.from_numpy(np.concatenate(d, axis=0))
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }

In [4]:
# Module-level copies of the collator's class attributes, so the body of
# `__call__` can be stepped through cell by cell below without a `self`.
return_tensors: str = "pt"
max_len: int = 20 #subsets of the episode we use for training
state_dim: int = 17  # size of state space
act_dim: int = 6  # size of action space
max_ep_len: int = 1000 # max episode length in the dataset
scale: float = 1000.0  # normalization of rewards/returns
state_mean: np.ndarray = None  # per-dimension state means (filled in below)
state_std: np.ndarray = None  # per-dimension state stds (filled in below)
p_sample: np.ndarray = None  # sampling distribution over trajectories, weighted by their length
n_traj: int = 0 # to store the number of trajectories in the dataset


In [6]:
# Keep only the "train" split. The guard makes the cell safe to re-run: once
# `dataset` has already been narrowed to a single split, indexing it with
# "train" a second time would fail.
if isinstance(dataset, datasets.DatasetDict):
    dataset = dataset["train"]

In [7]:
# Read the action / state sizes off the first step of the first trajectory,
# replacing the hard-coded defaults assigned earlier.
act_dim = len(dataset[0]["actions"][0])
state_dim = len(dataset[0]["observations"][0])

In [8]:
# Dataset-wide statistics: per-dimension state mean/std for normalisation and
# a length-proportional sampling distribution over trajectories.
episodes = dataset["observations"]

traj_lens = np.array([len(episode) for episode in episodes])
n_traj = len(traj_lens)

# Stack every observation of every trajectory into one (n_steps, state_dim) array.
states = np.vstack([step for episode in episodes for step in episode])
state_mean = states.mean(axis=0)
state_std = states.std(axis=0) + 1e-6  # epsilon guards against zero variance

p_sample = traj_lens / traj_lens.sum()

In [28]:
# Walk through the collator with a single sample so every intermediate array
# is easy to inspect.
batch_size = 1

In [29]:
n_traj

1000

In [30]:
# Draw trajectory indices with replacement, weighted by trajectory length.
# Passing the integer `n_traj` samples from range(n_traj).
batch_inds = np.random.choice(
    n_traj,
    size=batch_size,
    replace=True,
    p=p_sample,
)

In [31]:
batch_inds

array([400])

In [63]:
# One list per field; each sampled window is appended as an array with a
# leading batch axis of size 1.
s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []

In [64]:
# 400 is the index that the sampling cell produced in the recorded run
# (`batch_inds` displayed `array([400])`); it is hard-coded here, so a fresh
# draw is not picked up automatically.
feature = dataset[400]

In [65]:
# Pick a uniformly random start index inside the trajectory.
si = random.randrange(len(feature["rewards"]))

In [66]:
# Pin the window start (replacing the random draw) so that the slice runs off
# the end of the trajectory and the left-padding path is exercised.
si = 990

In [67]:
# Observation window, at most `max_len` steps starting at `si`, shaped
# (1, window_len, state_dim).
obs_window = feature["observations"][si : si + max_len]
obs = np.reshape(np.array(obs_window), (1, -1, state_dim))

In [68]:
# Same window for actions, shaped (1, window_len, act_dim), and for rewards,
# shaped (1, window_len, 1).
act_window = feature["actions"][si : si + max_len]
rew_window = feature["rewards"][si : si + max_len]
act = np.reshape(np.array(act_window), (1, -1, act_dim))
rew = np.reshape(np.array(rew_window), (1, -1, 1))

In [69]:
# Done flags for the same window, shaped (1, window_len).
done_window = feature["dones"][si : si + max_len]
done = np.reshape(np.array(done_window), (1, -1))

In [70]:
# Push the freshly cut windows onto the per-field lists.
# NOTE: this cell is not idempotent - re-running it appends the same windows
# again; re-run the cell that resets the lists first.
s.append(obs)
a.append(act)
r.append(rew)
d.append(done)

In [71]:
# Absolute timestep index of every step in the window, shaped (1, window_len).
window_len = s[-1].shape[1]
timesteps.append(np.arange(si, si + window_len)[np.newaxis, :])

In [73]:
timesteps

[array([[990, 991, 992, 993, 994, 995, 996, 997, 998, 999]])]

In [74]:
timesteps[-1] >=995

array([[False, False, False, False, False,  True,  True,  True,  True,
         True]])

In [75]:
# Lower the cap below the timesteps in this window (990-999) so the clamp in
# the following cell has a visible effect. This overwrites the earlier value
# of 1000 for the rest of the session.
max_ep_len = 995

In [76]:
# Clamp, in place, every timestep at or beyond the cap to the last valid index.
beyond_cap = timesteps[-1] >= max_ep_len
timesteps[-1][beyond_cap] = max_ep_len - 1

In [77]:
timesteps

[array([[990, 991, 992, 993, 994, 994, 994, 994, 994, 994]])]

In [80]:
# Rewards from the window start to the end of the trajectory (not cut to
# `max_len`): the return-to-go needs every remaining reward.
x = np.array(feature["rewards"][si:])

In [81]:
# Output buffer for the suffix sums; same shape and dtype as the rewards.
discount_cumsum = np.zeros_like(x)

In [82]:
discount_cumsum

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [83]:
# Base case: the return-to-go at the final step is just the final reward.
discount_cumsum[-1] = x[-1]

In [84]:
discount_cumsum

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        , 10.98636818])

In [86]:
# Show the order in which the recurrence visits indices: from the
# second-to-last element down to 0.
for t in range(x.shape[0] - 2, -1, -1):
    print(t)

8
7
6
5
4
3
2
1
0


In [88]:
x.shape[0]

10

In [89]:
x[8]

11.682415962219238

In [92]:
# Backward recurrence with a discount factor of 1, i.e. a plain suffix sum:
# each entry is its own reward plus everything that follows it.
for t in range(x.shape[0] - 2, -1, -1):
    discount_cumsum[t] = discount_cumsum[t + 1] + x[t]

In [97]:
discount_cumsum[:s[-1].shape[1]].reshape(1,-1,1)

array([[[115.7624464 ],
        [103.89986229],
        [ 92.61529922],
        [ 80.94272327],
        [ 69.47979927],
        [ 58.13008118],
        [ 45.9987936 ],
        [ 34.39903355],
        [ 22.66878414],
        [ 10.98636818]]])

In [98]:
# Number of real (unpadded) steps in the latest window; `max_len - tlen` rows
# of padding are prepended in the next step.
tlen = s[-1].shape[1]

In [100]:
max_len

20

In [102]:
np.concatenate([np.zeros((1, max_len - tlen, state_dim)), s[-1]], axis=1)

array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.

In [104]:
1e-2

0.01