SET PARAMS (RL_ALGO_ARG) and RUN

In [2]:
# ---- PARAMS ----
RL_ALGO_ARG = 'A2C'

# ---- CODE ----
import os, sys
# Make the project root (the folder one level above notebooks/) importable.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import os
import yaml
import torch
import random
import time
from torch import nn, optim
from torch.distributions import Categorical
from env.custom_maze_env import CustomMazeEnv
from env.get_retina_image import reconstruct
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
import matplotlib.pyplot as plt
import numpy as np
from IPython import display
from model.policy import *

# Anomaly detection makes autograd report the forward op behind a failing
# backward pass; PyTorch documents it as slowing execution (debugging aid).
torch.autograd.set_detect_anomaly(True)

#==========================
# set_seed
#==========================
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch generators with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for seeder in seeders:
        seeder(seed)

# Standard-library names used below are imported explicitly instead of being
# expected to leak in through `from model.policy import *`; `dt` in particular
# is not imported anywhere in this cell, so checkpointing raised NameError.
import math
import pickle
from datetime import datetime as dt

# 1) Load the configuration (the context manager closes the file handle).
with open(os.path.join(project_root, "experiments/config/default.yaml")) as cfg_file:
    cfg = yaml.safe_load(cfg_file)  ##TBD - ERROR
set_seed(cfg["train"]["seed"])

# Fail fast: an unsupported algorithm would otherwise only surface as a
# NameError on `loss` after the first full episode.
if RL_ALGO_ARG not in ('REINFORCE', 'A2C'):
    raise ValueError(f"RL_ALGO_ARG must be 'REINFORCE' or 'A2C', got {RL_ALGO_ARG!r}")

# 2) Create and wrap the environment.
#base_env = CustomMazeEnv(**cfg["env"]) ##TBD - ERROR
base_env = CustomMazeEnv(**{'layout_id': 'c',
                            'goal_pos': [0, 3],
                            'view_size': 5,
                            'max_steps': 1000,
                            'tile_size': 32,
                            'render_mode': 'rgb_array'})
env = RGBImgPartialObsWrapper(base_env, tile_size=cfg["env"]["tile_size"])
# obs_dim = np.prod(reconstruct(obs["image"], render_chanel=1).shape)  # flatten -> 60x80
action_dim = env.action_space.n

# 3) Agent and optimizer.
# Epsilon-greedy exploration; eps is multiplied by eps_decay after every episode.
eps_start = cfg["agent"]["eps_start"]
eps_decay = cfg["agent"]["eps_decay"]
eps = eps_start

HIDDEN_SIZE = cfg["agent"]["hidden_size"]
PARAMS = {
    'memory_bank_ep': {
        'decay_rate': 0.0001,
        'noise_std': 0.001,
        'et_lambda': 0.99,
        'memory_len': 5000,
        'update_freq': 100,
        'hidden_dim': HIDDEN_SIZE,
        'decay_yn': False
    },
    'cnn_embed': {
        'cnn_hidden_lyrs': [4, 8],
        'lin_hidden_lyrs': [512, HIDDEN_SIZE],
        'input_img_shape': (60, 80)
    },
    'rnn': {
        'input_size': HIDDEN_SIZE,
        'hidden_size': HIDDEN_SIZE,
        'batch_first': True
    },
    'memory_gate': {
        'hidden_dim_lyrs': [HIDDEN_SIZE, int(HIDDEN_SIZE/2)],
        'action_dim': action_dim,
        'attn_size': 5,
        'rl_algo_arg': RL_ALGO_ARG
    }
}

# policy = RNNPolicy(obs_dim, cfg["agent"]["hidden_size"], action_dim)
policy = cnnrnnattn_policy(PARAMS)
optimizer = optim.Adam(policy.parameters(), lr=float(cfg["agent"]["learning_rate"]))
memory_bank_ep = memory_bank(**PARAMS['memory_bank_ep'])
hx = torch.randn(1, 1, HIDDEN_SIZE) / math.sqrt(HIDDEN_SIZE)

# TensorBoard writer (None when disabled).
if cfg["logging"]["use_tensorboard"]:
    from torch.utils.tensorboard import SummaryWriter
    tb = SummaryWriter(cfg["logging"]["tensorboard_dir"])
else:
    tb = None

attn_size = PARAMS['memory_gate']['attn_size']
gamma = cfg["agent"]["gamma"]

# 4) Training loop: episodes are numbered 1..total_episodes inclusive.
policy.train()
for ep in range(1, cfg['train']['total_episodes'] + 1):

    obs, _ = env.reset(seed=cfg["train"]["seed"] + ep)
    # obs["image"] shape = (tile_size * view_size, tile_size * view_size, 3)
    retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
    # full_map = obs["image"] # 160 x 160 x 3
    state = torch.from_numpy(retina)[None, None, ...]  # 1 x 1 x 60 x 80

    log_probs, rewards, values = [], [], []
    gate_alpha_lst, terminated_lst = [], []
    # Per-episode output folder for attention weights (created once, not per step).
    attention_ep_dir = os.path.join(cfg["logging"]["attention_weight_dir"], f'ep{ep}')
    os.makedirs(attention_ep_dir, exist_ok=True)

    #==========================
    # Roll out one episode
    #==========================
    done = False
    while not done:
        logits, value, sx, hx, chosen_ids, gate_alpha_, attention_ = policy(state, hx, memory_bank_ep)
        m = Categorical(logits=logits)
        if random.random() < eps:
            # 0-d, like m.sample() for 1-D logits, so torch.stack(log_probs)
            # below sees equally shaped tensors from both branches.
            a = torch.randint(action_dim, ())
        else:
            a = m.sample()
        log_probs.append(m.log_prob(a))
        values.append(value)
        obs, r, term, trunc, _ = env.step(a.item())
        rewards.append(r)

        memory_bank_ep.update(retina, sx.detach().clone(), hx.detach().clone(), a, r, obs, ep, chosen_ids)
        memory_bank_ep.save(cfg["logging"]["timestep_dir"], cfg["logging"]["attention_dir"], ep, obs['timestep'], chosen_ids)

        # next state
        retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
        state = torch.from_numpy(retina)[None, None, ...]
        done = term or trunc

        with torch.no_grad():
            # Cut the recurrent state from this step's graph.
            hx = hx.detach().clone()
            gate_alpha_lst.append(gate_alpha_.item())
            terminated_lst.append(term)
            timestep_ = obs['timestep']
            torch.save(attention_,
                       os.path.join(attention_ep_dir, f'attention_weight_{timestep_}.pt'))

        ### for checking
        print(chosen_ids, gate_alpha_, r)

    #============================
    # Per-episode loss (REINFORCE or A2C)
    #============================
    # Discounted returns G_t = r_t + gamma * G_{t+1}, built backwards then reversed.
    returns = []
    G = 0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    # The first attn_size steps are excluded, presumably because memory_gate emits
    # random, gradient-free logits until the bank holds attn_size memories
    # (true in the first episode, while the bank is still filling).
    returns = torch.tensor(returns[attn_size:], dtype=torch.float32)

    loss = None
    if returns.numel() > 0:  # an episode of <= attn_size steps has nothing to train on
        log_probs = torch.stack(log_probs[attn_size:])

        if RL_ALGO_ARG == 'REINFORCE':
            if returns.numel() > 1:  # std() of a single element is NaN
                returns = (returns - returns.mean()) / (returns.std() + 1e-8)
            loss = -(log_probs * returns).sum()
        else:  # 'A2C' (validated before the loop)
            values = torch.stack(values[attn_size:])
            advantages = returns.detach() - values
            actor_loss = -(log_probs * advantages.detach()).mean()
            value_loss = advantages.pow(2).mean()
            alpha = cfg["train"]["actor_loss_coef"]
            loss = alpha * actor_loss + (1 - alpha) * value_loss

        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
    eps *= eps_decay
    loss_value = float('nan') if loss is None else loss.item()

    #============================
    # Logging
    #============================
    ep_reward = sum(rewards)
    print(ep_reward)

    if tb:
        tb.add_scalar("train/episode_reward", ep_reward, ep)
        tb.add_scalar("train/loss", loss_value, ep)

    if ep % cfg["train"]["log_interval"] == 0:
        print(f"[Episode {ep}] reward={ep_reward:.2f}, loss={loss_value:.4f}")

    if ep % cfg["train"]["save_interval"] == 0:
        time_now = dt.now().strftime("%Y-%m-%d-%H:%M")
        os.makedirs(cfg["train"]["checkpoint_dir"], exist_ok=True)
        torch.save(policy.state_dict(),
                   os.path.join(cfg["train"]["checkpoint_dir"], f"policy_ep{ep}_{time_now}.pt"))

    os.makedirs(cfg["logging"]["gate_alpha_dir"], exist_ok=True)
    os.makedirs(cfg["logging"]["terminated_dir"], exist_ok=True)
    with open(os.path.join(cfg["logging"]["gate_alpha_dir"], f"gate_alpha_ep{ep}.pkl"), "wb") as file_gate_alpha:
        pickle.dump(gate_alpha_lst, file_gate_alpha)
    with open(os.path.join(cfg["logging"]["terminated_dir"], f"terminated_ep{ep}.pkl"), "wb") as file_terminated:
        pickle.dump(terminated_lst, file_terminated)

