In [None]:
import compyute as cp

from compyute.functional import zeros, zeros_like
from compyute.nn.funcional import sigmoid
from compyute.nn import Module
from compyute.nn.parameter import Parameter
from compyute.random import uniform
from compyute.tensor import Tensor
from compyute.types import ArrayLike


class LSTMCell(Module):
    """Long short-term memory (LSTM) cell.

    Processes a batch of sequences and returns the hidden state of every timestep.
    For each timestep ``t`` the cell computes::

        i_t = sigmoid(x_t @ w_i.T + h_{t-1} @ u_i.T + b_i)   # input gate
        f_t = sigmoid(x_t @ w_f.T + h_{t-1} @ u_f.T + b_f)   # forget gate
        o_t = sigmoid(x_t @ w_o.T + h_{t-1} @ u_o.T + b_o)   # output gate
        g_t = tanh(x_t @ w_c.T + h_{t-1} @ u_c.T + b_c)      # candidate cell state
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    with ``h_{-1}`` and ``c_{-1}`` being zero.

    Parameters
    ----------
    in_channels : int
        Number of input features (Cin).
    h_channels : int
        Number of hidden features (Ch).
    use_bias : bool, optional
        Whether to add the bias terms ``b_i``, ``b_f``, ``b_o`` and ``b_c``, by default True.
    dtype : str, optional
        Datatype of the parameters and of the computation, by default "float32".
    """

    def __init__(self, in_channels: int, h_channels: int, use_bias: bool = True, dtype: str = "float32") -> None:
        super().__init__()
        self.in_channels = in_channels
        self.h_channels = h_channels
        self.use_bias = use_bias
        self.dtype = dtype

        # all weights are drawn from U(-k, k), biases start at zero
        k = in_channels**-0.5

        # input gate
        w_i = uniform((h_channels, in_channels), -k, k)
        self.w_i = Parameter(w_i, dtype=dtype, label="w_i")
        u_i = uniform((h_channels, h_channels), -k, k)
        self.u_i = Parameter(u_i, dtype=dtype, label="u_i")
        if use_bias:
            b_i = zeros((h_channels,))
            self.b_i = Parameter(b_i, dtype=dtype, label="b_i")

        # forget gate
        w_f = uniform((h_channels, in_channels), -k, k)
        self.w_f = Parameter(w_f, dtype=dtype, label="w_f")
        u_f = uniform((h_channels, h_channels), -k, k)
        self.u_f = Parameter(u_f, dtype=dtype, label="u_f")
        if use_bias:
            b_f = zeros((h_channels,))
            self.b_f = Parameter(b_f, dtype=dtype, label="b_f")

        # output gate
        w_o = uniform((h_channels, in_channels), -k, k)
        self.w_o = Parameter(w_o, dtype=dtype, label="w_o")
        u_o = uniform((h_channels, h_channels), -k, k)
        self.u_o = Parameter(u_o, dtype=dtype, label="u_o")
        if use_bias:
            b_o = zeros((h_channels,))
            self.b_o = Parameter(b_o, dtype=dtype, label="b_o")

        # cell (candidate cell state)
        w_c = uniform((h_channels, in_channels), -k, k)
        self.w_c = Parameter(w_c, dtype=dtype, label="w_c")
        u_c = uniform((h_channels, h_channels), -k, k)
        self.u_c = Parameter(u_c, dtype=dtype, label="u_c")
        if use_bias:
            b_c = zeros((h_channels,))
            self.b_c = Parameter(b_c, dtype=dtype, label="b_c")

    def forward(self, x: Tensor) -> Tensor:
        """Runs the cell over all timesteps of ``x``.

        Parameters
        ----------
        x : Tensor
            Input of shape (B, T, Cin); ``check_dims`` is called with ``[3]``,
            presumably rejecting anything that is not 3-dimensional.

        Returns
        -------
        Tensor
            Hidden states of all timesteps, shape (B, T, Ch).
        """
        self.check_dims(x, [3])
        x = x.astype(self.dtype)

        # input projections, computed once for all timesteps;
        # every gate uses its own input weight matrix
        # (B, T, Cin) @ (Cin, Ch) -> (B, T, Ch)
        i_h = x @ self.w_i.T
        f_h = x @ self.w_f.T
        o_h = x @ self.w_o.T
        c_h = x @ self.w_c.T

        if self.use_bias:
            # (B, T, Ch) + (Ch,) -> (B, T, Ch)
            i_h += self.b_i
            f_h += self.b_f
            o_h += self.b_o
            c_h += self.b_c

        # per-timestep gate activations and states, all (B, T, Ch)
        i = zeros_like(i_h, dtype=self.dtype, device=self.device)
        f = zeros_like(f_h, dtype=self.dtype, device=self.device)
        o = zeros_like(o_h, dtype=self.dtype, device=self.device)
        c = zeros_like(c_h, dtype=self.dtype, device=self.device)
        h = zeros_like(c_h, dtype=self.dtype, device=self.device)

        # iterate over timesteps
        for t in range(x.shape[1]):
            # At t == 0 the index t - 1 wraps around to the last timestep, which
            # has not been written yet and therefore still holds the zero
            # initial state.
            h_prev = h[:, t - 1]
            c_prev = c[:, t - 1]

            i[:, t] = sigmoid(i_h[:, t] + h_prev @ self.u_i.T)
            f[:, t] = sigmoid(f_h[:, t] + h_prev @ self.u_f.T)
            o[:, t] = sigmoid(o_h[:, t] + h_prev @ self.u_o.T)

            # the candidate cell state is driven by the previous hidden state,
            # not by the previous cell state
            c_t_p = (c_h[:, t] + h_prev @ self.u_c.T).tanh()

            c[:, t] = f[:, t] * c_prev + i[:, t] * c_t_p
            h[:, t] = o[:, t] * c[:, t].tanh()

        # the layer output is the hidden state sequence, not the output gate
        self.set_y(h)
        return h


In [None]:
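        # TODO: unfinished backward-pass draft. It still references names such as `w_h`, `b_h`
        # and `x_h`, which LSTMCell does not define, so it must be rewritten for the LSTM gates
        # before it is enabled.
        # if self.training: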

        #     def backward(dy: ArrayLike) -> ArrayLike:
        #         dh = dy.astype(self.dtype)
        #         self.set_dy(dh)

        #         dx_h = zeros_like(x_h, device=self.device).data
        #         self.w_i.grad = zeros_like(self.w_h, device=self.device).data
        #         self.w_i.grad = zeros_like(self.w_h, device=self.device).data
        #         self.w_i.grad = zeros_like(self.w_h, device=self.device).data
        #         self.w_i.grad = zeros_like(self.w_h, device=self.device).data

        #         for t in range(x.shape[1] - 1, -1, -1):
        #             # add hidden state grad of next t, if not last t
        #             if t == x_h.shape[1] - 1:
        #                 out_grad = dh[:, t]
        #             else:
        #                 out_grad = dh[:, t] + dx_h[:, t + 1] @ self.w_h.T

        #             # activation grads
        #             dx_h[:, t] = (1 - h.data[:, t] ** 2) * out_grad

        #             # hidden weight grads
        #             if t > 0:
        #                 self.w_h.grad += h[:, t - 1].T @ dx_h[:, t]

        #         # hidden bias grads
        #         self.b_h.grad = dx_h.sum((0, 1))

        #         # input grads
        #         dx = dx_h @ self.w_i.T

        #         # input weight grads
        #         dw = x.transpose().data @ dx_h
        #         self.w_i.grad = dw.sum(axis=0)

        #         # input bias grads
        #         self.b_i.grad = dx_h.sum((0, 1))

        #         return dx

        #     self.backward = backward

In [None]:
# toy problem size: batch, timesteps, input channels, hidden channels
B = 2
T = 3
Ci = 4
Ch = 5

# random input of shape (B, T, Ci)
X = cp.random.normal((B, T, Ci)).float()

# build the cell and display its output for the random input
lstm = LSTMCell(Ci, Ch)
lstm(X)

In [None]:
import torch

# reference implementation: single-layer PyTorch LSTM fed with the same input
torch_X = torch.tensor(X.to_numpy())
torch_lstm = torch.nn.LSTM(
    input_size=Ci,
    hidden_size=Ch,
    num_layers=1,
    batch_first=True,
)

# element 0 of the returned tuple holds the hidden states of all timesteps
torch_lstm(torch_X)[0]

In [None]:
# hidden-to-hidden weights; PyTorch stacks the four gate matrices, so expect (4 * Ch, Ch)
torch_lstm.weight_hh_l0.shape

In [None]:
# input-to-hidden weights; PyTorch stacks the four gate matrices, so expect (4 * Ch, Ci)
torch_lstm.weight_ih_l0.shape

In [None]:
import torch

In [None]:
# Minimal autograd sanity check: y = mean(2 * x), so dy/dx_k = 2 / len(x) = 1 for every element.
x = torch.tensor([1.0, 2.0], requires_grad=True)
doubled = x * 2
y = doubled.mean()
y.backward()