In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from d2l import torch as d2l

class GRUModel(nn.Module):
    def __init__(self, num_feature, hidden_size, device):
        super().__init__()
        self.reset_gate_params = self.get_parmas(num_feature, hidden_size, device)
        self.update_gate_params = self.get_parmas(num_feature, hidden_size, device)
        self.candidate_params = self.get_parmas(num_feature, hidden_size, device)
        self.out_linear = nn.Linear(hidden_size, num_feature, device=device)
        self.num_feature = num_feature
        self.hidden_size = hidden_size
        self.device = device

    def get_parmas(self, num_feature, hidden_size, device):
        W_hx = torch.normal(0, 0.01, (num_feature, hidden_size), requires_grad=True, device=device)
        W_hh = torch.normal(0, 0.01, (hidden_size, hidden_size), requires_grad=True, device=device)
        b_x = torch.zeros(hidden_size, requires_grad=True, device=device)
        b_h = torch.zeros(hidden_size, requires_grad=True, device=device)
        return (W_hx, W_hh, b_x, b_h)
    
    def params(self):
        return self.reset_gate_params + self.update_gate_params + self.candidate_params
    
    def weighted_sum(self, X, H, params):
        # X=>[batch_size, num_feature]
        # H=>[batch_size, hidden_size]
        # params[0]=>[num_feature, hidden_size]
        # params[1]=>[hidden_size]
        # params[2]=>[hidden_size, hidden_size]
        # params[3]=>[hidden_size]
        
        return torch.matmul(X, params[0]) + params[2] + torch.matmul(H, params[1]) + params[3]

    def forward(self, X, H): # X: [batch_size, num_step], H: [batch_size, hidden_size]
        if H is None:
            H = torch.zeros((X.shape[0], self.hidden_size), requires_grad=True, device=self.device)
        else:
            H.detach_()
        # X => [batch_size, num_step] => [batch_size, num_step, num_feature] => [num_step, batch_size, num_feature]
        X = F.one_hot(X, self.num_feature).type(torch.float32).permute([1, 0, 2])

        # x => [batch_size, num_feature]
        Y = []
        for x in X:
            reset = F.sigmoid( self.weighted_sum(x, H, self.reset_gate_params) )
            update = F.sigmoid( self.weighted_sum(x, H, self.update_gate_params) )
            h_candidate = F.tanh( torch.matmul(x,self.candidate_params[0]) + self.candidate_params[2] + torch.matmul(reset * H, self.candidate_params[1]) + self.candidate_params[3])
            H = update * H + (1-update) * h_candidate
            y = self.out_linear(H)
            Y.append(y)
        return torch.stack(Y, dim=0).permute([1, 0, 2]), H

In [8]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

batch_size, num_steps, use_random_iter = 32, 35, True
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps, use_random_iter)
epochs, lr , num_hidden, num_feature = 500, 1, 256, len(vocab)

device = 'cpu'
net = GRUModel(num_feature, num_hidden, device=device)
updater = optim.SGD(net.params(), lr=lr)
criteria = nn.CrossEntropyLoss()

step = 0
bar = tqdm(range(epochs))
for epoch in bar:
    H = None
    for x, y in train_iter:
        y_hat, H = net(x.to(device), H)
        loss = criteria(y_hat.reshape(-1, num_feature), y.reshape(-1).to(device))
        updater.zero_grad()
        loss.backward()

        d2l.grad_clipping(net, 1)
        updater.step()

        step += 1
        #writer.add_scalar("loss", torch.exp(loss.detach().to('cpu')).item(), step)
        bar.set_postfix_str(str(torch.exp(loss.detach())))

100%|██████████| 500/500 [02:10<00:00,  3.84it/s, tensor(5.4520)] 
