<a href="https://colab.research.google.com/github/deguc/Shannon/blob/main/412_A2C.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np

def zeros_ps(ps):
    """Return fresh zero-filled arrays matching each array in ``ps``.

    Each result has the same shape and dtype as its counterpart, which makes
    it suitable as a gradient or optimizer-moment buffer for a parameter list.
    """
    return [np.zeros_like(param) for param in ps]


class Module:
    """Minimal layer base class holding parameter and gradient lists."""

    def __init__(self):
        # Parameter arrays owned by the layer (empty for parameter-free layers).
        self.ps = []
        # Gradient buffers, one per entry of ``ps``.
        self.gs = []
        # Training-mode flag; left unset (None) by default.
        self.train_flag = None


class Linear(Module):
    """Fully connected layer computing ``x @ W + b``.

    Weights are drawn from a normal distribution with std ``sqrt(2/d_in)``
    (He initialisation); the bias starts at zero.
    """

    def __init__(self,d_in,d_out):
        super().__init__()

        scale = np.sqrt(2/d_in)
        weight = np.random.randn(d_in,d_out)*scale
        bias = np.zeros(d_out)
        self.ps = [weight, bias]
        self.gs = zeros_ps(self.ps)
        # Forward input cached for the weight gradient in backward().
        self.inputs = None

    def __call__(self,x):
        """Forward pass; remembers ``x`` for the backward pass."""
        self.inputs = x
        weight, bias = self.ps
        return x @ weight + bias

    def backward(self,dout):
        """Overwrite the gradient buffers in place and return dL/dx.

        The buffers are written with ``[...] =`` so that any optimizer holding
        references to ``self.gs`` sees the new values; gradients are not
        accumulated across calls.
        """
        grad_w, grad_b = self.gs
        grad_w[...] = self.inputs.T @ dout
        grad_b[...] = dout.sum(axis=0)
        weight = self.ps[0]
        return dout @ weight.T


class ReLU(Module):
    """Rectified linear unit: zeroes every input element that is <= 0."""

    def __init__(self):
        super().__init__()

        # Boolean mask of positions where the last forward input was <= 0.
        self.mask = None

    def __call__(self,x):
        """Return a new array equal to ``x`` with non-positive entries set to 0."""
        self.mask = x <= 0
        return np.where(self.mask, np.zeros_like(x), x)

    def backward(self,dout):
        """Block the gradient wherever the forward input was clipped."""
        return np.where(self.mask, np.zeros_like(dout), dout)


class Layers:
    """Sequential container that chains layers and gathers their parameters.

    ``ps`` is ``[params, grads]``: two flat lists, in layer order, holding
    every layer's parameter arrays and the matching gradient buffers. This is
    the pair layout ``AdamW.update`` unpacks with ``ps, gs = self.ps``.
    """

    def __init__(self,layers):

        self.layers = layers
        self.ps = [[],[]]

        for l in self.layers:
            self.ps[0].extend(l.ps)
            # Collect the gradient buffers (not the parameters again), so an
            # optimizer reading ps[1] receives the real gradients.
            self.ps[1].extend(l.gs)

    def __call__(self,x):
        """Run ``x`` through every layer in order and return the final output."""
        for l in self.layers:
            x = l(x)

        return x

    def backward(self,dout):
        """Backpropagate ``dout`` through the layers in reverse order.

        Returns the gradient with respect to the container's input, matching
        the per-layer ``backward`` contract.
        """
        for l in reversed(self.layers):
            dout = l.backward(dout)

        return dout


def softmax(x,temp=1.0):
    """Temperature-scaled softmax over the last axis.

    The per-row maximum is subtracted before exponentiating so large logits
    do not overflow; ``temp`` divides the shifted logits (higher = flatter).
    """
    peak = np.max(x, axis=-1, keepdims=True)
    weights = np.exp((x - peak) / temp)
    total = np.sum(weights, axis=-1, keepdims=True)
    return weights / total


