# LSTM
<img src="./LSTM.png">

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import numpy as np


In [2]:
# 1. Hyper Parameters
input_size = 28
sequence_size = 28
hidden_size = 128
num_layers = 1
num_classes = 10

learning_rate = 0.01
batch_size = 1
ephoc_size = 2


In [3]:
# 2. Preparing datasets
    # MNIST Dataset (Images and Labels)
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

    # Dataset Loader (Input Pipline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

In [8]:
# 3. Build a model
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fnn = nn.Linear(hidden_size, num_classes)

    def forward(self, x):

        out, _ = self.lstm(x)
        print(out, out.size())
        out_tmp = out[:, -1, :] 
        print(out_tmp)   # 1x128 takes the last hidden layer  
        output = self.fnn(out_tmp) #output = 1x10
        #print(output)
        
        return output

In [9]:
# 4. Generate the model
model = RNNModel(input_size, hidden_size, num_layers, num_classes)

In [10]:
# 5. Set Loss and Optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
# 6. Train
for ephoc in range(ephoc_size):
    for idx, (images, labels) in enumerate(train_loader):
        # convert dataset as for Pytorch type
        images = Variable(images.view(-1, sequence_size, input_size))
        labels = Variable(labels)
        
        # Forward, Backward and Gradient decent
        optimizer.zero_grad()
        output = model(images)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        if idx%100 == 0:
            print("loss:", loss.item())

