## 1. Minimalize LSTM

### 1.1 LSTM cell

In [None]:
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    """Single-step standard LSTM cell (baseline for the minimalization steps).

    Gates are computed from the current input and the previous hidden state:

        i_t, f_t, o_t = sigmoid(...),  g_t = tanh(...)
        c_t = f_t * c_prev + i_t * g_t
        h_t = o_t * tanh(c_t)

    Args:
        input_size: Size of the last dimension of ``x_t``.
        hidden_size: Size of the last dimension of ``h_prev`` / ``c_prev``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Each projection emits all four gate pre-activations at once,
        # laid out as [i | f | g | o] along the last dimension.
        self.linear_ih = nn.Linear(input_size, 4 * hidden_size)
        self.linear_hh = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, x_t, h_prev, c_prev):
        """Advance the cell by one time step.

        Args:
            x_t: Input, shape ``(..., input_size)`` -- e.g. ``(batch, input_size)``
                or an unbatched ``(input_size,)`` vector.
            h_prev: Previous hidden state, shape ``(..., hidden_size)``.
            c_prev: Previous cell state, shape ``(..., hidden_size)``.

        Returns:
            Tuple ``(h_t, c_t)``, each with last dimension ``hidden_size``.
        """
        # Gate pre-activations: input projection + recurrent projection.
        gates = self.linear_ih(x_t) + self.linear_hh(h_prev)

        # Split the i, f, g, o gates along the feature (last) axis, so that
        # both batched and unbatched inputs work. dim=1 would fail on 1-D
        # input and split the wrong axis on inputs with extra leading dims.
        i_t, f_t, g_t, o_t = torch.chunk(gates, 4, dim=-1)

        # Activations: sigmoid for the gates, tanh for the candidate.
        i_t = torch.sigmoid(i_t)
        f_t = torch.sigmoid(f_t)
        g_t = torch.tanh(g_t)
        o_t = torch.sigmoid(o_t)

        # Update the cell state and the hidden state.
        c_t = f_t * c_prev + i_t * g_t
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t



### 1.2 Minimalizing Step 1
- 출력 게이트 $o_t$ 제거: $h_t = \tanh(c_t)$

In [None]:
class MinLSTMCell_Step1(nn.Module):
    """LSTM cell after minimalization step 1: the output gate is removed.

    Only the input gate, forget gate and candidate are computed, and the
    hidden state is the squashed cell state:

        c_t = f_t * c_prev + i_t * g_t
        h_t = tanh(c_t)

    Args:
        input_size: Size of the last dimension of ``x_t``.
        hidden_size: Size of the last dimension of ``h_prev`` / ``c_prev``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Without the output gate only three gate blocks are needed: [i | f | g].
        self.linear_ih = nn.Linear(input_size, 3 * hidden_size)
        self.linear_hh = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x_t, h_prev, c_prev):
        """Advance the cell by one time step.

        Args:
            x_t: Input, shape ``(..., input_size)`` (batched or unbatched).
            h_prev: Previous hidden state, shape ``(..., hidden_size)``.
            c_prev: Previous cell state, shape ``(..., hidden_size)``.

        Returns:
            Tuple ``(h_t, c_t)``, each with last dimension ``hidden_size``.
        """
        gates = self.linear_ih(x_t) + self.linear_hh(h_prev)
        # Split along the feature (last) axis so unbatched input also works.
        i_t, f_t, g_t = torch.chunk(gates, 3, dim=-1)

        i_t = torch.sigmoid(i_t)
        f_t = torch.sigmoid(f_t)
        g_t = torch.tanh(g_t)

        c_t = f_t * c_prev + i_t * g_t
        h_t = torch.tanh(c_t)  # output gate removed: no o_t factor

        return h_t, c_t

### 1.3 Minimalizing Step 2
- 게이트 계산에서 이전 은닉 상태 $h_{t-1}$ 의존성 제거 (게이트는 $x_t$ 에만 의존하며, 셀 상태 $c_{t-1}$ 는 그대로 사용)

In [4]:
class MinLSTMCell_Step2(nn.Module):
    """LSTM cell after minimalization step 2: gates no longer depend on h_{t-1}.

    All gates are computed from ``x_t`` alone, so there is no recurrent
    projection. The cell state still carries information across steps:

        c_t = f_t * c_prev + i_t * g_t
        h_t = tanh(c_t)

    Args:
        input_size: Size of the last dimension of ``x_t``.
        hidden_size: Size of the last dimension of ``c_prev``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # linear_hh is gone; a single input projection emits [i | f | g].
        self.linear_ih = nn.Linear(input_size, 3 * hidden_size)

    def forward(self, x_t, c_prev):
        """Advance the cell by one time step.

        Args:
            x_t: Input, shape ``(..., input_size)`` (batched or unbatched).
            c_prev: Previous cell state, shape ``(..., hidden_size)``.

        Returns:
            Tuple ``(h_t, c_t)``, each with last dimension ``hidden_size``.
        """
        gates = self.linear_ih(x_t)  # no h_prev term
        # Split along the feature (last) axis so unbatched input also works.
        i_t, f_t, g_t = torch.chunk(gates, 3, dim=-1)

        i_t = torch.sigmoid(i_t)
        f_t = torch.sigmoid(f_t)
        g_t = torch.tanh(g_t)

        c_t = f_t * c_prev + i_t * g_t
        h_t = torch.tanh(c_t)

        return h_t, c_t

### 1.4 Minimalizing Step 3
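- Input Gate와 Forget Gate 통합: $f'_t = \frac{f_t}{f_t + i_t}$, $i'_t = \frac{i_t}{f_t + i_t}$ 로 정규화하여 $f'_t + i'_t = 1$ 이 되도록 함
- 후보 상태 $g_t$ 에서 $\tanh$ 제거 (선형 출력을 그대로 사용)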

In [None]:
class MinLSTMCell(nn.Module):
    """LSTM cell after minimalization step 3: input and forget gates are coupled.

    Compared with ``MinLSTMCell_Step2``:

    * the forget and input gates are normalized, f' = f / (f + i) and
      i' = i / (f + i), so f' + i' = 1 and ``c_t`` is a convex combination
      of ``c_prev`` and the candidate;
    * the candidate ``g_t`` is the raw linear output (no ``tanh``).

    Gates depend on ``x_t`` only.

    Args:
        input_size: Size of the last dimension of ``x_t``.
        hidden_size: Size of the last dimension of ``c_prev``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Single input projection emitting [i | f | g].
        self.linear_ih = nn.Linear(input_size, 3 * hidden_size)

    def forward(self, x_t, h_prev, c_prev):
        """Advance the cell by one time step.

        Args:
            x_t: Input, shape ``(..., input_size)`` (batched or unbatched).
            h_prev: Accepted but unused, since the gates depend on ``x_t`` only.
                Presumably kept so the call signature matches ``LSTMCell``.
            c_prev: Previous cell state, shape ``(..., hidden_size)``.

        Returns:
            Tuple ``(h_t, c_t)``, each with last dimension ``hidden_size``.
        """
        gates = self.linear_ih(x_t)
        # Split along the feature (last) axis so unbatched input also works.
        i_raw, f_raw, g_t = torch.chunk(gates, 3, dim=-1)

        # Normalized gates f' = f / (f + i), with f = sigmoid(f_raw) and
        # i = sigmoid(i_raw). Because f / (f + i) = sigmoid(log f - log i),
        # compute f' from log-sigmoids. This is exact and never divides by ~0.
        # The naive f / (f + i + eps) shrinks both gates toward 0 (wiping
        # c_prev) whenever f and i are both tiny. In float16 a 1e-10 eps also
        # rounds to 0, which turns that case into 0/0 = NaN.
        log_ratio = nn.functional.logsigmoid(f_raw) - nn.functional.logsigmoid(i_raw)
        f_prime = torch.sigmoid(log_ratio)
        i_prime = torch.sigmoid(-log_ratio)  # == 1 - f_prime

        c_t = f_prime * c_prev + i_prime * g_t
        h_t = torch.tanh(c_t)

        return h_t, c_t

## 2. Minimalize GRU