In [1]:
# set your start path
import sys
import os

os.chdir(r'C:/Develop/offlineRL-with-diffusion')

In [2]:
# import
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch
import gym
from tqdm import tqdm

from transformer.gpt_transformer.src.utils import D4RLTrajectoryDataset
from transformer.gpt_transformer.src.model import DecisionTransformer


In [3]:
# set hyperparameter
env_name = 'halfcheetah'
dataset = 'medium'

if env_name == 'hopper':
    env = gym.make('Hopper-v2')

elif env_name == 'halfcheetah':
    env = gym.make('HalfCheetah-v2')

elif env_name == 'walker2d':
    env = gym.make('Walker2d-v2')


if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')


  logger.warn(


In [4]:
# check dim

state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

print("state dim: ", state_dim)
print("action dim: ", act_dim)

state dim:  17
action dim:  6


In [5]:
# load best model

BEST_MODEL_PATH = f"transformer/gpt_transformer/src/best_model/{env_name}-{dataset}_model_best.pt"

batch_size = 128
embed_dim = 128
activation = 'relu'
drop_out = 0.1
k = 31
n_blocks = 3
n_heads = 1 # transformer head

AUG_DATA_PATH = f'transformer/gpt_transformer/src/data/augmented/{env_name}-{dataset}-v2.npz'
FILTERED_DATA_PATH = f'transformer/gpt_transformer/src/data/filtered/{env_name}-{dataset}-v2.npz'

best_model = DecisionTransformer(
            state_dim=state_dim,
            act_dim=act_dim,
            n_blocks=n_blocks,
            h_dim=embed_dim,
            context_len=k,
            n_heads=n_heads,
            drop_p=drop_out,
        ).to(DEVICE)


# load checkpoint
best_model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
best_model.eval()



DecisionTransformer(
  (transformer): Sequential(
    (0): Block(
      (attention): MaskedCausalAttention(
        (q_net): Linear(in_features=128, out_features=128, bias=True)
        (k_net): Linear(in_features=128, out_features=128, bias=True)
        (v_net): Linear(in_features=128, out_features=128, bias=True)
        (proj_net): Linear(in_features=128, out_features=128, bias=True)
        (att_drop): Dropout(p=0.1, inplace=False)
        (proj_drop): Dropout(p=0.1, inplace=False)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (1): Block(
      (attention): MaskedCausalAttention(
        (q_net): Linear(in_features=128, out_features=12

In [6]:
# load augmented dataset
aug_dataset_sample = np.load(AUG_DATA_PATH, allow_pickle=True)
aug_dataset_sample = aug_dataset_sample['data']

In [7]:
print(aug_dataset_sample[0])

{'observations': array([[-9.78982896e-02,  1.13477379e-01,  1.98929846e-01,
         6.73338473e-02, -2.23132133e-01, -9.95479763e-01,
        -4.97307360e-01, -5.16024351e-01,  9.02984238e+00,
         1.12497166e-01,  4.16288280e+00, -3.49108428e-01,
        -1.01194935e+01, -5.76310039e-01, -8.60520935e+00,
        -7.22706127e+00,  1.51465490e-01],
       [-9.34970081e-02,  1.04090542e-01, -1.35129943e-01,
         2.85682887e-01, -3.97425026e-01, -3.31844419e-01,
        -2.32354969e-01, -4.77415584e-02,  9.32954693e+00,
        -1.71609208e-01, -1.79579675e+00, -8.79514313e+00,
         8.46916866e+00, -4.05881453e+00,  2.01686878e+01,
         1.33829222e+01,  1.36221838e+01],
       [-7.25239292e-02,  3.64939049e-02, -4.39952761e-01,
        -1.81274801e-01, -4.14618701e-01,  4.15624201e-01,
        -1.49758965e-01,  2.10397542e-01,  8.33133316e+00,
         4.72206712e-01, -1.19015479e+00, -3.44378471e+00,
        -1.55777407e+01,  7.04492033e-01,  1.08165216e+01,
         2.6

In [8]:
print("shape of dataset: ", aug_dataset_sample.shape)
print("# of episode: ", len(aug_dataset_sample))
print("content len: ", aug_dataset_sample[0]['observations'].shape[0])

shape of dataset:  (161291,)
# of episode:  161291
content len:  31


In [9]:
cnt = 0

for aug in aug_dataset_sample:
    print(aug['observations'].shape)
    
    if cnt >= 5:
        break
    
    cnt+=1

(31, 17)
(31, 17)
(31, 17)
(31, 17)
(31, 17)
(31, 17)


In [10]:
# calculate mean, std

aug_states, aug_next_states, aug_rewards = [], [], []
for traj in aug_dataset_sample:
    aug_states.append(traj['observations'])
    aug_next_states.append(traj['next_observations'])
    aug_rewards.append(traj['rewards'])
    
# used for input, output normalization
aug_states = np.concatenate(aug_states, axis=0)
# print("state shape: ", states.shape)
aug_state_mean, aug_state_std = np.mean(aug_states, axis=0), np.std(aug_states, axis=0)

aug_next_states = np.concatenate(aug_next_states, axis=0)
aug_next_state_mean, aug_next_state_std = np.mean(aug_next_states, axis=0), np.std(aug_next_states, axis=0)

aug_rewards = np.concatenate(aug_rewards, axis=0)
aug_reward_mean, aug_reward_std = np.mean(aug_rewards, axis=0), np.std(aug_rewards, axis=0)

In [11]:
# check augmented dataset
# dataset[episode][feature][timesteps]
# episode <= 161290, timestep <= 30
ori = aug_dataset_sample[161290]['original_observations']
state = aug_dataset_sample[161290]['observations']
next_state = aug_dataset_sample[161290]['next_observations']
action = aug_dataset_sample[161290]['actions']
reward = aug_dataset_sample[161290]['rewards']
timestep = aug_dataset_sample[100]['observations'].shape[0]

print("state: ", state)
print("next_state: ", next_state)
print("action: ", action)
print("reward: ", reward)
print("timestep: ", timestep)
print("original: ", ori)

state:  [[-2.86957510e-02  7.06104562e-04  1.65495008e-01  6.58564866e-02
  -2.41920754e-01 -1.05059385e+00 -4.92760658e-01 -4.55361217e-01
   7.98386621e+00 -5.03951013e-01  5.16349506e+00 -2.13617635e+00
  -8.54519558e+00 -4.93509579e+00 -1.18504820e+01 -7.23576450e+00
   3.13616824e+00]
 [-2.65657008e-02  5.47008552e-02 -9.29864198e-02 -1.49895459e-01
  -4.39769804e-01 -3.73108029e-01 -2.85037011e-01  5.32570034e-02
   8.58654881e+00 -9.99051332e-02 -1.17479369e-01 -7.13248158e+00
  -2.29523802e+00 -3.03741619e-02  2.15435925e+01  1.26372910e+01
   1.30124493e+01]
 [-5.69641255e-02  4.53174114e-02 -4.47341532e-01 -4.43460852e-01
  -4.14766133e-01  2.90404439e-01  2.67008424e-01  2.12057412e-01
   8.15756989e+00 -5.08947372e-01 -1.50061691e+00 -6.03225327e+00
  -6.41128826e+00  1.50939032e-01  1.02164116e+01  5.09479523e+00
  -2.73729491e+00]
 [-9.76325274e-02 -1.00335684e-02 -5.42420268e-01 -7.66327202e-01
  -4.18085724e-01  7.75999188e-01  6.94274157e-02 -9.07363221e-02
   8.062833

In [12]:
aug_dataset = D4RLTrajectoryDataset(aug_dataset_sample, 1, not_path=True)

aug_data_loader = DataLoader(aug_dataset,
						batch_size=1,
						shuffle=True,
						pin_memory=True,
						drop_last=True)
						
aug_data_iter = iter(aug_data_loader)

timesteps, states, next_states, actions, rewards, traj_mask = next(aug_data_iter)

timesteps = timesteps.to(DEVICE)	# B x T
states = states.to(DEVICE)			# B x T x state_dim
next_states = next_states.to(DEVICE) # B X T X state_dim
actions = actions.to(DEVICE)		# B x T x act_dim
rewards = rewards.to(DEVICE).unsqueeze(dim=-1) # B x T x 1


In [13]:
print(timesteps.shape)
print(states.shape)
print(actions.shape)
print(rewards.shape)

torch.Size([1, 1])
torch.Size([1, 1, 17])
torch.Size([1, 1, 6])
torch.Size([1, 1, 1])


In [14]:
next_state_preds, rewards_preds = best_model.forward(
                                                rewards=rewards,
                                                timesteps=timesteps,
                                                states=states,
                                                actions=actions,
                                            )

In [15]:
next_state_preds

tensor([[[-0.0915,  0.5200,  0.7236, -0.2849,  0.3945, -0.8506,  0.2536,
          -0.6589, -0.1334, -0.2930, -0.2335, -1.0399, -0.2821,  0.4592,
          -0.7844, -0.4881,  0.1434]]], grad_fn=<ViewBackward0>)

In [16]:
print("next_state_shape: ", next_state_preds.shape)
print("rewards_shape: ", rewards_preds.shape)

# next_state_preds, rewards_preds

next_state_shape:  torch.Size([1, 1, 17])
rewards_shape:  torch.Size([1, 1, 1])


In [17]:
# filtering
Percentage = 0.1 # 0.1 ~ 1

def filtering_transformer(augmented_dataset_sample, model, Percentage=Percentage):
    
    # temp = 0
    
    filtered_dataset = pd.DataFrame(columns = ['states', 'next_states', 'actions', 'rewards', 'timesteps', 'traj_mask', 'mse'])
    
    states_list, next_states_list, actions_list, rewards_list, timesteps_list, traj_mask_list, mse_list = [], [], [], [], [], [], []
    
    aug_dataset = D4RLTrajectoryDataset(augmented_dataset_sample, k, not_path=True)

    aug_data_loader = DataLoader(aug_dataset,
                            batch_size=1,
                            shuffle=True,
                            pin_memory=True,
                            drop_last=True)
                            
    for timesteps, states, next_states, actions, rewards, traj_mask in tqdm(aug_data_loader):
        
        states_list.append(np.array(states.reshape(k, state_dim)))
        next_states_list.append(np.array(next_states.reshape(k, state_dim)))
        actions_list.append(np.array(actions.reshape(k, act_dim)))
        rewards_list.append(np.array(rewards.reshape(-1,)))
        timesteps_list.append(np.array(np.squeeze(timesteps, axis=0)))
        traj_mask_list.append(np.array(np.squeeze(traj_mask, axis=0)))

        # normalization
        states = (states - aug_state_mean) / aug_state_std
        next_states = (next_states - aug_next_state_mean) / aug_next_state_std
        rewards = (rewards - aug_reward_mean) / aug_reward_std
        
        timesteps = timesteps.to(DEVICE)	# B x T
        states = states.to(DEVICE)			# B x T x state_dim
        next_states = next_states.to(DEVICE) # B X T X state_dim
        actions = actions.to(DEVICE)		# B x T x act_dim
        rewards = rewards.to(DEVICE).unsqueeze(dim=-1) # B x T x 1
        traj_mask = traj_mask.to(DEVICE)	# B x T
    
        pred_next_states = torch.clone(next_states).detach().to(DEVICE)
        pred_rewards = torch.clone(rewards).detach().to(DEVICE)
    
        real_next_state, real_rewards = model.forward(
                                                        rewards=rewards,
                                                        timesteps=timesteps,
                                                        states=states,
                                                        actions=actions,
                                                    )
        pred_next_states = pred_next_states.view(-1, state_dim)[traj_mask.view(-1,) > 0]
        real_next_state = real_next_state.view(-1, state_dim)[traj_mask.view(-1,) > 0]

        pred_rewards = pred_rewards.view(-1, 1)[traj_mask.view(-1,) > 0]
        real_rewards = real_rewards.view(-1, 1)[traj_mask.view(-1,) > 0]

        state_loss = F.mse_loss(pred_next_states, real_next_state, reduction='mean')
        reward_loss = F.mse_loss(pred_rewards, real_rewards, reduction='mean')
        
        total_loss = state_loss.add(reward_loss)
        total_loss = torch.mean(total_loss)
        mse_list.append(total_loss.detach().cpu().item())
        
        # temp += 1
        
        # if temp == 10:
        #     break
                                                    
    filtered_dataset['states'] = states_list
    filtered_dataset['next_states'] = next_states_list
    filtered_dataset['actions'] = actions_list
    filtered_dataset['rewards'] = rewards_list
    filtered_dataset['timesteps'] = timesteps_list
    filtered_dataset['traj_mask'] = traj_mask_list
    filtered_dataset['mse'] = mse_list
    
    filtered_dataset.sort_values(by='mse', ascending=True, inplace=True)
    
    print("# of augmented dataset: ", len(filtered_dataset))
    keep_rows = int(len(filtered_dataset) * (1-Percentage))
    
    filtered_dataset = filtered_dataset.head(keep_rows)
    filtered_dataset = filtered_dataset.sample(frac=1).reset_index(drop=True)
    print("# of filtered dataset: ", len(filtered_dataset))
    
    # dataframe to numpy array with dict
    np_filtered_dataset = []
    
    for i in range(len(filtered_dataset)):
        np_filtered_dataset.append({'observations': np.array(filtered_dataset['states'][i]), 
                                    'next_observations': np.array(filtered_dataset['next_states'][i]),
                                    'actions': np.array(filtered_dataset['actions'][i]),
                                    'rewards': np.array(filtered_dataset['rewards'][i]),
                                    'timesteps': np.array(filtered_dataset['timesteps'][i]),
                                    'traj_mask': np.array(filtered_dataset['traj_mask'][i]),
                                    'mse': np.array(filtered_dataset['mse'][i]),
                                    })
        
    
    return np_filtered_dataset
    


In [18]:
filtered_dataset = filtering_transformer(aug_dataset_sample, best_model, Percentage=Percentage)

  0%|          | 9/161291 [00:00<35:39, 75.38it/s]

# of augmented dataset:  10
# of filtered dataset:  9





In [19]:
filtered_dataset[0]['observations'].shape

(31, 17)

In [20]:
# save filtered dataset -> .npz

temp_array = np.array([1,2,])

np.savez(FILTERED_DATA_PATH, data=filtered_dataset, config=temp_array)

In [21]:
# load filtered dataset

filtered_dataset_sample = np.load(FILTERED_DATA_PATH, allow_pickle=True)
filtered_dataset_sample = filtered_dataset_sample['data']

In [22]:
# check filtered dataset

# filtered_data[epi][feature][timestep]

print("# of dataset: ", len(filtered_dataset_sample))
print("state_shape: ", filtered_dataset_sample[0]['observations'].shape)
print("1 episode: ", filtered_dataset_sample[0])
print("state: ", filtered_dataset_sample[0]['observations'])

# of dataset:  9
state_shape:  (31, 17)
1 episode:  {'observations': array([[-8.22658986e-02,  4.86920588e-02,  1.78905189e-01,
         1.09770298e-01, -2.28323162e-01, -1.03105044e+00,
        -4.84628797e-01, -4.56260383e-01,  8.58324718e+00,
        -4.71299253e-02,  5.03728199e+00, -3.21572614e+00,
        -1.07802820e+01, -2.16359925e+00, -1.53369741e+01,
        -6.35729742e+00, -8.40144217e-01],
       [-6.57200515e-02,  5.67822717e-02, -1.28969565e-01,
         7.49948621e-03, -4.21711087e-01, -3.50266218e-01,
        -2.71108925e-01, -7.59878606e-02,  9.31469345e+00,
         9.97876078e-02, -1.46692729e+00, -7.93303347e+00,
         1.29986060e+00, -3.97610760e+00,  2.14901505e+01,
         1.27756901e+01,  1.20369215e+01],
       [-8.49063993e-02,  2.20313072e-02, -4.91307825e-01,
        -3.10196042e-01, -4.14100707e-01,  3.52463543e-01,
        -9.85431671e-03,  3.45665574e-01,  8.69810104e+00,
        -3.92837554e-01, -1.06338954e+00, -5.64835930e+00,
        -8.58873653