In [None]:
from compyute.functional import zeros, zeros_like
from compyute.nn.module import Module
from compyute.nn.parameter import Parameter
from compyute.random import uniform
from compyute.tensor import Tensor
from compyute.types import ArrayLike


class RecurrentCell(Module):
    """Gated (LSTM-style) recurrent cell.

    For every timestep ``t`` the cell computes::

        i_t = sigmoid(x_t @ w_i + h_{t-1} @ u_i + b_i)   # input gate
        f_t = sigmoid(x_t @ w_f + h_{t-1} @ u_f + b_f)   # forget gate
        o_t = sigmoid(x_t @ w_o + h_{t-1} @ u_o + b_o)   # output gate
        g_t = tanh(x_t @ w_c + h_{t-1} @ u_c + b_c)      # cell candidate
        c_t = f_t * c_{t-1} + i_t * g_t                  # cell state
        h_t = o_t * tanh(c_t)                            # hidden state

    with ``h_{-1} = c_{-1} = 0``.

    Parameters
    ----------
    in_channels : int
        Size of the last axis of the input.
    h_channels : int
        Size of the hidden state (last axis of the output).
    use_bias : bool, optional
        Whether the gates get a bias term, by default True.
    dtype : str, optional
        Datatype of the parameters and of the computation, by default "float32".
    """

    def __init__(
        self,
        in_channels: int,
        h_channels: int,
        use_bias: bool = True,
        dtype: str = "float32",
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.h_channels = h_channels
        self.use_bias = use_bias
        self.dtype = dtype

        k = in_channels**-0.5

        # The w_* matrices project the input (in_channels -> h_channels).
        # The u_* matrices project the previous hidden state, so they must be
        # square (h_channels -> h_channels); otherwise `h @ u` only works when
        # in_channels == h_channels.

        # input gate
        w_i = uniform((in_channels, h_channels), -k, k)
        self.w_i = Parameter(w_i, dtype=dtype, label="w_i")
        u_i = uniform((h_channels, h_channels), -k, k)
        self.u_i = Parameter(u_i, dtype=dtype, label="u_i")
        if use_bias:
            b_i = zeros((h_channels,))
            self.b_i = Parameter(b_i, dtype=dtype, label="b_i")

        # forget gate
        w_f = uniform((in_channels, h_channels), -k, k)
        self.w_f = Parameter(w_f, dtype=dtype, label="w_f")
        u_f = uniform((h_channels, h_channels), -k, k)
        self.u_f = Parameter(u_f, dtype=dtype, label="u_f")
        if use_bias:
            b_f = zeros((h_channels,))
            self.b_f = Parameter(b_f, dtype=dtype, label="b_f")

        # output gate
        w_o = uniform((in_channels, h_channels), -k, k)
        self.w_o = Parameter(w_o, dtype=dtype, label="w_o")
        u_o = uniform((h_channels, h_channels), -k, k)
        self.u_o = Parameter(u_o, dtype=dtype, label="u_o")
        if use_bias:
            b_o = zeros((h_channels,))
            self.b_o = Parameter(b_o, dtype=dtype, label="b_o")

        # cell
        w_c = uniform((in_channels, h_channels), -k, k)
        self.w_c = Parameter(w_c, dtype=dtype, label="w_c")
        u_c = uniform((h_channels, h_channels), -k, k)
        self.u_c = Parameter(u_c, dtype=dtype, label="u_c")
        if use_bias:
            b_c = zeros((h_channels,))
            self.b_c = Parameter(b_c, dtype=dtype, label="b_c")

    @staticmethod
    def _sigmoid(z: Tensor) -> Tensor:
        """Return the logistic sigmoid of ``z``, built from ``tanh``."""
        # sigmoid(z) == 0.5 * tanh(z / 2) + 0.5; this form cannot overflow.
        # NOTE(review): relies on Tensor supporting scalar `*` and `+` -- confirm.
        return (z * 0.5).tanh() * 0.5 + 0.5

    def forward(self, x: Tensor) -> Tensor:
        """Run the cell over all timesteps of ``x``.

        Parameters
        ----------
        x : Tensor
            Input with three dimensions; axis 1 is iterated as the time axis
            and the last axis must have size ``in_channels``
            (presumably ``(batch, time, in_channels)`` -- TODO confirm).

        Returns
        -------
        Tensor
            Hidden states of all timesteps, same leading axes as ``x`` and
            ``h_channels`` as last axis.
        """
        self.check_dims(x, [3])
        x = x.astype(self.dtype)

        # Input projections do not depend on the recurrence, so they are
        # computed for all timesteps at once.
        x_i = x @ self.w_i
        x_f = x @ self.w_f
        x_o = x @ self.w_o
        x_c = x @ self.w_c
        if self.use_bias:
            x_i = x_i + self.b_i
            x_f = x_f + self.b_f
            x_o = x_o + self.b_o
            x_c = x_c + self.b_c

        # Per-timestep activations, kept for the backward pass.
        i = zeros_like(x_i, device=self.device)  # input gate
        f = zeros_like(x_i, device=self.device)  # forget gate
        o = zeros_like(x_i, device=self.device)  # output gate
        g = zeros_like(x_i, device=self.device)  # cell candidate
        c = zeros_like(x_i, device=self.device)  # cell state
        c_act = zeros_like(x_i, device=self.device)  # tanh(cell state)
        h = zeros_like(x_i, device=self.device)  # hidden state

        # iterate over timesteps
        for t in range(x.shape[1]):
            z_i, z_f, z_o, z_c = x_i[:, t], x_f[:, t], x_o[:, t], x_c[:, t]

            # The initial hidden state is zero, so the recurrent terms
            # vanish at t == 0.
            if t > 0:
                h_prev = h[:, t - 1]
                z_i = z_i + h_prev @ self.u_i
                z_f = z_f + h_prev @ self.u_f
                z_o = z_o + h_prev @ self.u_o
                z_c = z_c + h_prev @ self.u_c

            # gate activations
            i[:, t] = self._sigmoid(z_i)
            f[:, t] = self._sigmoid(z_f)
            o[:, t] = self._sigmoid(z_o)
            g[:, t] = z_c.tanh()

            # cell state (the initial cell state is zero as well)
            if t > 0:
                c[:, t] = f[:, t] * c[:, t - 1] + i[:, t] * g[:, t]
            else:
                c[:, t] = i[:, t] * g[:, t]

            # hidden state
            c_act[:, t] = c[:, t].tanh()
            h[:, t] = o[:, t] * c_act[:, t]

        if self.training:

            def backward(dy: ArrayLike) -> ArrayLike:
                """Backpropagate ``dy`` through time; set parameter grads, return input grads."""
                dh = dy.astype(self.dtype)
                self.set_dy(dh)

                # raw arrays of the cached activations
                i_d, f_d, o_d, g_d = i.data, f.data, o.data, g.data
                c_d, c_act_d, h_d = c.data, c_act.data, h.data

                # transposed recurrent weights, hoisted out of the time loop
                u_i_t = self.u_i.data.T
                u_f_t = self.u_f.data.T
                u_o_t = self.u_o.data.T
                u_c_t = self.u_c.data.T

                # grads w.r.t. the gate pre-activations
                dz_i = zeros_like(i, device=self.device).data
                dz_f = zeros_like(f, device=self.device).data
                dz_o = zeros_like(o, device=self.device).data
                dz_c = zeros_like(g, device=self.device).data

                # recurrent weight grads, accumulated over timesteps
                du_i = zeros_like(self.u_i, device=self.device).data
                du_f = zeros_like(self.u_f, device=self.device).data
                du_o = zeros_like(self.u_o, device=self.device).data
                du_c = zeros_like(self.u_c, device=self.device).data

                last = x.shape[1] - 1
                dh_next = None  # hidden state grad flowing back from t + 1
                dc_next = None  # cell state grad flowing back from t + 1

                for t in range(last, -1, -1):
                    # add hidden state grad of next t, if not last t
                    if t == last:
                        dh_t = dh[:, t]
                    else:
                        dh_t = dh[:, t] + dh_next

                    # cell state grad: through h_t = o_t * tanh(c_t), plus
                    # the grad carried over from c_{t+1} = f_{t+1} * c_t + ...
                    dc_t = dh_t * o_d[:, t] * (1 - c_act_d[:, t] ** 2)
                    if t < last:
                        dc_t = dc_t + dc_next

                    # gate pre-activation grads
                    dz_o[:, t] = dh_t * c_act_d[:, t] * o_d[:, t] * (1 - o_d[:, t])
                    dz_i[:, t] = dc_t * g_d[:, t] * i_d[:, t] * (1 - i_d[:, t])
                    dz_c[:, t] = dc_t * i_d[:, t] * (1 - g_d[:, t] ** 2)

                    if t > 0:
                        # at t == 0 the previous cell state is zero, so the
                        # forget gate receives no gradient there
                        dz_f[:, t] = (
                            dc_t * c_d[:, t - 1] * f_d[:, t] * (1 - f_d[:, t])
                        )

                        # hidden weight grads
                        h_prev_t = h_d[:, t - 1].T
                        du_i += h_prev_t @ dz_i[:, t]
                        du_f += h_prev_t @ dz_f[:, t]
                        du_o += h_prev_t @ dz_o[:, t]
                        du_c += h_prev_t @ dz_c[:, t]

                        # grads handed to the previous timestep
                        dh_next = (
                            dz_i[:, t] @ u_i_t
                            + dz_f[:, t] @ u_f_t
                            + dz_o[:, t] @ u_o_t
                            + dz_c[:, t] @ u_c_t
                        )
                        dc_next = dc_t * f_d[:, t]

                self.u_i.grad = du_i
                self.u_f.grad = du_f
                self.u_o.grad = du_o
                self.u_c.grad = du_c

                # input weight grads: flatten all leading axes so one matmul
                # sums over every sample and timestep
                x_flat_t = x.data.reshape(-1, self.in_channels).T
                self.w_i.grad = x_flat_t @ dz_i.reshape(-1, self.h_channels)
                self.w_f.grad = x_flat_t @ dz_f.reshape(-1, self.h_channels)
                self.w_o.grad = x_flat_t @ dz_o.reshape(-1, self.h_channels)
                self.w_c.grad = x_flat_t @ dz_c.reshape(-1, self.h_channels)

                # bias grads (bias parameters only exist if use_bias is set)
                if self.use_bias:
                    self.b_i.grad = dz_i.sum((0, 1))
                    self.b_f.grad = dz_f.sum((0, 1))
                    self.b_o.grad = dz_o.sum((0, 1))
                    self.b_c.grad = dz_c.sum((0, 1))

                # input grads
                dx = (
                    dz_i @ self.w_i.data.T
                    + dz_f @ self.w_f.data.T
                    + dz_o @ self.w_o.data.T
                    + dz_c @ self.w_c.data.T
                )

                return dx

            self.backward = backward

        self.set_y(h)
        return h