tensor([[[ 7.0351e-05, -3.8359e-03,  2.3051e-02,  ...,  1.3764e-02,
           2.3978e-02, -7.3608e-03],
         [ 3.1582e-03, -6.9237e-03,  3.3783e-02,  ...,  2.4987e-02,
           3.3708e-02, -1.1957e-02],
         [ 5.8689e-03, -9.1283e-03,  3.8645e-02,  ...,  3.2525e-02,
           3.7763e-02, -1.4827e-02],
         ...,
         [ 1.4154e-02, -1.0473e-02,  4.1296e-02,  ...,  7.2624e-02,
           5.7798e-02, -4.3564e-02],
         [ 1.0614e-02, -1.2887e-02,  3.9037e-02,  ...,  6.2020e-02,
           5.1962e-02, -3.2387e-02],
         [ 9.5111e-03, -1.3586e-02,  3.8691e-02,  ...,  5.5264e-02,
           4.7955e-02, -2.6223e-02]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[ 0.0095, -0.0136,  0.0387,  0.0281, -0.0102,  0.0245, -0.0005,  0.0541,
         -0.0722, -0.0187,  0.0120,  0.0534, -0.0228, -0.0278,  0.0091,  0.0128,
         -0.0283, -0.0624, -0.0240,  0.0160,  0.0197, -0.0136,  0.0004,  0.0402,
         -0.0378,  0.0193,  0.0445,  0.0522, -0.0078, -0

tensor([[[-2.5082e-03, -1.0272e-02,  2.0683e-02,  ..., -3.0879e-07,
           1.6153e-02,  4.6446e-03],
         [-5.4771e-03, -1.8159e-02,  3.0220e-02,  ..., -5.1536e-03,
           1.6085e-02,  1.0561e-02],
         [-9.5936e-03, -1.9371e-02,  3.4904e-02,  ..., -1.0330e-02,
           9.7283e-03,  1.3472e-02],
         ...,
         [ 3.0839e-02,  9.6861e-03,  4.8860e-02,  ..., -3.1137e-02,
          -1.8442e-02, -4.0107e-03],
         [ 3.5519e-03,  1.0029e-02,  3.6388e-02,  ..., -1.8064e-02,
          -7.1463e-03,  8.9458e-03],
         [-6.5038e-03,  8.2099e-03,  3.5686e-02,  ..., -1.2257e-02,
          -1.0773e-03,  1.4645e-02]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0065,  0.0082,  0.0357, -0.0249,  0.0077, -0.0072,  0.0121,  0.0358,
         -0.0320,  0.0366, -0.0483,  0.0235,  0.0011, -0.0974, -0.0018,  0.0308,
         -0.0026, -0.0318,  0.0585,  0.0296,  0.0018, -0.0279,  0.0112,  0.0028,
         -0.0401,  0.0194,  0.0125,  0.0219,  0.0110, -0

tensor([[[-0.0058, -0.0016,  0.0126,  ..., -0.0122,  0.0096,  0.0050],
         [-0.0106,  0.0025,  0.0157,  ..., -0.0218,  0.0037,  0.0083],
         [-0.0122,  0.0117,  0.0151,  ..., -0.0277, -0.0037,  0.0087],
         ...,
         [ 0.0456,  0.0316, -0.0106,  ..., -0.1394, -0.0876, -0.0106],
         [ 0.0372,  0.0236, -0.0055,  ..., -0.1040, -0.0511, -0.0181],
         [ 0.0349,  0.0198,  0.0077,  ..., -0.0731, -0.0060, -0.0135]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[ 0.0349,  0.0198,  0.0077, -0.0363,  0.0118,  0.0269,  0.0183, -0.0256,
         -0.0078,  0.0851,  0.0117, -0.0264,  0.0047,  0.0298,  0.0429,  0.0029,
          0.0239,  0.0352, -0.0038,  0.0163, -0.0388, -0.0686,  0.0864, -0.0231,
          0.0034, -0.0259,  0.0896,  0.0569,  0.0260,  0.0275,  0.0923, -0.0817,
         -0.0064,  0.0388, -0.0361,  0.0197, -0.0287,  0.0074,  0.0120, -0.0146,
          0.0411, -0.0725,  0.0210, -0.0227,  0.0789,  0.0138,  0.0196,  0.0184,
         -

tensor([[[ 0.0022,  0.0075,  0.0176,  ..., -0.0175,  0.0113, -0.0010],
         [ 0.0034,  0.0200,  0.0222,  ..., -0.0291,  0.0064, -0.0029],
         [ 0.0090,  0.0351,  0.0231,  ..., -0.0351,  0.0023, -0.0067],
         ...,
         [ 0.1156,  0.0747,  0.0342,  ..., -0.0300,  0.0399, -0.0310],
         [ 0.0800,  0.0572,  0.0340,  ..., -0.0364,  0.0270, -0.0179],
         [ 0.0523,  0.0559,  0.0291,  ..., -0.0392,  0.0089, -0.0139]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[ 0.0523,  0.0559,  0.0291,  0.0540, -0.0452,  0.0898, -0.0530,  0.0061,
          0.0479,  0.1402, -0.0214,  0.0149, -0.0056, -0.0815,  0.0079,  0.0478,
          0.0185, -0.0134,  0.0517,  0.0517, -0.0331, -0.1535, -0.0337, -0.0023,
         -0.0260,  0.0287,  0.0038,  0.0263,  0.0276,  0.0021,  0.0121, -0.1101,
          0.0053, -0.0161, -0.0282, -0.0548, -0.0522, -0.0431,  0.0365,  0.0553,
          0.0051, -0.0208,  0.0185, -0.0656, -0.0262,  0.0045,  0.0270,  0.0474,
          

tensor([[[ 0.0105,  0.0170,  0.0253,  ..., -0.0263,  0.0125, -0.0036],
         [ 0.0243,  0.0428,  0.0364,  ..., -0.0462,  0.0099, -0.0074],
         [ 0.0473,  0.0730,  0.0444,  ..., -0.0621,  0.0098, -0.0117],
         ...,
         [ 0.3121,  0.3752,  0.0611,  ..., -0.4078,  0.0190, -0.0163],
         [ 0.3504,  0.3889,  0.0769,  ..., -0.3699,  0.0594, -0.0203],
         [ 0.3712,  0.3916,  0.0889,  ..., -0.3291,  0.0853, -0.0199]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[ 3.7121e-01,  3.9161e-01,  8.8861e-02,  2.8342e-01, -6.9397e-02,
          6.8895e-01, -2.7705e-01,  1.2550e-02,  9.4941e-02,  5.5295e-01,
          2.5414e-02, -4.2264e-03, -1.7427e-01, -6.0471e-03,  1.1039e-01,
          1.1610e-01,  4.2978e-02, -2.0076e-02,  3.6359e-02,  2.9447e-01,
         -2.1259e-01, -7.0483e-01, -6.7488e-02, -4.8643e-02, -4.0068e-01,
          1.1002e-01,  1.7935e-01,  3.6142e-01,  1.5514e-01, -4.9572e-03,
          2.9242e-01, -5.5549e-01, -7.4324e-02, -4.0

tensor([[[ 0.0102,  0.0167,  0.0257,  ..., -0.0238,  0.0162, -0.0042],
         [ 0.0217,  0.0364,  0.0352,  ..., -0.0353,  0.0204, -0.0079],
         [ 0.0384,  0.0545,  0.0403,  ..., -0.0394,  0.0268, -0.0113],
         ...,
         [ 0.1805,  0.1934,  0.0524,  ..., -0.1023,  0.0673, -0.0009],
         [ 0.1792,  0.1782,  0.0571,  ..., -0.0573,  0.1095, -0.0156],
         [ 0.1850,  0.1428,  0.0629,  ..., -0.0325,  0.1398, -0.0183]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[ 0.1850,  0.1428,  0.0629,  0.0892, -0.0401,  0.4161, -0.0393,  0.0229,
          0.0247,  0.2671, -0.0070,  0.0246, -0.0132, -0.0474,  0.0086,  0.1124,
         -0.0330, -0.0533,  0.0198,  0.3249, -0.0045, -0.5032, -0.0218,  0.0356,
         -0.0315,  0.0306,  0.1261,  0.4547,  0.0279,  0.0018,  0.1017, -0.2907,
         -0.0048, -0.0174,  0.0068, -0.0973, -0.0918, -0.0024,  0.0345,  0.0260,
         -0.0028, -0.1124,  0.0895, -0.1270,  0.0320,  0.1505, -0.0042,  0.3031,
          

tensor([[[-1.7427e-04,  7.8803e-03,  1.9648e-02,  ..., -1.3943e-02,
           2.2127e-02, -2.7600e-03],
         [-4.3788e-03,  9.8806e-03,  2.1468e-02,  ..., -1.2807e-02,
           3.0147e-02, -3.2792e-03],
         [-7.0671e-03,  8.6335e-03,  1.9606e-02,  ..., -7.4983e-03,
           3.3530e-02, -3.2764e-03],
         ...,
         [ 2.0551e-02,  5.3119e-02,  1.9711e-02,  ..., -7.2362e-06,
           5.9212e-02, -4.5317e-03],
         [ 1.3860e-02,  1.5877e-02,  1.9516e-02,  ...,  1.2544e-02,
           7.6157e-02, -2.0806e-03],
         [-2.7889e-03, -1.1676e-02,  1.4222e-02,  ...,  1.4931e-02,
           6.3670e-02,  1.0721e-03]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-2.7889e-03, -1.1676e-02,  1.4222e-02,  2.3208e-03, -1.8298e-03,
          5.7777e-02,  4.4590e-02, -2.9750e-03,  7.5734e-04,  7.7589e-02,
         -3.8256e-02,  3.5659e-02,  3.4817e-02, -9.2912e-02, -4.8518e-02,
          5.2766e-02, -6.5320e-02, -2.1106e-02,  6.9759e-04,  1.1184e-01,
   

tensor([[[-0.0173, -0.0060,  0.0082,  ...,  0.0010,  0.0365, -0.0033],
         [-0.0395, -0.0250,  0.0011,  ...,  0.0152,  0.0519, -0.0032],
         [-0.0597, -0.0457, -0.0079,  ...,  0.0269,  0.0494, -0.0027],
         ...,
         [-0.1291, -0.1465, -0.0456,  ...,  0.0600,  0.0167, -0.0038],
         [-0.1277, -0.1087, -0.0335,  ...,  0.0201, -0.0216, -0.0091],
         [-0.0876, -0.0406, -0.0019,  ..., -0.0258, -0.0244, -0.0143]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0876, -0.0406, -0.0019,  0.0389,  0.0270, -0.0781,  0.0317, -0.0588,
          0.0287, -0.0723,  0.0610,  0.0005,  0.0178, -0.0090, -0.0877, -0.0140,
          0.0364,  0.0685, -0.0431, -0.0859,  0.0435,  0.1021, -0.0125,  0.0058,
         -0.0219,  0.0130, -0.1148, -0.0610,  0.0062, -0.0181, -0.0345,  0.1055,
          0.0720,  0.0265, -0.0413,  0.0022,  0.0162, -0.1222,  0.0431,  0.0657,
         -0.0439,  0.0616,  0.0347, -0.0019, -0.1294, -0.0657,  0.0823, -0.0619,
         -

tensor([[[-0.0227, -0.0192,  0.0051,  ...,  0.0089,  0.0424, -0.0086],
         [-0.0485, -0.0528, -0.0028,  ...,  0.0262,  0.0557, -0.0108],
         [-0.0662, -0.0839, -0.0079,  ...,  0.0337,  0.0439, -0.0165],
         ...,
         [-0.0317,  0.0021,  0.0130,  ...,  0.0449, -0.0054, -0.0031],
         [-0.0375, -0.0230,  0.0039,  ...,  0.0644,  0.0827, -0.0133],
         [-0.0590, -0.0868, -0.0055,  ...,  0.0744,  0.1144, -0.0114]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-5.9039e-02, -8.6790e-02, -5.4985e-03, -7.1647e-02,  2.2070e-02,
         -6.9624e-03,  9.3283e-02, -3.0281e-02, -3.0479e-02,  1.5151e-02,
         -1.7740e-02,  5.6758e-02,  4.3873e-02, -2.4133e-02, -6.9480e-02,
          2.3204e-02, -9.2653e-02, -2.5476e-02, -6.1864e-02,  3.8510e-02,
          9.9101e-02, -3.9394e-02, -3.4042e-05,  1.0294e-01,  5.8117e-02,
         -5.2549e-02, -3.6115e-02,  2.5163e-01, -9.7167e-02,  1.3865e-02,
         -1.2447e-02, -1.0184e-02,  3.1789e-02,  1.0

tensor([[[-0.0233, -0.0242,  0.0049,  ...,  0.0122,  0.0459, -0.0102],
         [-0.0478, -0.0631, -0.0023,  ...,  0.0318,  0.0613, -0.0137],
         [-0.0611, -0.0991, -0.0041,  ...,  0.0406,  0.0521, -0.0226],
         ...,
         [-0.0537, -0.1124,  0.0090,  ...,  0.0585,  0.0555, -0.0303],
         [-0.0587, -0.1327,  0.0083,  ...,  0.0527,  0.0515, -0.0373],
         [-0.0607, -0.1383,  0.0085,  ...,  0.0417,  0.0348, -0.0430]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-6.0709e-02, -1.3827e-01,  8.5147e-03, -1.0950e-01,  1.7291e-02,
         -8.3145e-02,  8.9293e-02, -2.1615e-02, -4.8759e-02, -2.2332e-02,
         -2.7049e-02,  4.8756e-02,  1.4970e-02, -2.8152e-02, -2.0621e-02,
          4.3547e-03, -5.0103e-02, -1.5221e-02, -4.6608e-02, -1.4574e-02,
          6.1338e-02,  5.2940e-02,  1.2277e-02,  4.9946e-02,  4.8044e-02,
         -2.1883e-02, -4.5227e-02,  3.3201e-02, -1.1616e-01,  5.5298e-03,
         -2.5926e-02,  4.4304e-02, -1.8986e-02, -1.2

tensor([[[-0.0256, -0.0258,  0.0063,  ...,  0.0130,  0.0339, -0.0053],
         [-0.0498, -0.0588,  0.0021,  ...,  0.0303,  0.0394, -0.0042],
         [-0.0600, -0.0831,  0.0049,  ...,  0.0367,  0.0303, -0.0071],
         ...,
         [-0.0469, -0.0212,  0.0139,  ...,  0.0280,  0.0542,  0.0020],
         [-0.0582, -0.0678,  0.0111,  ...,  0.0439,  0.0762,  0.0058],
         [-0.0672, -0.0974,  0.0099,  ...,  0.0491,  0.0605,  0.0030]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0672, -0.0974,  0.0099, -0.1089,  0.0034, -0.0397,  0.0745,  0.0254,
         -0.0455, -0.0038, -0.0360,  0.0191,  0.0063, -0.0342, -0.0274,  0.0176,
         -0.0414,  0.0022, -0.0230,  0.0092,  0.0698,  0.0016,  0.0606,  0.0463,
          0.0460, -0.0396,  0.0009,  0.1361, -0.0947,  0.0343, -0.0129, -0.0109,
         -0.0388, -0.0262,  0.0354, -0.0069, -0.0088,  0.0599, -0.0185, -0.0065,
          0.0831, -0.0356,  0.0693,  0.0278,  0.0081, -0.0447,  0.0092,  0.0110,
         -

tensor([[[-0.0246, -0.0235,  0.0071,  ...,  0.0109,  0.0274, -0.0029],
         [-0.0460, -0.0496,  0.0049,  ...,  0.0240,  0.0282, -0.0003],
         [-0.0533, -0.0643,  0.0096,  ...,  0.0262,  0.0192, -0.0009],
         ...,
         [-0.0860,  0.0384,  0.0177,  ..., -0.0160, -0.0771,  0.0376],
         [-0.0669,  0.0470,  0.0152,  ..., -0.0101, -0.0550,  0.0196],
         [-0.0453,  0.0364,  0.0206,  ..., -0.0027, -0.0056,  0.0114]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0453,  0.0364,  0.0206, -0.0555,  0.0562,  0.0604,  0.0161, -0.0050,
         -0.0261,  0.0958,  0.0320, -0.0335, -0.0640,  0.0013, -0.0048,  0.0231,
          0.0704,  0.1054, -0.0299, -0.0302,  0.0221, -0.1533,  0.2914, -0.0088,
         -0.0535, -0.0589,  0.0518,  0.2053,  0.0052,  0.0083, -0.1215, -0.2403,
         -0.0132, -0.0237,  0.0052,  0.0010, -0.0831,  0.0624,  0.0236, -0.0120,
          0.0981, -0.0396,  0.0175,  0.0178,  0.0539, -0.0889,  0.0623,  0.1588,
         -

tensor([[[-0.0255, -0.0213,  0.0068,  ...,  0.0083,  0.0184, -0.0028],
         [-0.0462, -0.0415,  0.0062,  ...,  0.0158,  0.0114, -0.0002],
         [-0.0529, -0.0477,  0.0114,  ...,  0.0121, -0.0024, -0.0009],
         ...,
         [-0.0231, -0.1174,  0.0535,  ..., -0.0073, -0.0273, -0.0214],
         [-0.0148, -0.0956,  0.0515,  ...,  0.0082, -0.0106, -0.0285],
         [-0.0331, -0.0804,  0.0319,  ...,  0.0151, -0.0082, -0.0130]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0331, -0.0804,  0.0319, -0.1036, -0.0150, -0.0460,  0.0450,  0.0156,
         -0.0645, -0.0061, -0.0281,  0.0190, -0.0021, -0.0325,  0.0104, -0.0089,
         -0.0175,  0.0383, -0.0025,  0.0104,  0.0293,  0.0088,  0.0547,  0.0148,
          0.0328, -0.0295,  0.0236,  0.0678, -0.0559,  0.0071,  0.0138,  0.0201,
         -0.0434, -0.0440,  0.0282, -0.0006, -0.0120,  0.0370, -0.0186,  0.0235,
          0.0818, -0.0293,  0.0202,  0.0181,  0.0137,  0.0283,  0.0264, -0.0012,
         -

tensor([[[-0.0230, -0.0188,  0.0049,  ...,  0.0097,  0.0176, -0.0056],
         [-0.0425, -0.0376,  0.0036,  ...,  0.0188,  0.0096, -0.0038],
         [-0.0430, -0.0427, -0.0356,  ...,  0.0699,  0.0358, -0.0134],
         ...,
         [-0.0165, -0.0320, -0.0221,  ...,  0.0257,  0.3347,  0.0015],
         [-0.0541, -0.0834, -0.0203,  ...,  0.0542,  0.1656,  0.0142],
         [-0.0739, -0.1039, -0.0018,  ...,  0.0435,  0.0574,  0.0114]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0739, -0.1039, -0.0018, -0.1760,  0.0079, -0.0657,  0.0609,  0.0212,
         -0.0813, -0.0048, -0.0524, -0.0106,  0.0019, -0.0117,  0.0147, -0.0017,
         -0.1134, -0.0397,  0.0363, -0.0399,  0.0250,  0.0277,  0.0726,  0.0085,
          0.0437, -0.0416, -0.0180,  0.0262, -0.1014,  0.0495, -0.0933, -0.0022,
         -0.0581, -0.0362, -0.0040,  0.0609,  0.0042,  0.0569, -0.0539, -0.0309,
          0.1910, -0.0602,  0.0257,  0.0795, -0.0003, -0.0498,  0.0005, -0.0186,
         -

tensor([[[-1.5867e-02, -1.6738e-02,  2.8108e-03,  ...,  1.3792e-02,
           7.8230e-03, -7.9257e-03],
         [-3.2540e-02, -3.2531e-02,  3.0652e-05,  ...,  2.6125e-02,
          -8.5082e-03, -5.9926e-03],
         [-3.5845e-02, -3.4752e-02,  2.9957e-03,  ...,  2.6153e-02,
          -2.8847e-02, -7.8221e-03],
         ...,
         [ 2.2157e-01,  1.3686e-01,  2.4994e-02,  ..., -9.7969e-02,
           1.8313e-01, -9.7451e-02],
         [ 1.2062e-02, -1.6971e-02, -1.9085e-02,  ...,  2.7586e-02,
           3.0130e-01, -2.3279e-03],
         [-4.5497e-02, -7.1611e-02, -2.9618e-02,  ...,  5.6710e-02,
           9.7031e-02,  1.5615e-02]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0455, -0.0716, -0.0296, -0.1129,  0.0055, -0.0377,  0.0681, -0.0134,
         -0.0306, -0.0170, -0.0477,  0.0060,  0.0675, -0.0114, -0.0329, -0.0135,
         -0.1677, -0.0671,  0.0109, -0.0014,  0.0612,  0.0032, -0.0365,  0.0520,
          0.0604, -0.0396,  0.0074,  0.1410, -0.0746,  0

tensor([[[-0.0121, -0.0178,  0.0045,  ...,  0.0129,  0.0041, -0.0096],
         [-0.0280, -0.0358,  0.0023,  ...,  0.0270, -0.0156, -0.0083],
         [-0.0348, -0.0406,  0.0050,  ...,  0.0290, -0.0398, -0.0078],
         ...,
         [-0.0032, -0.0611, -0.0201,  ...,  0.0325,  0.2100, -0.0071],
         [-0.0397, -0.1085, -0.0081,  ...,  0.0701,  0.0726,  0.0043],
         [-0.0661, -0.1068,  0.0038,  ...,  0.0589, -0.0108,  0.0154]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0661, -0.1068,  0.0038, -0.1549, -0.0281, -0.0756,  0.0222,  0.1194,
         -0.0373, -0.0236, -0.1009, -0.0264,  0.0017, -0.0205,  0.0161, -0.0077,
         -0.0680, -0.0640,  0.0638, -0.0508,  0.0032,  0.0467, -0.0215,  0.0022,
          0.0472,  0.0759, -0.0332,  0.0056, -0.0600,  0.0361, -0.0540,  0.0341,
         -0.0942, -0.0261, -0.0152,  0.0479,  0.0075,  0.0434, -0.0757,  0.0069,
          0.0662, -0.0711, -0.0106,  0.0149,  0.0185, -0.0354, -0.0063, -0.0431,
          

tensor([[[-0.0113, -0.0239,  0.0052,  ...,  0.0124,  0.0009, -0.0091],
         [-0.0278, -0.0478,  0.0042,  ...,  0.0238, -0.0245, -0.0075],
         [-0.0359, -0.0521,  0.0079,  ...,  0.0213, -0.0518, -0.0072],
         ...,
         [ 0.0151,  0.1764,  0.0052,  ..., -0.2749, -0.0598, -0.0144],
         [ 0.0112,  0.1093, -0.0035,  ..., -0.0915,  0.0783, -0.0515],
         [ 0.0022, -0.0418,  0.0070,  ...,  0.0279,  0.0944, -0.0332]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[ 0.0022, -0.0418,  0.0070,  0.0065,  0.0183, -0.0415,  0.0667, -0.0438,
         -0.0512, -0.0204,  0.0112,  0.0552,  0.0995,  0.0280, -0.0190, -0.0114,
         -0.1174, -0.0543, -0.0103,  0.0220,  0.0916, -0.0317, -0.0540,  0.0402,
          0.0377,  0.0386,  0.0269,  0.1418, -0.0184, -0.0664,  0.0924,  0.0036,
          0.0155,  0.0160,  0.1281,  0.0188,  0.0575,  0.0267, -0.0437,  0.0267,
         -0.0466,  0.0581,  0.0197,  0.0230, -0.0549,  0.1170, -0.0786,  0.0399,
          

tensor([[[-0.0115, -0.0308,  0.0049,  ...,  0.0124, -0.0048, -0.0086],
         [-0.0286, -0.0590,  0.0039,  ...,  0.0203, -0.0357, -0.0064],
         [-0.0366, -0.0587,  0.0072,  ...,  0.0120, -0.0632, -0.0072],
         ...,
         [ 0.0148,  0.0323,  0.0200,  ..., -0.0225, -0.0641, -0.0476],
         [ 0.0039, -0.0209,  0.0104,  ...,  0.0250, -0.0099, -0.0252],
         [-0.0276, -0.0857,  0.0052,  ...,  0.0436, -0.0380, -0.0065]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0276, -0.0857,  0.0052, -0.0562, -0.0085, -0.0725,  0.0565,  0.0733,
         -0.0194, -0.0150, -0.0507,  0.0081,  0.0436, -0.0102, -0.0026, -0.0125,
         -0.0381, -0.0255,  0.0246, -0.0245,  0.0383,  0.0073, -0.0374,  0.0182,
          0.0371,  0.0514, -0.0177,  0.0456, -0.0414, -0.0146,  0.0351,  0.0295,
         -0.0083,  0.0054,  0.0275,  0.0127,  0.0179,  0.0055, -0.0496,  0.0377,
         -0.0178, -0.0020, -0.0258,  0.0118, -0.0508,  0.0036, -0.0195, -0.0214,
          

tensor([[[-0.0078, -0.0400,  0.0048,  ...,  0.0146, -0.0182, -0.0044],
         [-0.0248, -0.0732,  0.0017,  ...,  0.0286, -0.0487, -0.0014],
         [-0.0379, -0.0877,  0.0002,  ...,  0.0336, -0.0699,  0.0025],
         ...,
         [-0.0580, -0.0647,  0.0034,  ...,  0.0481, -0.0966,  0.0094],
         [-0.0612, -0.0283,  0.0023,  ...,  0.0340, -0.1062,  0.0085],
         [-0.0536,  0.0009,  0.0069,  ...,  0.0209, -0.1128,  0.0035]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0536,  0.0009,  0.0069, -0.0744, -0.0124, -0.0294,  0.0085,  0.0571,
         -0.0277,  0.0248, -0.3071, -0.0478, -0.0163, -0.0678, -0.0607,  0.0220,
          0.0977,  0.0520,  0.0410, -0.0958, -0.0783,  0.0290, -0.0703,  0.0219,
         -0.0263,  0.0694, -0.1407, -0.0685,  0.0324,  0.0563, -0.0614,  0.0230,
         -0.0294, -0.0088, -0.0703, -0.0113, -0.0127, -0.0161, -0.0230,  0.0315,
         -0.0225,  0.0401, -0.0474,  0.0086,  0.0117, -0.0893,  0.1424, -0.0390,
          

tensor([[[-4.5458e-03, -4.3711e-02,  5.3130e-03,  ...,  1.4734e-02,
          -2.2988e-02, -6.3029e-03],
         [-2.0799e-02, -8.0206e-02,  9.8678e-04,  ...,  3.0208e-02,
          -5.2368e-02, -3.8532e-03],
         [-3.3480e-02, -1.0205e-01, -2.9038e-03,  ...,  3.8216e-02,
          -6.9294e-02,  2.7921e-04],
         ...,
         [ 2.4219e-01,  1.2027e-01,  9.9095e-03,  ..., -3.2083e-01,
           4.4621e-01, -4.8287e-02],
         [ 1.0418e-01, -9.3736e-03, -4.9141e-03,  ..., -2.5309e-02,
           5.6293e-01, -6.6296e-03],
         [ 1.5478e-02, -8.4021e-02, -9.3174e-03,  ...,  9.1774e-03,
           2.3124e-01, -2.4856e-03]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[ 1.5478e-02, -8.4021e-02, -9.3174e-03, -3.3236e-02,  1.7974e-02,
         -2.1405e-02,  6.9059e-02, -1.3891e-01, -1.3575e-02,  1.0402e-02,
         -6.0988e-02,  5.6678e-02,  4.9735e-02,  1.3339e-02,  6.6050e-03,
         -6.7539e-03, -1.4877e-01, -2.3620e-02,  1.1594e-03,  7.2959e-02,
   

tensor([[[-3.2293e-03, -5.3536e-02,  2.4005e-03,  ...,  1.8431e-02,
          -2.1482e-02, -8.8397e-03],
         [-1.9147e-02, -1.0674e-01, -6.8826e-03,  ...,  3.6624e-02,
          -4.8180e-02, -5.3579e-03],
         [-3.1805e-02, -1.5272e-01, -2.0008e-02,  ...,  4.6590e-02,
          -5.8470e-02,  1.8840e-03],
         ...,
         [-3.0458e-02, -1.8470e-01, -1.6960e-02,  ...,  6.4692e-02,
          -5.6354e-02,  8.4982e-05],
         [-3.7112e-02, -2.1343e-01, -4.4129e-02,  ...,  5.7350e-02,
          -4.8542e-02,  1.3301e-02],
         [-3.8382e-02, -2.3593e-01, -8.5303e-02,  ...,  5.6593e-02,
          -3.0739e-02,  3.2873e-02]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0384, -0.2359, -0.0853, -0.2222, -0.1106, -0.1093,  0.2257,  0.2337,
         -0.0489, -0.1875, -0.3829, -0.0363,  0.0254, -0.2597, -0.2330, -0.0286,
         -0.0159, -0.0467,  0.4687, -0.0281, -0.1393,  0.0764, -0.4785,  0.1734,
          0.1561,  0.5792, -0.0662, -0.0660, -0.0526,  0

tensor([[[-4.5275e-03, -5.4664e-02,  2.4558e-03,  ...,  2.2139e-02,
          -2.2908e-02, -9.4980e-03],
         [-2.1862e-02, -1.0355e-01, -3.5890e-03,  ...,  4.0975e-02,
          -5.3804e-02, -5.9149e-03],
         [-3.4787e-02, -1.3346e-01, -8.4170e-03,  ...,  5.0903e-02,
          -6.9251e-02, -2.4011e-04],
         ...,
         [ 4.6070e-02,  5.5152e-02,  1.3891e-02,  ..., -9.5680e-02,
           6.3510e-02, -7.2697e-02],
         [ 2.8226e-02, -4.3334e-03, -2.8790e-03,  ...,  3.3077e-02,
           3.0507e-01, -2.4663e-02],
         [-1.3173e-02, -1.2289e-01, -1.7427e-02,  ...,  4.0323e-02,
           9.2361e-02, -1.7404e-03]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0132, -0.1229, -0.0174, -0.0442,  0.0098, -0.0126,  0.0790, -0.0026,
         -0.0107, -0.0014, -0.0839,  0.0117,  0.0482,  0.0089, -0.0070, -0.0064,
         -0.1512, -0.0540,  0.0021,  0.0600,  0.0472, -0.0480, -0.0366,  0.0267,
          0.0341, -0.0077,  0.0040,  0.1380, -0.0764, -0

tensor([[[-0.0059, -0.0546,  0.0020,  ...,  0.0255, -0.0269, -0.0100],
         [-0.0251, -0.0991, -0.0016,  ...,  0.0444, -0.0648, -0.0061],
         [-0.0393, -0.1156, -0.0016,  ...,  0.0544, -0.0861, -0.0018],
         ...,
         [-0.0236,  0.0099,  0.0033,  ...,  0.0932,  0.1574, -0.0371],
         [-0.0364, -0.1173, -0.0169,  ...,  0.0686,  0.0412, -0.0018],
         [-0.0324, -0.2130, -0.0096,  ...,  0.0582, -0.0348,  0.0008]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0324, -0.2130, -0.0096, -0.0873, -0.0345, -0.0760,  0.0866,  0.0985,
         -0.0255, -0.0382, -0.1000, -0.0123,  0.0267, -0.0236,  0.0158, -0.0152,
         -0.0860, -0.0512,  0.0661, -0.0219,  0.0310,  0.0396, -0.0455,  0.0119,
          0.0627,  0.1379, -0.0487, -0.0348, -0.0741, -0.0013,  0.0130,  0.0723,
         -0.0212, -0.0123,  0.0240,  0.0914,  0.0163,  0.0126, -0.0829,  0.0315,
         -0.0325, -0.0891, -0.0214, -0.0079, -0.0106, -0.0148, -0.1602, -0.0494,
          

tensor([[[-0.0095, -0.0562,  0.0022,  ...,  0.0283, -0.0337, -0.0097],
         [-0.0305, -0.1025, -0.0006,  ...,  0.0493, -0.0754, -0.0054],
         [-0.0458, -0.1187,  0.0005,  ...,  0.0612, -0.0960, -0.0007],
         ...,
         [-0.0550, -0.1423,  0.0069,  ...,  0.0765, -0.0656,  0.0014],
         [-0.0567, -0.1365,  0.0101,  ...,  0.0656, -0.0990,  0.0003],
         [-0.0633, -0.0984,  0.0098,  ...,  0.0686, -0.1184,  0.0002]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-6.3274e-02, -9.8416e-02,  9.7883e-03, -7.2962e-02, -4.6450e-02,
         -4.5073e-02,  2.2196e-02,  9.5259e-02, -5.0076e-02, -5.6828e-03,
         -2.9111e-01, -3.8673e-02,  2.0658e-02, -1.5883e-01, -1.5190e-01,
          1.0902e-02,  2.7127e-02, -2.8027e-03,  3.7068e-02, -5.4246e-02,
         -1.9606e-02,  5.4591e-02, -1.4327e-01,  7.4010e-02,  3.5191e-02,
          6.0249e-02, -9.5296e-02, -1.0550e-01, -3.5449e-02,  7.7034e-02,
         -8.0144e-03,  7.7012e-02, -3.9339e-02, -2.8

tensor([[[-0.0161, -0.0565,  0.0045,  ...,  0.0356, -0.0517,  0.0012],
         [-0.0361, -0.0911,  0.0014,  ...,  0.0619, -0.0990,  0.0039],
         [-0.0478, -0.0939,  0.0064,  ...,  0.0810, -0.1287,  0.0061],
         ...,
         [-0.0734,  0.0294,  0.0377,  ...,  0.0574, -0.3693, -0.0131],
         [-0.0425,  0.0466,  0.0531,  ...,  0.0986, -0.3305, -0.0039],
         [-0.0587, -0.0549, -0.0041,  ...,  0.0899, -0.2697,  0.0071]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-5.8662e-02, -5.4903e-02, -4.1299e-03, -4.9066e-02,  1.0150e-02,
          7.0406e-02,  6.7145e-02,  1.1117e-02, -4.9614e-03,  6.1656e-02,
         -2.7474e-01, -9.7465e-03,  9.5516e-02,  7.7632e-04, -1.3257e-02,
         -1.0957e-02, -8.4724e-02, -5.5728e-02, -1.5372e-02,  1.2748e-02,
          7.0475e-02, -1.1944e-01, -5.9234e-02,  4.0568e-02,  2.8219e-02,
         -7.6157e-02, -6.2546e-02,  4.6590e-02, -6.4310e-02, -1.4599e-03,
          1.8227e-01, -1.3731e-01, -1.8039e-02,  2.9

tensor([[-0.0476, -0.0947,  0.0036, -0.1002,  0.0053,  0.0328,  0.0607,  0.0481,
         -0.0211,  0.0442, -0.2640, -0.0391,  0.0668, -0.0047, -0.0072, -0.0123,
         -0.0115, -0.0225, -0.0037, -0.0330,  0.0195, -0.0485, -0.0704,  0.0135,
          0.0174, -0.0688, -0.0555, -0.0117, -0.0241, -0.0060,  0.1409, -0.0817,
         -0.0232, -0.0101,  0.0048,  0.0216, -0.0350, -0.0350, -0.0313,  0.0511,
         -0.0056, -0.0819, -0.0490,  0.0283,  0.0479, -0.0300, -0.0216,  0.0014,
         -0.0219,  0.0951, -0.0072, -0.0291,  0.0032,  0.0262, -0.0100, -0.0050,
          0.0677,  0.0194, -0.0484, -0.0333,  0.0889,  0.0429, -0.0568, -0.0016,
         -0.0327, -0.0111, -0.0832,  0.0182, -0.0144, -0.0439, -0.0361, -0.0127,
          0.0197, -0.0216,  0.0794, -0.0570, -0.0825, -0.0125,  0.0130, -0.0065,
          0.0076,  0.0892, -0.0911, -0.0439,  0.1395, -0.0126,  0.0397,  0.0145,
          0.0211,  0.0326, -0.1687, -0.0560,  0.0095, -0.0047,  0.0267, -0.0180,
          0.0036,  0.0916, -

tensor([[[-1.4298e-02, -5.1645e-02,  3.1654e-03,  ...,  3.7442e-02,
          -5.5259e-02,  9.2468e-03],
         [-2.7900e-02, -7.9141e-02, -3.2255e-03,  ...,  6.2292e-02,
          -8.8410e-02,  1.1159e-02],
         [-3.2169e-02, -8.1813e-02, -2.0807e-04,  ...,  7.8894e-02,
          -1.0686e-01,  1.3777e-02],
         ...,
         [-2.6507e-02,  1.6175e-02, -8.2636e-03,  ...,  5.8295e-02,
          -2.2520e-01,  1.0503e-02],
         [-2.8572e-02, -9.8976e-02, -1.5060e-02,  ...,  8.5050e-02,
          -1.3244e-01,  9.7981e-03],
         [-2.9110e-02, -1.3528e-01, -8.6842e-03,  ...,  9.7851e-02,
          -1.0774e-01,  8.6631e-03]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-2.9110e-02, -1.3528e-01, -8.6842e-03, -1.5418e-01, -1.5653e-02,
         -1.6092e-02,  6.4163e-02,  6.3717e-02, -2.2628e-02,  1.1829e-04,
         -2.0136e-01, -4.3638e-02,  3.2021e-02, -3.3013e-02, -4.0577e-02,
         -1.1514e-02, -2.7669e-02, -2.1225e-02,  4.7306e-02, -2.4812e-02,
   

tensor([[[-0.0086, -0.0366,  0.0025,  ...,  0.0323, -0.0507,  0.0099],
         [-0.0184, -0.0522, -0.0047,  ...,  0.0495, -0.0778,  0.0121],
         [-0.0197, -0.0520, -0.0032,  ...,  0.0596, -0.0928,  0.0148],
         ...,
         [ 0.0016,  0.1031,  0.0178,  ...,  0.0494, -0.2324,  0.0214],
         [-0.0102,  0.0885, -0.0037,  ...,  0.0455, -0.1874,  0.0128],
         [-0.0143, -0.0039, -0.0136,  ...,  0.0569, -0.1202,  0.0106]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0143, -0.0039, -0.0136, -0.0445,  0.0092,  0.0846,  0.0750,  0.0173,
          0.0080,  0.0884, -0.0592, -0.0262,  0.0423,  0.0163,  0.0139, -0.0056,
         -0.0555, -0.0348,  0.0016,  0.0103,  0.0415, -0.1666, -0.0677,  0.0121,
          0.0188, -0.0821, -0.0200,  0.1660, -0.0121, -0.0209,  0.1527, -0.1131,
         -0.0078,  0.0255,  0.0477,  0.0429, -0.0245, -0.0099, -0.0376,  0.0177,
         -0.0219, -0.0445, -0.0150,  0.0458,  0.0042,  0.0066, -0.0608,  0.0462,
         -

tensor([[[-0.0051, -0.0226,  0.0028,  ...,  0.0287, -0.0510,  0.0102],
         [-0.0128, -0.0256, -0.0043,  ...,  0.0416, -0.0787,  0.0127],
         [-0.0126, -0.0187, -0.0032,  ...,  0.0484, -0.0953,  0.0156],
         ...,
         [-0.0756,  0.0607,  0.0097,  ...,  0.0208, -0.2755,  0.0182],
         [-0.0743,  0.0743,  0.0086,  ...,  0.0214, -0.2996,  0.0166],
         [-0.0432,  0.1288, -0.0064,  ...,  0.0285, -0.3123,  0.0168]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-0.0432,  0.1288, -0.0064,  0.0073,  0.0267,  0.1931,  0.0441, -0.0034,
          0.0213,  0.3202, -0.1701, -0.0281,  0.0453,  0.0299,  0.0273,  0.0153,
          0.0609, -0.0378, -0.0094, -0.0234,  0.0024, -0.2660, -0.0561, -0.0058,
         -0.1217, -0.1027, -0.0771, -0.0281,  0.0728, -0.0303,  0.2213, -0.3942,
         -0.0022,  0.0162,  0.0228,  0.0144, -0.0521, -0.0292, -0.0073,  0.0386,
          0.0073,  0.0199, -0.0013,  0.0964,  0.0249, -0.0283,  0.0434,  0.1371,
         -

tensor([[[-2.7840e-03, -8.2335e-03,  2.2334e-03,  ...,  2.7033e-02,
          -5.2076e-02,  1.0445e-02],
         [-9.2704e-03,  1.3045e-03, -4.7355e-03,  ...,  3.8493e-02,
          -8.2396e-02,  1.3528e-02],
         [-8.1129e-03,  1.8176e-02, -3.9894e-03,  ...,  4.4109e-02,
          -1.0271e-01,  1.6717e-02],
         ...,
         [ 1.0321e-02,  2.1108e-01, -2.6215e-03,  ...,  2.5283e-02,
          -2.8791e-01,  1.4626e-02],
         [ 6.4609e-04,  1.9441e-01, -1.3539e-02,  ...,  3.0993e-02,
          -1.8967e-01,  1.1731e-02],
         [ 2.0500e-04,  1.3771e-01, -1.6465e-02,  ...,  4.1440e-02,
          -1.3020e-01,  1.1653e-02]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[ 2.0500e-04,  1.3771e-01, -1.6465e-02, -1.8094e-02,  1.6834e-02,
          1.1964e-01,  5.7356e-02,  2.8902e-03,  1.1043e-02,  1.1339e-01,
         -2.6934e-02, -1.2416e-02,  2.5782e-02,  3.0825e-03,  1.0746e-02,
          9.6534e-04, -6.3827e-02, -4.0570e-02, -9.7179e-03,  1.9762e-02,
   

tensor([[[ 1.1669e-03, -9.0295e-03,  1.4565e-03,  ...,  2.7994e-02,
          -3.9154e-02,  1.0004e-02],
         [-3.4953e-03, -4.8947e-03, -5.8480e-03,  ...,  4.0384e-02,
          -5.6633e-02,  1.2587e-02],
         [-9.9429e-04,  2.4783e-03, -5.7661e-03,  ...,  4.6901e-02,
          -6.4811e-02,  1.4692e-02],
         ...,
         [ 1.2152e-02,  2.9751e-01, -6.9247e-03,  ...,  2.4293e-02,
          -2.4896e-01,  2.4048e-02],
         [ 1.2066e-02,  3.0804e-01, -1.7612e-02,  ...,  2.2906e-02,
           9.1833e-03,  1.1588e-02],
         [-1.2874e-04,  9.8437e-02, -2.3381e-02,  ...,  3.3847e-02,
           1.9461e-02,  6.1641e-03]]], grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[-1.2874e-04,  9.8437e-02, -2.3381e-02, -1.2776e-02,  1.8152e-02,
          6.1677e-02,  7.3157e-02, -8.1546e-03,  8.3682e-03,  6.5867e-02,
          1.2507e-02, -2.7900e-03,  1.6323e-02,  2.2289e-03,  4.2387e-03,
          2.8475e-02, -1.8833e-01, -5.9484e-02, -1.8173e-02,  6.7972e-02,
   

tensor([[[ 0.0009, -0.0153,  0.0008,  ...,  0.0313, -0.0367,  0.0099],
         [-0.0032, -0.0184, -0.0064,  ...,  0.0463, -0.0524,  0.0122],
         [-0.0005, -0.0183, -0.0064,  ...,  0.0549, -0.0595,  0.0139],
         ...,
         [-0.0189,  0.1153, -0.0139,  ...,  0.0350, -0.1442,  0.0204],
         [ 0.0252,  0.2213,  0.0025,  ...,  0.0500, -0.0608,  0.0161],
         [ 0.0128,  0.1138, -0.0167,  ...,  0.0570,  0.0128,  0.0098]]],
       grad_fn=<TransposeBackward0>) torch.Size([1, 28, 128])
tensor([[ 1.2777e-02,  1.1377e-01, -1.6741e-02,  1.6178e-02,  2.7639e-02,
          1.4836e-01,  6.0328e-02, -3.4484e-03,  2.3845e-02,  8.8721e-02,
         -2.1493e-03,  4.4291e-03,  3.7688e-02,  3.2256e-03,  1.3914e-02,
          7.4407e-02, -1.3725e-01, -6.6162e-02, -3.9236e-02,  7.6152e-02,
          9.8077e-02, -2.2135e-01, -8.1213e-02,  2.1231e-02,  4.1257e-02,
         -1.7529e-01,  8.3133e-02,  4.4871e-01, -5.5699e-02, -2.1296e-02,
          2.1302e-01, -1.1650e-01,  8.5725e-03,  1.8

KeyboardInterrupt: 

In [None]:
# 7. Test
correct = 0
total = 0
for images, labels in test_loader:
    images = Variable(images.view(-1, sequence_size, input_size))
    
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += len(predicted)
    correct += (predicted == labels).sum()
    
print("accuracy:", correct.item()/total)