In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD
from torch.nn import CrossEntropyLoss

from time import time
import copy

In [2]:
seq_len = 10
batch_size = 3
hidden_size = 5
num_layers = 1

### Prepare Data

In [17]:
inputs = Variable(torch.randn(batch_size, seq_len, hidden_size))
labels = Variable(torch.ones(batch_size, seq_len).long())

inputs.size(), labels.size()
print(inputs)

Variable containing:
(0 ,.,.) = 
 -0.4896 -0.2160 -0.6313 -0.3543  0.2337
 -0.0718  0.5989  0.8464  1.5978 -0.0819
 -0.9158  0.5622 -0.7768 -1.1263 -1.0465
  0.4506 -0.0126 -0.8850  0.0562 -2.5094
  0.7123  0.3983 -0.7156 -0.0581  1.9640
 -0.4603 -1.1451  0.4822  1.2087  0.1134
  0.3681  1.4029  0.4841 -0.3085  0.6774
  1.8666  0.7289 -0.7285  1.5165 -0.9713
  0.1606 -0.0850  0.5290  0.8232  0.3798
  0.1931 -1.1215 -2.9263  2.3361 -0.8520

(1 ,.,.) = 
  1.2109 -0.0862 -0.1067 -0.5303 -0.1768
  0.5569 -0.1062 -0.5554  0.2678 -1.4154
  0.7367 -2.3189 -0.0328  0.9362  1.1795
  1.5108 -1.2707  0.3172 -2.9128  0.0304
 -0.6279  0.2067 -1.7887 -1.0441 -0.2283
  0.5645 -0.1671  0.1792 -0.0699  0.9892
 -0.2849  0.5176  0.7113  1.3107  0.9954
  0.8644 -1.4601 -0.4753 -0.0710  0.0877
  0.4778  2.1989  1.9863 -0.3994  0.1753
  0.1334 -1.6774 -0.4392  1.6171  1.8517

(2 ,.,.) = 
 -1.0962 -0.4176 -0.3727  1.3332  0.1398
 -0.0092 -0.8003  1.1821  0.6830 -0.2044
 -0.8851 -1.2152 -0.2499 -1.8230 -0.459

### Define model, Loss function, optimizer

In [6]:
rnn = nn.RNN( hidden_size, hidden_size, num_layers, batch_first=True)
h_n = Variable(torch.randn(1, batch_size, hidden_size))

loss_fn = CrossEntropyLoss()
opt = SGD(rnn.parameters(), lr=0.01)

### RNN module

In [7]:
start = time()

for i in range(1000):
    loss = 0
    
    out, last_h = rnn(inputs, h_n)
    
    # out: [batch_size, seq_len, hidden_size]
    # lables: [batch_size, seq_len]
    
    for j in range(seq_len):
        loss += loss_fn(out[:,j,:], labels[:,j])
        
    opt.zero_grad()
    loss.backward()
    opt.step()
    
    if (i+1) % 100 == 0:
        print(loss)

print(f'{time() - start:.2f}')

Variable containing:
 5.1214
[torch.FloatTensor of size 1]

Variable containing:
 4.7259
[torch.FloatTensor of size 1]

Variable containing:
 4.5771
[torch.FloatTensor of size 1]

Variable containing:
 4.4747
[torch.FloatTensor of size 1]

Variable containing:
 4.4322
[torch.FloatTensor of size 1]

Variable containing:
 4.4101
[torch.FloatTensor of size 1]

Variable containing:
 4.3961
[torch.FloatTensor of size 1]

Variable containing:
 4.3863
[torch.FloatTensor of size 1]

Variable containing:
 4.3790
[torch.FloatTensor of size 1]

Variable containing:
 4.3733
[torch.FloatTensor of size 1]

16.22


### RNNCell

In [9]:
rnncell = nn.RNNCell(hidden_size, hidden_size)
loss_fn = CrossEntropyLoss()
opt = SGD(rnncell.parameters(), lr=0.01)

In [11]:
inputs = Variable(torch.randn(batch_size,seq_len, hidden_size))
labels = Variable(torch.ones(batch_size, seq_len).long())
h = Variable(torch.randn(batch_size,hidden_size))

In [14]:
start = time()
for i in range(1000):
    loss = 0
    
    h_next = Variable(h.data.new(batch_size,hidden_size))

    for j in range(seq_len):
        h_next = rnncell(inputs[:,j,:], h_next)
        loss += loss_fn(h_next, labels[:, j])
        
    opt.zero_grad()
    loss.backward()
    opt.step()

    if (i+1) % 5 == 0:
        print(loss)

print(f'{time() - start:.2f}')

Variable containing:
 4.6914
[torch.FloatTensor of size 1]

Variable containing:
 4.3650
[torch.FloatTensor of size 1]

Variable containing:
 4.3630
[torch.FloatTensor of size 1]

Variable containing:
 4.3580
[torch.FloatTensor of size 1]

Variable containing:
 4.3634
[torch.FloatTensor of size 1]

Variable containing:
 4.3626
[torch.FloatTensor of size 1]

Variable containing:
 4.3627
[torch.FloatTensor of size 1]

Variable containing:
 4.3623
[torch.FloatTensor of size 1]

Variable containing:
 4.3621
[torch.FloatTensor of size 1]

Variable containing:
 4.3620
[torch.FloatTensor of size 1]

Variable containing:
 4.6896
[torch.FloatTensor of size 1]

Variable containing:
 4.3617
[torch.FloatTensor of size 1]

Variable containing:
 4.3578
[torch.FloatTensor of size 1]

Variable containing:
 4.3657
[torch.FloatTensor of size 1]

Variable containing:
 4.3612
[torch.FloatTensor of size 1]

Variable containing:
 4.3559
[torch.FloatTensor of size 1]

Variable containing:
 4.3610
[torch.Floa

Variable containing:
 4.3505
[torch.FloatTensor of size 1]

Variable containing:
 4.3464
[torch.FloatTensor of size 1]

Variable containing:
 4.3499
[torch.FloatTensor of size 1]

Variable containing:
 4.3498
[torch.FloatTensor of size 1]

Variable containing:
 4.3498
[torch.FloatTensor of size 1]

Variable containing:
 4.3497
[torch.FloatTensor of size 1]

Variable containing:
 4.6768
[torch.FloatTensor of size 1]

Variable containing:
 4.3454
[torch.FloatTensor of size 1]

Variable containing:
 4.3527
[torch.FloatTensor of size 1]

Variable containing:
 4.3352
[torch.FloatTensor of size 1]

Variable containing:
 4.3468
[torch.FloatTensor of size 1]

Variable containing:
 4.3459
[torch.FloatTensor of size 1]

Variable containing:
 4.3467
[torch.FloatTensor of size 1]

Variable containing:
 4.3493
[torch.FloatTensor of size 1]

Variable containing:
 4.6764
[torch.FloatTensor of size 1]

Variable containing:
 4.6764
[torch.FloatTensor of size 1]

Variable containing:
 4.3522
[torch.Floa

KeyboardInterrupt: 