<a href="https://colab.research.google.com/github/deguc/Shannon/blob/main/408_QLearning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
#%%
import numpy as np

def zeros_ps(ps):
    """Build zero-filled buffers shaped like the given parameters.

    Args:
        ps: iterable of numpy arrays.

    Returns:
        A list holding ``np.zeros_like(p)`` for each ``p`` in ``ps``, in order.
    """
    return [np.zeros_like(param) for param in ps]


class Module:
    """Minimal base class for layers.

    Attributes:
        ps: list of parameter arrays owned by the layer (empty by default).
        gs: list of gradient arrays parallel to ``ps`` (empty by default).
        train_flag: initialised to None; nothing visible in this file reads it.
    """

    def __init__(self):
        self.ps = []
        self.gs = []
        self.train_flag = None


class Linear(Module):
    """Fully connected layer computing ``x @ W + b``.

    Weights are drawn from the global NumPy RNG with He-style scaling
    (std = sqrt(2 / d_in)); biases start at zero. ``ps`` is ``[W, b]`` and
    ``gs`` holds gradient buffers of the same shapes.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        scale = np.sqrt(2 / d_in)
        weight = np.random.randn(d_in, d_out) * scale
        bias = np.zeros(d_out)
        self.ps = [weight, bias]
        self.gs = zeros_ps(self.ps)
        self.inputs = None  # last forward input, needed by backward()

    def __call__(self, x):
        self.inputs = x
        weight, bias = self.ps
        return x @ weight + bias

    def backward(self, dout):
        """Overwrite ``gs`` with dL/dW and dL/db, and return dL/dx.

        Assumes the cached input is 2-D (batch, d_in) — TODO confirm if other
        ranks are ever passed.
        """
        weight = self.ps[0]
        grad_w, grad_b = self.gs
        grad_w[...] = self.inputs.T @ dout
        grad_b[...] = dout.sum(axis=0)
        return dout @ weight.T


class ReLU(Module):
    """Element-wise rectifier: entries <= 0 become 0, others pass through."""

    def __init__(self):
        super().__init__()
        self.mask = None  # boolean array, True where the last input was <= 0

    @staticmethod
    def _masked_copy(arr, mask):
        """Return a copy of ``arr`` with the entries selected by ``mask`` zeroed."""
        result = arr.copy()
        result[mask] = 0
        return result

    def __call__(self, x):
        self.mask = x <= 0
        return self._masked_copy(x, self.mask)

    def backward(self, dout):
        # Gradient is blocked exactly where the forward output was clamped.
        return self._masked_copy(dout, self.mask)


class Layers:
    """Sequential container that chains layers forward and backward.

    Args:
        layers: list of layer objects. Each must be callable, provide
            ``backward(dout)``, and expose ``ps``/``gs`` lists.

    Attributes:
        ps: ``[params, grads]`` — flat lists concatenating every layer's
            ``ps`` and ``gs`` in layer order. This is the pair layout that
            ``AdamW`` unpacks as ``ps, gs``.
    """

    def __init__(self, layers):
        self.layers = layers
        self.ps = [[], []]

        for l in self.layers:
            self.ps[0] += l.ps
            self.ps[1] += l.gs

    def __call__(self, x):
        """Run ``x`` through every layer in order and return the output."""
        for l in self.layers:
            x = l(x)

        return x

    def backward(self, dout):
        """Backpropagate ``dout`` through the layers in reverse order.

        Returns the gradient with respect to the container's input, matching
        the ``backward`` contract of the individual layers so a ``Layers`` can
        itself be nested as a layer.
        """
        for l in reversed(self.layers):
            dout = l.backward(dout)

        return dout


class AdamW:
    """Adam optimizer with decoupled weight decay.

    Args:
        ps: ``[params, grads]`` pair of parallel lists of numpy arrays.
            Params are updated in place; grads are read on each ``update``.
        lr: learning rate.
        beta1: decay rate of the first-moment (mean) estimate.
        beta2: decay rate of the second-moment estimate.
        weigh_decay: decoupled weight-decay coefficient (parameter name kept
            unchanged for backward compatibility).
        eps: constant added to the denominator for numerical stability.

    NOTE(review): beta1=0.2 / beta2=0.9 differ from the usual Adam defaults
    (0.9 / 0.999); presumably tuned for this small task — confirm intent.
    """

    def __init__(self, ps, lr=0.01, beta1=0.2, beta2=0.9, weigh_decay=0.1,
                 eps=1e-6):
        self.ps = ps
        self.cache = (lr, beta1, beta2, weigh_decay)
        self.eps = eps
        self.cnt = 0  # number of update steps taken; drives bias correction
        # First (m) and second (v) moment buffers, one per parameter.
        self.hs = [
            zeros_ps(ps[0]),
            zeros_ps(ps[0])
        ]

    def update(self):
        """Apply one AdamW step to every parameter in place."""
        eps = self.eps
        ps, gs = self.ps
        ms, vs = self.hs
        lr, b1, b2, w = self.cache
        self.cnt += 1

        # Bias-correction denominators are shared by all parameters this step.
        c1 = 1 - b1**self.cnt
        c2 = 1 - b2**self.cnt

        for p, g, m, v in zip(ps, gs, ms, vs):
            m[...] = b1*m + (1-b1)*g
            v[...] = b2*v + (1-b2)*g*g

            m0 = m / c1
            v0 = v / c2

            # Decoupled weight decay: shrink weights independently of the
            # adaptive gradient step.
            p -= lr*w*p

            p -= lr*m0/(np.sqrt(v0)+eps)

class GridWorld:
    """Deterministic 5x7 grid world with walls and a single goal cell.

    ``grid`` cell codes: 0 = free, 1 = wall, 2 = goal (row 4, col 6).
    Actions 0..3 move the agent up, down, left, right respectively.
    ``memory`` records the positions the agent left since the last ``reset``.
    """

    # (row, col) offset applied by each action index; built once, not per step.
    _MOVES = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]])

    def __init__(self):
        self.H, self.W = 5, 7
        self.action_size = 4
        self.grid = np.zeros((self.H, self.W), dtype=np.int8)
        self.grid[1, 1:5] = 1
        self.grid[3, 2:6] = 1
        self.grid[4, 6] = 2
        self.agent_pos = np.array([0, 0])
        self.memory = []

    def render(self):
        """Print the grid: '.' free, '#' wall, 'G' goal, '*' visited, 'A' agent."""
        legend = np.array(['.', '#', 'G'], dtype='>U1')
        view = legend[self.grid]

        for pos in self.memory:
            view[tuple(pos)] = '*'

        view[tuple(self.agent_pos)] = 'A'

        for row in view:
            print(' '.join(row))

        print()

    def step(self, state, action):
        """Return ``(next_state, reward, done)`` for ``action`` taken in ``state``.

        Pure transition function; the environment is not modified.
        Moving off the grid or into a wall leaves the agent in place with
        reward -1.0. Entering the goal gives 10.0 and ``done=True``. Any
        other move costs -0.1.
        """
        next_state = state + self._MOVES[action]
        row, col = next_state

        if not (0 <= row < self.H and 0 <= col < self.W):
            return state, -1.0, False

        cell = self.grid[row, col]

        if cell == 1:
            return state, -1.0, False

        if cell == 2:
            return next_state, 10.0, True

        return next_state, -0.1, False

    def update(self, action):
        """Move the agent by ``action`` and return ``step``'s result.

        The position before the move is appended to ``memory``.
        """
        self.memory.append(self.agent_pos)
        next_state, reward, done = self.step(self.agent_pos, action)
        self.agent_pos = next_state

        return next_state, reward, done

    def reset(self, state=None):
        """Place the agent at ``state`` (default ``(0, 0)``) and clear ``memory``.

        A fresh array is created on every call. The previous code used a
        shared default array, so mutating the returned start state would
        silently change every later reset.

        Returns:
            The new start position as a numpy array.
        """
        start = np.array([0, 0]) if state is None else np.array(state)
        self.agent_pos = start
        self.memory.clear()

        return start

class QLearning:
    """Q-learning agent whose action values come from ``QNet``.

    Grid positions ``(row, col)`` are one-hot encoded into vectors of length
    ``H * W`` before being fed to the network.

    Args:
        H, W: grid height and width.
        action_size: number of discrete actions.
        lr: learning rate forwarded to ``QNet``.
        gamma: discount factor for the bootstrapped target.
        epsilon: exploration probability used by ``get_action``.
    """

    def __init__(self, H, W, action_size, lr=0.01, gamma=0.9, epsilon=0.1):

        self.H = H
        self.W = W
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon

        # NOTE(review): this tabular Q array is never read by this class;
        # kept for backward compatibility — presumably a leftover.
        self.Q = np.zeros((self.H, self.W, self.action_size))
        self.qnet = QNet(self.H, self.W, action_size, lr=lr)
        # Row k is the one-hot encoding of flat cell index k.
        self.eye = np.eye(H*W, dtype=np.float32)

    def onehot(self, state):
        """Return the ``(1, H*W)`` one-hot row for grid position ``state``."""
        i, j = state
        idx = i*self.W + j

        return self.eye[idx][np.newaxis, :]

    def pi(self, state):
        """Greedy action: index of the largest predicted Q-value."""
        state = self.onehot(state)
        return np.argmax(self.qnet(state))

    def get_action(self, state):
        """Epsilon-greedy: uniform random action with probability ``epsilon``."""
        if np.random.rand() < self.epsilon:
            return np.random.choice(self.action_size)
        return self.pi(state)

    def update(self, state, action, reward, next_state, done):
        """Take one Q-learning gradient step on a single transition.

        The target is ``reward + gamma * max_a Q(next_state, a)``, with no
        bootstrap term when ``done``. It is treated as a constant. The
        gradient fed back is ``Q(state, action) - target``, the derivative
        of ``0.5 * (Q - target)**2`` for the taken action only.
        """
        # Build the target FIRST. The layers cache the inputs of their most
        # recent forward pass for backward(), so the forward pass over
        # `state` must be the last one before backward().
        next_x = self.onehot(next_state)
        qmax = np.max(self.qnet(next_x))
        target = reward + (1 - int(done)) * self.gamma * qmax

        x = self.onehot(state)
        qs = self.qnet(x)

        dout = np.zeros((1, self.action_size))
        dout[0, action] = qs[0, action] - target

        self.qnet.backward(dout)
        self.qnet.optimizer.update()


class QNet:
    """Two-layer MLP producing one Q-value per action.

    Architecture: Linear(H*W -> 4*H*W) -> ReLU -> Linear(4*H*W -> action_size).
    Input is a (batch, H*W) array; ``QLearning`` feeds it one-hot rows.
    Parameters are trained by ``AdamW`` at learning rate ``lr``.
    """

    def __init__(self, H, W, action_size, lr=0.01):
        n_cells = H * W
        hidden = 4 * n_cells

        # Keep this construction order: each Linear draws its initial weights
        # from the global NumPy RNG, so reordering would change results.
        self.layers = [
            Linear(n_cells, hidden),
            ReLU(),
            Linear(hidden, action_size),
        ]

        params = [p for layer in self.layers for p in layer.ps]
        grads = [g for layer in self.layers for g in layer.gs]
        self.ps = [params, grads]

        self.optimizer = AdamW(self.ps, lr=lr)

    def __call__(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

    def backward(self, dout):
        """Propagate ``dout`` from the output back through every layer."""
        grad = dout
        for layer in self.layers[::-1]:
            grad = layer.backward(grad)


np.set_printoptions(precision=2, suppress=True)

env = GridWorld()
agent = QLearning(env.H, env.W, env.action_size, lr=0.01)

# Training: 200 epsilon-greedy episodes, each capped at 100 environment steps.
for _ in range(200):
    state = env.reset()
    for _ in range(100):
        action = agent.get_action(state)
        next_state, reward, done = env.update(action)
        agent.update(state, action, reward, next_state, done)
        if done:
            break
        state = next_state

# Evaluation: follow the greedy policy for at most 20 steps, then draw the trail.
state = env.reset()
for _ in range(20):
    next_state, _, done = env.update(agent.pi(state))
    if done:
        break
    state = next_state

env.render()


* * * * * * .
. # # # # * *
. . . . . . *
. . # # # # *
. . . . . . A

