MODULES

In [None]:
import torch 
import torch.nn as nn
import math
import pickle

class cnnrnnattn_policy(nn.Module):
    '''
    Policy network chaining CNN embedding > RNN > memory attention.

    forward(state, hx, memory_bank_ep) returns (log_probs, hidden_emd):
    the first output of the memory gate and the output sequence of the RNN.
    The intermediate tensors (state_emd, hidden_emd, log_probs,
    attention_weights) are also kept as attributes on the module so the
    caller can read them after the forward pass.
    '''
    def __init__(self, params):
        super().__init__()
        # `params` is a dict of keyword-argument dicts, one per sub-module.
        cnn_kwargs = params['cnn_embed']
        self.cnn_embed_lyr = cnn_embed(**cnn_kwargs)
        rnn_kwargs = params['rnn']
        self.rnn_lyr = nn.RNN(**rnn_kwargs)
        gate_kwargs = params['memory_gate']
        self.memory_gate_lyr = memory_gate(**gate_kwargs)

    def forward(self, state, hx, memory_bank_ep):
        # CNN embedding, with a new leading axis added before the RNN.
        embedded = self.cnn_embed_lyr(state)
        self.state_emd = embedded.unsqueeze(0)

        # nn.RNN returns (output, h_n); only the output sequence is kept.
        rnn_result = self.rnn_lyr(self.state_emd, hx)
        self.hidden_emd = rnn_result[0]

        # Attend from the current hidden state over the stored memory bank.
        gate_result = self.memory_gate_lyr(
            self.hidden_emd,
            memory_bank_ep.memory_bank_hidden,
        )
        self.log_probs, self.attention_weights = gate_result

        return self.log_probs, self.hidden_emd
    
class cnn_embed(nn.Module):
    '''
    CNN encoder: single-channel image > conv stack > flatten > MLP embedding.

    input: retina image tensor of shape (N, 1, H, W)
    output: embedding of shape (N, lin_hidden_lyrs[-1]); when lin_hidden_lyrs
            is empty the flattened conv features are returned instead.

    Args:
        cnn_hidden_lyrs: output channels of each Conv2d(3x3) > ReLU > MaxPool(2) stage.
        lin_hidden_lyrs: widths of the linear layers applied after flattening.
        input_img_shape: (H, W) of the input image, used to size the first
            linear layer.
    '''
    def __init__(
        self,
        cnn_hidden_lyrs: "list | tuple" = (16, 32),
        lin_hidden_lyrs: "list | tuple" = (32, 64),
        input_img_shape: tuple = (60, 80) # HW
    ):
        ### 0)
        super().__init__()

        ### 1) CNN LAYERS
        cnn_lyrs = []
        in_channels = 1 ## TBD: initial channel count is fixed to 1
        for out_channels in cnn_hidden_lyrs:
            cnn_lyrs += [
                nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=3,
                          stride=1,
                          padding=0),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
            in_channels = out_channels
        # Wrap in nn.Sequential so the layers are registered as sub-modules.
        # A plain Python list hides their parameters from .parameters(),
        # .to(), state_dict() and therefore from the optimizer.
        self.cnn_lyrs = nn.Sequential(*cnn_lyrs)

        ### 2) LINEAR LAYERS
        lin_lyrs = [nn.Flatten()]
        if lin_hidden_lyrs:
            # Probe the conv stack with a dummy batch of one image to size
            # the first linear layer.
            feat_shape = self.get_shape(self.cnn_lyrs, (1, 1, *input_img_shape)) ## TBD BATCH
            in_features = math.prod(feat_shape)
            for out_features in lin_hidden_lyrs[:-1]:
                lin_lyrs += [nn.Linear(in_features, out_features), nn.ReLU()]
                in_features = out_features
            # No activation after the final projection.
            lin_lyrs.append(nn.Linear(in_features, lin_hidden_lyrs[-1])) ##TBD CHECK
        self.lin_lyrs = nn.Sequential(*lin_lyrs)

    def forward(self, input_img, only_output=True):
        '''
        Embed `input_img`.

        Returns the embedding, or (embedding, conv feature map) when
        only_output is False.
        '''
        inter_img = self.cnn_lyrs(input_img)
        output_img = self.lin_lyrs(inter_img)

        if only_output:
            return output_img
        return output_img, inter_img

    def get_shape(self, lyrs, input_img_shape):
        '''
        Return the per-sample output shape (batch axis dropped) obtained by
        running a zero tensor of shape `input_img_shape` through `lyrs`.

        `lyrs` may be an nn.Module or an iterable of layers.
        '''
        with torch.no_grad():
            embedder = lyrs if isinstance(lyrs, nn.Module) else nn.Sequential(*lyrs)
            output = embedder(torch.zeros(*input_img_shape))
            return output.shape[1:]
    
