In [1]:
# import library

import sys
import os
import d4rl
import gym
import numpy as np
import wandb
import collections
import pickle
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F

from datetime import datetime
from torch.utils.data import Dataset, DataLoader

from utils import discount_cumsum, D4RLTrajectoryDataset, evaluate_on_env, get_d4rl_normalized_score
from model import MaskedCausalAttention, Block, DecisionTransformer


No module named 'flow'
No module named 'carla'


In [2]:
# set environment
# sys.path.append(r'C:\Develop\offlineRL-with-diffusion') 

In [3]:
# test mujoco, d4rl
# Runs the repository's check script in a separate Python process; the recorded
# output below shows it reporting "mujoco-py check passed" and "d4rl check passed".

!python ./test/mujoco_test.py

mujoco-py check passed
d4rl check passed


No module named 'flow'
No module named 'carla'
pybullet build time: Apr 30 2024 12:01:25
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [4]:
# data download
# If the datasets have already been downloaded, there is no need to run this cell again.

# !python ./data/download_d4rl_datasets.py


In [5]:
# parameter setting
# Selects the environment/dataset pair and derives every path and constant
# that depends on that choice.

env_name = 'halfcheetah'
dataset = 'medium'

# env_name -> (gym environment id, max episode length, reward/return normalization scale)
# Evaluation conditioning targets that were noted for each environment
# (currently unused): hopper [3600, 1800], halfcheetah [12000, 6000],
# walker2d [5000, 2500].
ENV_CONFIGS = {
    'hopper': ('Hopper-v3', 1000, 1000.),
    'halfcheetah': ('HalfCheetah-v3', 1000, 1000.),
    'walker2d': ('Walker2d-v3', 1000, 1000.),
}

# Fail right here on an unsupported name instead of leaving `env`,
# `max_ep_len` and `scale` undefined and hitting a NameError in a later cell.
if env_name not in ENV_CONFIGS:
    raise ValueError(
        f"unsupported env_name {env_name!r}; expected one of {sorted(ENV_CONFIGS)}"
    )

gym_env_id, max_ep_len, scale = ENV_CONFIGS[env_name]
env = gym.make(gym_env_id)

DATA_PATH = f'data/{env_name}-{dataset}-v2.pkl'
LOG_PATH = "./log/"
DEVICE = 'cpu'

In [6]:
# data check
# Print the observations and next observations of the first few trajectories.

NUM_PREVIEW_TRAJ = 3

# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only open dataset files that you generated yourself or otherwise trust.
with open(DATA_PATH, 'rb') as f:
    trajectories = pickle.load(f)

# zip() against a range stops after NUM_PREVIEW_TRAJ items (or earlier if the
# dataset is shorter) without needing a manual counter and break.
for _, traj in zip(range(NUM_PREVIEW_TRAJ), trajectories):
    print("state: ", traj['observations'], "\n")
    print("next_state: ", traj['next_observations'], "\n")

state:  [[ 1.9831914e-02 -8.9501314e-02 -3.1969063e-03 ...  1.1365079e-01
   6.8424918e-02 -1.3811582e-01]
 [-3.8486063e-03 -5.2394319e-02  8.3050327e-03 ...  4.5068407e+00
  -9.2885571e+00  4.7328596e+00]
 [-5.5298433e-02 -7.7850236e-05 -2.3952831e-01 ... -7.0811687e+00
  -1.4037068e+00  7.5524049e+00]
 ...
 [-3.1975684e-01  5.3305399e-01 -4.8704177e-01 ...  1.5455554e+00
   2.6812897e+00  8.7905388e+00]
 [-3.2200974e-01  3.5745117e-01  1.0463273e-02 ... -6.3428599e-01
   1.6292539e+00  9.7356015e-01]
 [-3.0673215e-01  1.9843711e-01  6.9996923e-01 ...  5.0098950e-01
   1.5680059e+00  9.4733723e-02]] 