# Close the TensorBoard writer once, after training; closing it inside the
# loop tore it down after every episode.
if tb:
    tb.close()

tensor([]) tensor([0.]) -0.01
tensor([]) tensor([0.]) -0.01
tensor([]) tensor([0.]) -0.01
tensor([]) tensor([0.]) -0.01
tensor([]) tensor([0.]) -0.01
tensor([1, 3, 5, 4, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 5, 4, 6], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 5, 4, 7], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 5, 4, 8], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 5, 4, 9], dtype=torch.int32) tensor([0.]) -0.01
tensor([ 1,  3,  5,  4, 10], dtype=torch.int32) tensor([0.]) -0.01
tensor([ 1,  3,  5,  4, 10], dtype=torch.int32) tensor([0.]) -0.01
tensor([ 1,  3,  5,  4, 10], dtype=torch.int32) tensor([0.]) -0.01
tensor([ 1,  3,  5,  4, 10], dtype=torch.int32) tensor([0.]) -0.01
tensor([ 1,  3,  5,  4, 10], dtype=torch.int32) tensor([0.]) -0.01
tensor([ 1,  3,  5, 15,  4], dtype=torch.int32) tensor([0.]) -0.01
tensor([ 1,  3,  5,  4, 15], dtype=torch.int32) tensor([0.]) -0.01
tensor([ 1,  3,  5,  4, 15], dtype=torch.int32) tensor([0.]) -0.01
tens

KeyboardInterrupt: 

---------------------------IGNORE BELOW-----------------------------------------

In [None]:
import os, sys
# Put the project root (parent of notebooks/) at the front of the import path.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import os
import yaml
import torch
import random
import time
from torch import nn, optim
from torch.distributions import Categorical
from env.custom_maze_env import CustomMazeEnv
from env.get_retina_image import reconstruct
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
import matplotlib.pyplot as plt
import numpy as np
from IPython import display
from model.policy import *

# Anomaly detection makes autograd report the forward op behind a failing
# backward pass; PyTorch documents it as slowing execution (debugging aid).
torch.autograd.set_detect_anomaly(True)

#==========================
# set_seed
#==========================
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch generators with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for seeder in seeders:
        seeder(seed)

def main(RL_ALGO_ARG):
    """Train ``cnnrnnattn_policy`` on the custom maze environment.

    Args:
        RL_ALGO_ARG: 'REINFORCE' or 'A2C'. Selects the loss, and is forwarded to
            the memory gate, which builds a critic head only for 'A2C'.

    Raises:
        ValueError: if RL_ALGO_ARG is not a supported algorithm.
    """
    # Imported here instead of relying on `from model.policy import *` to
    # re-export them; `dt` in particular is not imported anywhere in this cell.
    import math
    import pickle
    from datetime import datetime as dt

    # Fail fast: otherwise `loss` is never assigned and the first episode ends
    # in a NameError.
    if RL_ALGO_ARG not in ('REINFORCE', 'A2C'):
        raise ValueError(f"RL_ALGO_ARG must be 'REINFORCE' or 'A2C', got {RL_ALGO_ARG!r}")

    # 1) Load the configuration (the context manager closes the file handle).
    with open(os.path.join(project_root, "experiments/config/default.yaml")) as cfg_file:
        cfg = yaml.safe_load(cfg_file)  ##TBD - ERROR
    set_seed(cfg["train"]["seed"])

    # 2) Create and wrap the environment.
    #base_env = CustomMazeEnv(**cfg["env"]) ##TBD - ERROR
    base_env = CustomMazeEnv(**{'layout_id': 'c',
                                'goal_pos': [0, 3],
                                'view_size': 5,
                                'max_steps': 1000,
                                'tile_size': 32,
                                'render_mode': 'rgb_array'})
    env = RGBImgPartialObsWrapper(base_env, tile_size=cfg["env"]["tile_size"])
    # obs_dim = np.prod(reconstruct(obs["image"], render_chanel=1).shape)  # flatten -> 60x80
    action_dim = env.action_space.n

    # 3) Agent, memory bank and optimizer.
    HIDDEN_SIZE = cfg["agent"]["hidden_size"]
    PARAMS = {
        'memory_bank_ep': {
            'decay_rate': 0.0001,
            'noise_std': 0.001,
            'et_lambda': 0.99,
            'memory_len': 5000,
            'update_freq': 100,
            'hidden_dim': HIDDEN_SIZE,
            'decay_yn': False
        },
        'cnn_embed': {
            'cnn_hidden_lyrs': [4, 8],
            'lin_hidden_lyrs': [512, HIDDEN_SIZE],
            'input_img_shape': (60, 80)
        },
        'rnn': {
            'input_size': HIDDEN_SIZE,
            'hidden_size': HIDDEN_SIZE,
            'batch_first': True
        },
        'memory_gate': {
            'hidden_dim_lyrs': [HIDDEN_SIZE, int(HIDDEN_SIZE/2)],
            # Taken from the environment (was a hard-coded 4) so the policy's
            # logits always cover the env's action space.
            'action_dim': action_dim,
            'attn_size': 5,
            'rl_algo_arg': RL_ALGO_ARG
        }
    }

    # policy = RNNPolicy(obs_dim, cfg["agent"]["hidden_size"], action_dim)
    policy = cnnrnnattn_policy(PARAMS)
    optimizer = optim.Adam(policy.parameters(), lr=float(cfg["agent"]["learning_rate"]))
    memory_bank_ep = memory_bank(**PARAMS['memory_bank_ep'])
    hx = torch.randn(1, 1, HIDDEN_SIZE) / math.sqrt(HIDDEN_SIZE)

    # TensorBoard writer (None when disabled).
    if cfg["logging"]["use_tensorboard"]:
        from torch.utils.tensorboard import SummaryWriter
        tb = SummaryWriter(cfg["logging"]["tensorboard_dir"])
    else:
        tb = None

    attn_size = PARAMS['memory_gate']['attn_size']
    gamma = cfg["agent"]["gamma"]

    # 4) Training loop.
    policy.train()
    # NOTE(review): hard-coded to episodes 1-4, presumably a debugging cap; the
    # notebook version of this loop uses cfg['train']['total_episodes'].
    for ep in range(1, 5):

        obs, _ = env.reset(seed=cfg["train"]["seed"] + ep)
        # obs["image"] shape = (tile_size * view_size, tile_size * view_size, 3)
        retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
        # full_map = obs["image"] # 160 x 160 x 3
        state = torch.from_numpy(retina)[None, None, ...]  # 1 x 1 x 60 x 80

        log_probs, rewards, values = [], [], []
        gate_alpha_lst, terminated_lst = [], []
        # Per-episode output folder for attention weights (created once, not per step).
        attention_ep_dir = os.path.join(cfg["logging"]["attention_weight_dir"], f'ep{ep}')
        os.makedirs(attention_ep_dir, exist_ok=True)

        #==========================
        # Roll out one episode
        #==========================
        done = False
        while not done:
            logits, value, sx, hx, chosen_ids, gate_alpha_, attention_ = policy(state, hx, memory_bank_ep)
            m = Categorical(logits=logits)
            a = m.sample()
            log_probs.append(m.log_prob(a))
            values.append(value)
            obs, r, term, trunc, _ = env.step(a.item())
            rewards.append(r)

            memory_bank_ep.update(retina, sx.detach().clone(), hx.detach().clone(), a, r, obs, ep, chosen_ids)
            memory_bank_ep.save(cfg["logging"]["timestep_dir"], cfg["logging"]["attention_dir"], ep, obs['timestep'], chosen_ids)

            # next state
            retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
            state = torch.from_numpy(retina)[None, None, ...]
            done = term or trunc

            with torch.no_grad():
                # Cut the recurrent state from this step's graph.
                hx = hx.detach().clone()
                gate_alpha_lst.append(gate_alpha_.item())
                terminated_lst.append(term)
                timestep_ = obs['timestep']
                torch.save(attention_,
                           os.path.join(attention_ep_dir, f'attention_weight_{timestep_}.pt'))

            ### for checking
            print(chosen_ids, gate_alpha_, r)

        #============================
        # Per-episode loss (REINFORCE or A2C)
        #============================
        # Discounted returns G_t = r_t + gamma * G_{t+1}, built backwards then reversed.
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.append(G)
        returns.reverse()
        # The first attn_size steps are excluded, presumably because memory_gate
        # emits random, gradient-free logits until the bank holds attn_size memories.
        returns = torch.tensor(returns[attn_size:], dtype=torch.float32)

        loss = None
        if returns.numel() > 0:  # an episode of <= attn_size steps has nothing to train on
            log_probs = torch.stack(log_probs[attn_size:])

            if RL_ALGO_ARG == 'REINFORCE':
                if returns.numel() > 1:  # std() of a single element is NaN
                    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
                loss = -(log_probs * returns).sum()
            else:  # 'A2C' (validated at the top)
                values = torch.stack(values[attn_size:])
                advantages = returns.detach() - values
                actor_loss = -(log_probs * advantages.detach()).mean()
                value_loss = advantages.pow(2).mean()
                alpha = cfg["train"]["actor_loss_coef"]
                loss = alpha * actor_loss + (1 - alpha) * value_loss

            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
        loss_value = float('nan') if loss is None else loss.item()

        #============================
        # Logging
        #============================
        ep_reward = sum(rewards)
        print(ep_reward)

        if tb:
            tb.add_scalar("train/episode_reward", ep_reward, ep)
            tb.add_scalar("train/loss", loss_value, ep)

        if ep % cfg["train"]["log_interval"] == 0:
            print(f"[Episode {ep}] reward={ep_reward:.2f}, loss={loss_value:.4f}")

        if ep % cfg["train"]["save_interval"] == 0:
            time_now = dt.now().strftime("%Y-%m-%d-%H:%M")
            os.makedirs(cfg["train"]["checkpoint_dir"], exist_ok=True)
            torch.save(policy.state_dict(),
                       os.path.join(cfg["train"]["checkpoint_dir"], f"policy_ep{ep}_{time_now}.pt"))

        os.makedirs(cfg["logging"]["gate_alpha_dir"], exist_ok=True)
        os.makedirs(cfg["logging"]["terminated_dir"], exist_ok=True)
        with open(os.path.join(cfg["logging"]["gate_alpha_dir"], f"gate_alpha_ep{ep}.pkl"), "wb") as file_gate_alpha:
            pickle.dump(gate_alpha_lst, file_gate_alpha)
        with open(os.path.join(cfg["logging"]["terminated_dir"], f"terminated_ep{ep}.pkl"), "wb") as file_terminated:
            pickle.dump(terminated_lst, file_terminated)

    # Close the TensorBoard writer once, after all episodes; closing it inside
    # the loop tore it down after every episode.
    if tb:
        tb.close()