class memory_bank():
    '''
    Episodic memory holding two parallel stores:

    * memory_bank_org: list of per-step dicts (observation, embeddings,
      action, timestep, position, episode number), pickled to
      results/memory_bank/<ep>.pkl when an episode finishes.
    * memory_bank_hidden: tensor of shape (1, memory_len, hidden_dim) with the
      most recent hidden states, initialised with random values and perturbed
      with age-dependent noise on every update.

    Args:
        decay_rate: exponent rate; the noise scale of a slot is
            exp(decay_rate * age), where age 0 is the newest slot.
        noise_std: base standard deviation of the Gaussian noise.
        memory_len: maximum number of hidden states kept.
        hidden_dim: size of each hidden state.
    '''
    def __init__(
            self,
            decay_rate: float=0.1,
            noise_std: float=0.1,
            memory_len: int=100,
            hidden_dim: int=128
    ):
        self.memory_bank_org = []
        self.memory_bank_hidden = torch.randn(1, memory_len, hidden_dim)
        self.decay_rate = decay_rate
        self.noise_std = noise_std
        self.memory_len = memory_len

    def update(self, retina, embed_state, hidden_state, action, obs, ep, done):
        '''
        Record one environment step.

        `hidden_state` is concatenated along axis 1 of memory_bank_hidden, so
        it must be shaped (1, k, hidden_dim). `obs` must provide the keys
        'timestep' and 'position'. When `done` is true the raw memory of the
        episode (including this step) is written to disk and cleared.
        '''
        ## 1) build the slot and store it BEFORE saving, so the terminal step
        ##    is part of the pickled episode instead of leaking into the
        ##    freshly reset buffer
        self.push(retina, embed_state, hidden_state, action, obs, ep)
        self.memory_bank_org.append(self.memory_slot)
        self.save(ep, done)

        ## 2) append the new hidden state and keep only the newest memory_len
        bank = torch.cat([self.memory_bank_hidden, hidden_state], dim=1) ## TBD FOR LATER RNN MODIFICATION USED AXIS 1 AGGREGATION
        self.memory_bank_hidden = bank[:, -self.memory_len:, :]

        ## 3) perturb the stored hidden states (older slots get more noise)
        self.add_decay(self.decay_rate, self.noise_std)

    def push(self, retina, embed_state, hidden_state, action, obs, ep):
        '''Build the raw memory record for the current step in self.memory_slot.'''
        self.memory_slot = {
            'obs': retina,
            'embed_state': embed_state,
            'hidden_state': hidden_state,
            'action': action,
            'timestep': obs['timestep'],
            'position': obs['position'],
            'epi_no': ep
        }

    def save(self, ep, done):
        '''
        If `done`, pickle memory_bank_org to results/memory_bank/<ep>.pkl
        (creating the directory if needed) and reset the buffer.
        '''
        if not done:
            return

        import os  # local import: this module's import cell does not import os

        out_dir = os.path.join('results', 'memory_bank')
        os.makedirs(out_dir, exist_ok=True)
        # Binary write mode is required by pickle.
        with open(os.path.join(out_dir, f'{ep}.pkl'), 'wb') as file:
            pickle.dump(self.memory_bank_org, file)
        self.memory_bank_org = []

    def add_decay(self, decay_rate, noise_std): ## TBD EXPONENTIAL DECAY
        '''
        Add zero-mean Gaussian noise to every slot of memory_bank_hidden.

        The slot at index idx gets noise with standard deviation
        noise_std * exp(decay_rate * age), where age = (n_slots - 1 - idx),
        so the newest slot (last index) receives the least noise. The noise
        is added on top of whatever was stored, i.e. it accumulates over
        successive calls.

        The perturbed bank is stored back in self.memory_bank_hidden and also
        returned.
        '''
        bank = self.memory_bank_hidden
        n_slots = bank.size(1)

        # age of each slot along axis 1: n_slots-1 (oldest) ... 0 (newest)
        ages = torch.arange(n_slots - 1, -1, -1, dtype=bank.dtype, device=bank.device)
        noise_scale = torch.exp(decay_rate * ages).view(1, -1, 1)
        noise = torch.randn_like(bank) * noise_std * noise_scale

        # Out-of-place add: avoids mutating a view in place and makes the
        # state change explicit instead of relying on aliasing.
        self.memory_bank_hidden = bank + noise
        return self.memory_bank_hidden

