In [1]:
# import library
import collections
import csv
import os
import pickle
from datetime import datetime

import d4rl
import gym
import numpy as np
import pyrootutils
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Locate the repository root: search upwards from the notebook's directory
# for the ".project-root" marker file.
path = pyrootutils.find_root(
    search_from=os.path.abspath(''),
    indicator=".project-root",
)

# Register that root with pyrootutils. This must run before the
# `transformer.*` imports below, which are resolved relative to the root.
pyrootutils.set_root(
    path=path,
    project_root_env_var=True,
    dotenv=True,
    pythonpath=True,
)

# Root as a forward-slash string; used to build the data/log/model paths.
PATH = str(path).replace("\\", "/")

from transformer.gpt_transformer.src.model import DecisionTransformer
from transformer.gpt_transformer.src.utils import (D4RLTrajectoryDataset,
                                                   check_batch)

No module named 'flow'
No module named 'carla'


In [2]:
# Sanity check for the simulator stack: runs the project's check script in a
# subprocess. On success it prints "mujoco-py check passed" and
# "d4rl check passed". The path is relative to the notebook's working directory.

!python ./check/mujoco_test.py

mujoco-py check passed
d4rl check passed


No module named 'flow'
No module named 'carla'
pybullet build time: Apr 30 2024 12:01:25
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [3]:
# Download the D4RL datasets by running the project's download script.
# Judging by its output, the script also reports an original/train/val
# trajectory split for every dataset.
# Skip this cell if the data has already been downloaded.

!python ../src/data/download_d4rl_datasets.py


N:  1000000
env name:  halfcheetah-medium-v2
original n of traj:  1000
train n of traj:  800
val n of traj:  200
-----------------------------------------------------
N:  202000
env name:  halfcheetah-medium-replay-v2
original n of traj:  202
train n of traj:  161
val n of traj:  41
-----------------------------------------------------
N:  2000000
env name:  halfcheetah-medium-expert-v2
original n of traj:  2000
train n of traj:  1600
val n of traj:  400
-----------------------------------------------------
N:  1000000
env name:  hopper-medium-v2
original n of traj:  2186
train n of traj:  1748
val n of traj:  438
-----------------------------------------------------
N:  402000
env name:  hopper-medium-replay-v2
original n of traj:  2041
train n of traj:  1632
val n of traj:  409
-----------------------------------------------------
N:  1999906
env name:  hopper-medium-expert-v2
original n of traj:  3213
train n of traj:  2570
val n of traj:  643
---------------------------------------

No module named 'flow'
No module named 'carla'
pybullet build time: Apr 30 2024 12:01:25
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")

load datafile:   0%|          | 0/21 [00:00<?, ?it/s]
load datafile:   5%|▍         | 1/21 [00:00<00:04,  4.85it/s]
load datafile:  14%|█▍        | 3/21 [00:00<00:04,  4.07it/s]
load datafile:  19%|█▉        | 4/21 [00:01<00:05,  3.03it/s]
load datafile:  81%|████████  | 17/21 [00:01<00:00, 11.41it/s]
load datafile:  86%|████████▌ | 18/21 [00:02<00:00,  7.33it/s]
load datafile: 100%|██████████| 21/21 [00:02<00:00,  8.23it/s]

load datafile:   0%|          | 0/11 [00:00<?, ?it/s]
load datafile:  27%|██▋       | 3/11 [00:00<00:00, 16.86it/s]
load datafile:  45%|████▌     | 5/11 [00:00<00:00, 17.34it/s]
load datafile:  64%|██████▎   | 7/11 [00:00<00:00, 16.41it/s]
load datafile:  82%|████████▏ | 9/11 [00:00<00:00, 12.15it/s]
load datafile: 100%|██████████| 11/11 [00:00<00:00, 16.66it/s]