next_state:  [[-3.8486063e-03 -5.2394319e-02  8.3050327e-03 ...  4.5068407e+00
  -9.2885571e+00  4.7328596e+00]
 [-5.5298433e-02 -7.7850236e-05 -2.3952831e-01 ... -7.0811687e+00
  -1.4037068e+00  7.5524049e+00]
 [-1.2996776e-01  2.2959358e-03 -2.2985412e-01 ... -7.0144100e+00
   2.6917322e+00 -1.6729002e+00]
 ...
 [-3.2200974e-01  3.5745117e-01  1.0463273e-02 ... -6.3428599e-01
   1.6292

In [16]:
# train parameter
# Hyperparameters read by the dataset, model, optimizer and training cells.
batch_size = 64             # batch size of the training DataLoader
embed_dim = 128             # passed to DecisionTransformer as h_dim
activation = 'relu'         # NOTE(review): not referenced by any visible cell - confirm it is still needed
drop_out = 0.1              # passed to DecisionTransformer as drop_p
k = 20                      # context length: passed to D4RLTrajectoryDataset and as context_len to the model
n_blocks = 3                # passed to DecisionTransformer as n_blocks
n_heads = 1 # transformer head

# total updates = max_train_iters x num_updates_per_iter
max_train_iters = 200
num_updates_per_iter = 100
# Running state mutated by the training cell; re-run this cell to reset it
# before starting a fresh training run.
total_updates = 0           # incremented by num_updates_per_iter after every training iteration
max_d4rl_score = -1.0       # best evaluation score so far; decides when the "_best" checkpoint is saved

wt_decay = 1e-4             # weight decay
lr = 1e-4                   # learning rate
warmup_steps = 10000        # warmup steps for lr scheduler

# weight of mse loss
state_weight = 1            # multiplies the next-state MSE loss
reward_weight = 1           # multiplies the reward MSE loss

# evaluation parameter
max_eval_ep_len = 1000      # max len of one evaluation episode
num_eval_ep = 10            # num of evaluation episodes per iteration

In [8]:
# load the training data
# (per the original note, D4RLTrajectoryDataset handles normalization and
# padding to the context length k - see utils for the details)

traj_dataset = D4RLTrajectoryDataset(DATA_PATH, k)

# Pinned host memory only speeds up host-to-GPU copies. Requesting it while
# running on CPU is wasted work and makes PyTorch emit a warning when no
# accelerator is present, so enable it only for CUDA devices.
use_pinned_memory = str(DEVICE).startswith('cuda')

traj_data_loader = DataLoader(
    traj_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=use_pinned_memory,
    drop_last=True,
)

# Iterator consumed (and re-created when exhausted) by the training cell.
data_iter = iter(traj_data_loader)

## get state stats from dataset (passed to evaluate_on_env in the training cell)
state_mean, state_std = traj_dataset.get_state_stats()

In [9]:
# environment dimensions
# Read the size of the first axis of the observation and action spaces;
# these set the model's input/output widths in later cells.

observation_shape = env.observation_space.shape
action_shape = env.action_space.shape

state_dim = observation_shape[0]
act_dim = action_shape[0]

for label, value in (("state dim: ", state_dim), ("action dim: ", act_dim)):
    print(label, value)

state dim:  17
action dim:  6


In [10]:
# test data
# Smoke test: draw one tiny batch (context length 2, batch size 4) and print
# the tensor shapes the dataset produces.
temp_dataset = D4RLTrajectoryDataset(DATA_PATH, 2)
temp_data_loader = DataLoader(
    temp_dataset,
    batch_size=4,
    shuffle=True,
    # pinned memory only benefits host-to-GPU copies; skip it on CPU
    pin_memory=str(DEVICE).startswith('cuda'),
    drop_last=True,
)

temp_data_iter = iter(temp_data_loader)

timesteps, states, next_states, actions, rewards, traj_mask = next(temp_data_iter)

timesteps = timesteps.to(DEVICE)                # B x T
states = states.to(DEVICE)                      # B x T x state_dim
next_states = next_states.to(DEVICE)            # B x T x state_dim
actions = actions.to(DEVICE)                    # B x T x act_dim
rewards = rewards.to(DEVICE).unsqueeze(dim=-1)  # B x T x 1

print("timesteps shape: ", timesteps.shape)
print("rewards shape: ", rewards.shape)
print("states shape: ", states.shape)
print("actions shape: ", actions.shape)



timesteps shape:  torch.Size([4, 2])
rewards shape:  torch.Size([4, 2, 1])
states shape:  torch.Size([4, 2, 17])
actions shape:  torch.Size([4, 2, 6])


In [11]:
# test model
# Smoke test: build a small model (h_dim=16, context_len=2) and push the
# sample batch from the test-data cell through it once.

temp_model = DecisionTransformer(
    state_dim=state_dim,
    act_dim=act_dim,
    # original note (translated): "includes reward + excludes r0"
    # NOTE(review): confirm what this refers to in the model's inputs
    n_blocks=n_blocks,
    h_dim=16,
    context_len=2,
    n_heads=n_heads,
    drop_p=drop_out,
).to(DEVICE)

# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
next_state_preds, rewards_preds = temp_model(
    rewards=rewards,
    timesteps=timesteps,
    states=states,
    actions=actions,
)

In [12]:
# define model
# Build the full-size model together with its optimizer and a linear
# learning-rate warmup schedule.

model_kwargs = dict(
    state_dim=state_dim,
    act_dim=act_dim,
    n_blocks=n_blocks,
    h_dim=embed_dim,
    context_len=k,
    n_heads=n_heads,
    drop_p=drop_out,
)
model = DecisionTransformer(**model_kwargs).to(DEVICE)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay)


def _warmup_lr_factor(step):
    """Return the lr multiplier: ramps linearly up to 1 over warmup_steps, then stays at 1."""
    return min((step + 1) / warmup_steps, 1)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _warmup_lr_factor)
	