class memory_gate(torch.nn.Module):
    """
    Single-head dot-product attention from the current hidden state over the
    memory bank, followed by an MLP that outputs one score per action.

    input: current hidden state (B x 1 x H), memory_bank_hidden (B x M x H)
    output: (action scores, attention weights), both squeezed of size-1 axes

    Args:
        hidden_dim_lyrs: [H, w1, w2, ...] where H is the hidden-state size and
            the remaining entries are the widths of the MLP layers. At least
            two entries are required.
        action_dim: number of output scores.
    """
    def __init__(
            self,
            hidden_dim_lyrs: "list | tuple" = (128, 64),
            action_dim: int=4
        ):

        super().__init__()

        ### 0) ATTENTION WEIGHTS
        hidden_dim = hidden_dim_lyrs[0]
        init_scale = math.sqrt(hidden_dim)
        self.Q = torch.nn.Parameter(torch.randn(hidden_dim, hidden_dim) / init_scale)
        self.K = torch.nn.Parameter(torch.randn(hidden_dim, hidden_dim) / init_scale)
        self.V = torch.nn.Parameter(torch.randn(hidden_dim, hidden_dim) / init_scale)

        ### 1) LINEAR LAYERS
        # Input is [hidden_state ; attended memory], hence 2 * hidden_dim.
        lin_lyrs = [
            nn.ReLU(),
            nn.Linear(2 * hidden_dim, hidden_dim_lyrs[1])
        ]
        lin_lyr_prev = hidden_dim_lyrs[1]
        for lin_lyr in hidden_dim_lyrs[2:]:
            lin_lyrs += [nn.ReLU(), nn.Linear(lin_lyr_prev, lin_lyr)]
            lin_lyr_prev = lin_lyr
        lin_lyrs += [nn.ReLU(), nn.Linear(lin_lyr_prev, action_dim)]

        # nn.Sequential registers the layers as sub-modules; a plain Python
        # list would hide their parameters from .parameters() / the optimizer.
        self.lin_lyrs = nn.Sequential(*lin_lyrs)

    def forward(self, hidden_state, memory_bank_hidden):
        """
        Args:
            hidden_state: B x 1 x H (e.g. 1 x 1 x 128)
            memory_bank_hidden: B x M x H (e.g. 1 x 100 x 128)

        Returns:
            (scores, attention): the MLP output (unnormalised, no softmax is
            applied here) and the softmax attention weights over the M memory
            slots, each passed through torch.squeeze.
        """
        ### 0) QUERY, KEY, VALUE CALCULATION
        Query = torch.matmul(hidden_state, self.Q)       ## B x 1 x H
        Key = torch.matmul(memory_bank_hidden, self.K)   ## B x M x H
        Value = torch.matmul(memory_bank_hidden, self.V) ## B x M x H

        ### 1) ATTENTION (scaled by sqrt of the key dimension)
        scores = torch.matmul(Query, Key.transpose(-2, -1)) / math.sqrt(Key.size(-1)) ## B x 1 x M
        Wattn = torch.softmax(scores, dim=-1) ## B x 1 x M
        attn = torch.matmul(Wattn, Value)     ## B x 1 x H

        ### 2) OUTPUT ACTIONS
        new_memory = torch.cat([hidden_state, attn], dim=2) ## B x 1 x 2H
        logits = self.lin_lyrs(new_memory)                  ## B x 1 x action_dim
        action_probs = torch.squeeze(logits)                ## action_dim (when B == 1)

        return action_probs, torch.squeeze(Wattn)
    