load datafile:   0%|          | 0/9 [00:0

In [4]:
# parameter setting

# Task selection: together these pick the D4RL dataset "<env_name>-<dataset>-v2".
env_name = 'halfcheetah'
dataset = 'medium-replay'

# Gym environment id for every supported task.
GYM_ENV_IDS = {
    'hopper': 'Hopper-v2',
    'halfcheetah': 'HalfCheetah-v2',
    'walker2d': 'Walker2d-v2',
}

# Fail fast on an unsupported name. Otherwise `env` would stay undefined (or,
# worse, keep a stale value from an earlier run of this cell) and the error
# would only surface in a later cell.
if env_name not in GYM_ENV_IDS:
    raise ValueError(
        f"unsupported env_name {env_name!r}; expected one of {sorted(GYM_ENV_IDS)}"
    )
env = gym.make(GYM_ENV_IDS[env_name])

# Pickled trajectory files for the selected dataset.
TRAIN_DATA_PATH = f'{PATH}/transformer/gpt_transformer/src/data/train/{env_name}-{dataset}-v2.pkl'
VAL_DATA_PATH = f'{PATH}/transformer/gpt_transformer/src/data/val/{env_name}-{dataset}-v2.pkl'
ORIGINAL_DATA_PATH = f'{PATH}/transformer/gpt_transformer/src/data/original/{env_name}-{dataset}-v2.pkl'

# Output directories for the CSV training log / latest checkpoint and for the
# best checkpoint.
LOG_PATH = f"{PATH}/transformer/gpt_transformer/src/log/"
BEST_MODEL_PATH = f"{PATH}/transformer/gpt_transformer/src/best_model/"

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  logger.warn(


In [5]:
# env dataset check
# Load the raw D4RL dataset for the task chosen in the parameter cell. The id
# is built from env_name / dataset so this check always inspects the same data
# the rest of the notebook trains on, instead of a hard-coded dataset.
check_env = gym.make(f'{env_name}-{dataset}-v2')
check_dataset = check_env.get_dataset()

# print(check_dataset['observations'][1])  # original note: extracted per trajectory


  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
load datafile: 100%|██████████| 11/11 [00:00<00:00, 20.29it/s]


In [6]:
# print("overall len: ", dataset.shape)

In [7]:
# Report the array shapes of the raw dataset, one line per field.
for label, key in (
    ("state shape: ", 'observations'),
    ("action shape: ", 'actions'),
    ("reward shape: ", 'rewards'),
):
    print(label, check_dataset[key].shape)

# Total number of transitions, and the size of an 80% slice of them.
num_transitions = check_dataset['rewards'].shape[0]
print("N: ", num_transitions)
print("train_size: ", int(0.8 * num_transitions))


state shape:  (202000, 17)
action shape:  (202000, 6)
reward shape:  (202000,)
N:  202000
train_size:  161600


In [8]:
# data check
# Load the unsplit ("original") trajectories and print the observations of the
# first three as a quick visual check.
# NOTE: pickle.load executes arbitrary code from the file; only load pickles
# produced by this project's own download script.
with open(ORIGINAL_DATA_PATH, 'rb') as f:
    trajectories = pickle.load(f)

n = 0                   # number of trajectories printed so far
max_rewards_list = []   # left empty; reserved for per-trajectory max rewards
for n, traj in enumerate(trajectories, start=1):
    print("state: ", traj['observations'], "\n")
    # Stop after three trajectories to keep the output short.
    if n == 3:
        break

state:  [[-4.3362133e-02 -2.6833022e-03  6.5432638e-02 ...  1.0888433e-01
   8.8169552e-02  6.3931167e-02]
 [-3.7674177e-02  1.5365051e-02  2.1673881e-01 ... -8.2050276e+00
   2.6932850e+00 -4.4459348e+00]
 [-8.1146188e-02  1.5728043e-02  2.3461881e-01 ... -1.2342579e+00
  -4.2705555e+00 -3.4320550e+00]
 ...
 [-5.6956607e-01  3.2922680e+00 -3.1815246e-01 ... -3.0207298e+00
   2.2088470e+00 -4.5268812e+00]
 [-5.7592243e-01  3.3024230e+00 -2.3679866e-01 ...  9.0879889e+00
  -1.4607348e+00 -2.9853027e+00]
 [-5.6932127e-01  3.2941539e+00 -2.2014110e-01 ... -6.1263404e+00
   3.7058628e+00  1.0001764e+01]] 

state:  [[ 6.11328473e-03 -8.39964487e-03  8.37445110e-02 ...  1.55069038e-01
  -8.99704620e-02 -3.59976701e-02]
 [-3.13807465e-03  5.21850809e-02 -3.83801684e-02 ... -4.89395149e-02
  -6.29623222e+00  3.82143348e-01]
 [-1.61008500e-02  1.11764394e-01 -2.30730549e-02 ... -5.31878829e-01
   5.92412138e+00  1.48380601e+00]
 ...
 [-5.94031950e-03 -1.12105735e-01  4.97356743e-01 ...  5.75514

In [9]:
# check train data shape
# NOTE: pickle.load executes arbitrary code from the file; only load pickles
# produced by this project's own download script.
with open(TRAIN_DATA_PATH, 'rb') as f:
    train_trajectories = pickle.load(f)

# Total number of timesteps in the split. Sum the real episode lengths rather
# than multiplying by the first episode's length: episodes are not guaranteed
# to be equally long (e.g. hopper episodes can terminate early).
print("length: ", sum(len(traj['observations']) for traj in train_trajectories))
print("n of epi: ", len(train_trajectories))
# Number of timesteps in the first episode only.
print("n of traj in one epi: ", len(train_trajectories[0]['observations']))


length:  161000
n of epi:  161
n of traj in one epi:  1000


In [10]:
# Compare the first trajectory of the original data with the first trajectory
# of the train split; differing observations indicate the split was shuffled.
for label, trajs in (("ori:", trajectories), ("train:", train_trajectories)):
    print(label, trajs[0]['observations'])


ori: [[-4.3362133e-02 -2.6833022e-03  6.5432638e-02 ...  1.0888433e-01
   8.8169552e-02  6.3931167e-02]
 [-3.7674177e-02  1.5365051e-02  2.1673881e-01 ... -8.2050276e+00
   2.6932850e+00 -4.4459348e+00]
 [-8.1146188e-02  1.5728043e-02  2.3461881e-01 ... -1.2342579e+00
  -4.2705555e+00 -3.4320550e+00]
 ...
 [-5.6956607e-01  3.2922680e+00 -3.1815246e-01 ... -3.0207298e+00
   2.2088470e+00 -4.5268812e+00]
 [-5.7592243e-01  3.3024230e+00 -2.3679866e-01 ...  9.0879889e+00
  -1.4607348e+00 -2.9853027e+00]
 [-5.6932127e-01  3.2941539e+00 -2.2014110e-01 ... -6.1263404e+00
   3.7058628e+00  1.0001764e+01]]
train: [[-1.1576745e-02  2.9706063e-02 -4.3819766e-02 ... -1.5118380e-01
   2.1798883e-02  8.2527779e-02]
 [-2.4950452e-02  1.0053803e-02 -3.0048784e-02 ...  2.8199830e+00
   2.1435523e+00 -7.0969844e-01]
 [-5.7788070e-02 -2.7738808e-04  1.7422727e-01 ... -2.3943918e+00
  -2.0879815e+00  7.5078583e+00]
 ...
 [-1.4442080e-01 -1.4363490e-01 -5.4363328e-01 ... -8.4365475e-01
   6.1477447e-01  1.

In [11]:
# check shuffle
# Toy example: np.random.shuffle on a list of dicts reorders the dicts as
# whole units and leaves each dict's contents untouched.

# Each record looks like {'a': [x, y], 'b': [x+10, y+10], 'c': [x+20, y+20],
# 'd': [x+30, y+30]} for the (x, y) pairs below.
pairs = [(1, 2), (3, 4), (5, 6), (7, 8), (9, 0)]
array = [
    {key: [first + offset, second + offset]
     for key, offset in zip('abcd', (0, 10, 20, 30))}
    for first, second in pairs
]

np.random.shuffle(array)

print(array)

[{'a': [9, 0], 'b': [19, 10], 'c': [29, 20], 'd': [39, 30]}, {'a': [1, 2], 'b': [11, 12], 'c': [21, 22], 'd': [31, 32]}, {'a': [7, 8], 'b': [17, 18], 'c': [27, 28], 'd': [37, 38]}, {'a': [5, 6], 'b': [15, 16], 'c': [25, 26], 'd': [35, 36]}, {'a': [3, 4], 'b': [13, 14], 'c': [23, 24], 'd': [33, 34]}]


In [12]:
# check valid data shape
# NOTE: pickle.load executes arbitrary code from the file; only load pickles
# produced by this project's own download script.
with open(VAL_DATA_PATH, 'rb') as f:
    val_trajectories = pickle.load(f)

# Total number of timesteps in the split. Sum the real episode lengths rather
# than multiplying by the first episode's length: episodes are not guaranteed
# to be equally long (e.g. hopper episodes can terminate early).
print("length: ", sum(len(traj['observations']) for traj in val_trajectories))
print("n of epi: ", len(val_trajectories))
# Number of timesteps in the first episode only.
print("n of traj in one epi: ", len(val_trajectories[0]['observations']))

length:  41000
n of epi:  41
n of traj in one epi:  1000


In [13]:
# train parameter

# --- model architecture ---
embed_dim = 128         # passed to the model as h_dim
n_blocks = 3            # number of transformer blocks
n_heads = 1             # attention heads per block
drop_out = 0.1          # dropout probability
activation = 'relu'
k = 31                  # context length (timesteps per training sequence)

# --- optimisation ---
batch_size = 128
lr = 1e-4               # learning rate
wt_decay = 1e-4         # weight decay
warmup_steps = 10_000   # warmup steps for the lr scheduler

# --- training schedule: total updates = max_train_iters x num_updates_per_iter ---
max_train_iters = 1000
num_updates_per_iter = 100

# --- running state used by the training loop ---
total_updates = 0
min_total_log_loss = 1e10   # best (lowest) loss seen so far

# --- weights of the two MSE loss terms ---
state_weight = 1
reward_weight = 1

In [14]:
# check dim
# Read the observation and action vector sizes from the environment's spaces.
state_dim, act_dim = env.observation_space.shape[0], env.action_space.shape[0]

for label, dim in (("state dim: ", state_dim), ("action dim: ", act_dim)):
    print(label, dim)

state dim:  17
action dim:  6


In [15]:
# train data loader tester
# test_traj_dataset = D4RLTrajectoryDataset(TRAIN_DATA_PATH, k)
# test_traj_data_loader = DataLoader(test_traj_dataset,
# 						batch_size=batch_size,
# 						shuffle=True,
# 						pin_memory=True,
# 						drop_last=True)
                        
# test_data_iter = iter(test_traj_data_loader)


# for i_train_iter in tqdm(range(max_train_iters)):
	
# 	for _ in range(num_updates_per_iter):
# 		try:
# 			timesteps, states, next_states, actions, rewards, traj_mask = next(test_data_iter)
# 		except StopIteration:
# 			test_traj_data_loader = DataLoader(test_traj_dataset,
# 									batch_size=batch_size,
# 									shuffle=True,
# 									pin_memory=True,
# 									drop_last=True)
# 			test_data_iter = iter(test_traj_data_loader)
# 			timesteps, states, next_states, actions, rewards, traj_mask = next(test_data_iter)

In [16]:
# Load the validation data (normalization, padding to the context length).
# The train path is passed as well as the validation path; presumably the
# dataset class takes its normalization statistics from the training data
# (TODO confirm in D4RLTrajectoryDataset).
val_traj_dataset = D4RLTrajectoryDataset(
    TRAIN_DATA_PATH,
    k,
    val=True,
    val_dataset_path=VAL_DATA_PATH,
)

# check_batch may replace the configured batch size based on the size of the
# validation set (see its definition in utils for the exact rule).
batch_size = check_batch(batch_size, len(val_traj_dataset))

print(f"batch_size: {batch_size}")

batch_size: 32


In [17]:
# Load the training data (normalization, padding to the context length).
train_traj_dataset = D4RLTrajectoryDataset(TRAIN_DATA_PATH, k)

# Shuffled loader that yields only full batches (the last partial batch is dropped).
train_traj_data_loader = DataLoader(
    dataset=train_traj_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

# The training loop pulls batches from this iterator with next().
train_data_iter = iter(train_traj_data_loader)

# State statistics as reported by the training dataset.
state_mean, state_std = train_traj_dataset.get_state_stats()

In [18]:
# define model

# Network built from the hyper-parameters above and moved to DEVICE.
model = DecisionTransformer(
    state_dim=state_dim,
    act_dim=act_dim,
    n_blocks=n_blocks,
    h_dim=embed_dim,
    context_len=k,
    n_heads=n_heads,
    drop_p=drop_out,
).to(DEVICE)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay)


def warmup_lr_factor(steps):
    """Learning-rate multiplier: ramps linearly to 1 over warmup_steps, then stays at 1."""
    return min((steps + 1) / warmup_steps, 1)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_lr_factor)
	


In [None]:
def _compute_losses(model, batch):
    """Run one forward pass on `batch` and return its masked MSE losses.

    Args:
        model: network called with timesteps/states/actions/rewards that
            returns (next_state_preds, rewards_preds).
        batch: the 6-tuple yielded by the trajectory data loaders:
            (timesteps, states, next_states, actions, rewards, traj_mask).

    Returns:
        (total_loss, state_loss, reward_loss) as scalar tensors, where
        total_loss = state_loss + reward_loss and each term is already scaled
        by state_weight / reward_weight.

    Reads the notebook globals DEVICE, state_dim, state_weight, reward_weight.
    """
    timesteps, states, next_states, actions, rewards, traj_mask = batch

    timesteps = timesteps.to(DEVICE)                    # B x T
    states = states.to(DEVICE)                          # B x T x state_dim
    next_states = next_states.to(DEVICE)                # B x T x state_dim
    actions = actions.to(DEVICE)                        # B x T x act_dim
    rewards = rewards.to(DEVICE).unsqueeze(dim=-1)      # B x T x 1
    traj_mask = traj_mask.to(DEVICE)                    # B x T

    # Targets are copied before the forward pass so they stay independent of
    # the tensors handed to the model.
    next_states_target = torch.clone(next_states).detach()
    rewards_target = torch.clone(rewards).detach()

    # Call the module itself (not .forward) so nn.Module hooks are honoured.
    next_state_preds, rewards_preds = model(
        timesteps=timesteps,
        states=states,
        actions=actions,
        rewards=rewards,
    )

    # only consider non padded elements
    valid = traj_mask.view(-1,) > 0
    next_state_preds = next_state_preds.view(-1, state_dim)[valid]
    next_states_target = next_states_target.view(-1, state_dim)[valid]
    rewards_preds = rewards_preds.view(-1, 1)[valid]
    rewards_target = rewards_target.view(-1, 1)[valid]

    state_loss = F.mse_loss(next_state_preds, next_states_target, reduction='mean') * state_weight
    reward_loss = F.mse_loss(rewards_preds, rewards_target, reduction='mean') * reward_weight
    # todo: try to use mae

    return state_loss + reward_loss, state_loss, reward_loss


start_time = datetime.now().replace(microsecond=0)

start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")

prefix = f"{env_name}-{dataset}"

save_model_name = f'{prefix}_model.pt'
save_best_model_name = f'{prefix}_model_best.pt'
save_model_path = os.path.join(LOG_PATH, save_model_name)
save_best_model_path = os.path.join(BEST_MODEL_PATH, save_best_model_name)

log_csv_name = prefix + "_log_" + start_time_str + ".csv"
log_csv_path = os.path.join(LOG_PATH, log_csv_name)

# Make sure the output directories exist before anything is written to them.
os.makedirs(LOG_PATH, exist_ok=True)
os.makedirs(BEST_MODEL_PATH, exist_ok=True)

csv_header = (["duration", "num_updates", "total_loss", "state_loss", "reward_loss", "val_total_loss", "val_state_loss", "val_reward_loss"])

# The validation loader is built once. It is neither shuffled nor truncated,
# so every iteration is scored on the same, complete validation set and the
# "best model" comparison is made on comparable numbers.
# (The last batch may be smaller; batch means are averaged unweighted.)
val_traj_data_loader = DataLoader(val_traj_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        pin_memory=True,
                        drop_last=False)

# newline='' is what the csv module requires of its file object (avoids blank
# rows on Windows); buffering=1 keeps the log line-buffered so rows reach the
# disk as training progresses. The `with` block closes the file even if
# training is interrupted.
with open(log_csv_path, 'a', 1, newline='') as log_file:
    csv_writer = csv.writer(log_file)
    csv_writer.writerow(csv_header)

    print("=" * 60)
    print("start time: " + start_time_str)
    print("=" * 60)

    print("device set to: " + str(DEVICE))
    print("dataset: " + prefix)
    print("batch_size: " + str(batch_size))
    print("best model save path: " + save_best_model_path)
    print("log csv save path: " + log_csv_path)

    # train
    for i_train_iter in tqdm(range(max_train_iters)):

        log_state_losses, log_reward_losses, log_total_losses = [], [], []
        val_log_state_losses, val_log_reward_losses, val_log_total_losses = [], [], []
        model.train()

        for _ in range(num_updates_per_iter):
            try:
                batch = next(train_data_iter)
            except StopIteration:
                # Epoch exhausted: a fresh iterator over the same loader
                # starts a new, reshuffled pass over the training data.
                train_data_iter = iter(train_traj_data_loader)
                batch = next(train_data_iter)

            total_loss, state_loss, reward_loss = _compute_losses(model, batch)

            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            optimizer.step()
            scheduler.step()

            # save loss
            log_state_losses.append(state_loss.item())
            log_reward_losses.append(reward_loss.item())
            log_total_losses.append(total_loss.item())

        # validation: no gradients are needed, so skip building the autograd graph
        model.eval()
        with torch.no_grad():
            for val_batch in val_traj_data_loader:
                val_total_loss, val_state_loss, val_reward_loss = _compute_losses(model, val_batch)

                # save val loss
                val_log_state_losses.append(val_state_loss.item())
                val_log_reward_losses.append(val_reward_loss.item())
                val_log_total_losses.append(val_total_loss.item())

        mean_total_log_loss = np.mean(log_total_losses)
        mean_state_log_loss = np.mean(log_state_losses)
        mean_reward_log_loss = np.mean(log_reward_losses)

        mean_val_total_log_loss = np.mean(val_log_total_losses)
        mean_val_state_log_loss = np.mean(val_log_state_losses)
        mean_val_reward_log_loss = np.mean(val_log_reward_losses)

        time_elapsed = str(datetime.now().replace(microsecond=0) - start_time)

        total_updates += num_updates_per_iter

        log_str = ("=" * 60 + '\n' +
                "time elapsed: " + time_elapsed  + '\n' +
                "num of updates: " + str(total_updates) + '\n' +
                "train total loss: " + format(mean_total_log_loss, ".5f") + '\n' +
                "train state loss: " + format(mean_state_log_loss, ".5f") + '\n' +
                "train reward loss: " +  format(mean_reward_log_loss, ".5f") + '\n' +
                "val total loss: " + format(mean_val_total_log_loss, ".5f") + '\n' +
                "val state loss: " + format(mean_val_state_log_loss, ".5f") + '\n' +
                "val reward loss: " +  format(mean_val_reward_log_loss, ".5f")
                )

        print(log_str)

        log_data = [time_elapsed, total_updates, mean_total_log_loss, mean_state_log_loss, mean_reward_log_loss,
                    mean_val_total_log_loss, mean_val_state_log_loss, mean_val_reward_log_loss]

        csv_writer.writerow(log_data)

        # save model: keep a separate copy of the weights with the lowest
        # validation loss seen so far
        if mean_val_total_log_loss <= min_total_log_loss:
            print("saving min loss model at: " + save_best_model_path)
            torch.save(model.state_dict(), save_best_model_path)
            min_total_log_loss = mean_val_total_log_loss

        # always overwrite the "latest" checkpoint
        print("saving current model at: " + save_model_path)
        torch.save(model.state_dict(), save_model_path)


print("=" * 60)
print("finished training!")
print("=" * 60)
end_time = datetime.now().replace(microsecond=0)
time_elapsed = str(end_time - start_time)
end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S")
print("started training at: " + start_time_str)
print("finished training at: " + end_time_str)
print("total training time: " + time_elapsed)
print("saved best model at: " + save_best_model_path)
print("saved last updated model at: " + save_model_path)
print("=" * 60)