class AdamW:
    """Adam with decoupled weight decay, updating parameter arrays in place.

    ``ps`` is a pair ``[params, grads]`` of equal-length lists of arrays. The
    gradient arrays are only read; the parameter arrays are modified in place.
    """

    def __init__(self,ps,lr=0.01,beta1=0.9,beta2=0.999,weight_decay=1e-4):
        self.ps = ps
        # Hyper-parameters, unpacked together at the start of update().
        self.cache = (lr,beta1,beta2,weight_decay)
        # Number of update() calls so far; drives the bias correction.
        self.cnt = 0
        params = ps[0]
        first_moments = [np.zeros_like(p) for p in params]
        second_moments = [np.zeros_like(p) for p in params]
        self.hs = [first_moments, second_moments]

    def update(self):
        """Apply one AdamW step to every parameter array."""
        lr,beta1,beta2,decay = self.cache
        self.cnt += 1
        step = self.cnt
        correction1 = 1-beta1**step
        correction2 = 1-beta2**step
        eps = 1e-6
        params,grads = self.ps
        firsts,seconds = self.hs

        for param,grad,m,v in zip(params,grads,firsts,seconds):
            # Exponential moving averages of the gradient and its square,
            # written in place so the buffers persist across steps.
            m[...] = beta1*m + (1-beta1)*grad
            v[...] = beta2*v + (1-beta2)*grad*grad
            m_hat = m/correction1
            v_hat = v/correction2
            # Decoupled weight decay: shrink the weights directly instead of
            # adding the decay term to the gradient.
            param -= lr*decay*param
            param -= lr*m_hat/(np.sqrt(v_hat)+eps)


class GridWorld:
    """5x7 grid maze with two wall segments, one goal cell and a visit trace.

    ``grid`` codes: 0 = free, 1 = wall, 2 = goal. Actions 0..3 move the
    agent by row/column offsets (-1,0), (1,0), (0,-1), (0,1), i.e. up, down,
    left, right as printed by ``render``.
    """

    def __init__(self):
        self.H,self.W = 5,7
        self.action_size = 4
        layout = np.zeros((self.H,self.W),dtype=np.int8)
        layout[1,1:5] = 1
        layout[3,2:6] = 1
        layout[4,6] = 2
        self.grid = layout
        self.agent_pos = np.array([0,0])
        # Positions the agent occupied before each step() since the last reset.
        self.memory = []

    def render(self):
        """Print the grid: '.' free, '#' wall, 'G' goal, '*' visited, 'A' agent."""
        symbols = np.array(['.','#','G'],dtype='>U1')
        canvas = symbols[self.grid]

        for visited in self.memory:
            canvas[tuple(visited)] = '*'

        canvas[tuple(self.agent_pos)] = 'A'

        for row in canvas:
            print(' '.join(row))

        print()

    def onestep(self,state,action):
        """Simulate one move from ``state`` without changing the environment.

        Returns ``(next_state, reward, done)``: leaving the grid or hitting a
        wall keeps ``state`` with reward -0.1; entering the goal yields 10 and
        done=True; any other move costs -0.01.
        """
        offsets = np.array([[-1,0],[1,0],[0,-1],[0,1]])
        target = state + offsets[action]
        row, col = target[0], target[1]

        inside = 0 <= row < self.H and 0 <= col < self.W
        if not inside:
            return state,-0.1,False

        cell = self.grid[tuple(target)]
        if cell == 2:
            return target,10,True
        if cell == 1:
            return state,-0.1,False
        return target,-0.01,False

    def step(self,action):
        """Move the agent by ``action``, logging its previous position."""
        self.memory.append(self.agent_pos)
        outcome = self.onestep(self.agent_pos,action)
        self.agent_pos = outcome[0]
        return outcome

    def reset(self,state=None):
        """Clear the trace, place the agent at ``state`` (default (0, 0)) and return it."""
        start = np.array([0,0]) if state is None else state
        self.agent_pos = start
        self.memory.clear()
        return start