'''ANOTHER RETRIEVE FUNCTION - based on COSINE SIMILARITY'''
def retrieve(hidden_state_t, memorys):
    """
    Return the memory slot most similar (cosine similarity) to the query.

    Args:
        hidden_state_t: query tensor of shape 1 x 1 x H.
        memorys: memory tensor of shape 1 x M x H.

    Returns:
        (memory_retrieve, cos_smlr_max): the best-matching slot as a
        1 x 1 x H tensor and its similarity as a float. On ties the earliest
        slot wins. If there are no slots, (None, -inf) is returned.
    """
    cos_smlr_fn = nn.CosineSimilarity(dim=2)  # built once, reused per slot
    cos_smlr_max = float('-inf')
    memory_retrieve = None
    for idx in range(memorys.size(1)):
        memory = memorys[:, idx, :].unsqueeze(0)
        cos_smlr = cos_smlr_fn(hidden_state_t, memory).item()
        if cos_smlr > cos_smlr_max:
            cos_smlr_max = cos_smlr
            memory_retrieve = memory
    # Return only after ALL slots were scanned (previously the return sat
    # inside the loop, so only slot 0 was ever considered).
    return memory_retrieve, cos_smlr_max
    

# class rnn_embed(nn.Module):
#     def __init__(self, 
#                  rnn_input_dim: int,
#                  rnn_hidden_dim: int):
    
#         super().__init__()
#         self.rnn_lyr = nn.RNN(rnn_input_dim, rnn_hidden_dim)
#         self.lin_lyr = nn.Linear(rnn_hidden_dim, rnn_hidden_dim)

#     def forward(self, hidden_state, current_obs):

#         prev_current = torch.cat([hidden_state, current_obs])
#         hidden_state_, _ = self.rnn_lyr(prev_current)
#         hidden_state_new = self.lin_lyr(hidden_state_)
        
#         return hidden_state_new

TRAINING LOOP CHECK

In [398]:
# # 4) 학습 루프
# HIDDEN_SIZE = cfg["agent"]["hidden_size"]
# memory_bank_ep = memory_bank(
#     decay_rate=0.1,
#     noise_std=0.1,
#     memory_len=100,
#     hidden_dim=HIDDEN_SIZE
#     )
# ep = 1

# obs, _ = env.reset(seed=cfg["train"]["seed"] + ep)
# obs_dim = np.prod(reconstruct(obs["image"], render_chanel=1).shape)  # flatten -> 6ox80
# # obs["image"] shape = (tile_size * view_size, tile_size * view_size, 3)
# retina = reconstruct(obs["image"], render_chanel=1) # 60 x 80
# # full_map = obs["image"] # 160 x 160 x 3

# state = torch.from_numpy(retina).float().view(1,1,-1)
# hx = torch.zeros(1, 1, HIDDEN_SIZE)
# log_probs = []
# rewards = []


# state = torch.from_numpy(retina)[None, None, ...] ## 1 x 1 x 60 x 80
# hx = torch.zeros(1, 1, cfg["agent"]["hidden_size"])

# ### 0) STATE EMBEDDING
# cnn_embed_lyr = cnn_embed(**PARAMS['cnn_embed'])
# state_emd = cnn_embed_lyr(state).unsqueeze(0) # 1 x 1 x stateDim (128)

# ### 1) AGGREGATE STATE EMBEDDING & HIDDEN STATE > CURRENT HIDDEN STATE
# rnn_lyr = nn.RNN(**PARAMS['rnn'])
# hx, _ = rnn_lyr.forward(state_emd, hx) # 1 x 1 x hiddenDim (128)

