In [None]:
import torch 
from torch import nn

In [13]:
vocab = ['I', 'like', 'cats', 'dogs', 'and']
word2idx = {w: i for i,w in enumerate(vocab)}
idx2word = {i: w for i, w in enumerate(vocab)}

In [14]:
sequences = [
    [word2idx[w] for w in ['I','like', 'cats']],
    [word2idx[w] for w in ['I', 'like', 'dogs']],
    [word2idx[w] for w in ['cats', 'and', 'dogs']]
]

x = torch.tensor([seq[:-1] for seq in sequences])
y = torch.tensor([seq[1:] for seq in sequences])

In [15]:
embedding_dim = 4
embed = nn.Embedding(len(vocab), embedding_dim)

In [16]:
hidden_size = 8
output_size = len(vocab)
loss_fn = nn.CrossEntropyLoss()

# RNN

In [None]:
class RNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, *args, **kwargs):
        super().__init__(*args, **kwargs)        
        self.W_x = nn.Linear(input_size, hidden_size, False)        
        self.W_h = nn.Linear(hidden_size, hidden_size, True)
        self.W_y = nn.Linear(hidden_size, output_size, True)
        self.hidden_size = hidden_size
    def forward(self, x, h0=None):
        seq_len, batch_size, _ = x.size()
        if h0 is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else: h = h0
        outputs = []
        for t in range(seq_len):
            xt = x[t]
            h = torch.tanh(self.W_x(xt) + self.W_h(h))
            y = self.W_y(h)
            outputs.append(y.unsqueeze(0)) # seq_len, batch_size, vocab_len
        return torch.cat(outputs, dim=0), h

In [None]:
rnn = RNN(embedding_dim, output_size, hidden_size)
optimizer = torch.optim.Adam(list(rnn.parameters()) + list(embed.parameters()), lr=3e-4)

for epoch in range(501):        
    optimizer.zero_grad()
    x_emb_batch = embed(x).permute(1, 0, 2) #seq_len, batch_size, embedding_dim    
    outputs, _ = rnn(x_emb_batch)    
    outputs = outputs.permute(1, 0, 2) #batch_size, seq_len, vocab_size
        
    loss = loss_fn(outputs.reshape(-1, len(vocab)), y.reshape(-1))        
    loss.backward()
    optimizer.step()    

    if epoch % 50 == 0:
        print(f'Epoch {epoch}: {loss:.4f}')        

Epoch 0: 1.2608
Epoch 50: 1.1929
Epoch 100: 1.1277
Epoch 150: 1.0638
Epoch 200: 1.0016
Epoch 250: 0.9415
Epoch 300: 0.8836
Epoch 350: 0.8283
Epoch 400: 0.7760
Epoch 450: 0.7272
Epoch 500: 0.6820


In [25]:
with torch.no_grad():
    x_test = torch.tensor(
        [[word2idx['I'], word2idx['like']]]
        )
    x_emb_test = embed(x_test).permute(1, 0, 2)
    y_pred, _ = rnn(x_emb_test)
    print(y_pred[-1])
    predicted_idx = y_pred[-1].argmax(dim=-1)
    print(idx2word[predicted_idx.item()])

tensor([[-1.1400,  0.0856,  1.1383,  0.9932,  0.3025]])
cats


# Bidirectional RNN