class PolicyNet:
    """Two-layer MLP mapping a length-H*W input to ``action_size`` logits.

    Hidden width is four times the input width; the network owns an AdamW
    optimizer over all of its layers' parameters.
    """

    def __init__(self,H,W,action_size,lr=0.01):
        n_inputs = H * W
        hidden = 4 * n_inputs

        self.layers = [
            Linear(n_inputs,hidden),
            ReLU(),
            Linear(hidden,action_size),
        ]

        # [params, grads]: the very array objects held by the layers, so the
        # optimizer sees gradients the layers write in place.
        params = [p for layer in self.layers for p in layer.ps]
        grads = [g for layer in self.layers for g in layer.gs]
        self.ps = [params, grads]

        self.optimizer = AdamW(self.ps,lr=lr)

    def __call__(self,x):
        """Forward pass; returns unnormalised action logits."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

    def backward(self,dout):
        """Backpropagate ``dout`` through the layers, filling their gradients."""
        grad = dout
        for layer in self.layers[::-1]:
            grad = layer.backward(grad)


class ValueNet:
    """Two-layer MLP mapping a length-H*W input to a single state value.

    Hidden width is four times the input width; the network owns an AdamW
    optimizer over all of its layers' parameters.
    """

    def __init__(self,H,W,lr=0.01):
        n_inputs = H * W
        hidden = 4 * n_inputs

        self.layers = [
            Linear(n_inputs,hidden),
            ReLU(),
            Linear(hidden,1),
        ]

        # [params, grads]: the very array objects held by the layers, so the
        # optimizer sees gradients the layers write in place.
        params = [p for layer in self.layers for p in layer.ps]
        grads = [g for layer in self.layers for g in layer.gs]
        self.ps = [params, grads]

        self.optimizer = AdamW(self.ps,lr=lr)

    def __call__(self,x):
        """Forward pass; returns an (N, 1) array of values."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

    def backward(self,dout):
        """Backpropagate ``dout`` through the layers, filling their gradients."""
        grad = dout
        for layer in self.layers[::-1]:
            grad = layer.backward(grad)


class A2C:
    """Advantage actor-critic agent over one-hot encoded grid cells.

    ``self.pi`` (PolicyNet) produces action logits and ``self.v`` (ValueNet)
    a scalar state value; each owns its own AdamW optimizer. ``update``
    expects returns already computed by the caller (with ``self.gamma``).
    """

    def __init__(self,H,W,action_size,actor_lr=1e-3,critic_lr=1e-2,gamma=0.9):
        """
        Args:
            H, W: grid height and width; a state is encoded as a one-hot
                vector of length H*W.
            action_size: number of discrete actions (actor output width).
            actor_lr: learning rate for the policy network's optimizer.
            critic_lr: learning rate for the value network's optimizer.
            gamma: discount factor exposed as ``self.gamma`` for return
                computation (default 0.9, the previously hard-coded value).
        """
        self.H = H
        self.W = W
        self.action_size = action_size
        self.gamma = gamma

        self.pi = PolicyNet(self.H,self.W,action_size,lr=actor_lr)
        self.v = ValueNet(self.H,self.W,lr=critic_lr)

        # Row k is the one-hot encoding of flattened cell index k.
        self.eye = np.eye(H*W,dtype=np.float32)

        # (one-hot state, action, reward) tuples appended by memory().
        self.trajectory = []

    def onehot(self,states):
        """Encode integer (row, col) positions as float32 one-hot rows.

        ``states`` has shape (N, 2); a single (2,) position is also accepted
        and treated as a batch of one. Returns an (N, H*W) array.
        """
        states = np.atleast_2d(states)
        i = states[:,0]
        j = states[:,1]

        # Row-major flattening of (row, col) into a cell index.
        idx = i*self.W + j

        return self.eye[idx]

    def reset(self):
        """Discard the stored trajectory."""
        self.trajectory.clear()

    def memory(self,state,action,reward):
        """Append ``(onehot(state), action, reward)`` to ``self.trajectory``."""
        data = (self.onehot(state),action,reward)
        self.trajectory.append(data)

    def get_action(self,states,greedy=False):
        """Choose one action per state from the current policy.

        Args:
            states: integer (row, col) positions, shape (N, 2).
            greedy: if True return the argmax action; otherwise sample each
                action from the softmax of the policy logits.

        Returns:
            int array of shape (N,).
        """
        states = self.onehot(states)
        logits = self.pi(states)
        probs = softmax(logits)

        if greedy:
            return np.argmax(probs,axis=1).astype(int)

        actions = np.array([np.random.choice(self.action_size,p=probs[i])
                   for i in range(probs.shape[0])],dtype=int)

        return actions

    def update(self,states,actions,returns):
        """One critic step and one actor step on a batch of transitions.

        Args:
            states: integer (row, col) positions, shape (B, 2).
            actions: action index taken in each state, shape (B,).
            returns: target return for each transition, shape (B,).
        """
        batch = len(actions)
        if batch == 0:
            # Nothing to learn from; skip instead of stepping the optimizers
            # (which would still apply momentum and weight decay).
            return

        x = self.onehot(states)

        # Advantage A = R - V(s), taken before the critic moves and treated as
        # a constant. The raw advantage is clipped; the batch-size division
        # is applied exactly once, on the policy gradient below (dividing here
        # as well would scale the actor gradient by 1/batch**2 and make the
        # clip bound meaningless).
        V = self.v(x).reshape(-1)
        adv = np.clip((returns-V).astype(np.float32),-5,5)

        # Critic: gradient of 0.5 * sum((V - R)^2) with respect to V.
        dV = (V-returns).astype(np.float32).reshape(-1,1)
        self.v.backward(dV)
        self.v.optimizer.update()

        # Actor: gradient of -mean(A * log pi(a|s)) with respect to the logits
        # is (softmax(logits) - onehot(a)) * A / batch.
        logits = self.pi(x)
        probs = softmax(logits,temp=1.0).astype(np.float32)

        dout = probs.copy()
        dout[np.arange(batch),actions] -= 1.0
        dout *= adv[:,None]
        dout /= batch

        self.pi.backward(dout)
        self.pi.optimizer.update()


