In [1]:
import torch
import torch.nn as nn
from torch.optim import SGD 
import numpy as np

# Упражнение, для реализации "Ванильной" RNN
* Попробуем обучить сеть восстанавливать слово hello по первой букве. т.е. построим charecter-level модель

In [2]:
a = torch.ones((3,3))*3
b = torch.ones((3,3))*5

In [3]:
a @ b

tensor([[45., 45., 45.],
        [45., 45., 45.],
        [45., 45., 45.]])

In [4]:
a * b

tensor([[15., 15., 15.],
        [15., 15., 15.],
        [15., 15., 15.]])

In [232]:
word = 'ololoasdasddqweqw123456789'
#word = 'hello'

## Датасет. 
Позволяет:
* Закодировать символ при помощи one-hot
* Делать итератор по слову, которыей возвращает текущий символ и следующий как таргет

In [233]:
class WordDataSet:
    
    def __init__(self, word):
        self.chars2idx = {}
        self.indexs  = []
        for c in word: 
            if c not in self.chars2idx:
                self.chars2idx[c] = len(self.chars2idx)
                
            self.indexs.append(self.chars2idx[c])
            
        self.vec_size = len(self.chars2idx)
        self.seq_len  = len(word)
        
    def get_one_hot(self, idx):
        x = torch.zeros(self.vec_size)
        x[idx] = 1
        return x
    
    def __iter__(self):
        return zip(self.indexs[:-1], self.indexs[1:])
    
    def __len__(self):
        return self.seq_len
    
    def get_char_by_id(self, id):
        for c, i in self.chars2idx.items():
            if id == i: return c
        return None

## Реализация базовой RNN
<br/>
Скрытый элемент
$$ h_t= tanh⁡ (W_{ℎℎ} h_{t−1}+W_{xh} x_t) $$
Выход сети

$$ y_t = W_{hy} h_t $$

In [234]:
class VanillaRNN(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(VanillaRNN, self).__init__()        
        self.x2hidden    = nn.Linear(in_features=in_size, out_features=hidden_size)
        self.hidden      = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.activation  = nn.Tanh()
        self.outweight   = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_hidden):
        hidden = self.activation(self.x2hidden(x) + self.hidden(prev_hidden))
        output = self.outweight(hidden)
        
        return output, hidden

## Инициализация переменных 

