In [1]:
import torch
import torch.nn as nn
from torch.optim import SGD 
import numpy as np

# Упражнение, для реализации "Ванильной" RNN
* Попробуем обучить сеть восстанавливать слово hello по первой букве. т.е. построим charecter-level модель

In [2]:
a = torch.ones((3,3))*3
b = torch.ones((3,3))*5

In [3]:
a @ b

tensor([[45., 45., 45.],
        [45., 45., 45.],
        [45., 45., 45.]])

In [4]:
a * b

tensor([[15., 15., 15.],
        [15., 15., 15.],
        [15., 15., 15.]])

In [5]:
#word = 'ololoasdasddqweqw123456789'
word = 'hello'

## Датасет. 
Позволяет:
* Закодировать символ при помощи one-hot
* Делать итератор по слову, которыей возвращает текущий символ и следующий как таргет

In [6]:
class WordDataSet:
    
    def __init__(self, word):
        self.chars2idx = {}
        self.indexs  = []
        for c in word: 
            if c not in self.chars2idx:
                self.chars2idx[c] = len(self.chars2idx)
                
            self.indexs.append(self.chars2idx[c])
            
        self.vec_size = len(self.chars2idx)
        self.seq_len  = len(word)
        
    def get_one_hot(self, idx):
        x = torch.zeros(self.vec_size)
        x[idx] = 1
        return x
    
    def __iter__(self):
        return zip(self.indexs[:-1], self.indexs[1:])
    
    def __len__(self):
        return self.seq_len
    
    def get_char_by_id(self, id):
        for c, i in self.chars2idx.items():
            if id == i: return c
        return None

## Реализация базовой RNN
<br/>
Скрытый элемент
$$ h_t= tanh⁡ (W_{ℎℎ} h_{t−1}+W_{xh} x_t) $$
Выход сети

$$ y_t = W_{hy} h_t $$

In [7]:
class VanillaRNN(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(VanillaRNN, self).__init__()        
        self.x2hidden    = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden      = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.activation  = nn.Tanh()
        self.outweight   = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_hidden):
        hidden = self.activation(self.x2hidden(x) + self.hidden(prev_hidden))
#         Версия без активации - может происходить gradient exploding
#         hidden = self.x2hidden(x) + self.hidden(prev_hidden)
        output = self.outweight(hidden)
        return output, hidden

## Инициализация переменных 

