In [41]:
# import
import os
import sys
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch
import gym
from tqdm import tqdm

from model import DecisionTransformer
from utils import D4RLTrajectoryDataset


In [2]:
# set hyperparameter
env_name = 'halfcheetah'   # one of the keys of ENV_CONFIGS below
dataset = 'medium'         # dataset variant name; only used to build AUG_DATA_PATH below

# gym env id and max episode length for each supported env_name
ENV_CONFIGS = {
    'hopper': ('Hopper-v2', 1000),
    'halfcheetah': ('HalfCheetah-v2', 1000),
    'walker2d': ('Walker2d-v2', 1000),
}

# fail fast on a typo instead of leaving `env` / `max_ep_len` undefined for later cells
if env_name not in ENV_CONFIGS:
    raise ValueError(f"unsupported env_name {env_name!r}; expected one of {sorted(ENV_CONFIGS)}")

gym_env_id, max_ep_len = ENV_CONFIGS[env_name]
env = gym.make(gym_env_id)

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# augmented dataset to be filtered in this notebook
AUG_DATA_PATH = f'data/augmented/{env_name}-{dataset}-v2.npz'


  logger.warn(


In [3]:
# check dim
# flat observation / action vector sizes, read from the env's gym spaces
state_dim, act_dim = env.observation_space.shape[0], env.action_space.shape[0]

for label, dim in (("state dim: ", state_dim), ("action dim: ", act_dim)):
    print(label, dim)

state dim:  17
action dim:  6


In [4]:
# load best model
eval_chk_pt_dir = "./best_model/"
eval_chk_pt_name = "dt_halfcheetah_model_24-05-16-00-24-38_best.pt"
chk_pt_path = os.path.join(eval_chk_pt_dir, eval_chk_pt_name)

# architecture hyperparameters: they must match the ones the checkpoint was trained with,
# otherwise load_state_dict (strict by default) rejects the mismatched parameter shapes
batch_size = 32      # NOTE(review): not consumed in this cell
embed_dim = 128
activation = 'relu'  # NOTE(review): not passed to DecisionTransformer here
drop_out = 0.1
k = 31               # context length, passed as context_len
n_blocks = 3
n_heads = 1          # transformer head

model_kwargs = dict(
    state_dim=state_dim,
    act_dim=act_dim,
    n_blocks=n_blocks,
    h_dim=embed_dim,
    context_len=k,
    n_heads=n_heads,
    drop_p=drop_out,
)
best_model = DecisionTransformer(**model_kwargs).to(DEVICE)

# load checkpoint; map_location lets a GPU-saved checkpoint load onto the current DEVICE
state_dict = torch.load(chk_pt_path, map_location=DEVICE)
best_model.load_state_dict(state_dict)
best_model.eval()

DecisionTransformer(
  (transformer): Sequential(
    (0): Block(
      (attention): MaskedCausalAttention(
        (q_net): Linear(in_features=128, out_features=128, bias=True)
        (k_net): Linear(in_features=128, out_features=128, bias=True)
        (v_net): Linear(in_features=128, out_features=128, bias=True)
        (proj_net): Linear(in_features=128, out_features=128, bias=True)
        (att_drop): Dropout(p=0.1, inplace=False)
        (proj_drop): Dropout(p=0.1, inplace=False)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (1): Block(
      (attention): MaskedCausalAttention(
        (q_net): Linear(in_features=128, out_features=12

In [9]:
# load augmented dataset
# np.load on an .npz returns an NpzFile that keeps the file open; the context manager closes
# it once 'data' has been read into memory.
# allow_pickle is needed because the elements of 'data' are Python objects indexed by string
# keys in later cells. Only load files you trust: unpickling can execute arbitrary code.
with np.load(AUG_DATA_PATH, allow_pickle=True) as aug_npz:
    aug_dataset_sample = aug_npz['data']

In [24]:
n_episodes = len(aug_dataset_sample)
print("# of episode: ", n_episodes)

# length of the first stored segment only; other segments are assumed equal — TODO confirm
first_segment_len = aug_dataset_sample[0]['observations'].shape[0]
print("content len: ", first_segment_len)

# of episode:  161291
content len:  31


In [10]:
# check augmented dataset
# dataset[episode][feature][timestamp]
# inspect the last episode; the index is derived from the data instead of hard-coded
sample_idx = len(aug_dataset_sample) - 1
sample = aug_dataset_sample[sample_idx]

ori = sample['original_observations']
state = sample['observations']
next_state = sample['next_observations']
action = sample['actions']
reward = sample['rewards']
# number of timesteps, taken from the same episode as every other field above
timestep = state.shape[0]

print("state: ", state)
print("next_state: ", next_state)
print("action: ", action)
print("reward: ", reward)
print("timestep: ", timestep)
print("original: ", ori)
print("original: ", ori)

state:  [[-2.86957510e-02  7.06104562e-04  1.65495008e-01  6.58564866e-02
  -2.41920754e-01 -1.05059385e+00 -4.92760658e-01 -4.55361217e-01
   7.98386621e+00 -5.03951013e-01  5.16349506e+00 -2.13617635e+00
  -8.54519558e+00 -4.93509579e+00 -1.18504820e+01 -7.23576450e+00
   3.13616824e+00]
 [-2.65657008e-02  5.47008552e-02 -9.29864198e-02 -1.49895459e-01
  -4.39769804e-01 -3.73108029e-01 -2.85037011e-01  5.32570034e-02
   8.58654881e+00 -9.99051332e-02 -1.17479369e-01 -7.13248158e+00
  -2.29523802e+00 -3.03741619e-02  2.15435925e+01  1.26372910e+01
   1.30124493e+01]
 [-5.69641255e-02  4.53174114e-02 -4.47341532e-01 -4.43460852e-01
  -4.14766133e-01  2.90404439e-01  2.67008424e-01  2.12057412e-01
   8.15756989e+00 -5.08947372e-01 -1.50061691e+00 -6.03225327e+00
  -6.41128826e+00  1.50939032e-01  1.02164116e+01  5.09479523e+00
  -2.73729491e+00]
 [-9.76325274e-02 -1.00335684e-02 -5.42420268e-01 -7.66327202e-01
  -4.18085724e-01  7.75999188e-01  6.94274157e-02 -9.07363221e-02
   8.062833

In [25]:
# wrap the augmented data in the project's trajectory dataset; the window length is the model's
# context length k rather than a separate literal, so the two cannot drift apart
aug_dataset = D4RLTrajectoryDataset(aug_dataset_sample, k, not_path=True)

aug_data_loader = DataLoader(aug_dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=(DEVICE.type == 'cuda'),  # pinning only helps host->GPU copies
                             drop_last=True)

# draw a single random batch to sanity-check shapes and the model forward pass
aug_data_iter = iter(aug_data_loader)
timesteps, states, next_states, actions, rewards, traj_mask = next(aug_data_iter)

timesteps = timesteps.to(DEVICE)                    # B x T
states = states.to(DEVICE)                          # B x T x state_dim
next_states = next_states.to(DEVICE)                # B x T x state_dim
actions = actions.to(DEVICE)                        # B x T x act_dim
rewards = rewards.to(DEVICE).unsqueeze(dim=-1)      # B x T x 1
traj_mask = traj_mask.to(DEVICE)                    # assumed B x T — TODO confirm


In [26]:
# shapes of the sampled batch tensors, in loader order
for batch_tensor in (timesteps, states, actions, rewards):
    print(batch_tensor.shape)

torch.Size([1, 31])
torch.Size([1, 31, 17])
torch.Size([1, 31, 6])
torch.Size([1, 31, 1])


In [27]:
# single forward pass on the sampled batch; inference only, so skip autograd bookkeeping.
# Calling the module (rather than .forward) also runs any registered hooks.
with torch.no_grad():
    next_state_preds, rewards_preds = best_model(
        rewards=rewards,
        timesteps=timesteps,
        states=states,
        actions=actions,
    )

In [33]:
# display the model's next-state predictions for the sampled batch (B x T x state_dim)
next_state_preds

tensor([[[-5.4316e-01,  1.4823e-01, -1.2098e+00, -1.2899e+00, -4.8106e-01,
           1.6355e+00,  4.6842e-01,  4.0672e-01,  1.1142e+00, -1.0664e+00,
          -2.4927e-01, -6.6667e-02, -6.4739e-01, -2.0439e-02,  6.9192e-01,
          -1.3320e+00, -4.4519e-01],
         [-9.4849e-01, -1.3340e-01, -7.0845e-01,  5.5197e-01, -4.9704e-01,
           1.4145e+00,  2.4859e-01, -2.0949e-01,  1.4056e+00, -7.7549e-01,
          -1.2360e+00,  1.0170e+00,  2.4697e+00,  1.9510e-02, -3.0441e-01,
          -7.2813e-01, -8.5487e-01],
         [-3.6512e-01, -3.4905e-01,  4.9690e-01,  1.7375e+00, -1.5669e-01,
          -6.3781e-01, -5.2263e-02, -9.3233e-01,  1.4282e+00,  8.8194e-01,
           1.1068e+00,  1.7353e+00,  1.1013e+00,  1.9105e+00, -1.9285e+00,
          -3.0551e-01, -3.5011e-01],
         [-2.6033e-01, -1.0182e-02,  1.5834e+00,  1.4596e+00,  6.2074e-01,
          -1.8418e+00, -9.3794e-01, -1.2843e+00,  9.0662e-01,  3.8654e-01,
           1.3210e+00, -3.3410e-01, -5.6223e-01, -1.8165e-01, -4

In [29]:
# shapes of the two model outputs
for label, pred in (("next_state_shape: ", next_state_preds), ("rewards_shape: ", rewards_preds)):
    print(label, pred.shape)

next_state_shape:  torch.Size([1, 31, 17])
rewards_shape:  torch.Size([1, 31, 1])


In [42]:
# filtering
# fraction (0 ~ 1) of the highest-error samples to DROP; 0.1 keeps the best-predicted 90%
Percentage = 0.1

FILTERED_COLUMNS = ['states', 'next_states', 'actions', 'rewards', 'timestep', 'traj_mask', 'mse']


def _masked_prediction_mse(pred_next_states, pred_rewards, next_states, rewards, traj_mask):
    """Mean of next-state MSE and reward MSE over the timesteps where traj_mask > 0.

    Args:
        pred_next_states: model output, B x T x state_dim.
        pred_rewards: model output, B x T x 1.
        next_states: ground-truth next states, B x T x state_dim (any device).
        rewards: ground-truth rewards with B * T elements (any device).
        traj_mask: mask with B * T elements; only entries > 0 are scored.

    Returns:
        The combined loss as a Python float (NaN if no timestep is valid).
    """
    device = pred_next_states.device
    state_dim = pred_next_states.shape[-1]
    valid = traj_mask.reshape(-1).to(device) > 0

    target_next_states = next_states.to(device).reshape(-1, state_dim)[valid]
    target_rewards = rewards.to(device).reshape(-1, 1)[valid]

    state_loss = F.mse_loss(pred_next_states.reshape(-1, state_dim)[valid], target_next_states, reduction='mean')
    reward_loss = F.mse_loss(pred_rewards.reshape(-1, 1)[valid], target_rewards, reduction='mean')
    # .item() detaches the score, so stored values do not keep device tensors alive
    return ((state_loss + reward_loss) / 2).item()


def filtering_transformer(augmented_dataset_sample, model, Percentage=Percentage, context_len=1):
    """Score every augmented sample with `model` and drop the worst-predicted fraction.

    For each item yielded by ``D4RLTrajectoryDataset`` the model predicts next states and
    rewards. The score is the mean of the masked next-state MSE and the masked reward MSE.
    Rows are sorted by score (ascending) and the highest-error ``Percentage`` is discarded.

    Args:
        augmented_dataset_sample: in-memory augmented dataset, passed to
            ``D4RLTrajectoryDataset(..., not_path=True)``.
        model: module whose forward returns ``(next_state_preds, reward_preds)``. It is put in
            eval mode while scoring and its previous train/eval mode is restored afterwards.
        Percentage: fraction in [0, 1] of rows to remove, highest MSE first.
        context_len: window length handed to ``D4RLTrajectoryDataset``.
            NOTE(review): defaults to 1 (the previously hard-coded value), while the model was
            built with context_len=k=31 and the sanity-check cell uses k. Presumably each row
            then covers a single timestep — confirm which window is intended.

    Returns:
        DataFrame with columns states, next_states, actions, rewards, timestep, traj_mask
        (CPU tensors as yielded by the DataLoader, batch dim 1) and mse (float). It is sorted
        by mse ascending and has ``int(n * (1 - Percentage))`` rows. NaN scores sort last,
        so they are the first to be dropped.

    Raises:
        ValueError: if Percentage is outside [0, 1].
    """
    if not 0 <= Percentage <= 1:
        raise ValueError(f"Percentage must be in [0, 1], got {Percentage!r}")

    aug_dataset = D4RLTrajectoryDataset(augmented_dataset_sample, context_len, not_path=True)
    # Scoring pass: order does not matter (rows are sorted by mse afterwards), and nothing
    # may be silently dropped.
    aug_data_loader = DataLoader(aug_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=(DEVICE.type == 'cuda'),
                                 drop_last=False)

    records = []
    was_training = model.training
    model.eval()  # disable dropout so scores are deterministic
    try:
        # without no_grad every stored loss would keep its full autograd graph alive
        with torch.no_grad():
            for timesteps, states, next_states, actions, rewards, traj_mask in tqdm(aug_data_loader):
                pred_next_states, pred_rewards = model(
                    rewards=rewards.to(DEVICE).unsqueeze(dim=-1),  # B x T x 1
                    timesteps=timesteps.to(DEVICE),                # B x T
                    states=states.to(DEVICE),                      # B x T x state_dim
                    actions=actions.to(DEVICE),                    # B x T x act_dim
                )
                records.append({
                    'states': states,
                    'next_states': next_states,
                    'actions': actions,
                    'rewards': rewards,
                    'timestep': timesteps,
                    'traj_mask': traj_mask,
                    'mse': _masked_prediction_mse(pred_next_states, pred_rewards,
                                                  next_states, rewards, traj_mask),
                })
    finally:
        model.train(was_training)

    # sort_values returns a new frame; it must be reassigned, otherwise head() would keep
    # rows in loader order instead of the lowest-MSE ones
    scored_dataset = (
        pd.DataFrame(records, columns=FILTERED_COLUMNS)
        .sort_values(by='mse', ascending=True, kind='stable')
        .reset_index(drop=True)
    )
    print("# of scored dataset: ", len(scored_dataset))

    keep_rows = int(len(scored_dataset) * (1 - Percentage))
    filtered_dataset = scored_dataset.head(keep_rows)
    print("# of filtered dataset: ", len(filtered_dataset))

    return filtered_dataset


filtered_dataset = filtering_transformer(aug_dataset_sample, best_model, Percentage=Percentage)

 59%|█████▉    | 94901/161291 [12:29<11:15, 98.26it/s]  

In [None]:
# save filtered dataset