In [254]:
ds = WordDataSet(word=word)
rnn = VanillaRNN(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 10000
optim     = SGD(rnn.parameters(), lr = 0.001, momentum=0.9)

# Обучение

In [255]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.zeros(rnn.hidden.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target =  torch.LongTensor([next_sample])

        y, hh = rnn(x, hh)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1)    
    optim.step()

72.76573181152344
Clip gradient :  7.424330306597151
72.28658294677734
Clip gradient :  6.8001687682390255
71.56970977783203
Clip gradient :  5.850411206458492
70.86940002441406
Clip gradient :  5.1828639895514925
70.20771789550781
Clip gradient :  4.866023158072851
69.56620025634766
Clip gradient :  4.703865331978205
68.93803405761719
Clip gradient :  4.618980548594914
68.31706237792969
Clip gradient :  4.6015285773913295
67.69499969482422
Clip gradient :  4.628361074147418
67.0657958984375
Clip gradient :  4.677260204404777
66.42640686035156
Clip gradient :  4.737456407803565
65.775390625
Clip gradient :  4.801360433584209
65.11256408691406
Clip gradient :  4.861388265476922
64.43895721435547
Clip gradient :  4.9112052579387955
63.75661087036133
Clip gradient :  4.946634255795367
63.06826400756836
Clip gradient :  4.9659682497813415
62.37687683105469
Clip gradient :  4.969683527387066
61.68523025512695
Clip gradient :  4.959565704194863
60.99580001831055
Clip gradient :  4.9378417345

16.59884262084961
Clip gradient :  5.900293981214814
16.610668182373047
Clip gradient :  30.010774861183855
16.594438552856445
Clip gradient :  32.41509352255174
16.53795623779297
Clip gradient :  23.39244487324397
16.557275772094727
Clip gradient :  31.814211430098055
16.528730392456055
Clip gradient :  28.534126613412777
16.47156524658203
Clip gradient :  16.28357859409582
16.4298038482666
Clip gradient :  2.828673989924226
16.39084815979004
Clip gradient :  3.9169409435369227
16.43103790283203
Clip gradient :  32.53373932797405
16.335765838623047
Clip gradient :  4.530563744642722
16.28983497619629
Clip gradient :  1.6838224692952897
16.23444366455078
Clip gradient :  6.006278692351697
16.234241485595703
Clip gradient :  24.98332476620495
16.170970916748047
Clip gradient :  6.726463102704335
16.231517791748047
Clip gradient :  28.780480633979668
16.1967830657959
Clip gradient :  31.730587529324936
16.180641174316406
Clip gradient :  25.033719063675967
16.186847686767578
Clip gradien

Clip gradient :  28.51373936805839
14.943370819091797
Clip gradient :  28.404910494319424
14.93516731262207
Clip gradient :  28.420647710910643
14.926944732666016
Clip gradient :  28.361648890331686
14.918763160705566
Clip gradient :  28.38050116617848
14.910560607910156
Clip gradient :  28.34595092901849
14.902385711669922
Clip gradient :  28.361960142926105
14.894186973571777
Clip gradient :  28.341413648457188
14.886018753051758
Clip gradient :  28.35566246217309
14.87783432006836
Clip gradient :  28.34300319948722
14.869671821594238
Clip gradient :  28.35742883256905
14.861509323120117
Clip gradient :  28.354286723461104
14.85336685180664
Clip gradient :  28.36922419179493
14.845222473144531
Clip gradient :  28.371850178978242
14.837100982666016
Clip gradient :  28.389285020172448
14.828975677490234
Clip gradient :  28.394613341390155
14.82086181640625
Clip gradient :  28.409496431820727
14.812766075134277
Clip gradient :  28.420484314156674
14.804677963256836
Clip gradient :  28.4

13.830977439880371
Clip gradient :  31.80920933174213
13.824468612670898
Clip gradient :  31.83366204850462
13.817981719970703
Clip gradient :  31.86070252122647
13.811494827270508
Clip gradient :  31.88488482106193
13.805023193359375
Clip gradient :  31.911531238066395
13.79855728149414
Clip gradient :  31.936333054917927
13.792104721069336
Clip gradient :  31.9621472391723
13.785654067993164
Clip gradient :  31.985856199915215
13.779217720031738
Clip gradient :  32.010801907544696
13.77279281616211
Clip gradient :  32.036072387585165
13.766380310058594
Clip gradient :  32.06180915725324
13.759973526000977
Clip gradient :  32.08767893467245
13.753583908081055
Clip gradient :  32.11445990323772
13.747199058532715
Clip gradient :  32.1405566937722
13.740822792053223
Clip gradient :  32.16531535732032
13.734456062316895
Clip gradient :  32.19066482768433
13.728096008300781
Clip gradient :  32.21457418945461
13.721746444702148
Clip gradient :  32.23977869781786
13.715414047241211
Clip gra

12.93625259399414
Clip gradient :  35.359391292574635
12.930984497070312
Clip gradient :  35.38182355396243
12.92572021484375
Clip gradient :  35.40250056529535
12.920461654663086
Clip gradient :  35.423626191884374
12.915205955505371
Clip gradient :  35.44293140376702
12.909965515136719
Clip gradient :  35.46448110277357
12.904726028442383
Clip gradient :  35.48634702441777
12.899494171142578
Clip gradient :  35.50745600111218
12.89427661895752
Clip gradient :  35.52951739222085
12.889045715332031
Clip gradient :  35.54717468175582
12.88383674621582
Clip gradient :  35.56871564395153
12.878633499145508
Clip gradient :  35.59003007989981
12.873429298400879
Clip gradient :  35.60931182958307
12.868232727050781
Clip gradient :  35.62925609807657
12.863044738769531
Clip gradient :  35.65003184467018
12.85785961151123
Clip gradient :  35.66968537501199
12.85268783569336
Clip gradient :  35.69075332498178
12.847515106201172
Clip gradient :  35.71038425801556
12.842358589172363
Clip gradient

12.197303771972656
Clip gradient :  38.27025552768443
12.192873001098633
Clip gradient :  38.28866123869841
12.188436508178711
Clip gradient :  38.30250802135169
12.184017181396484
Clip gradient :  38.32041950927971
12.179603576660156
Clip gradient :  38.338626515652415
12.175189018249512
Clip gradient :  38.3557435816237
12.17078971862793
Clip gradient :  38.375587499801874
12.1663818359375
Clip gradient :  38.391851749073716
12.161979675292969
Clip gradient :  38.40773134134144
12.157583236694336
Clip gradient :  38.42389935043138
12.1531982421875
Clip gradient :  38.44247645300513
12.148807525634766
Clip gradient :  38.45777328573952
12.144430160522461
Clip gradient :  38.47597849565185
12.140056610107422
Clip gradient :  38.494278997496096
12.135675430297852
Clip gradient :  38.50983932245549
12.131309509277344
Clip gradient :  38.52598576672353
12.126953125
Clip gradient :  38.545217471024706
12.12259292602539
Clip gradient :  38.56085292819784
12.118240356445312
Clip gradient :  

11.571795463562012
Clip gradient :  40.66054252363001
11.567962646484375
Clip gradient :  40.671914581786474
11.564146041870117
Clip gradient :  40.68701878127895
11.56032943725586
Clip gradient :  40.70198343781049
11.55651569366455
Clip gradient :  40.7154211467107
11.552699089050293
Clip gradient :  40.72885962198501
11.548891067504883
Clip gradient :  40.743232382999
11.545087814331055
Clip gradient :  40.757756333298865
11.541289329528809
Clip gradient :  40.77208574824347
11.53748893737793
Clip gradient :  40.78536373575359
11.533689498901367
Clip gradient :  40.79813190376843
11.529902458190918
Clip gradient :  40.812750637831044
11.526117324829102
Clip gradient :  40.82827569555423
11.522327423095703
Clip gradient :  40.84028681114946
11.518548965454102
Clip gradient :  40.85550113632431
11.514774322509766
Clip gradient :  40.86977497957636
11.510998725891113
Clip gradient :  40.883633535432345
11.50722885131836
Clip gradient :  40.89748440680556
11.50345230102539
Clip gradient

# Тестирование

In [256]:
rnn.eval()
hh = torch.zeros(rnn.hidden.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = rnn(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789


# ДЗ
Реализовать LSTM и GRU модули, обучить их предсказывать тестовое слово

In [220]:
#тестовое слово
word = 'ololoasdasddqweqw123456789'

## Реализовать LSTM

In [221]:
class LSTM(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(LSTM, self).__init__()        
        self.xc    = nn.Linear(in_features=in_size, out_features=hidden_size,bias=True)
        self.hc    = nn.Linear(in_features=hidden_size, out_features=hidden_size,bias=True)
        self.xi    = nn.Linear(in_features=in_size, out_features=hidden_size,bias=True)
        self.hi    = nn.Linear(in_features=hidden_size, out_features=hidden_size,bias=True)
        self.xo    = nn.Linear(in_features=in_size, out_features=hidden_size,bias=True)
        self.ho    = nn.Linear(in_features=hidden_size, out_features=hidden_size,bias=True)
        self.xt    = nn.Linear(in_features=in_size, out_features=hidden_size,bias=True)
        self.ht    = nn.Linear(in_features=hidden_size, out_features=hidden_size,bias=True)
        self.activation  = nn.Tanh()
        self.activ       = nn.Sigmoid()
        self.outweight   = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_hidden,cellst1):
        Ccellst = self.activ(self.xc(x) + self.hc(prev_hidden))
        ingate = self.activ(self.xi(x) + self.hi(prev_hidden))
        forgate = self.activ(self.xo(x) + self.ho(prev_hidden))
        outgate = self.activ(self.xt(x) + self.ht(prev_hidden))
        cellst = forgate*cellst1 + ingate*Ccellst
        hidden = outgate*self.activation(cellst)
        output = self.outweight(hidden)
        return output, hidden ,cellst

## Инициализация переменных 


In [319]:
ds = WordDataSet(word=word)
lstm = LSTM(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 1000
optim     = SGD(lstm.parameters(), lr = 0.1, momentum=0.9)

# Обучение

In [320]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.zeros(lstm.hc.in_features)
    cl = torch.zeros(lstm.hc.out_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        #print(x.size())
        target =  torch.LongTensor([next_sample])

        y, hh ,cl = lstm(x, hh, cl)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(lstm.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(lstm.parameters(), max_norm=1)    
    optim.step()

72.26292419433594
Clip gradient :  3.7709233241206026
68.93780517578125
Clip gradient :  1.9001808847230268
63.14332962036133
Clip gradient :  2.358402027178236
52.52511978149414
Clip gradient :  2.3223270124266007
41.89170455932617
Clip gradient :  1.612955956993511
32.48084259033203
Clip gradient :  1.2231775853116444
25.17783546447754
Clip gradient :  0.825619234542956
20.623340606689453
Clip gradient :  1.3962368354649226
18.075407028198242
Clip gradient :  4.969852318573088
20.976055145263672
Clip gradient :  12.307082350109836
18.93661117553711
Clip gradient :  9.123636560090237
20.56926727294922
Clip gradient :  7.638050834974273
17.7988224029541
Clip gradient :  5.134077667119628
19.129968643188477
Clip gradient :  5.811157728007747
20.64801597595215
Clip gradient :  6.460114425652703
18.362600326538086
Clip gradient :  5.502865885357311
17.136390686035156
Clip gradient :  4.1685992448596965
15.018474578857422
Clip gradient :  2.6919803111047416
13.453314781188965
Clip gradient

# Тестирование

In [321]:
lstm.eval()
hh = torch.zeros(lstm.hc.in_features)
cl = torch.zeros(lstm.hc.out_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh ,cl = lstm(x, hh, cl)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789


## Реализовать GRU

In [260]:
class GRU(nn.Module):
    
    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super(GRU, self).__init__()        
        self.xu    = nn.Linear(in_features=in_size, out_features=hidden_size,bias=True)
        self.hu    = nn.Linear(in_features=hidden_size, out_features=hidden_size,bias=True)
        self.xr    = nn.Linear(in_features=in_size, out_features=hidden_size,bias=True)
        self.hr    = nn.Linear(in_features=hidden_size, out_features=hidden_size,bias=True)
        self.xh    = nn.Linear(in_features=in_size, out_features=hidden_size,bias=True)
        self.hh    = nn.Linear(in_features=hidden_size, out_features=hidden_size,bias=True)
    
        self.activation  = nn.Tanh()
        self.activ       = nn.Sigmoid()
        self.outweight   = nn.Linear(in_features=hidden_size, out_features=out_size)
    
    def forward(self, x, prev_hidden):
        update = self.activ(self.xu(x) + self.hu(prev_hidden))
        reset = self.activ(self.xr(x) + self.hr(prev_hidden))
        hidden1=self.activation(self.xh(x) + self.hh(reset*prev_hidden))
        hidden = (1-update)*hidden1+update*prev_hidden
        output = self.outweight(hidden)
        return output, hidden

## Инициализация переменных 

In [338]:
ds = WordDataSet(word=word)
gru = GRU(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)
criterion = nn.CrossEntropyLoss()
e_cnt     = 860
optim     = SGD(gru.parameters(), lr = 0.01, momentum=0.9)

# Обучение

In [339]:
CLIP_GRAD = True

for epoch in range(e_cnt):
    hh = torch.zeros(gru.hh.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        #print(x.size())
        target =  torch.LongTensor([next_sample])

        y, hh  = gru(x, hh)
        
        loss += criterion(y, target)
     

    loss.backward()
    
    if epoch % 10 == 0:
        print (loss.data.item())
        if CLIP_GRAD: print("Clip gradient : ", torch.nn.utils.clip_grad_norm_(gru.parameters(), max_norm=5))
    else: 
        if CLIP_GRAD: torch.nn.utils.clip_grad_norm_(gru.parameters(), max_norm=1)    
    optim.step()

72.01179504394531
Clip gradient :  4.694825276933077
69.50656127929688
Clip gradient :  3.3663715714490667
66.61558532714844
Clip gradient :  3.079606120261528
63.32560729980469
Clip gradient :  3.6284331175995077
59.18557357788086
Clip gradient :  3.7364832613917076
54.932796478271484
Clip gradient :  3.567781675470245
50.814720153808594
Clip gradient :  3.470458974828412
46.81064987182617
Clip gradient :  3.299330972704257
43.0397834777832
Clip gradient :  3.135889592990137
39.607666015625
Clip gradient :  2.8348292938874997
36.649208068847656
Clip gradient :  2.4823671139857537
34.050052642822266
Clip gradient :  2.2420775573171734
31.698883056640625
Clip gradient :  2.0921583834373716
29.48609733581543
Clip gradient :  2.0327580440517625
27.43887710571289
Clip gradient :  1.9745101894171024
25.485923767089844
Clip gradient :  1.789851720823352
23.696613311767578
Clip gradient :  1.613829901896937
22.110973358154297
Clip gradient :  1.4535225092389665
20.689916610717773
Clip gradien

# Тестирование

In [340]:
gru.eval()
hh = torch.zeros(gru.hh.in_features)
id = 0
softmax  = nn.Softmax(dim=1)
predword = ds.get_char_by_id(id)
for c in enumerate(word[:-1]):
    x = ds.get_one_hot(id).unsqueeze(0)
    y, hh = gru(x, hh)
    y = softmax(y)
    m, id = torch.max(y, 1)
    id = id.data[0]
    predword += ds.get_char_by_id(id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert(predword == word)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789