In [17]:
# train
# Each training iteration runs num_updates_per_iter gradient updates, evaluates
# the model on the environment, appends one row of metrics to a CSV log and
# saves checkpoints (latest, plus best-by-d4rl-score).
# NOTE: total_updates and max_d4rl_score are initialised in the train-parameter
# cell; re-run that cell first if you want to start a run from scratch.

start_time = datetime.now().replace(microsecond=0)
start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")

prefix = "dt_" + env_name

# Checkpoints and the CSV log are written here; create the directory if missing.
os.makedirs(LOG_PATH, exist_ok=True)

save_model_name = prefix + "_model_" + start_time_str + ".pt"
save_model_path = os.path.join(LOG_PATH, save_model_name)
save_best_model_path = save_model_path[:-3] + "_best.pt"

log_csv_name = prefix + "_log_" + start_time_str + ".csv"
log_csv_path = os.path.join(LOG_PATH, log_csv_name)

csv_header = ["duration", "num_updates", "total_loss", "state_loss", "reward_loss",
              "eval_avg_reward", "eval_avg_ep_len", "eval_d4rl_score"]

# gradient-norm clipping threshold used for every update
max_grad_norm = 0.25

print("=" * 60)
print("start time: " + start_time_str)
print("=" * 60)

print("device set to: " + str(DEVICE))
print("dataset path: " + DATA_PATH)
print("model save path: " + save_model_path)
print("log csv save path: " + log_csv_path)