# ### 2) AGGREGATE CURRENT HIDDEN STATE & MEMORY > ACTION
# memory_gate_lyr = memory_gate(**PARAMS['memory_gate'])
# log_probs, attention_weights = memory_gate_lyr(hx, memory_bank_ep.memory_bank_hidden)


In [399]:
# HIDDEN_SIZE = cfg["agent"]["hidden_size"]
# PARAMS = {
#     'cnn_embed': {
#         'cnn_hidden_lyrs': [4, 8],
#         'lin_hidden_lyrs': [512, HIDDEN_SIZE],
#         'input_img_shape': (60, 80)
#     },
#     'rnn': {
#         'input_size': HIDDEN_SIZE,
#         'hidden_size': HIDDEN_SIZE,
#         'batch_first': True
#     },
#     'memory_gate': {
#         'hidden_dim_lyrs': [HIDDEN_SIZE, int(HIDDEN_SIZE/2)],
#         'action_dim': 4
#     }
# }


# policy = cnnrnnattn_policy(PARAMS)
# log_probs, hx, state_emd, attention_weights = policy.forward(state, hx, memory_bank_ep)


TRAIN.py

In [None]:
### FROM TRAIN.PY
import os, sys
# Add the folder one level above notebooks/ (the parent of the current
# working directory) to the import path.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if not (project_root in sys.path):
    sys.path.insert(0, project_root)


import os
import yaml
import torch
import random
import time
from torch import nn, optim
from torch.distributions import Categorical
from env.custom_maze_env import CustomMazeEnv
from env.get_retina_image import reconstruct
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
import matplotlib.pyplot as plt
import numpy as np
from IPython import display

#==========================
# Model import and load
#==========================

####

#==========================
# Memory import and load
#==========================

####


