# In [1]:
import numpy as np

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    The value is computed from ``exp(-|x|)``, which always lies in
    ``(0, 1]``, so large-magnitude inputs cannot overflow the exponential
    (the naive form emits an overflow RuntimeWarning for very negative x).

    Parameters
    ----------
    x : scalar or array_like
        Input value(s).

    Returns
    -------
    NumPy scalar or ndarray
        Sigmoid of ``x`` with the same shape as ``x``; a scalar input
        yields a NumPy scalar.
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))  # bounded by 1, so it can only underflow to 0
    # Both branches are algebraically equal to 1 / (1 + exp(-x)); np.where
    # picks the one written in terms of the bounded exponential.
    result = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    # Indexing with () unwraps a 0-d array to a scalar and leaves n-d arrays
    # untouched, matching the scalar-in/scalar-out behaviour of the naive form.
    return result[()]

def tanh(x):
    """Return the element-wise hyperbolic tangent of ``x`` via ``np.tanh``."""
    activated = np.tanh(x)
    return activated

class GRU:
    """Minimal bias-free GRU cell (forward pass only).

    Three weight matrices ``Wz``, ``Wr`` and ``Wh`` -- each of shape
    ``(hidden_dim, hidden_dim + input_dim)`` -- act on the concatenation
    ``[h, x]`` of the hidden state and the current input.
    """

    def __init__(self, input_dim, hidden_dim, *, rng=None):
        """Create the cell and draw its weights from a standard normal.

        Parameters
        ----------
        input_dim : int
            Length of each input vector.
        hidden_dim : int
            Length of the hidden state.
        rng : numpy.random.Generator, optional
            Source of randomness for the weights. When omitted, the global
            ``np.random.randn`` is used, so ``np.random.seed`` still
            controls initialization.
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Draw order (Wz, Wr, Wh) is kept fixed so seeded runs stay stable.
        self.Wz = self._draw_weights(rng)
        self.Wr = self._draw_weights(rng)
        self.Wh = self._draw_weights(rng)

    def _draw_weights(self, rng):
        """Return one standard-normal ``(hidden_dim, hidden_dim + input_dim)`` matrix."""
        shape = (self.hidden_dim, self.hidden_dim + self.input_dim)
        if rng is None:
            return np.random.randn(*shape)
        return rng.standard_normal(shape)

    def _gate(self, weights, xh):
        """Return ``sigmoid(weights @ xh)``; shared by both gates."""
        return sigmoid(np.dot(weights, xh))

    def _update_gate(self, wz, xh):
        """Update gate ``z``: weight given to the previous hidden state."""
        return self._gate(wz, xh)

    def _reset_gate(self, wr, xh):
        """Reset gate ``r``: scales the previous state inside the candidate."""
        return self._gate(wr, xh)

    def _check_shapes(self, x_t, h_prev):
        """Raise ``ValueError`` if ``x_t`` / ``h_prev`` cannot be combined."""
        if x_t.ndim == 0 or x_t.shape[0] != self.input_dim:
            raise ValueError(
                f"x_t must have leading dimension {self.input_dim}, "
                f"got shape {x_t.shape}"
            )
        if h_prev.ndim == 0 or h_prev.shape[0] != self.hidden_dim:
            raise ValueError(
                f"h_prev must have leading dimension {self.hidden_dim}, "
                f"got shape {h_prev.shape}"
            )
        # Concatenation along axis 0 needs every other axis to agree.
        if x_t.shape[1:] != h_prev.shape[1:]:
            raise ValueError(
                f"x_t and h_prev must agree on trailing dimensions, "
                f"got {x_t.shape} and {h_prev.shape}"
            )

    def step(self, x_t, h_prev, *, verbose=True):
        """Advance the cell by one time step.

        Parameters
        ----------
        x_t : array_like
            Current input; its leading dimension must be ``input_dim``.
        h_prev : array_like
            Previous hidden state; its leading dimension must be
            ``hidden_dim``. Any trailing dimensions must match ``x_t``.
        verbose : bool, optional
            When true (the default) the new hidden state is printed.

        Returns
        -------
        ndarray
            ``(1 - z) * h_tilde + z * h_prev``.

        Raises
        ------
        ValueError
            If the inputs do not have the expected shapes.
        """
        x_t = np.asarray(x_t)
        h_prev = np.asarray(h_prev)
        self._check_shapes(x_t, h_prev)

        xh = np.concatenate((h_prev, x_t))

        z = self._update_gate(self.Wz, xh)
        r = self._reset_gate(self.Wr, xh)

        # The candidate state sees the previous state only through the reset gate.
        xh_prime = np.concatenate((r * h_prev, x_t))
        h_tilde = tanh(np.dot(self.Wh, xh_prime))

        # Final hidden state: z keeps the old state, (1 - z) admits the candidate.
        h_next = (1 - z) * h_tilde + z * h_prev

        if verbose:
            print("h_next:", h_next)
        return h_next

# Demo: push a short random sequence through a freshly initialized GRU cell.
input_dim, hidden_dim = 3, 5

# The cell's weights are drawn here, before the inputs below, so the order of
# these two statements determines which random numbers each one receives.
gru = GRU(input_dim, hidden_dim)

# Six random input vectors, one per time step.
sequence = []
for _ in range(6):
    sequence.append(np.random.randn(input_dim))

# Start from an all-zero hidden state and feed the inputs in order.
h = np.zeros((hidden_dim,))
for x_t in sequence:
    h = gru.step(x_t, h)

print("Final hidden state:", h)

# Output of one run (weights and inputs are unseeded, so values differ per run):
# h_next: [ 0.04156905 -0.06337281  0.0766981  -0.34031248  0.7415777 ]
# h_next: [ 0.27289368 -0.85467567  0.49969283  0.13101114  0.09204578]
# h_next: [ 0.27574468 -0.85553479  0.79735377 -0.13234336  0.947264  ]
# h_next: [ 0.28138932 -0.853681    0.64133945 -0.64399339  0.98120636]
# h_next: [ 0.2876113  -0.77654789  0.90739334 -0.30594302  0.84444034]
# h_next: [ 0.3086645  -0.41402165  0.96442385 -0.1886135   0.9515056 ]
# Final hidden state: [ 0.3086645  -0.41402165  0.96442385 -0.1886135   0.9515056 ]