# Print NumPy arrays with 2 decimals and without scientific notation.
np.set_printoptions(precision=2,suppress=True)

# In the code shown, `env` is only used to read the grid size and action
# count for the agent; rollouts use the separate `envs` below.
env = GridWorld()
agent = A2C(env.H,env.W,env.action_size,actor_lr=1e-3,critic_lr=1e-2)

# N parallel environments, each rolled out for T steps per update.
N = 16
T = 10
envs = [GridWorld() for _ in range(N)]
states = np.stack([e.reset() for e in envs],axis=0).astype(int)

# 200 synchronous updates, each from a (T, N) batch of transitions.
for upd in range(200):

    # Rollout buffers indexed [time step, environment].
    buf_states = np.zeros((T,N,2),dtype = int)
    buf_actions = np.zeros((T,N),dtype = int)
    buf_rewards = np.zeros((T,N),dtype = np.float32)
    buf_dones = np.zeros((T,N),dtype = np.float32)

    for t in range(T):
        # Sample one action per environment from the current policy.
        actions = agent.get_action(states)
        next_states = np.zeros_like(states)
        rewards = np.zeros((N,),dtype=np.float32)
        dones = np.zeros((N,),dtype=np.float32)

        for i,e in enumerate(envs):
            ns,r,d = e.step(int(actions[i]))
            next_states[i] = ns
            rewards[i] = r
            dones[i] = 1.0 if d else 0.0
            # Auto-reset finished episodes so the next step starts fresh.
            # NOTE(review): GridWorld.step appends to e.memory on every call
            # and only reset() clears it, so environments that never reach the
            # goal keep growing that list for the whole run.
            if d:
                next_states[i] = e.reset()

        # Store the transition from the pre-step states.
        buf_states[t] = states
        buf_actions[t] = actions
        buf_rewards[t] = rewards
        buf_dones[t] = dones

        states = next_states

    # Critic estimate of the state each environment is in after the rollout,
    # used to bootstrap the returns.
    x_last = agent.onehot(states)
    v_last = agent.v(x_last).reshape(-1).astype(np.float32)

    # n-step discounted returns, computed backwards in time:
    #   R_t = r_t + gamma * (1 - done_t) * R_{t+1},  R_T = V(s_T).
    # A done flag cuts the bootstrap so returns do not cross episode ends.
    returns = np.zeros((T,N),dtype=np.float32)
    R = v_last.copy()
    for t in reversed(range(T)):
        R = buf_rewards[t] + (1-buf_dones[t])*agent.gamma*R
        returns[t] = R

    # Flatten (T, N) into one batch of T*N transitions.
    B_states = buf_states.reshape(T*N,2)
    B_actions = buf_actions.reshape(T*N)
    B_returns = returns.reshape(T*N)

    agent.update(B_states,B_actions,B_returns)

# Greedy evaluation: from the default start cell, follow argmax actions for
# at most 20 steps (stopping early at the goal), then print the visited path.
e = GridWorld()
s = e.reset()

for _ in range(20):

    a = agent.get_action(np.array([s],dtype=int),greedy=True)
    s,r,d = e.step(int(a[0]))

    if d:
        break

e.render()