# GRU

GRU（Gated Recurrent Unit）是在LSTM的基础上改进而来的，结构比LSTM更加精简，比LSTM少了一个门（Gate），因此计算效率更高、占用内存更小。

1）GRU将输入门、遗忘门、输出门3个门变为2个：更新门（Update Gate）和重置门（Reset Gate）  
2）将单元状态（cell state）与隐藏状态（hidden state）合并为一个隐藏状态

![](./pics/GRU_model.png)

参考：《Python深度学习：基于PyTorch》

## 1 PyTorch实现

GRU和LSTM很相似，只是LSTM有3个门、2个状态（隐藏状态和单元状态），而GRU只有2个门、1个隐藏状态。

In [2]:
import torch
import torch.nn as nn

In [12]:
class GRUCell(nn.Module):
    """A single GRU cell followed by a linear + log-softmax read-out.

    Unlike an LSTM cell there is no separate cell state: the update gate ``z``
    and the reset gate ``r`` act directly on the hidden state ``h``::

        z  = sigmoid(W_z [x, h])
        r  = sigmoid(W_r [x, h])
        h~ = tanh(W_h [x, r * h])
        h' = (1 - z) * h + z * h~
        y  = log_softmax(W_o h')

    Args:
        input_size: number of features in each input row.
        hidden_size: number of features in the hidden state.
        output_size: number of classes produced by the read-out layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Each gate needs its own weights. Sharing a single Linear made z and r
        # numerically identical and left the reset gate with no effect.
        self.gate = nn.Linear(input_size + hidden_size, hidden_size)  # update gate z
        self.reset_gate = nn.Linear(input_size + hidden_size, hidden_size)  # reset gate r
        self.candidate = nn.Linear(input_size + hidden_size, hidden_size)  # candidate h~
        self.output = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one GRU step.

        Args:
            input: tensor of shape (batch, input_size).
            hidden: previous hidden state, shape (batch, hidden_size).

        Returns:
            ``(output, h_state)``. ``output`` holds log-probabilities of shape
            (batch, output_size) computed from the new state. ``h_state`` is the
            new hidden state of shape (batch, hidden_size).
        """
        combined = torch.cat((input, hidden), 1)
        z_gate = self.sigmoid(self.gate(combined))
        r_gate = self.sigmoid(self.reset_gate(combined))
        # The reset gate controls how much of the previous state feeds the candidate.
        combined_reset = torch.cat((input, r_gate * hidden), 1)
        h_candidate = self.tanh(self.candidate(combined_reset))
        # Interpolate between the previous state and the candidate via the update gate.
        h_state = (1 - z_gate) * hidden + z_gate * h_candidate
        # Read out from the *new* hidden state, not the previous one.
        output = self.softmax(self.output(h_state))
        return output, h_state

    def initHidden(self, batch_size=1):
        """Return an all-zero hidden state of shape (batch_size, hidden_size)."""
        return torch.zeros(batch_size, self.hidden_size)

In [14]:
# Instantiate a cell: 10 input features, 20 hidden units, 10 output classes.
gru_cell = GRUCell(input_size=10, hidden_size=20, output_size=10)

In [15]:
# A batch of 32 random inputs (10 features each) and a random initial hidden
# state (20 units) with the same batch size.
# NOTE: `input` shadows the Python builtin of the same name in this notebook.
input = torch.randn(32, 10)
h_0 = torch.randn(32, 20)

In [18]:
# One step of the cell: returns log-softmax outputs and the next hidden state.
output, hn  = gru_cell(input, h_0)

In [19]:
# Shape check: output is (batch, output_size), hn is (batch, hidden_size).
output.size(), hn.size()

(torch.Size([32, 10]), torch.Size([32, 20]))