In [20]:
import torch
import torch.nn as nn
from utils import TextProcess, TextGenerator

In [3]:
embed_size = 128    #Input features to the LSTM
hidden_size = 1024  #Number of LSTM units
num_layers = 1
num_epochs = 20
batch_size = 20
timesteps = 30
learning_rate = 0.002

In [4]:
corpus = TextProcess()

In [5]:
rep_tensor = corpus.get_data('alice.txt', batch_size=batch_size)

In [11]:
vocab_size = len(corpus.dictionary)
num_batches = rep_tensor.shape[1]//timesteps

In [9]:
rep_tensor.shape

torch.Size([20, 1652])

In [28]:
embed_size

128

In [29]:
timesteps

30

In [38]:
# Get model
model = TextGenerator(vocab_size, embed_size, hidden_size, num_layers)

In [26]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [42]:
for epoch in range(num_epochs):
    # Set zeros to be the initial hidden and cell states
    states = (torch.zeros(num_layers, batch_size, hidden_size),
             torch.zeros(num_layers, batch_size, hidden_size))
    for i in range(0, rep_tensor.size(1) - timesteps, timesteps):
        # Get mini-batch inputs and targets
        inputs = rep_tensor[:, i:i+timesteps]
        targets = rep_tensor[:, (i+1):(i+1)+timesteps]
        # 600 = 20 * 30 = batch_size * timesteps
        outputs, _ = model(inputs, states) # outputs is 600x1
        loss = loss_fn(outputs, targets.reshape(-1)) # reshape to be 600
        
        # Backpropagation and Weight Update
        model.zero_grad()
        loss.backward()
        
        # Try no gradient clipping
        
        optimizer.step()
        
        step = (i+1) // timesteps
        if step % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")

outputs


AttributeError: 'NoneType' object has no attribute 'data'