if __name__ == "__main__":
    # Usage: python <script> {REINFORCE|A2C}
    # NOTE(review): inside a Jupyter kernel __name__ is also "__main__" and
    # sys.argv[1] is typically a kernel flag rather than an algorithm name.
    if len(sys.argv) < 2:
        sys.exit("usage: python <script> {REINFORCE|A2C}")
    RL_ALGO_ARG = sys.argv[1]
    main(RL_ALGO_ARG)

MODULES (5/27)

In [None]:
# class memory_bank():
#     def __init__(
#             self,
#             decay_rate: float=0.1,
#             noise_std: float=0.1,
#             memory_len: int=100,
#             hidden_dim: int=128
#     ):
#         self.memory_idx = 0 ## 시간 순서대로의 memory_idx
#         self.memory_bank_org = []
#         self.memory_bank_hidden = torch.randn(1, memory_len, hidden_dim)
#         self.decay_rate = decay_rate
#         self.noise_std = noise_std 
#         self.memory_len = memory_len
        
#     def update(self, retina, embed_state, hidden_state, action, obs, ep, done):

#         ## 1) push memory_slot & save if necessary
#         self.push(retina, embed_state, hidden_state, action, obs, ep)
#         self.save(ep, done)

#         ## 2) update two memory banks 
#         self.memory_bank_org.append(self.memory_slot)
#         self.memory_bank_hidden = torch.cat([
#             self.memory_bank_hidden,
#             hidden_state 
#         ], axis=1) ## TBD FOR LATER RNN MODIFICATION USED AXIS 1 AGGREGATION
        
#         ## self.memory_bank_hidden = self.memory_bank_hidden[:, -self.memory_len:, :]
#         self.add_decay(self.decay_rate, self.noise_std)
    
#     def truncate(self, timestep, ):
#         if (timestep + 1) % 10 == 0:

        

#     def push(self, retina, embed_state, hidden_state, action, obs, ep):
        
#         self.memory_slot = {
#             'memory_idx': self.memory_idx, 
#             'obs': retina,
#             'embed_state': embed_state,
#             'hidden_state': hidden_state,
#             'action': action,
#             'timestep': obs['timestep'],
#             'position': obs['position'],
#             'epi_no': ep
#         }
#         self.memory_idx += 1 

#     def save(self, ep, done):
#         if done:
#             # os.makedirs('results/memory_bank', exist_ok=True) 
#             with (f'results/memory_bank/{ep}.pkl') as file:
#                 pickle.dump(self.memory_bank_org, file)
#             self.memory_bank_org = []
    
#     def add_decay(self, decay_rate, noise_std): ## TBD EXPONENTIAL DECAY
        
#         memory_len_real = self.memory_bank_hidden.size(1)
#         noise_size = (self.memory_bank_hidden.size(0), 1, self.memory_bank_hidden.size(2))
#         memory_decay_lst = []
#         for idx in range(memory_len_real):
#             memory = self.memory_bank_hidden[:, idx, :].unsqueeze(1)
#             decay_prod = math.exp(decay_rate * (memory_len_real-1-idx))
#             memory += torch.normal(0, noise_std, size=noise_size)*decay_prod
#             memory_decay_lst.append(memory)
#         memory_decay = torch.cat(memory_decay_lst, axis=1)
        
#         return memory_decay

In [158]:
# memory_bank = {}
# for i in range(10):
#     memory_bank[i] = [
#         torch.randn(1,1,128),
#         torch.randn(1,1,20),
#         1, 
#         3.2
#     ]

# torch.cat(list(np.array(list(memory_bank.values()))[:,2]), dim=0)
# torch.tensor(list(np.array(list(memory_bank.values()))[:,3]))

In [431]:
import torch 
import torch.nn as nn
import math
import pickle
import numpy as np

class cnnrnnattn_policy(nn.Module):
    '''Policy network: CNN retina encoder -> RNN -> attention over episodic memory.

    Args:
        params: dict with keyword-argument dicts under 'cnn_embed', 'rnn'
            (forwarded to ``nn.RNN``) and 'memory_gate'.

    forward(state, hx, memory_bank_set) returns
        (logits, value, state_emd, hidden_emd, chosen_ids, gate_alpha,
         attention_weights_full)
    where ``memory_bank_set`` must expose hidden_memory, action_memory,
    reward_memory and memory_ids. Logits are unnormalised action scores.
    '''
    def __init__(self, params):
        super().__init__()
        # Registration order (cnn -> rnn -> gate) is kept so parameter init and
        # state_dict layout are unchanged.
        self.cnn_embed_lyr = cnn_embed(**params['cnn_embed'])
        self.rnn_lyr = nn.RNN(**params['rnn'])
        self.memory_gate_lyr = memory_gate(**params['memory_gate'])

    def forward(self, state, hx, memory_bank_set):
        # Add a leading dim so the embedding is fed to the RNN as a length-1 sequence.
        state_emd = self.cnn_embed_lyr(state).unsqueeze(0)
        rnn_out = self.rnn_lyr(state_emd, hx)
        hidden_emd = rnn_out[0]
        gate_out = self.memory_gate_lyr(
            hidden_emd,
            memory_bank_set.hidden_memory,
            memory_bank_set.action_memory,
            memory_bank_set.reward_memory,
            memory_bank_set.memory_ids,
        )
        logits, value, _topk_weights, chosen_ids, gate_alpha, attention_full = gate_out
        return logits, value, state_emd, hidden_emd, chosen_ids, gate_alpha, attention_full