#==========================
# set_seed
#==========================
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch RNGs with the same value."""
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

# 1) Load the configuration
# Open the YAML file in a context manager so the handle is always closed
# (the previous bare open() left it to the garbage collector).
config_path = os.path.join(project_root, "experiments/config/default.yaml")
with open(config_path, encoding="utf-8") as cfg_file:
    cfg = yaml.safe_load(cfg_file) ##TBD - ERROR
set_seed(cfg["train"]["seed"])

# 2) Create and wrap the environment
#base_env = CustomMazeEnv(**cfg["env"]) ##TBD - ERROR
# NOTE(review): the env kwargs are hard-coded here instead of read from
# cfg["env"] (see the TBD above); only tile_size below still comes from cfg.
base_env = CustomMazeEnv(**{'layout_id': 'c',
 'goal_pos': [0, 3],
 'view_size': 5,
 'max_steps': 1000,
 'tile_size': 32,
 'render_mode': 'rgb_array'})
env = RGBImgPartialObsWrapper(base_env, tile_size=cfg["env"]["tile_size"])
# obs_dim = np.prod(reconstruct(obs["image"], render_chanel=1).shape)  # flatten -> 6ox80
action_dim = env.action_space.n

# 3) Agent and optimizer
# policy = RNNPolicy(obs_dim, cfg["agent"]["hidden_size"], action_dim)

HIDDEN_SIZE = cfg["agent"]["hidden_size"]
PARAMS = {
    'memory_bank_ep': {
        'decay_rate': 0.1, 
        'noise_std': 0.1, 
        'memory_len': 50,
        'hidden_dim': HIDDEN_SIZE
    },
    'cnn_embed': {
        'cnn_hidden_lyrs': [4, 8],
        'lin_hidden_lyrs': [512, HIDDEN_SIZE],
        'input_img_shape': (60, 80)
    },
    'rnn': {
        'input_size': HIDDEN_SIZE,
        'hidden_size': HIDDEN_SIZE,
        'batch_first': True
    },
    'memory_gate': {
        'hidden_dim_lyrs': [HIDDEN_SIZE, int(HIDDEN_SIZE/2)],
        # NOTE(review): hard-coded to 4 while `action_dim` above is read from
        # the env and left unused — confirm the two are meant to agree.
        'action_dim': 4
    }
}

policy = cnnrnnattn_policy(PARAMS)
# float(...) guards against the learning rate being parsed from YAML as a string.
optimizer = optim.Adam(policy.parameters(), lr=float(cfg["agent"]["learning_rate"]))

# Prepare tensorboard
if cfg["logging"]["use_tensorboard"]:
    from torch.utils.tensorboard import SummaryWriter
    tb = SummaryWriter(cfg["logging"]["tensorboard_dir"])
else:
    tb = None
hx = torch.zeros(1, 1, HIDDEN_SIZE)



# 4) Training loop (REINFORCE, one policy update per episode)
for ep in range(1, 2):
    obs, _ = env.reset(seed=cfg["train"]["seed"] + ep)
    # obs["image"] shape = (tile_size * view_size, tile_size * view_size, 3)
    retina = reconstruct(obs["image"], render_chanel=1)# 60 x 80
    # full_map = obs["image"] # 160 x 160 x 3
    memory_bank_ep = memory_bank(**PARAMS['memory_bank_ep'])

    # Start every episode from a fresh recurrent state. Carrying hx over from
    # the previous episode would make backward() walk into a graph that the
    # previous loss.backward() has already freed.
    hx = torch.zeros(1, 1, HIDDEN_SIZE)
    # .float(): the conv weights are float32; this is a no-op if reconstruct()
    # already returns float32.
    state = torch.from_numpy(retina).float()[None, None, ...] ## 1 x 1 x 60 x 80

    log_probs = []
    rewards = []

    #==========================
    # Episode rollout
    #==========================
    done = False
    while not done:
        logits, hx = policy(state, hx, memory_bank_ep)
        m = Categorical(logits=logits)
        a = m.sample()
        log_probs.append(m.log_prob(a))

        obs, r, term, trunc, _ = env.step(a.item())
        rewards.append(r)
        # Compute `done` BEFORE updating the memory bank: the bank only
        # writes the episode to disk when it sees done=True.
        done = term or trunc

        #==========================
        # Image reconstruct
        #==========================
        retina = reconstruct(obs["image"], render_chanel=1) # 60 x 80

        #======================================
        ##### Store the step in memory ###
        # retina : retina(egocentric view)
        # action : a
        # hidden_state
        # timestep : obs['timestep']
        # agent_postion : obs['position']
        # episode_number : ep
        #========================================
        memory_bank_ep.update(retina, policy.state_emd, policy.hidden_emd, a, obs, ep, done)

        # next state
        state = torch.from_numpy(retina).float()[None, None, ...]

    #============================
    # Per-episode loss (REINFORCE)
    #============================
    # Discounted returns, accumulated back-to-front then reversed
    # (append + reverse is O(n); insert(0, ...) was O(n^2)).
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + cfg["agent"]["gamma"] * G
        returns.append(G)
    returns.reverse()
    returns = torch.tensor(returns, dtype=torch.float32)
    # Normalise only when there is more than one step: std() of a single
    # element is NaN and would poison the parameters.
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    # reshape(-1) keeps this a per-step vector whether each log-prob is a
    # 0-dim or a 1-element tensor.
    loss = -(torch.stack(log_probs).reshape(-1) * returns).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    #============================
    # Logging
    #============================
    ep_reward = sum(rewards)
    if tb:
        tb.add_scalar("train/episode_reward", ep_reward, ep)
        tb.add_scalar("train/loss", loss.item(), ep)

    if ep % cfg["train"]["log_interval"] == 0:
        print(f"[Episode {ep}] reward={ep_reward:.2f}, loss={loss.item():.4f}")

    if ep % cfg["train"]["save_interval"] == 0:
        os.makedirs(cfg["train"]["checkpoint_dir"], exist_ok=True)
        torch.save(policy.state_dict(),
                    os.path.join(cfg["train"]["checkpoint_dir"], f"policy_ep{ep}.pt"))

# Close the writer once, after the last episode (it used to be closed inside
# the loop, i.e. after every episode).
if tb:
    tb.close()