In [32]:
class BRNN(nn.Module):
    class ForwardRNN(RNN):
        pass
    class BackwardRNN(RNN):
        def forward(self, x, h0=None):
            seq_len, batch_size, _ = x.size()
            if h0 is None:
                h = torch.zeros(batch_size, hidden_size,device=x.device)
            else:
                h = h0
            outputs = []
            for t in range(seq_len - 1, -1, -1):                
                xt= x[t]
                h = torch.tanh(self.W_x(xt) + self.W_h(h))
                y = self.W_y(h)
                outputs.append(y.unsqueeze(0))
            return torch.cat(outputs, dim=0), h
                
            

                        
    def __init__(self, input_size, output_size, hidden_size, *args, **kwargs):
        super().__init__(*args, **kwargs)        
        self.forward_rnn = BRNN.ForwardRNN(input_size, output_size, hidden_size)        
        self.backward_rnn = BRNN.BackwardRNN(input_size, output_size, hidden_size)
        self.out = nn.Linear(2 * output_size, output_size, True)
    def forward(self, x, h0=None):
        f_output, _ = self.forward_rnn(x, h0)
        b_output, _ = self.backward_rnn(x, h0)
        conc_output = torch.cat((f_output, b_output), dim=-1)
        return self.out(conc_output)

In [33]:
brnn = BRNN(embedding_dim, output_size, hidden_size)
optimizer = torch.optim.Adam(list(brnn.parameters()) + list(embed.parameters()))

for epoch in range(501):
    optimizer.zero_grad()
    x_emb_batch = embed(x).permute(1, 0, 2)    
    outputs = brnn(x_emb_batch)
    outputs = outputs.permute(1, 0, 2)    

    loss = loss_fn(outputs.reshape(-1, output_size), y.reshape(-1))
    loss.backward()
    optimizer.step()

    if epoch % 50 == 0:
        print(f'Epoch {epoch}: {loss.item()}')

Epoch 0: 1.7647596597671509
Epoch 50: 1.3231393098831177
Epoch 100: 0.8804828524589539
Epoch 150: 0.5359984636306763
Epoch 200: 0.36696314811706543
Epoch 250: 0.30078092217445374
Epoch 300: 0.2727583050727844
Epoch 350: 0.2587410509586334
Epoch 400: 0.2507844865322113
Epoch 450: 0.24584811925888062
Epoch 500: 0.24257569015026093


In [45]:
with torch.no_grad():
    x_test = torch.tensor(
        [[word2idx['I'], word2idx['like']]]
        )
    x_emb_test = embed(x_test).permute(1, 0, 2)
    y_pred, _ = brnn(x_emb_test)
    print(y_pred[-1])
    predicted_idx = y_pred[-1].argmax(dim=-1)
    print(idx2word[predicted_idx.item()])

tensor([-0.2106,  5.4616, -3.2930, -1.0603,  0.0478])
like


# LSTM 

In [37]:
class LSTM(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, *args, **kwargs):        
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size

        self.W_xi = nn.Linear(input_size, hidden_size, False)
        self.W_hi = nn.Linear(hidden_size, hidden_size,True)
        
        self.W_xf = nn.Linear(input_size, hidden_size, False)
        self.W_hf = nn.Linear(hidden_size, hidden_size, True)    

        self.W_xo = nn.Linear(input_size, hidden_size, False)
        self.W_ho = nn.Linear(hidden_size, hidden_size, True)

        self.W_xg = nn.Linear(input_size, hidden_size, False) #candidate
        self.W_hg = nn.Linear(hidden_size, hidden_size, True)

        self.W_y = nn.Linear(hidden_size, output_size, True)
    def forward(self, x, h0=None, c0=None):
        seq_len, batch_size, _ = x.size()
        if h0 is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h = h0
        
        if c0 is None:
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            c = c0

        outputs = []                    
        for t in range(seq_len):
            xt = x[t]

            i = torch.sigmoid(self.W_xi(xt) + self.W_hi(h)) # input gate
            f = torch.sigmoid(self.W_xf(xt) + self.W_hf(h)) # forget gate
            o = torch.sigmoid(self.W_xo(xt) + self.W_ho(h)) # output gate
            g = torch.tanh(self.W_xg(xt) + self.W_hg(h)) # candidate

            c = f * c + i * g
            h = o * torch.tanh(c)

            y = self.W_y(h)
            outputs.append(y.unsqueeze(0))            
        
        return torch.cat(outputs, dim=0), (h, c)