In [8]:
ds = WordDataSet(word=word)
rnn = VanillaRNN(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 100
optim     = SGD(rnn.parameters(), lr = 0.1, momentum=0.9)

# Обучение

In [9]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.zeros(rnn.hidden.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh = rnn(x, hh)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1)
            
#     print("Params : ")
#     num_params = 0
#     for item in rnn.parameters():
#         num_params += 1
#         print(item.grad)
#     print("NumParams :", num_params)
#     print("Optimize")
    
    optim.step()

6.128508567810059
Clip gradient :  2.287575206371895
1.8400852680206299
Clip gradient :  1.4283505115113067
0.030864715576171875
Clip gradient :  0.04848498980189007
0.005702495574951172
Clip gradient :  0.011568606113990224
0.0030508041381835938
Clip gradient :  0.006323590722674666
0.0022988319396972656
Clip gradient :  0.0034661851186251038
0.0020341873168945312
Clip gradient :  0.0029335735564451535
0.0018992424011230469
Clip gradient :  0.0027512204360319857
0.00180816650390625
Clip gradient :  0.00258813861116277
0.001735687255859375
Clip gradient :  0.002443450412496375


# Тестирование

In [10]:
### rnn.eval()
hh = torch.zeros(rnn.hidden.in_features)
id =0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = rnn(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 hello
Original:	 hello


# ДЗ
Реализовать LSTM и GRU модули, обучить их предсказывать тестовое слово

In [11]:
#тестовое слово
word = 'ololoasdasddqweqw123456789'

## Реализовать LSTM

In [12]:
#Написать реализацию LSTM и обучить предсказывать слово

class LSTM(nn.Module):
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(LSTM, self).__init__()
        self.tg = nn.Tanh()
        
        self.sigm = nn.Sigmoid()
        
        self.hid_c_1 = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hid_c_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.hid_i_1 = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hid_i_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.hid_f_1 = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hid_f_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.hid_o_1 = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hid_o_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.out = nn.Linear(in_features=hidden_size, out_features=out_size)   
        
    def forward(self, x, prev_hidden, c_prev):
        
        c_1 = self.tg(self.hid_c_1(x) + self.hid_c_2(prev_hidden))
        
        i = self.sigm(self.hid_i_1(x) + self.hid_i_2(prev_hidden))
        
        f = self.sigm(self.hid_f_1(x) + self.hid_f_2(prev_hidden))
        
        o = self.sigm(self.hid_o_1(x) + self.hid_o_2(prev_hidden))
        
        c = f * c_prev + i * c_1
        
        h = o * self.tg(c_1)
        
        output = self.out(h)
        return output, h, c
        
        

## Инициализируем переменные

In [13]:
ds = WordDataSet(word=word)
rnn = LSTM(in_size=ds.vec_size, hidden_size=9, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 150
optim     = SGD(rnn.parameters(), lr = 0.1, momentum=0.9)

## Обучаем LSTM

In [14]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.zeros(rnn.hid_c_2.in_features)
    cc = torch.zeros(rnn.hid_c_2.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh, cc = rnn(x, hh, cc)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1)
            
#     print("Params : ")
#     num_params = 0
#     for item in rnn.parameters():
#         num_params += 1
#         print(item.grad)
#     print("NumParams :", num_params)
#     print("Optimize")
    
    optim.step()

71.44445037841797
Clip gradient :  3.4239971525674724
61.82365798950195
Clip gradient :  5.2059090300201
31.809295654296875
Clip gradient :  10.128291933534362
13.372635841369629
Clip gradient :  3.9085528295780834
5.584018707275391
Clip gradient :  1.9610680996020962
3.6011781692504883
Clip gradient :  17.21535204069431
4.612643718719482
Clip gradient :  4.20854768935105
3.5759267807006836
Clip gradient :  2.9361500243323833
2.1023778915405273
Clip gradient :  0.8815032780740604
1.6403532028198242
Clip gradient :  0.18537671483384546
1.5364923477172852
Clip gradient :  0.2445170733825659
1.2395315170288086
Clip gradient :  0.8985193819853546
0.4145965576171875
Clip gradient :  0.5612500392137285
0.1454639434814453
Clip gradient :  0.12626990587414183
0.1023101806640625
Clip gradient :  0.05755485032470367


## Тестируем LSTM

In [15]:
### rnn.eval()
hh = torch.zeros(rnn.hid_c_2.in_features)
cc = torch.zeros(rnn.hid_c_2.in_features)
id =0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh, cc = rnn(x, hh, cc)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789


## Реализовать GRU

In [16]:
#Написать реализацию GRU и обучить предсказывать слово

class GRU(nn.Module):
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(GRU, self).__init__()
        self.tg = nn.Tanh()
        
        self.sigm = nn.Sigmoid()
        
        self.hid_u_1 = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hid_u_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.hid_r_1 = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hid_r_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        
        self.hid_h_1 = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hid_h_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)  
        
        self.out = nn.Linear(in_features=hidden_size, out_features=out_size) 
        
    def forward(self, x, prev_hidden):
        
        u = self.sigm(self.hid_u_1(x) + self.hid_u_2(prev_hidden))
        r = self.sigm(self.hid_r_1(x) + self.hid_r_2(prev_hidden))
        h_1 = self.tg(self.hid_h_1(x) + self.hid_h_2(r * prev_hidden))
        h = ((1 - u) * h_1) + (u * prev_hidden)
        output = self.out(h)
        return output, h

## Инициализируем переменные

In [17]:
ds = WordDataSet(word=word)
rnn = GRU(in_size=ds.vec_size, hidden_size=8, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 100
optim     = SGD(rnn.parameters(), lr = 0.1, momentum=0.9)

## Обучаем  GRU

In [18]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.zeros(rnn.hid_u_2.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh = rnn(x, hh)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1)
            
#     print("Params : ")
#     num_params = 0
#     for item in rnn.parameters():
#         num_params += 1
#         print(item.grad)
#     print("NumParams :", num_params)
#     print("Optimize")
    
    optim.step()

71.58972930908203
Clip gradient :  4.96453132858786
57.73640441894531
Clip gradient :  6.801313649382623
36.154781341552734
Clip gradient :  8.235216859516882
40.8128662109375
Clip gradient :  20.397599273181275
24.074424743652344
Clip gradient :  10.30891107345717
14.305188179016113
Clip gradient :  7.274720719042617
7.27403450012207
Clip gradient :  2.426016169940197
3.599790096282959
Clip gradient :  1.3004842161153347
1.127805233001709
Clip gradient :  1.5290294618058125
0.3359670639038086
Clip gradient :  0.9887720768294291


## Тестируем GRU

In [19]:
### rnn.eval()
hh = torch.zeros(rnn.hid_u_2.in_features)
id =0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = rnn(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789