In [33]:
target = torch.tensor([[   1,    2,    3,    4,    5,    5,    6,    7,    8,    9,   10,   11,
           12,   13,   14,   15,   16,   17,   18,    3,    5,   19,   20,   13,
           21,   22,    9,   23,   24,   25],
        [ 621,   28,  622,  623,  238,  624,  110,   44,  625,    5,  626,  627,
          628,  285,  251,  285,  629,   34,  196,  630,   20,  110,  251,  285,
          631,  632,    5,  633,   88,  634],
        [ 305, 1075,  272,  554,  555,   20,   16,   82,    5,  204,   80,  744,
         1076,   55,   27,  161,  132,  265,  274,   69, 1077,  148, 1078,   20,
            5,  103,  104,  112,    9,  552],
        [1462,    5,    3, 1463, 1464, 1129,    7, 1465,   15,    3, 1466,    7,
          379, 1467,    5,    9,   15,    3, 1468,  610, 1469, 1470,   20,   28,
          485,   13,  468,   94,    5, 1471],
        [  13, 1844,   38,    3,  919,   20,   27,  129,   73,    5, 1803,  575,
          576,  110,    3, 1232,   28,  975, 1055, 1845,   20,    7,  194,    5,
          574,    9, 1763, 1055, 1846,    5],
        [ 103,    5, 2152, 2153,    9,  552, 2154, 2136,    3, 2155, 1359,  358,
            9,    5,  552,  153,    3, 2156,    5,    5,  937,   49, 2125,  309,
            9,  262,  153,    3, 1976, 1608],
        [ 272,    3, 2364,    5,    5, 2423,  110, 2424,  272,  555, 2425,  153,
           16, 2426,   55,   54,   55,   27,    5, 1675,    5,    5, 2427,  272,
            3, 2364,    5,    5,    6,   46],
        [   5,    5,  265, 2630, 2727, 2729, 1699,  272,  555,  610,    7,   44,
           11, 2730,    5, 2731, 1702,  367, 2721,  408, 2732,  119,   55,   94,
           55, 2686,  378,  285,    5, 2547],
        [  16, 2991, 2992,   93,   13,    5,   41, 2604,   34,   68,  193,   80,
            3,   59,   20, 2993, 2994,  151,    3, 2995, 2996,    5, 2997, 2998,
            9, 2391,  483,   18, 1235, 2999],
        [   5,  174,  756,  452,  526,  250,  130,   27,  491,   73, 3242,   34,
          334, 2934,    5,  413,    9,  176, 3243,   34,  384,   69, 3215, 3167,
          148, 3244,  103,  642,    5,  395],
        [ 469, 3491,    5,    5, 3492,  212,   44,  116,  208,  203, 1934,  367,
         3493,    3, 3252,  320,    5,   38,   44,  224, 3494,   40,  329, 3495,
          203, 3496, 3497,   20, 3498,   20],
        [3426,  610,  562, 2822,   20, 3748,   38, 1761,    5,    5, 3749,  272,
            3, 3750, 3037,   16,  550, 3751,  964, 1533,    9,    5,  555,   27,
          161,  132, 3752,  632, 3753, 3754],
        [  20,   27,   92,  173,  167,    3,  186,    5,   13,  289,  114, 1078,
           55,    3,  749,    7,   38,  238, 2184,  110,   27,  137,  563,    5,
           64,   34,    7,   16, 2391,   25],
        [1099, 3977,    9,  372,  115,   87,  110,   15,    5,    3, 1314,   13,
          575,  263, 1603,   25,   87,  208,  203,   35, 3977, 4167,   20,  117,
            3,    5, 4162,  500,    3, 3789],
        [   5, 4393, 4394, 4395,    5,    5, 4396,   80,   44, 4397,   55,   44,
         4398,  822,    3, 4189,    5,    5, 1647, 1628,    3, 4009, 4169, 4328,
         4399, 2574,  814,    9, 4400,    5],
        [  34,   68,  117,  262,    5, 1036,   20,   27,  161,   18,   38,   44,
         1313, 4614,    5,    5,  265, 3940,   15, 1055,  787,   20, 4615,   80,
          218, 4616,    5,  242,    3, 4617],
        [ 411,  931,  309,   73,   11, 3539,   20, 1042,  107,    9,    3,    5,
          615,  486,   13,    3, 4854,    5,    5, 4161,  270,  116,    3, 2813,
           28,  137,  765,  252, 2828,  114],
        [  69,   44,    5, 2807,  375,   15,    3, 5045, 5046, 5047,    5,    5,
         1698,  274,  112,  485, 3153,  272,    3, 3789, 5048,   34,    7,  375,
            9,    5, 4200,  426, 3276, 5049],
        [  61,   80,    5,  117,  329,  620, 5284,   20,  577,   44,   65,   38,
          117,  329,  620, 5285,    5, 5286,   16,   51, 5287,   20,    3, 5288,
         4780, 5289,    5,    5,  879, 5290],
        [  15,   44, 5612,  610, 5613,    5,  285,   38, 4741, 3103,   15, 5614,
         2962, 5615, 2895,   13, 5616,  110, 5617,    5, 2737,  173, 4097,    9,
            3, 3392,   13,    3, 1984, 5292]])

In [37]:
target.reshape(-1).shape

torch.Size([600])

In [None]:
[  25,   26,   27,   28,   29,   30,    3,    5,   31,   16,   17,    7,
           32,   33,   34,   28,   35,   36,   25,   37,   38,    5,   39,   40,
           41,   42,    3,   43,   13,   44]

In [None]:
[  26,   27,   28,   29,   30,    3,    5,   31,   16,   17,    7,   32,
           33,   34,   28,   35,   36,   25,   37,   38,    5,   39,   40,   41,
           42,    3,   43,   13,   44,   45]

In [18]:
for i in [ 0,  1,  2,  3,  4,  5,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18,  3,  5, 19, 20, 13, 21, 22,  9, 23, 24, 25, 26, 27, 28, 29, 30,
         3,  5, 31, 16, 17,  7, 32, 33, 34, 28, 35, 36, 25, 37]:
    print(corpus.dictionary.idx2word[i])

CHAPTER
I.
Down
the
Rabbit-Hole
<eos>
<eos>
Alice
was
beginning
to
get
very
tired
of
sitting
by
her
sister
on
the
<eos>
bank,
and
of
having
nothing
to
do:
once
or
twice
she
had
peeped
into
the
<eos>
book
her
sister
was
reading,
but
it
had
no
pictures
or
conversations


In [22]:
for i in range(0, rep_tensor.size(1) - timesteps, timesteps):
    print(i)

0
30
60
90
120
150
180
210
240
270
300
330
360
390
420
450
480
510
540
570
600
630
660
690
720
750
780
810
840
870
900
930
960
990
1020
1050
1080
1110
1140
1170
1200
1230
1260
1290
1320
1350
1380
1410
1440
1470
1500
1530
1560
1590
1620
