# GRU cell 구현

<img src="https://user-images.githubusercontent.com/52481037/100781826-e3ddf600-344e-11eb-87bb-2d2254568759.png" width="500"/>


In [1]:
import math
import torch
from torch import nn
from torch.nn import functional as F

In [2]:
class GRUCell(nn.Module):
    """A single GRU step, computed the same way as ``torch.nn.GRUCell``.

    With ``x`` the input and ``h`` the previous hidden state::

        r  = sigmoid(W_ir x + b_ir + W_hr h + b_hr)      # reset gate
        z  = sigmoid(W_iz x + b_iz + W_hz h + b_hz)      # update ("input") gate
        n  = tanh(W_in x + b_in + r * (W_hn h + b_hn))   # candidate state
        h' = (1 - z) * n + z * h

    Args:
        input_size: Number of features in the input ``x``.
        hidden_size: Number of features in the hidden state.
        bias: If False, neither linear projection uses a bias term.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # One fused projection per source; each output holds the pre-activations
        # of the 3 gates (reset, update, candidate) side by side.
        self.x2h = nn.Linear(input_size, 3 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every weight and bias from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        std = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -std, std)

    def forward(self, x, hidden=None):
        """Advance the hidden state by one step.

        Args:
            x: Input of shape ``(batch, input_size)``; any extra leading dims
                are flattened into the batch dimension.
            hidden: Previous hidden state of shape ``(batch, hidden_size)``.
                Defaults to zeros when omitted, as in ``nn.GRUCell``.

        Returns:
            The next hidden state, shape ``(batch, hidden_size)``.
        """
        # Flatten by input_size (not x.size(1)) so the feature axis stays intact
        # for >2-D inputs; reshape (unlike view) also copes with non-contiguous x.
        x = x.reshape(-1, self.input_size)
        if hidden is None:
            hidden = x.new_zeros(x.size(0), self.hidden_size)

        gate_x = self.x2h(x)
        gate_h = self.h2h(hidden)

        # i_*: input contributions, h_*: hidden contributions for the
        # reset (r), update/"input" (i) and candidate (n) gates.
        i_r, i_i, i_n = gate_x.chunk(3, 1)
        h_r, h_i, h_n = gate_h.chunk(3, 1)

        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        # The reset gate scales only the hidden-side candidate term (PyTorch convention).
        newgate = torch.tanh(i_n + resetgate * h_n)

        # Same as (1 - z) * n + z * h, written with one fewer multiplication.
        return newgate + inputgate * (hidden - newgate)

In [3]:
# Smoke test: run one GRU step on a batch of one sample and inspect the output shape.
gru_cell = GRUCell(input_size=50, hidden_size=32)

# Random input (batch=1, 50 features) and previous hidden state (batch=1, 32 units);
# drawn in this order so the RNG sequence is unchanged.
inp, hidden = torch.randn(1, 50), torch.randn(1, 32)

hd = gru_cell(x=inp, hidden=hidden)
print(hd.shape)

torch.Size([1, 32])