# The context manager closes the log file even when training is interrupted
# (csv.writer objects have no close() method of their own). newline='' is what
# the csv module expects and prevents blank rows on Windows; buffering=1 keeps
# the file line-buffered so each row reaches disk promptly.
with open(log_csv_path, 'a', 1, newline='') as log_csv_file:
    csv_writer = csv.writer(log_csv_file)
    csv_writer.writerow(csv_header)

    for i_train_iter in range(max_train_iters):

        log_state_losses, log_reward_losses, log_total_losses = [], [], []
        model.train()

        for _ in range(num_updates_per_iter):
            # restart the loader once an epoch's worth of batches is exhausted
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(traj_data_loader)
                batch = next(data_iter)
            timesteps, states, next_states, actions, rewards, traj_mask = batch

            timesteps = timesteps.to(DEVICE)                # B x T
            states = states.to(DEVICE)                      # B x T x state_dim
            next_states = next_states.to(DEVICE)            # B x T x state_dim
            actions = actions.to(DEVICE)                    # B x T x act_dim
            rewards = rewards.to(DEVICE).unsqueeze(dim=-1)  # B x T x 1
            traj_mask = traj_mask.to(DEVICE)                # B x T

            # Call the module itself (not .forward()) so registered hooks run.
            next_state_preds, rewards_preds = model(
                timesteps=timesteps,
                states=states,
                actions=actions,
                rewards=rewards,
            )

            # only consider non padded elements
            valid = traj_mask.view(-1) > 0
            next_state_preds = next_state_preds.view(-1, state_dim)[valid]
            next_states_target = next_states.detach().view(-1, state_dim)[valid]

            rewards_preds = rewards_preds.view(-1, 1)[valid]
            rewards_target = rewards.detach().view(-1, 1)[valid]

            state_loss = F.mse_loss(next_state_preds, next_states_target, reduction='mean') * state_weight
            reward_loss = F.mse_loss(rewards_preds, rewards_target, reduction='mean') * reward_weight

            # both terms are already scalar means, so their sum is the total loss
            total_loss = state_loss + reward_loss

            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()

            # save loss
            log_state_losses.append(state_loss.item())
            log_reward_losses.append(reward_loss.item())
            log_total_losses.append(total_loss.item())

        # evaluate on env
        results = evaluate_on_env(model, DEVICE, k, env, num_eval_ep, max_eval_ep_len, state_mean, state_std)
        eval_avg_reward = results['eval/avg_reward']
        eval_avg_ep_len = results['eval/avg_ep_len']
        eval_d4rl_score = get_d4rl_normalized_score(results['eval/avg_reward'], env_name) * 100

        mean_total_log_loss = np.mean(log_total_losses)
        mean_state_log_loss = np.mean(log_state_losses)
        mean_reward_log_loss = np.mean(log_reward_losses)

        time_elapsed = str(datetime.now().replace(microsecond=0) - start_time)

        total_updates += num_updates_per_iter

        log_str = ("=" * 60 + '\n' +
                   "time elapsed: " + time_elapsed + '\n' +
                   "num of updates: " + str(total_updates) + '\n' +
                   "total loss: " + format(mean_total_log_loss, ".5f") + '\n' +
                   "state loss: " + format(mean_state_log_loss, ".5f") + '\n' +
                   "reward loss: " + format(mean_reward_log_loss, ".5f") + '\n' +
                   "eval avg reward: " + format(eval_avg_reward, ".5f") + '\n' +
                   "eval avg ep len: " + format(eval_avg_ep_len, ".5f") + '\n' +
                   "eval d4rl score: " + format(eval_d4rl_score, ".5f")
                   )

        print(log_str)

        # one value per csv_header column, in the same order
        log_data = [time_elapsed, total_updates, mean_total_log_loss,
                    mean_state_log_loss, mean_reward_log_loss,
                    eval_avg_reward, eval_avg_ep_len,
                    eval_d4rl_score]

        csv_writer.writerow(log_data)

        # save model
        print("max d4rl score: " + format(max_d4rl_score, ".5f"))
        if eval_d4rl_score >= max_d4rl_score:
            print("saving max d4rl score model at: " + save_best_model_path)
            torch.save(model.state_dict(), save_best_model_path)
            max_d4rl_score = eval_d4rl_score

        print("saving current model at: " + save_model_path)
        torch.save(model.state_dict(), save_model_path)


print("=" * 60)
print("finished training!")
print("=" * 60)
end_time = datetime.now().replace(microsecond=0)
time_elapsed = str(end_time - start_time)
end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S")
print("started training at: " + start_time_str)
print("finished training at: " + end_time_str)
print("total training time: " + time_elapsed)
print("max d4rl score: " + format(max_d4rl_score, ".5f"))
print("saved max d4rl score model at: " + save_best_model_path)
print("saved last updated model at: " + save_model_path)
print("=" * 60)

KeyboardInterrupt: 