# Cube Env Training

## Import configurations

In [4]:
import yaml
import torch as th

# Load the default DQN configuration. The context manager closes the file
# handle as soon as parsing is done instead of leaving it open for the
# lifetime of the kernel.
with open('args/dqn_config_default.yaml') as config_file:
    config = yaml.full_load(config_file)

# Split the configuration into its three sections: DQN constructor arguments,
# policy arguments, and arguments for the learn() call.
model_kwargs = config['model']
policy_kwargs = config['policy']
learning_kwargs = config['learning']

### default config
- policy kwargs

In [5]:
from pprint import pprint

# Display the default policy keyword arguments loaded from the config file.
pprint(policy_kwargs)

{'features_extractor_class': None,
 'features_extractor_kwargs': None,
 'net_arch': None,
 'normalize_images': True,
 'optimizer_class': <class 'torch.optim.adam.Adam'>,
 'optimizer_kwargs': None}


- model kwargs

In [6]:
# Display the default DQN constructor keyword arguments loaded from the config file.
pprint(model_kwargs)

{'_init_setup_model': True,
 'batch_size': 32,
 'buffer_size': 1000000,
 'create_eval_env': False,
 'device': 'auto',
 'exploration_final_eps': 0.05,
 'exploration_fraction': 0.1,
 'exploration_initial_eps': 1.0,
 'gamma': 0.99,
 'gradient_steps': 1,
 'learning_rate': 0.0001,
 'learning_starts': 50000,
 'max_grad_norm': 10.0,
 'optimize_memory_usage': False,
 'policy_kwargs': {'features_extractor_class': None,
                   'features_extractor_kwargs': None,
                   'net_arch': None,
                   'normalize_images': True,
                   'optimizer_class': <class 'torch.optim.adam.Adam'>,
                   'optimizer_kwargs': None},
 'replay_buffer_class': None,
 'replay_buffer_kwargs': None,
 'seed': None,
 'target_update_interval': 10000,
 'tau': 1.0,
 'tensorboard_log': None,
 'train_freq': 4,
 'verbose': 0}


- learning kwargs

In [7]:
# Display the default keyword arguments for the learn() call loaded from the config file.
pprint(learning_kwargs)

{'callback': None,
 'eval_env': None,
 'eval_freq': 100000,
 'eval_log_path': None,
 'log_interval': 100,
 'n_eval_episodes': 5,
 'reset_num_timesteps': True,
 'tb_log_name': None}


## optimize configurations

### Env setting
- environment<br>
    `CubeEnv(3)`
- observation space<br>
    `MultiDiscrete([6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6
 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6])`<br>
  = `MultiDiscrete(np.zeros((6,3,3)).flatten() + 6)`<br>
  = 3x3 크기의 큐브면 6개를 flatten한 array

  <details><summary>flatten 이유</summary>

    `MultiDiscrete` space는 2차원 형식의 training을 지원하지 않았음 <br>
    (어차피 `feature_extractor`가 2차원을 지원하게 만들어도, `basepolicy`에서 obs를 preprocess할 때 2차원 형식은 env의 vectorization 이유로 처리 불가)
  </details>

- action space<br>
    `Discrete(12)`<br>
     6개의 면에 대한 회전 * 2개 방향

- preprocessed observation space<br>
    `one_hot(Env.observation_space)`<br>
    size: (1,324)

## Training setting

In [8]:
from Sample.Cube.Env.cube_env import CubeEnv

# Build two separate CubeEnv(3) instances: `env` is the environment handed to
# the DQN model; `eval_evn` (original spelling kept so existing references
# keep working) is presumably meant as a dedicated evaluation environment.
env, eval_evn = CubeEnv(3), CubeEnv(3)

### Training set 1
- FlattenExtractor<br>
    just a flatten layer

In [13]:
from stable_baselines3.dqn.policies import FlattenExtractor
from time import ctime

# Work on shallow copies of the default sections so the top-level overrides
# made for training set 1 do not modify the loaded defaults.
policy_kwargs1 = dict(policy_kwargs)
model_kwargs1 = dict(model_kwargs)
learning_kwargs1 = dict(learning_kwargs)