In [44]:
lstm = LSTM(embedding_dim, output_size, hidden_size)
optimizer = torch.optim.Adam(list(lstm.parameters()) + list(embed.parameters()))

for epoch in range(201):
    optimizer.zero_grad()
    x_emb_batch = embed(x).permute(1, 0, 2)    
    outputs, _ = lstm(x_emb_batch)
    outputs = outputs.permute(1, 0, 2)    

    loss = loss_fn(outputs.reshape(-1, output_size), y.reshape(-1))
    loss.backward()
    optimizer.step()

    if epoch % 50 == 0:
        print(f'Epoch {epoch}: {loss.item()}')

Epoch 0: 1.4843062162399292
Epoch 50: 1.3674039840698242
Epoch 100: 1.221854329109192
Epoch 150: 1.0498194694519043
Epoch 200: 0.8678433299064636


In [46]:
with torch.no_grad():
    x_test = torch.tensor(
        [[word2idx['I'], word2idx['like']]]
        )
    x_emb_test = embed(x_test).permute(1, 0, 2)
    y_pred, _ = lstm(x_emb_test)
    print(y_pred[-1])
    predicted_idx = y_pred[-1].argmax(dim=-1)
    print(idx2word[predicted_idx.item()])

tensor([[-0.6963, -0.4010,  0.5341,  1.0510, -0.4065]])
dogs


# GRU

In [51]:
class GRU(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size

        self.W_xz = nn.Linear(input_size, hidden_size, False)
        self.W_hz = nn.Linear(hidden_size, hidden_size, True)

        self.W_xr = nn.Linear(input_size, hidden_size, False)
        self.W_hr = nn.Linear(hidden_size, hidden_size, True)

        self.W_xh = nn.Linear(input_size, hidden_size, False)
        self.W_hh = nn.Linear(hidden_size, hidden_size, True)

        self.W_y = nn.Linear(hidden_size, output_size, True)        
    
    def forward(self, x, h0=None):
        seq_len, batch_size, _ = x.size()

        if h0 is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)            
        else:
            h= h0
        
        outputs = []
        for t in range(seq_len):
            xt = x[t]

            z = torch.sigmoid(self.W_hz(h) + self.W_xz(xt)) # update gate
            r = torch.sigmoid(self.W_hr(h) + self.W_xr(xt)) # reset gate

            g = torch.tanh(self.W_hh(r * h) + self.W_xh(xt)) # candidate
            h = (1 - z) * h + z * g # hidden state

            y = self.W_y(h)
            outputs.append(y.unsqueeze(0))

        return torch.cat(outputs, dim=0), h    


In [52]:
gru = GRU(embedding_dim, output_size, hidden_size)
optimizer = torch.optim.Adam(list(gru.parameters()) + list(embed.parameters()))

for epoch in range(201):
    optimizer.zero_grad()
    x_emb_batch = embed(x).permute(1, 0, 2)    
    outputs, _ = gru(x_emb_batch)
    outputs = outputs.permute(1, 0, 2)    

    loss = loss_fn(outputs.reshape(-1, output_size), y.reshape(-1))
    loss.backward()
    optimizer.step()

    if epoch % 50 == 0:
        print(f'Epoch {epoch}: {loss.item()}')

Epoch 0: 1.6581777334213257
Epoch 50: 1.4477931261062622
Epoch 100: 1.2707582712173462
Epoch 150: 1.0755815505981445
Epoch 200: 0.8779329657554626


In [53]:
with torch.no_grad():
    x_test = torch.tensor(
        [[word2idx['I'], word2idx['like']]]
        )
    x_emb_test = embed(x_test).permute(1, 0, 2)
    y_pred, _ = gru(x_emb_test)
    print(y_pred[-1])
    predicted_idx = y_pred[-1].argmax(dim=-1)
    print(idx2word[predicted_idx.item()])

tensor([[-0.9891,  0.5437,  1.1459,  0.6621, -0.2743]])
cats
