In [2]:
# import library

import sys
import os
import d4rl
import gym
import numpy as np
import wandb
import collections
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from utils import discount_cumsum, D4RLTrajectoryDataset
from model import MaskedCausalAttention, Block, DecisionTransformer


In [18]:
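# (optional) make the project importable: uncomment and adjust the machine-specific path below if `utils` / `model` cannot be imported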
# sys.path.append(r'C:\Develop\offlineRL-with-diffusion') 

In [3]:
# test mujoco, d4rl

!python ./test/mujoco_test.py

mujoco-py check passed
d4rl check passed


No module named 'flow'
No module named 'carla'
pybullet build time: Apr 30 2024 12:01:25
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [18]:
# data download
# Run once: uncomment the line below to download the datasets; skip it if they are already downloaded.

# !python ./data/download_d4rl_datasets.py


<mujoco._structs.MjData object at 0x000002451ABFF0F0>


array([ 0.09314913, -0.02171162,  0.08267734,  0.01693893, -0.01558907,
        0.08933625,  0.05449321,  0.04947496, -0.02719941,  0.01293334,
       -0.09930255,  0.12025792,  0.08570378,  0.01392469, -0.00379726,
        0.08063133, -0.04713409])

In [4]:
# parameter setting

env_name = 'halfcheetah'  # one of: 'hopper', 'halfcheetah', 'walker2d'
dataset = 'medium'        # dataset variant; only used to build DATA_PATH below

if env_name == 'hopper':
    env = gym.make('Hopper-v3')
    max_ep_len = 1000
    env_targets = [3600, 1800]  # evaluation conditioning targets
    scale = 1000.  # normalization for rewards/returns
elif env_name == 'halfcheetah':
    env = gym.make('HalfCheetah-v3')
    max_ep_len = 1000
    env_targets = [12000, 6000]
    scale = 1000.
elif env_name == 'walker2d':
    env = gym.make('Walker2d-v3')
    max_ep_len = 1000
    env_targets = [5000, 2500]
    scale = 1000.
else:
    # Fail fast: otherwise env / max_ep_len / env_targets / scale stay undefined
    # and the notebook breaks several cells later with a confusing NameError.
    raise ValueError(
        f"unknown env_name {env_name!r}; expected 'hopper', 'halfcheetah' or 'walker2d'"
    )

DATA_PATH = f'data/{env_name}-{dataset}-v2.pkl'

# Use the GPU when one is available; fall back to the CPU so the notebook
# still runs (slowly) on machines without CUDA instead of crashing on .to('cuda').
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
# data check
# Print the first few trajectories of the dataset for a quick sanity check.
# NOTE: pickle.load can execute arbitrary code -- only load DATA_PATH files you
# produced yourself (e.g. with the download script) or otherwise trust.

N_PREVIEW_TRAJS = 3  # number of trajectories to print

with open(DATA_PATH, 'rb') as f:
    trajectories = pickle.load(f)

# zip with range() stops after N_PREVIEW_TRAJS items (or earlier if the dataset
# is shorter) without a manual counter and break.
for n, traj in zip(range(N_PREVIEW_TRAJS), trajectories):
    # The header label reads "<n>-th trajectory" (Korean).
    print(f"{n+1}번째 trajectory")
    print("state: ", traj['observations'], "\n")
    print("action: ", traj['actions'], "\n")
    print("next_state: ", traj['next_observations'], "\n")
    print("reward: ", traj['rewards'], "\n")
    print("")

1번째 trajectory
state:  [[ 1.9831914e-02 -8.9501314e-02 -3.1969063e-03 ...  1.1365079e-01
   6.8424918e-02 -1.3811582e-01]
 [-3.8486063e-03 -5.2394319e-02  8.3050327e-03 ...  4.5068407e+00
  -9.2885571e+00  4.7328596e+00]
 [-5.5298433e-02 -7.7850236e-05 -2.3952831e-01 ... -7.0811687e+00
  -1.4037068e+00  7.5524049e+00]
 ...
 [-3.1975684e-01  5.3305399e-01 -4.8704177e-01 ...  1.5455554e+00
   2.6812897e+00  8.7905388e+00]
 [-3.2200974e-01  3.5745117e-01  1.0463273e-02 ... -6.3428599e-01
   1.6292539e+00  9.7356015e-01]
 [-3.0673215e-01  1.9843711e-01  6.9996923e-01 ...  5.0098950e-01
   1.5680059e+00  9.4733723e-02]] 

action:  [[-0.22293739 -0.7359478  -0.8599511   0.29579234 -0.8416547   0.43432042]
 [-0.72515714 -0.9237766  -0.9887563  -0.5074778  -0.931533    0.9786059 ]
 [ 0.85085297  0.4005884  -0.9951021  -0.9082174  -0.9951816   0.74173075]
 ...
 [ 0.99994427  0.98067576  0.97369236  0.98475415 -0.7468327  -0.8671992 ]
 [ 0.9979593   0.73803234  0.97451997  0.9981649  -0.92976135

In [7]:
# load the preprocessed trajectory data (normalization and padding are handled by D4RLTrajectoryDataset)
# NOTE(review): `k` (context length) and `batch_size` are defined in the
# "train parameter" cell further down. That cell must be run before this one
# (better: move it above), otherwise this cell raises NameError on a fresh kernel.

traj_dataset = D4RLTrajectoryDataset(DATA_PATH, k, scale)
traj_data_loader = DataLoader(
    traj_dataset,
    batch_size=batch_size,
    shuffle=True,
    # Pinned host memory only speeds up host->GPU copies; without CUDA it just
    # triggers a warning, so enable it only when a GPU is present.
    pin_memory=torch.cuda.is_available(),
    # Drop the last incomplete batch so every batch has exactly batch_size samples.
    drop_last=True,
)

data_iter = iter(traj_data_loader)

## get state stats from dataset
# Presumably used to normalize environment states at evaluation time -- TODO confirm.
state_mean, state_std = traj_dataset.get_state_stats()

In [None]:
# make environment
# Reuse the `env` created in the parameter-setting cell from the registered gym
# id (e.g. 'HalfCheetah-v3'). `env_name` is only a short key such as
# 'halfcheetah', which is not a registered gym id, so gym.make(env_name) fails.

state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

In [None]:
# define model
# Build the Decision Transformer, an AdamW optimizer over all of its parameters,
# and a learning-rate scheduler with linear warmup.
# NOTE(review): this cell reads hyperparameters (n_blocks, embed_dim, n_heads,
# drop_out, context_len, lr, wt_decay, warmup_steps) that must already be
# defined -- run the hyperparameter cell before this one.

model = DecisionTransformer(
    state_dim=state_dim,
    act_dim=act_dim,
    n_blocks=n_blocks,
    h_dim=embed_dim,
    context_len=context_len,
    n_heads=n_heads,
    drop_p=drop_out,
)
model = model.to(DEVICE)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay)

# LambdaLR scales the base lr by this factor: it grows linearly from
# 1/warmup_steps to 1 over the first warmup_steps scheduler steps, then stays at 1.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: min((step + 1) / warmup_steps, 1),
)
	


In [None]:
# train parameter
# NOTE(review): the dataset and model cells above read these values, so this
# cell must be executed before them on a fresh kernel (ideally move it up,
# next to the parameter-setting cell).
batch_size = 64
embed_dim = 128     # transformer hidden size (passed to the model as h_dim)
activation = 'relu' # not passed to the model in the visible cells
drop_out = 0.1      # dropout probability (passed to the model as drop_p)
k = 20              # presumably the trajectory window length per sample (passed to D4RLTrajectoryDataset)
n_blocks = 3        # number of transformer blocks
n_heads = 1         # transformer attention heads

# The model's context length should match the dataset's window length `k`,
# so derive it instead of keeping a second, independently-editable constant.
context_len = k

# Optimizer / scheduler settings read by the model cell (not defined elsewhere
# in the visible notebook). Typical Decision Transformer values -- TODO tune.
lr = 1e-4
wt_decay = 1e-4
warmup_steps = 10000

# total updates = max_train_iters x num_updates_per_iter
max_train_iters = 200
num_updates_per_iter = 100
total_updates = 0
max_d4rl_score = -1.0
max_d4rl_score = -1.0

In [None]:
# train
# Training loop (work in progress: see the TODOs below; as written it raises at runtime).
# Each outer iteration runs num_updates_per_iter optimizer updates on batches
# drawn from traj_data_loader, re-creating the iterator when it is exhausted.
for i_train_iter in range(max_train_iters):

	log_action_losses = []	# per-update loss values (NOTE: it stores total_loss, not an action loss)
	model.train()

	for _ in range(num_updates_per_iter):
		# Fetch the next batch; when the loader is exhausted, start a new pass over the data.
		try:
			timesteps, states, actions, returns_to_go, traj_mask = next(data_iter)
		except StopIteration:
			data_iter = iter(traj_data_loader)
			timesteps, states, actions, returns_to_go, traj_mask = next(data_iter)

		timesteps = timesteps.to(DEVICE)	# B x T
		states = states.to(DEVICE)			# B x T x state_dim
		actions = actions.to(DEVICE)		# B x T x act_dim
		returns_to_go = returns_to_go.to(DEVICE).unsqueeze(dim=-1) # B x T x 1
		traj_mask = traj_mask.to(DEVICE)	# B x T

		# TODO: `next_state` is not defined -- the batch above yields only timesteps,
		# states, actions, returns_to_go and traj_mask. The data processing must be
		# changed so that next states can be extracted.
		state_target = torch.clone(next_state).detach().to(DEVICE) # must change data processing so next_state can be extracted
		# NOTE(review): rtg_target is not used in the visible code.
		rtg_target = torch.clone(returns_to_go).detach().to(DEVICE)

		# NOTE(review): prefer model(...) over model.forward(...) so module hooks run.
		state_preds, action_preds, return_preds = model.forward(
			timesteps=timesteps,
			states=states,
			actions=actions,
			returns_to_go=returns_to_go
		)

		# only consider non padded elements: flatten B x T and keep positions where traj_mask > 0
		state_preds = state_preds.view(-1, state_dim)[traj_mask.view(-1,) > 0]
		state_target = state_target.view(-1, state_dim)[traj_mask.view(-1,) > 0]

		state_loss = F.mse_loss(state_preds, state_target, reduction='mean')
		# TODO: placeholder -- the reward loss still has to be computed; it is an empty string here.
		reward_loss = ''

		# TODO: placeholder -- intended to be the mean of the state loss and the reward loss.
		# As an empty string it has no .backward(), so the line below raises AttributeError.
		total_loss = '' # mean of the state loss and the reward loss

		optimizer.zero_grad()
		total_loss.backward()
		# clip the global gradient norm to 0.25 before the optimizer step
		torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
		optimizer.step()
		# the LambdaLR scheduler is stepped once per update, so warmup is counted in updates
		scheduler.step()

		log_action_losses.append(total_loss.detach().cpu().item())