# LSTM Cell 구현하기

<img src="https://user-images.githubusercontent.com/52481037/100769594-9f972980-343f-11eb-83bc-72d4aae7729d.png" width="600"/>

In [1]:
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
import math

In [2]:
class LSTMCell(nn.Module):
    """Single-step LSTM cell with peephole connections.

    The input, forget and output gates each receive an extra element-wise
    term that depends on the cell state (the "peephole"): the input and
    forget gates look at the previous cell state, the output gate looks at
    the freshly computed cell state.

    Args:
        input_size: Number of features in the input ``x``.
        hidden_size: Number of features in the hidden and cell states.
        bias: Whether the two linear projections use a bias term.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Each projection emits 4 * hidden_size values: one slice per gate
        # (input, forget, cell candidate, output), split apart in forward().
        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        # Peephole weights for the input, forget and output gates.
        # They must be an nn.Parameter: a plain tensor is not returned by
        # parameters(), so it would never be initialized by
        # reset_parameters(), never be trained, never be saved in the
        # state_dict and never follow the module across devices.
        self.c2c = nn.Parameter(torch.empty(3 * hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            nn.init.uniform_(w, -std, std)

    def forward(self, x, hidden):
        """Advance the cell by one time step.

        Args:
            x: Input tensor; its dimension 1 must equal ``input_size``.
            hidden: Tuple ``(hx, cx)`` with the previous hidden state and
                previous cell state, each with ``hidden_size`` features on
                dimension 1.

        Returns:
            ``(h, (h, c))`` where ``h`` is the new hidden state and ``c``
            is the new cell state. ``h`` is returned twice so the first
            element can be used as the step output and the tuple can be fed
            back in as ``hidden`` for the next step.
        """
        hx, cx = hidden

        # Collapse to 2-D, keeping dimension 1 as the feature dimension.
        x = x.view(-1, x.size(1))
        gates = self.x2h(x) + self.h2h(hx)

        # (3 * hidden_size,) -> (1, 3 * hidden_size) so the peephole
        # weights broadcast over the batch dimension.
        ci, cf, co = self.c2c.unsqueeze(0).chunk(3, 1)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        # Input and forget gates peek at the previous cell state.
        ingate = torch.sigmoid(ingate + ci * cx)
        forgetgate = torch.sigmoid(forgetgate + cf * cx)
        cy = forgetgate * cx + ingate * torch.tanh(cellgate)
        # The output gate peeks at the new cell state.
        outgate = torch.sigmoid(outgate + co * cy)

        hy = outgate * torch.tanh(cy)
        return hy, (hy, cy)

In [3]:
# Smoke test: push one random sample through the cell and print the shapes
# of the returned output and of the new cell state.
lstm_cell = LSTMCell(input_size=50, hidden_size=32)

inp = torch.randn(1, 50)
# Drawn in the order hidden, cell (tuple elements evaluate left to right).
hidden, cell = torch.randn(1, 32), torch.randn(1, 32)

ou, hd = lstm_cell(inp, (hidden, cell))
print(ou.shape, hd[1].shape)

torch.Size([1, 32]) torch.Size([1, 32])