# Training set 1 uses the plain flatten feature extractor.
policy_kwargs1.update(features_extractor_class=FlattenExtractor)

def learning_rate_fun(x):
    """Return an exponentially decaying learning rate for schedule input *x*.

    The rate is ``0.01`` when ``x == 1`` and is halved eight times as ``x``
    goes to ``0`` (ending at ``0.01 * 0.5 ** 8``).

    NOTE(review): *x* is presumably the "progress remaining" value that
    stable-baselines3 passes to callable learning rates (1 at the start of
    training, 0 at the end) -- confirm against the library version in use.
    """
    base_rate = 0.01
    elapsed = 1 - x
    halvings = 8 * elapsed
    return base_rate * (0.5 ** halvings)
# --- DQN constructor overrides for training set 1 ---
# Use the callable schedule instead of the constant default learning rate.
model_kwargs1['learning_rate'] = learning_rate_fun
model_kwargs1['buffer_size'] = 1000
model_kwargs1['learning_starts'] = 100
model_kwargs1['batch_size'] = 16
model_kwargs1['tau'] = 0.99
model_kwargs1['train_freq'] = 1
# NOTE(review): this reads train_freq from the *default* section
# (`model_kwargs`), not the value 1 assigned just above -- confirm intended.
model_kwargs1['gradient_steps'] = model_kwargs['train_freq']
model_kwargs1['target_update_interval'] = 10
model_kwargs1['exploration_fraction'] = 0.2
# One TensorBoard directory per run, named after the current time with
# spaces replaced by dashes.
model_kwargs1['tensorboard_log'] = 'logs/tb_log/' + ctime().replace(' ', '-')
model_kwargs1['policy_kwargs'] = policy_kwargs1
model_kwargs1['verbose'] = 1

# --- learn() overrides for training set 1 ---
learning_kwargs1['log_interval'] = 100
# Evaluate on the dedicated evaluation environment (`eval_evn`, created
# alongside `env`) rather than on the training environment itself, so that
# periodic evaluation episodes do not reset or advance the environment the
# agent is being trained on.
learning_kwargs1['eval_env'] = eval_evn
learning_kwargs1['eval_freq'] = 100
learning_kwargs1['n_eval_episodes'] = 5
learning_kwargs1['eval_log_path'] = 'logs/eval_log/'

In [None]:
from stable_baselines3.dqn import DQN
from stable_baselines3.dqn.policies import MlpPolicy

# Build a DQN agent with the MLP policy on the training environment, using
# the training-set-1 constructor arguments.
model = DQN(policy=MlpPolicy, env=env, **model_kwargs1)
# Train for 25,000 timesteps; the remaining learn() options come from
# learning_kwargs1.
model.learn(total_timesteps=25000, **learning_kwargs1)

### Save training_set_1

In [14]:
# Assemble the custom configuration: a shallow copy of the loaded config whose
# three sections are replaced by the training-set-1 dictionaries.
dqn_config_custom1 = {
    **config,
    'policy': policy_kwargs1,
    'model': model_kwargs1,
    'learning': learning_kwargs1,
}

In [15]:
# Serialize first, then open the target for writing: opening with 'w'
# truncates the file immediately, so dumping inside the `with` block would
# leave an empty file behind if serialization raised.
serialized_config = yaml.dump(dqn_config_custom1)
with open('args/dqn_config_custom1.yaml', 'w') as f:
    f.write(serialized_config)


In [16]:
# Read the saved configuration back as a round-trip check. The context
# manager closes the file handle once parsing is done.
# NOTE: yaml.unsafe_load can construct arbitrary Python objects. It is
# presumably needed here because the dumped config holds Python objects
# (e.g. the learning-rate function and the extractor class); only use it on
# files written by this notebook, never on untrusted input.
with open('args/dqn_config_custom1.yaml') as saved_config:
    tmp = yaml.unsafe_load(saved_config)