class cnn_embed(nn.Module):
    '''Encode a single-channel image into a flat feature vector.

    Architecture: for every width in ``cnn_hidden_lyrs`` a 3x3 Conv2d (no
    padding) -> ReLU -> 2x2 MaxPool, then Flatten, then Linear layers whose
    output widths are ``lin_hidden_lyrs`` (ReLU between them, none after the
    last). An empty ``lin_hidden_lyrs`` leaves only the Flatten.

    Args:
        cnn_hidden_lyrs: output channels of each conv stage.
        lin_hidden_lyrs: output widths of the linear head.
        input_img_shape: (H, W) of the input, used to size the first Linear.
    '''
    def __init__(
        self,
        cnn_hidden_lyrs: list = [16, 32],
        lin_hidden_lyrs: list = [32, 64],
        input_img_shape: tuple = (60, 80) # HW
    ):
        super().__init__()

        # 1) Convolutional stack; the input is assumed to have exactly 1 channel.
        conv_stack = []
        in_channels = 1
        for out_channels in cnn_hidden_lyrs:
            conv_stack.extend([
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=3, stride=1, padding=0),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
            in_channels = out_channels
        self.cnn_lyrs = conv_stack

        # 2) Linear head, sized by probing the conv stack with a dummy batch of 1.
        head = [nn.Flatten()]
        if lin_hidden_lyrs != []:
            flat_dim = math.prod(self.get_shape(conv_stack, (1, 1, *input_img_shape)))
            widths = [flat_dim, *lin_hidden_lyrs]
            for w_in, w_out in zip(widths[:-2], widths[1:-1]):
                head.extend([nn.Linear(w_in, w_out), nn.ReLU()])
            head.append(nn.Linear(widths[-2], widths[-1]))
        self.lin_lyrs = head

        self.cnn_model = nn.Sequential(*conv_stack)
        self.lin_model = nn.Sequential(*head)

    def forward(self, input_img, only_output=True):
        '''Return the embedding, or (embedding, conv feature map) if only_output is False.'''
        feature_map = self.cnn_model(input_img)
        embedding = self.lin_model(feature_map)
        return embedding if only_output else (embedding, feature_map)

    def get_shape(self, lyrs, input_img_shape):
        '''Shape (minus the batch dim) produced by running ``lyrs`` on zeros of ``input_img_shape``.'''
        with torch.no_grad():
            probe = nn.Sequential(*lyrs)(torch.zeros(*input_img_shape))
        return probe.shape[1:]

def gen_topk_random(sbj_tsnr, k, dim, largest):
    """Return indices of the k largest (or smallest) entries along ``dim``, ties broken at random.

    The tensor is randomly permuted along ``dim`` before ``torch.topk`` so equal
    values are not always resolved in favour of the lowest index; the selected
    positions are then mapped back to indices of the original tensor.

    Args:
        sbj_tsnr: floating-point tensor to rank.
        k: number of indices to return.
        dim: dimension to rank along.
        largest: True for the k largest values, False for the k smallest.

    Returns:
        Index tensor shaped like ``sbj_tsnr`` except with size ``k`` along ``dim``.
    """
    random_idx = torch.rand_like(sbj_tsnr).argsort(dim=dim)
    shuffled = torch.take_along_dim(sbj_tsnr, random_idx, dim=dim)
    _, shuffled_topk_idx = torch.topk(shuffled, k=k, largest=largest, dim=dim)
    return torch.take_along_dim(random_idx, shuffled_topk_idx, dim=dim)

class memory_bank():
    """Episodic memory of past steps with eligibility-trace based eviction.

    Each ``update`` stores one step as a memory slot under a new integer
    ``memory_id`` (1, 2, 3, ...), refreshes the per-memory trace, evicts
    low-trace memories when truncation is due, and rebuilds the tensors read by
    ``memory_gate``:

    * ``hidden_memory`` -- stored hidden states concatenated on dim 1
    * ``action_memory`` -- int tensor of stored actions, shape (n,)
    * ``reward_memory`` -- float tensor of stored rewards, shape (n,)

    ``memory_ids`` and ``trace`` stay aligned position-for-position with those
    tensors: all follow the insertion order of ``memory_bank_org``, and eviction
    removes the same positions from each.

    Args:
        decay_rate: growth rate of the noise scale used by ``add_decay``.
        noise_std: std of the Gaussian noise added by ``add_decay``.
        et_lambda: per-step multiplicative decay of the eligibility trace.
        memory_len: truncation only happens once more than this many ids were issued.
        update_freq: truncation runs every ``update_freq`` ids and evicts that many memories.
        hidden_dim: width of the stored hidden states.
        decay_yn: if True, store a noisy copy of the hidden state for attention.
    """
    def __init__(
            self,
            decay_rate: float=0.0001,
            noise_std: float=0.001,
            et_lambda: float=0.99,
            memory_len: int=5000,
            update_freq: int=100,
            hidden_dim: int=128,
            decay_yn: bool=False
    ):
        self.decay_rate = decay_rate
        self.noise_std = noise_std
        self.et_lambda = et_lambda
        self.update_freq = update_freq
        self.memory_len = memory_len
        self.hidden_dim = hidden_dim
        self.decay_yn = decay_yn

        self.memory_id = 0  # id of the most recently pushed memory (ids start at 1)
        self.memory_bank_org = {}  # memory_id -> memory slot (see push)
        # Empty bank: zero memories along dim 1, width taken from hidden_dim
        # (previously a hard-coded 128 with a spurious first slot).
        self.hidden_memory = torch.empty(1, 0, hidden_dim)
        self.action_memory = torch.empty(0, dtype=torch.int)
        self.reward_memory = torch.empty(0)

        self.memory_ids = torch.empty(0, dtype=torch.int)
        self.trace = torch.empty(0)

    def update(self, retina, embed_state, hidden_state, action, reward, obs, ep, chosen_ids):
        """Push one step, refresh traces, truncate if due, and rebuild the lookup tensors.

        ``chosen_ids`` are the memory ids attended to on this step; their traces
        are bumped so recently used memories survive eviction.
        """
        # 1) push, trace update, truncation
        self.push(retina, embed_state, hidden_state, action, reward, obs, ep)
        self.update_trace(chosen_ids)
        self.trunc()

        # 2) rebuild the tensors consumed by memory_gate (slot indices: see push)
        slots = list(self.memory_bank_org.values())
        self.hidden_memory = torch.cat([slot[3] for slot in slots], dim=1)
        self.action_memory = torch.tensor([int(slot[4]) for slot in slots], dtype=torch.int)
        self.reward_memory = torch.tensor([float(slot[5]) for slot in slots], dtype=torch.float32)

    def push(self, retina, embed_state, hidden_state, action, reward, obs, ep):
        """Store one step as a new memory slot under the next memory_id.

        Slot layout: [0 retina, 1 embed_state, 2 hidden_state, 3 hidden_state used
        for attention (noisy copy when decay_yn), 4 action, 5 reward,
        6 obs['timestep'], 7 obs['position'], 8 episode number].
        Tensors are detached and cloned so the bank holds no autograd graph.
        """
        # add_decay reads the id *before* it is incremented.
        if self.decay_yn:
            hidden_state_decay = self.add_decay(hidden_state)
        else:
            hidden_state_decay = hidden_state

        self.memory_id += 1  # memory ids start at 1
        self.memory_ids = torch.cat((self.memory_ids, torch.tensor([self.memory_id], dtype=torch.int)))

        self.memory_slot = [
            retina,
            embed_state.detach().clone(),
            hidden_state.detach().clone(),
            hidden_state_decay.detach().clone(),
            action,  # action as passed by the caller (int or 1-element tensor)
            reward,  # scalar reward
            obs['timestep'],
            obs['position'],
            ep
        ]
        self.memory_bank_org[self.memory_id] = self.memory_slot

    def trunc(self):
        """Evict the ``update_freq`` lowest-trace memories when truncation is due.

        Due when more than ``memory_len`` ids have been issued and the latest id
        is a multiple of ``update_freq``. Ties in the trace are broken at random.
        """
        if self.memory_id > self.memory_len and self.memory_id % self.update_freq == 0:
            bottomk_idx = gen_topk_random(self.trace, k=self.update_freq, dim=0, largest=False)

            # Dict keys are Python ints, so convert the id tensor before deleting.
            for evicted_id in self.memory_ids[bottomk_idx].tolist():
                del self.memory_bank_org[evicted_id]

            # Drop the same positions from ids and trace to keep them aligned.
            keep = torch.ones(self.trace.size(0), dtype=torch.bool)
            keep[bottomk_idx] = False
            self.memory_ids = self.memory_ids[keep]
            self.trace = self.trace[keep]

    def update_trace(self, chosen_ids):
        """Decay all traces, append 1 for the newest memory, and add 1 to each chosen memory."""
        self.trace = torch.cat((self.trace * self.et_lambda, torch.ones(1)))
        if chosen_ids.nelement() != 0:  # no memories were attended to (bank still filling)
            recency = torch.isin(self.memory_ids, chosen_ids).to(dtype=self.trace.dtype)
            self.trace = self.trace + recency

    def add_decay(self, hidden_state):
        """Return ``hidden_state`` plus Gaussian noise scaled by exp(decay_rate * memory_id)."""
        decay_prod = math.exp(self.decay_rate * self.memory_id)
        return hidden_state + torch.normal(0, self.noise_std, size=hidden_state.shape) * decay_prod

    def save(self, timestep_basedir, attention_base_dir, ep, timestep, chosen_ids):
        """Pickle the newest slot and the slots of ``chosen_ids`` under per-episode folders."""
        ep_timestep_dir = os.path.join(timestep_basedir, f'ep{ep}')
        ep_attention_dir = os.path.join(attention_base_dir, f'ep{ep}')
        os.makedirs(ep_timestep_dir, exist_ok=True)
        os.makedirs(ep_attention_dir, exist_ok=True)
        timestep_memory = {self.memory_id: self.memory_slot}
        # A chosen memory can be evicted by trunc() in the same update; skip it
        # rather than raising KeyError.
        attn_memory = {}
        for chosen in chosen_ids.tolist():
            chosen = int(chosen)
            if chosen in self.memory_bank_org:
                attn_memory[chosen] = self.memory_bank_org[chosen]

        with open(os.path.join(ep_timestep_dir, f'timestep_memory_{timestep}.pkl'), 'wb') as file1:
            pickle.dump(timestep_memory, file1)
        with open(os.path.join(ep_attention_dir, f'attention_memory_{timestep}.pkl'), 'wb') as file2:
            pickle.dump(attn_memory, file2)

In [432]:
class memory_gate(nn.Module):
    """Attend from the current hidden state over stored memories and emit action logits.

    The query (current hidden state) is scored against every stored hidden
    memory with scaled dot-product attention. The top ``attn_size`` memories
    (ties broken at random) are mixed into the hidden state through a learnable
    gate ``gate_alpha``. The gated hidden state is concatenated with those
    memories' actions and rewards and fed to an actor head, plus a critic head
    when ``rl_algo_arg == 'A2C'``.

    Until at least ``attn_size`` memories exist, forward returns random,
    gradient-free logits and placeholder attention outputs.

    Args:
        hidden_dim_lyrs: [hidden_dim, *hidden widths of the actor/critic MLPs].
        action_dim: number of logits produced.
        attn_size: number of memories attended to.
        rl_algo_arg: 'A2C' adds a scalar critic head; anything else does not.
    """
    def __init__(
            self,
            hidden_dim_lyrs: list=[128, 64],
            action_dim: int=4,
            attn_size: int=5,
            rl_algo_arg: str='default'
        ):

        super().__init__()

        ### 0) attention parameters (registration order kept for identical init)
        self.hidden_dim = hidden_dim_lyrs[0]
        self.attn_size = attn_size
        self.action_dim = action_dim
        scale = math.sqrt(self.hidden_dim)
        self.Q = torch.nn.Parameter(torch.randn(self.hidden_dim, self.hidden_dim) / scale)
        self.K = torch.nn.Parameter(torch.randn(self.hidden_dim, self.hidden_dim) / scale)
        self.V = torch.nn.Parameter(torch.randn(self.hidden_dim, self.hidden_dim) / scale)
        self.gate_alpha = torch.nn.Parameter(torch.zeros(1))

        ### 1) actor (and critic) MLPs; input = gated hidden state + attn_size actions + attn_size rewards
        # NOTE(review): each MLP starts with ReLU, which also zeroes negative
        # hidden units and negative rewards in the input features.
        self.rl_algo_arg = rl_algo_arg
        self.lin_lyrs_actor = []
        if rl_algo_arg == 'A2C':
            self.lin_lyrs_critic = []
        lin_lyr_prev = hidden_dim_lyrs[0] + 2*self.attn_size
        for lin_lyr in hidden_dim_lyrs[1:]:
            self.lin_lyrs_actor += [
                nn.ReLU(),
                nn.Linear(lin_lyr_prev, lin_lyr)
            ]
            if rl_algo_arg == 'A2C':
                self.lin_lyrs_critic += [
                    nn.ReLU(),
                    nn.Linear(lin_lyr_prev, lin_lyr)
                ]
            lin_lyr_prev = lin_lyr

        self.lin_lyrs_actor += [nn.ReLU(), nn.Linear(lin_lyr_prev, action_dim)]
        self.lin_actor = nn.Sequential(*self.lin_lyrs_actor)
        if rl_algo_arg == 'A2C':
            self.lin_lyrs_critic += [nn.ReLU(), nn.Linear(lin_lyr_prev, 1)]
            self.lin_critic = nn.Sequential(*self.lin_lyrs_critic)

    def forward(self, hidden_state, hidden_memory, action_memory, reward_memory, memory_ids):
        """Return (logits, value, topk_weights, chosen_memory_ids, gate_alpha, full_weights).

        ``hidden_memory`` is indexed as (1, n_memories, hidden); ``action_memory``,
        ``reward_memory`` and ``memory_ids`` are aligned 1-D tensors of length
        n_memories. ``value`` is a scalar for A2C and an empty tensor otherwise.
        The attention weights, ids and gate value are returned detached.
        """
        if hidden_memory.size(1) >= self.attn_size:
            ### 0) query / key / value projections
            query = torch.matmul(hidden_state, self.Q)              # presumably (1, 1, H)
            key = torch.matmul(hidden_memory, self.K).squeeze(0)    # (N, H)
            value_proj = torch.matmul(hidden_memory, self.V)        # (1, N, H)

            ### 1) scaled dot-product attention over all memories
            wattn = torch.matmul(query, key.T) / math.sqrt(self.hidden_dim)
            wattn = torch.softmax(wattn, dim=2)                      # (1, 1, N)

            ### top-k memories (weights are not renormalised over the k)
            topk_idx = gen_topk_random(wattn, k=self.attn_size, dim=2, largest=True)  # (1, 1, k)
            topk_w = torch.take_along_dim(wattn, topk_idx, dim=2)    # (1, 1, k)
            flat_idx = topk_idx.reshape(-1)                           # (k,), also safe for k == 1

            chosen_attn = torch.matmul(topk_w, value_proj[:, flat_idx, :])  # (1, 1, H)
            chosen_memory_ids = memory_ids[flat_idx]
            chosen_action_feature = action_memory[flat_idx]
            chosen_reward_feature = reward_memory[flat_idx]

            ### 2) gate the retrieved memory into the hidden state and score actions
            new_memory_hidden = (1 - self.gate_alpha) * hidden_state + self.gate_alpha * chosen_attn
            new_memory = torch.cat([new_memory_hidden.reshape(-1), chosen_action_feature, chosen_reward_feature])

            logits = self.lin_actor(new_memory)
            if self.rl_algo_arg == 'A2C':
                value = torch.squeeze(self.lin_critic(new_memory))
            else:
                value = torch.empty(0)
        else:
            # Not enough memories yet: random logits sized to the action space and
            # zero-filled (not uninitialised) placeholder weights sized to attn_size.
            logits = torch.randn(self.action_dim)
            value = torch.empty(0)
            topk_w = torch.zeros(1, 1, self.attn_size)
            chosen_memory_ids = torch.empty(0)
            wattn = torch.zeros(1, 1, self.attn_size)

        return (logits, value, torch.squeeze(topk_w).detach().clone(), chosen_memory_ids,
                self.gate_alpha.detach().clone(), torch.squeeze(wattn).detach().clone())
    
'''ANOTHER RETRIEVE FUNCTION - based on COSINE SIMILARITY'''
def retrieve(hidden_state_t, memorys):
    """Return the stored memory most cosine-similar to ``hidden_state_t``.

    Args:
        hidden_state_t: query hidden state, presumably (1, 1, hidden); it must
            broadcast against ``memorys``.
        memorys: memory tensor indexed as (1, n_memories, hidden).

    Returns:
        ``(memory_retrieve, cos_smlr_max)``: the best memory as a (1, 1, hidden)
        slice and its cosine similarity as a float. Ties keep the earliest
        memory. With no memories, returns ``(None, float('-inf'))``.
    """
    # The former loop returned after its first iteration, so only memory 0 was
    # ever considered; compare all memories at once instead.
    n_memories = memorys.size(1)
    if n_memories == 0:
        return None, float('-inf')
    # Broadcasts to one similarity per memory: shape (1, N).
    cos_smlr = nn.functional.cosine_similarity(hidden_state_t, memorys, dim=2)
    best_idx = int(torch.argmax(cos_smlr[0]))  # argmax returns the first maximum on ties
    memory_retrieve = memorys[:, best_idx, :].unsqueeze(0)
    cos_smlr_max = cos_smlr[0, best_idx].item()
    return memory_retrieve, cos_smlr_max
    
# class rnn_embed(nn.Module):
#     def __init__(self, 
#                  rnn_input_dim: int,
#                  rnn_hidden_dim: int):
    
#         super().__init__()
#         self.rnn_lyr = nn.RNN(rnn_input_dim, rnn_hidden_dim)
#         self.lin_lyr = nn.Linear(rnn_hidden_dim, rnn_hidden_dim)

#     def forward(self, hidden_state, current_obs):

#         prev_current = torch.cat([hidden_state, current_obs])
#         hidden_state_, _ = self.rnn_lyr(prev_current)
#         hidden_state_new = self.lin_lyr(hidden_state_)
        
#         return hidden_state_new

TRAINING LOOP CHECK (5/28)

In [2]:
### GLOBAL ARGUMENTS
# Selects the loss used by the training loop below ('REINFORCE' or 'A2C') and
# is also forwarded to the memory gate via PARAMS['memory_gate']['rl_algo_arg'].
RL_ALGO_ARG = 'REINFORCE'  # alternative: 'A2C'

### FROM TRAIN.PY
import os, sys
# Add the folder one level above notebooks/ (the project root) to sys.path so
# the project packages (env, model) imported below can be found. This must run
# before those imports.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import os
import yaml
import torch
import random
import time
from torch import nn, optim
from torch.distributions import Categorical
from env.custom_maze_env import CustomMazeEnv
from env.get_retina_image import reconstruct
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
import matplotlib.pyplot as plt
import numpy as np
from IPython import display
# Star import: presumably supplies cnnrnnattn_policy and memory_bank used below
# (neither is imported anywhere else in this cell) -- TODO confirm.
from model.policy import *

# Debug aid: autograd anomaly detection points at the forward op behind a
# failing backward pass. PyTorch documents it as debugging-only because it
# slows execution; consider disabling it for real training runs.
torch.autograd.set_detect_anomaly(True)

#==========================
# set_seed
#==========================
def set_seed(seed):
    """Seed every global RNG this notebook draws from.

    The same ``seed`` goes to Python's ``random`` module, NumPy's legacy
    global generator and PyTorch's default generator, in that order.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for seeder in seeders:
        seeder(seed)

# 1) Load the experiment config. The `with` block closes the file handle that
#    the previous bare open() call leaked. (Original note: ##TBD - ERROR)
with open(os.path.join(project_root, "experiments/config/default.yaml"), encoding="utf-8") as cfg_file:
    cfg = yaml.safe_load(cfg_file)
set_seed(cfg["train"]["seed"])

# 2) Create & wrap the environment.
#base_env = CustomMazeEnv(**cfg["env"]) ##TBD - ERROR
base_env = CustomMazeEnv(**{'layout_id': 'c',
                            'goal_pos': [0, 3],
                            'view_size': 5,
                            'max_steps': 1000,
                            'tile_size': 32,
                            'render_mode': 'rgb_array'})
# NOTE(review): base_env hard-codes tile_size=32 while the wrapper reads
# cfg["env"]["tile_size"]; keep the two in sync.
env = RGBImgPartialObsWrapper(base_env, tile_size=cfg["env"]["tile_size"])
# obs_dim = np.prod(reconstruct(obs["image"], render_chanel=1).shape)  # flatten -> 6ox80
action_dim = env.action_space.n

# 3) Agent and optimizer.
import math  # for the hidden-state scale below; imported here so the cell runs standalone

# hyperparameters
HIDDEN_SIZE = cfg["agent"]["hidden_size"]
PARAMS = {
    'memory_bank_ep': {
        'decay_rate': 0.0001,
        'noise_std': 0.001,
        'et_lambda': 0.99,
        'memory_len': 5000,
        'update_freq': 100,
        'hidden_dim': HIDDEN_SIZE,
        'decay_yn': False
    },
    'cnn_embed': {
        'cnn_hidden_lyrs': [4, 8],
        'lin_hidden_lyrs': [512, HIDDEN_SIZE],
        'input_img_shape': (60, 80)
    },
    'rnn': {
        'input_size': HIDDEN_SIZE,
        'hidden_size': HIDDEN_SIZE,
        'batch_first': True
    },
    'memory_gate': {
        'hidden_dim_lyrs': [HIDDEN_SIZE, int(HIDDEN_SIZE/2)],
        # Hard-coded; presumably must match the 4-logit random fallback in the
        # memory gate rather than env.action_space.n -- TODO confirm.
        'action_dim': 4,
        'attn_size': 5,
        'rl_algo_arg': RL_ALGO_ARG
    }
}

# policy = RNNPolicy(obs_dim, cfg["agent"]["hidden_size"], action_dim)
policy = cnnrnnattn_policy(PARAMS)
optimizer = optim.Adam(policy.parameters(), lr=float(cfg["agent"]["learning_rate"]))
memory_bank_ep = memory_bank(**PARAMS['memory_bank_ep'])
# Random initial RNN hidden state, scaled by 1/sqrt(HIDDEN_SIZE).
hx = torch.randn(1, 1, HIDDEN_SIZE) / math.sqrt(HIDDEN_SIZE)

# Prepare TensorBoard (imported lazily so it is only required when enabled).
if cfg["logging"]["use_tensorboard"]:
    from torch.utils.tensorboard import SummaryWriter
    tb = SummaryWriter(cfg["logging"]["tensorboard_dir"])
else:
    tb = None

# 4) Training loop.
import pickle  # per-episode gate-alpha / terminated logs
from datetime import datetime as dt  # checkpoint timestamps (dt was never imported)

if RL_ALGO_ARG not in ('REINFORCE', 'A2C'):
    # Fail fast: otherwise `loss` would be undefined after the first episode.
    raise ValueError(f"RL_ALGO_ARG must be 'REINFORCE' or 'A2C', got {RL_ALGO_ARG!r}")

gamma = cfg["agent"]["gamma"]
# The first attn_size steps of every episode are excluded from the loss.
# Presumably the memory bank cannot yet supply attn_size entries there, so the
# policy falls back to random logits and an empty value -- TODO confirm.
attn_size = PARAMS['memory_gate']['attn_size']

policy.train()
for ep in range(1, 5):

    obs, _ = env.reset(seed=cfg["train"]["seed"] + ep)
    # obs["image"] shape = (tile_size * view_size, tile_size * view_size, 3)
    retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
    state = torch.from_numpy(retina)[None, None, ...]  # 1 x 1 x 60 x 80

    log_probs, rewards, values = [], [], []
    gate_alpha_lst, terminated_lst = [], []

    # Per-episode attention-weight directory, created once per episode rather
    # than on every step.
    attn_weight_ep_dir = os.path.join(cfg["logging"]["attention_weight_dir"], f'ep{ep}')
    os.makedirs(attn_weight_ep_dir, exist_ok=True)

    #==========================
    # Episode rollout
    #==========================
    done = False
    while not done:
        logits, value, sx, hx, chosen_ids, gate_alpha_, attention_ = policy(state, hx, memory_bank_ep)
        m = Categorical(logits=logits)
        a = m.sample()
        log_probs.append(m.log_prob(a))
        values.append(value)
        obs, r, term, trunc, _ = env.step(a.item())
        rewards.append(r)

        # Store the transition (current retina + detached embeddings) and persist it.
        memory_bank_ep.update(retina, sx.detach().clone(), hx.detach().clone(), a, r, obs, ep, chosen_ids)
        memory_bank_ep.save(cfg["logging"]["timestep_dir"], cfg["logging"]["attention_dir"], ep, obs['timestep'], chosen_ids)

        # next state
        retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
        state = torch.from_numpy(retina)[None, None, ...]
        done = term or trunc

        with torch.no_grad():
            # The next step starts from a graph-free hidden state.
            hx = hx.detach().clone()
            gate_alpha_lst.append(gate_alpha_.item())
            terminated_lst.append(term)
            timestep_ = obs['timestep']
            torch.save(attention_,
                       os.path.join(attn_weight_ep_dir, f'attention_weight_{timestep_}.pt'))

        ### for checking
        print(chosen_ids, gate_alpha_, r)

    #============================
    # Per-episode loss
    #============================
    # Discounted returns G_t = r_t + gamma * G_{t+1}, built back-to-front and
    # reversed once (O(T) instead of O(T^2) list.insert(0, ...)).
    returns = []
    G = 0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    returns = torch.tensor(returns[attn_size:])
    n_train = len(returns)

    loss = None
    if RL_ALGO_ARG == 'REINFORCE':
        # Normalising needs >= 2 samples: std() of one element is NaN and
        # would poison the parameters.
        if n_train >= 2:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
            loss = -(torch.stack(log_probs[attn_size:]) * returns).sum()
    elif RL_ALGO_ARG == 'A2C':
        if n_train >= 1:
            values_t = torch.stack(values[attn_size:])
            advantages = returns.detach() - values_t
            actor_loss = -(torch.stack(log_probs[attn_size:]) * advantages.detach()).mean()
            value_loss = advantages.pow(2).mean()
            alpha = cfg["train"]["actor_loss_coef"]
            loss = alpha * actor_loss + (1 - alpha) * value_loss

    if loss is None:
        # Episode ended during (or right after) the warm-up window.
        print(f"[Episode {ep}] too few steps after warm-up ({n_train}); skipping update")
    else:
        optimizer.zero_grad()
        # retain_graph=True kept from the original -- TODO confirm it is still needed.
        loss.backward(retain_graph=True)
        optimizer.step()
    loss_value = float('nan') if loss is None else loss.item()

    #============================
    # Logging
    #============================
    ep_reward = sum(rewards)
    print(ep_reward)

    if tb:
        tb.add_scalar("train/episode_reward", ep_reward, ep)
        tb.add_scalar("train/loss", loss_value, ep)

    if ep % cfg["train"]["log_interval"] == 0:
        print(f"[Episode {ep}] reward={ep_reward:.2f}, loss={loss_value:.4f}")

    if ep % cfg["train"]["save_interval"] == 0:
        time_now = dt.now().strftime("%Y-%m-%d-%H:%M")
        os.makedirs(cfg["train"]["checkpoint_dir"], exist_ok=True)
        torch.save(policy.state_dict(),
                   os.path.join(cfg["train"]["checkpoint_dir"], f"policy_ep{ep}_{time_now}.pt"))

    os.makedirs(cfg["logging"]["gate_alpha_dir"], exist_ok=True)
    os.makedirs(cfg["logging"]["terminated_dir"], exist_ok=True)
    with open(os.path.join(cfg["logging"]["gate_alpha_dir"], f"gate_alpha_ep{ep}.pkl"), "wb") as file_gate_alpha:
        pickle.dump(gate_alpha_lst, file_gate_alpha)
    with open(os.path.join(cfg["logging"]["terminated_dir"], f"terminated_ep{ep}.pkl"), "wb") as file_terminated:
        pickle.dump(terminated_lst, file_terminated)

# Close the TensorBoard writer once, after all episodes (it used to be closed
# inside the loop, at the end of the first episode).
if tb:
    tb.close()

tensor([]) tensor([0.]) -0.01
tensor([]) tensor([0.]) -0.01
tensor([]) tensor([0.]) -0.01
tensor([]) tensor([0.]) -0.01
tensor([]) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) tensor([0.]) -0.01
tensor([1, 3, 4, 5, 2], dtype=torch.int32) t

In [398]:
# # 4) 학습 루프
# HIDDEN_SIZE = cfg["agent"]["hidden_size"]
# memory_bank_ep = memory_bank(
#     decay_rate=0.1,
#     noise_std=0.1,
#     memory_len=100,
#     hidden_dim=HIDDEN_SIZE
#     )
# ep = 1

# obs, _ = env.reset(seed=cfg["train"]["seed"] + ep)
# obs_dim = np.prod(reconstruct(obs["image"], render_chanel=1).shape)  # flatten -> 6ox80
# # obs["image"] shape = (tile_size * view_size, tile_size * view_size, 3)
# retina = reconstruct(obs["image"], render_chanel=1) # 60 x 80
# # full_map = obs["image"] # 160 x 160 x 3

# state = torch.from_numpy(retina).float().view(1,1,-1)
# hx = torch.zeros(1, 1, HIDDEN_SIZE)
# log_probs = []
# rewards = []


# state = torch.from_numpy(retina)[None, None, ...] ## 1 x 1 x 60 x 80
# hx = torch.zeros(1, 1, cfg["agent"]["hidden_size"])

# ### 0) STATE EMBEDDING
# cnn_embed_lyr = cnn_embed(**PARAMS['cnn_embed'])
# state_emd = cnn_embed_lyr(state).unsqueeze(0) # 1 x 1 x stateDim (128)

# ### 1) AGGREGATE STATE EMBEDDING & HIDDEN STATE > CURRENT HIDDEN STATE
# rnn_lyr = nn.RNN(**PARAMS['rnn'])
# hx, _ = rnn_lyr.forward(state_emd, hx) # 1 x 1 x hiddenDim (128)

# ### 2) AGGREGATE CURRENT HIDDEN STATE & MEMORY > ACTION
# memory_gate_lyr = memory_gate(**PARAMS['memory_gate'])
# log_probs, attention_weights = memory_gate_lyr(hx, memory_bank_ep.memory_bank_hidden)


In [399]:
# HIDDEN_SIZE = cfg["agent"]["hidden_size"]
# PARAMS = {
#     'cnn_embed': {
#         'cnn_hidden_lyrs': [4, 8],
#         'lin_hidden_lyrs': [512, HIDDEN_SIZE],
#         'input_img_shape': (60, 80)
#     },
#     'rnn': {
#         'input_size': HIDDEN_SIZE,
#         'hidden_size': HIDDEN_SIZE,
#         'batch_first': True
#     },
#     'memory_gate': {
#         'hidden_dim_lyrs': [HIDDEN_SIZE, int(HIDDEN_SIZE/2)],
#         'action_dim': 4
#     }
# }


# policy = cnnrnnattn_policy(PARAMS)
# log_probs, hx, state_emd, attention_weights = policy.forward(state, hx, memory_bank_ep)


Train.py (5/26)

In [None]:
### FROM TRAIN.PY
import os, sys
# Add the folder one level above notebooks/ (the project root) to sys.path so
# the project packages (env, ...) imported below can be found. This must run
# before those imports.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)


import os
import yaml
import torch
import random
import time
from torch import nn, optim
from torch.distributions import Categorical
from env.custom_maze_env import CustomMazeEnv
from env.get_retina_image import reconstruct
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
import matplotlib.pyplot as plt
import numpy as np
from IPython import display
# NOTE(review): this cell does not import model.policy; cnnrnnattn_policy and
# memory_bank (used below) must already be defined by an earlier cell.

#==========================
# set_seed
#==========================
def set_seed(seed):
    """Make runs reproducible by seeding the Python, NumPy and PyTorch RNGs.

    The generators are seeded in a fixed order: stdlib ``random``, then
    ``numpy.random``, then ``torch``.
    """
    rng_seeders = [random.seed, np.random.seed]
    rng_seeders.append(torch.manual_seed)
    for apply_seed in rng_seeders:
        apply_seed(seed)

# 1) Load the experiment config. The `with` block closes the file handle that
#    the previous bare open() call leaked. (Original note: ##TBD - ERROR)
with open(os.path.join(project_root, "experiments/config/default.yaml"), encoding="utf-8") as cfg_file:
    cfg = yaml.safe_load(cfg_file)
set_seed(cfg["train"]["seed"])

# 2) Create & wrap the environment.
#base_env = CustomMazeEnv(**cfg["env"]) ##TBD - ERROR
base_env = CustomMazeEnv(**{'layout_id': 'c',
                            'goal_pos': [0, 3],
                            'view_size': 5,
                            'max_steps': 1000,
                            'tile_size': 32,
                            'render_mode': 'rgb_array'})
env = RGBImgPartialObsWrapper(base_env, tile_size=cfg["env"]["tile_size"])
# obs_dim = np.prod(reconstruct(obs["image"], render_chanel=1).shape)  # flatten -> 6ox80
action_dim = env.action_space.n

# 3) Agent and optimizer.
import math  # for the hidden-state scale below; imported here so the cell runs standalone

# policy = RNNPolicy(obs_dim, cfg["agent"]["hidden_size"], action_dim)
HIDDEN_SIZE = cfg["agent"]["hidden_size"]

# Hyperparameters. These must be defined BEFORE the policy and memory bank are
# built: PARAMS used to be assigned only after its first use, so this cell
# either raised NameError on a fresh kernel or silently built the model from a
# stale PARAMS left over from another cell.
PARAMS = {
    'memory_bank_ep': {
        'decay_rate': 0.1,
        'noise_std': 0.1,
        'memory_len': 50,
        'hidden_dim': HIDDEN_SIZE
    },
    'cnn_embed': {
        'cnn_hidden_lyrs': [4, 8],
        'lin_hidden_lyrs': [512, HIDDEN_SIZE],
        'input_img_shape': (60, 80)
    },
    'rnn': {
        'input_size': HIDDEN_SIZE,
        'hidden_size': HIDDEN_SIZE,
        'batch_first': True
    },
    'memory_gate': {
        'hidden_dim_lyrs': [HIDDEN_SIZE, int(HIDDEN_SIZE/2)],
        'action_dim': 4
    }
}

# NOTE(review): cnnrnnattn_policy / memory_bank are not imported in this cell;
# presumably an earlier cell's `from model.policy import *` provides them.
policy = cnnrnnattn_policy(PARAMS)
optimizer = optim.Adam(policy.parameters(), lr=float(cfg["agent"]["learning_rate"]))
memory_bank_ep = memory_bank(**PARAMS['memory_bank_ep'])
# Random initial RNN hidden state, scaled by 1/sqrt(HIDDEN_SIZE).
hx = torch.randn(1, 1, HIDDEN_SIZE) / math.sqrt(HIDDEN_SIZE)

# Prepare TensorBoard (imported lazily so it is only required when enabled).
if cfg["logging"]["use_tensorboard"]:
    from torch.utils.tensorboard import SummaryWriter
    tb = SummaryWriter(cfg["logging"]["tensorboard_dir"])
else:
    tb = None

# 4) Training loop.
gamma = cfg["agent"]["gamma"]
for ep in range(1, 2):
    # Start each episode from a graph-free hidden state: hx otherwise still
    # references the previous episode's autograd graph, which backward() frees.
    hx = hx.detach()

    obs, _ = env.reset(seed=cfg["train"]["seed"] + ep)
    # obs["image"] shape = (tile_size * view_size, tile_size * view_size, 3)
    retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
    state = torch.from_numpy(retina)[None, None, ...]  # 1 x 1 x 60 x 80

    log_probs = []
    rewards = []

    #==========================
    # Episode rollout
    #==========================
    done = False
    while not done:
        logits, hx = policy(state, hx, memory_bank_ep)
        m = Categorical(logits=logits)
        a = m.sample()
        log_probs.append(m.log_prob(a))

        obs, r, term, trunc, _ = env.step(a.item())
        rewards.append(r)
        # Compute termination BEFORE storing the transition so the memory gets
        # this step's flag (it used to receive the previous step's value).
        done = term or trunc

        # Store the transition in the memory bank: retina (egocentric view of
        # the current state), state/hidden embeddings, action a, obs (carries
        # 'timestep' / 'position'), episode number and the done flag.
        memory_bank_ep.update(retina, policy.state_emd, policy.hidden_emd, a, obs, ep, done)

        # next state
        retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
        state = torch.from_numpy(retina)[None, None, ...]

    #============================
    # Per-episode loss (REINFORCE)
    #============================
    # Discounted returns, built back-to-front and reversed once (O(T)).
    returns = []
    G = 0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    returns = torch.tensor(returns)
    # Normalise only with >= 2 steps; std() of a single element is NaN.
    if len(returns) > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    # loss = -sum_t log pi(a_t | s_t) * G_t (reshape guards a (1,) log_prob).
    loss = -(torch.stack(log_probs).reshape(-1) * returns).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    #============================
    # Logging
    #============================
    ep_reward = sum(rewards)
    if tb:
        tb.add_scalar("train/episode_reward", ep_reward, ep)
        tb.add_scalar("train/loss", loss.item(), ep)

    if ep % cfg["train"]["log_interval"] == 0:
        print(f"[Episode {ep}] reward={ep_reward:.2f}, loss={loss.item():.4f}")

    if ep % cfg["train"]["save_interval"] == 0:
        os.makedirs(cfg["train"]["checkpoint_dir"], exist_ok=True)
        torch.save(policy.state_dict(),
                   os.path.join(cfg["train"]["checkpoint_dir"], f"policy_ep{ep}.pt"))

# Close the TensorBoard writer once, after all episodes.
if tb:
    tb.close()

TRAIN.py (5/25)

In [None]:
### FROM TRAIN.PY
import os, sys
# Add the folder one level above notebooks/ (the project root) to sys.path so
# the project packages (env, ...) imported below can be found. This must run
# before those imports.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)


import os
import yaml
import torch
import random
import time
from torch import nn, optim
from torch.distributions import Categorical
from env.custom_maze_env import CustomMazeEnv
from env.get_retina_image import reconstruct
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
import matplotlib.pyplot as plt
import numpy as np
from IPython import display

#==========================
# Model import & load
#==========================
# NOTE(review): empty placeholder -- cnnrnnattn_policy (used below) is not
# imported in this cell; presumably an earlier cell provides it.
####

#==========================
# Memory import & load
#==========================
# NOTE(review): empty placeholder -- memory_bank (used below) likewise must
# come from an earlier cell.
####


#==========================
# set_seed
#==========================
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch, in that order."""
    # Library RNGs first ...
    for lib_seed in (random.seed, np.random.seed):
        lib_seed(seed)
    # ... then the PyTorch default generator (its return value is not needed).
    torch.manual_seed(seed)

# 1) Load the experiment config. The `with` block closes the file handle that
#    the previous bare open() call leaked. (Original note: ##TBD - ERROR)
with open(os.path.join(project_root, "experiments/config/default.yaml"), encoding="utf-8") as cfg_file:
    cfg = yaml.safe_load(cfg_file)
set_seed(cfg["train"]["seed"])

# 2) Create & wrap the environment.
#base_env = CustomMazeEnv(**cfg["env"]) ##TBD - ERROR
base_env = CustomMazeEnv(**{'layout_id': 'c',
                            'goal_pos': [0, 3],
                            'view_size': 5,
                            'max_steps': 1000,
                            'tile_size': 32,
                            'render_mode': 'rgb_array'})
env = RGBImgPartialObsWrapper(base_env, tile_size=cfg["env"]["tile_size"])
# obs_dim = np.prod(reconstruct(obs["image"], render_chanel=1).shape)  # flatten -> 6ox80
action_dim = env.action_space.n

# 3) Agent and optimizer.
# policy = RNNPolicy(obs_dim, cfg["agent"]["hidden_size"], action_dim)

HIDDEN_SIZE = cfg["agent"]["hidden_size"]
PARAMS = {
    'memory_bank_ep': {
        'decay_rate': 0.1,
        'noise_std': 0.1,
        'memory_len': 50,
        'hidden_dim': HIDDEN_SIZE
    },
    'cnn_embed': {
        'cnn_hidden_lyrs': [4, 8],
        'lin_hidden_lyrs': [512, HIDDEN_SIZE],
        'input_img_shape': (60, 80)
    },
    'rnn': {
        'input_size': HIDDEN_SIZE,
        'hidden_size': HIDDEN_SIZE,
        'batch_first': True
    },
    'memory_gate': {
        'hidden_dim_lyrs': [HIDDEN_SIZE, int(HIDDEN_SIZE/2)],
        'action_dim': 4
    }
}

policy = cnnrnnattn_policy(PARAMS)
optimizer = optim.Adam(policy.parameters(), lr=float(cfg["agent"]["learning_rate"]))

# Prepare TensorBoard (imported lazily so it is only required when enabled).
if cfg["logging"]["use_tensorboard"]:
    from torch.utils.tensorboard import SummaryWriter
    tb = SummaryWriter(cfg["logging"]["tensorboard_dir"])
else:
    tb = None
# Initial RNN hidden state (all zeros).
hx = torch.zeros(1, 1, HIDDEN_SIZE)



# 4) Training loop.
gamma = cfg["agent"]["gamma"]
for ep in range(1, 2):
    # Per-episode attention output directory. The original passed open(path)
    # to os.makedirs, which cannot work: open() on a directory path raises,
    # and makedirs needs a path, not a file object.
    os.makedirs(os.path.join(project_root, f"results/attention/ep{ep}"), exist_ok=True)
    # Start each episode from a graph-free hidden state: hx otherwise still
    # references the previous episode's autograd graph, which backward() frees.
    hx = hx.detach()
    obs, _ = env.reset(seed=cfg["train"]["seed"] + ep)
    # obs["image"] shape = (tile_size * view_size, tile_size * view_size, 3)
    retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
    # A fresh memory bank for every episode.
    memory_bank_ep = memory_bank(**PARAMS['memory_bank_ep'])
    state = torch.from_numpy(retina)[None, None, ...]  # 1 x 1 x 60 x 80

    log_probs = []
    rewards = []

    #==========================
    # Episode rollout
    #==========================
    done = False
    while not done:
        logits, hx = policy(state, hx, memory_bank_ep)
        m = Categorical(logits=logits)
        a = m.sample()
        log_probs.append(m.log_prob(a))

        obs, r, term, trunc, _ = env.step(a.item())
        rewards.append(r)
        # Compute termination BEFORE storing the transition so the memory gets
        # this step's flag (it used to receive the previous step's value).
        done = term or trunc

        # Store the transition. `retina` still holds the view the action was
        # chosen from; the original re-ran reconstruct() on the NEW obs right
        # before this call, pairing the next view with the current action and
        # embeddings.
        memory_bank_ep.update(retina, policy.state_emd, policy.hidden_emd, a, obs, ep, done)
        memory_bank_ep.save()

        # next state
        retina = reconstruct(obs["image"], render_chanel=1)  # 60 x 80
        state = torch.from_numpy(retina)[None, None, ...]

    #============================
    # Per-episode loss (REINFORCE)
    #============================
    # Discounted returns, built back-to-front and reversed once (O(T)).
    returns = []
    G = 0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    returns = torch.tensor(returns)
    # Normalise only with >= 2 steps; std() of a single element is NaN.
    if len(returns) > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    # loss = -sum_t log pi(a_t | s_t) * G_t (reshape guards a (1,) log_prob).
    loss = -(torch.stack(log_probs).reshape(-1) * returns).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    #============================
    # Logging
    #============================
    ep_reward = sum(rewards)
    if tb:
        tb.add_scalar("train/episode_reward", ep_reward, ep)
        tb.add_scalar("train/loss", loss.item(), ep)

    if ep % cfg["train"]["log_interval"] == 0:
        print(f"[Episode {ep}] reward={ep_reward:.2f}, loss={loss.item():.4f}")

    if ep % cfg["train"]["save_interval"] == 0:
        os.makedirs(cfg["train"]["checkpoint_dir"], exist_ok=True)
        torch.save(policy.state_dict(),
                   os.path.join(cfg["train"]["checkpoint_dir"], f"policy_ep{ep}.pt"))

# Close the TensorBoard writer once, after all episodes.
if tb:
    tb.close()

In [280]:
# Inspect the 'timestep' entry of the most recent `obs` left in the kernel.
obs['timestep']

107

In [None]:
# Write a per-timestep pickle to results/attention/ep{ep}/{timestep}.pkl,
# using the `ep` and `obs` currently in the kernel.
# NOTE(review): pickle.dump() is called with no arguments, so this cell raises
# TypeError as written. pickle.dump needs (obj, file), and the object to save
# cannot be determined from this cell -- TODO supply it. Nothing here creates
# the directory, and `pickle` is not imported in any visible cell; confirm
# both are available.
with open(os.path.join(project_root, f"results/attention/ep{ep}", f"{obs['timestep']}.pkl"), 'wb') as file:
    pickle.dump()
