In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters 
input_size = 784
hidden_size = 500
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001

# MNIST dataset 
train_dataset = torchvision.datasets.MNIST(root='../../data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),  
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size) 
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)  
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

model = NeuralNet(input_size, hidden_size, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

In [25]:
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # Move tensors to the configured device
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        print(outputs[:,0])
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

tensor([ 1.1014e+01, -2.9951e+00, -1.5285e+00, -3.7712e+00, -2.7242e+00,
        -4.4759e+00, -6.9170e+00, -6.7646e+00, -4.5772e+00, -5.1280e+00,
        -5.0609e+00, -4.5529e+00, -9.2128e+00, -9.4913e+00, -3.3312e+00,
        -3.1143e+00, -3.1803e+00,  1.1503e+01, -5.4753e+00, -7.9168e+00,
        -9.5003e+00, -4.0110e+00, -5.1246e+00,  2.7731e+00, -4.7944e+00,
         9.5319e-03, -3.9997e+00, -2.7452e+00, -4.3944e+00, -7.6632e+00,
        -5.5096e+00,  1.4400e+01,  1.0504e+01,  1.0781e+01, -5.2989e+00,
        -2.2540e+00, -9.4116e+00, -4.6377e+00, -2.7919e+00, -5.2862e+00,
        -1.0479e+01, -8.6635e+00,  1.1063e+01, -5.3389e+00, -5.1664e+00,
         8.1792e+00, -5.6235e+00, -8.8688e-01, -5.0780e+00, -7.4639e+00,
        -7.3063e+00, -4.3005e+00, -3.7366e+00, -4.3964e+00, -1.9899e+00,
        -5.6418e+00,  7.1291e+00,  1.0333e-01,  7.4733e-01, -7.4196e+00,
         5.2582e+00, -3.2838e+00, -6.4246e+00, -9.2048e+00, -3.3543e+00,
        -2.8614e+00, -2.7982e+00, -3.4293e-02, -3.8

tensor([ -8.5991,   6.8040,  -6.1631,  -6.4326,  -7.6933,  -2.9412,  -2.8803,
         -8.4113,  13.5586,  -7.5850,  -1.7859,  -3.8247,  -8.4105,  -1.0314,
         -3.5782,  -2.0462,  -7.0407,  -7.6459,  -2.8229,  -4.8908,  -5.0016,
         -8.9533,  -4.0895,  -3.8617,  -6.7481,  -9.1776,  -4.7086,  -0.3064,
         -5.9564,  -6.0871,  -3.9873,  -3.1378,  -5.4598,  -2.1402,  -9.9481,
         -5.7999,  -5.4354,  -6.9927,   7.6487,  -7.6139,  -3.8711,  -5.5261,
         -6.7108,  -6.7535,  -2.5138, -11.2620,  -4.1698,  -5.1008,  -7.9972,
         -9.8002,  -6.3719,  -9.0416,  -2.2841,  10.7072,  -0.2954,  -9.2838,
         -0.7896, -10.7289,  -4.6089,  -6.3848,  -4.2321,  -6.2324,  -4.8924,
         -8.2210,  -5.8960,  14.6761,   8.4039,  -9.0141,  -5.2282,  -7.0043,
         -3.6909,  -2.2197,  -5.6610,  -4.6301,  -7.3202,  -8.7212,  -7.3594,
         -7.2956,  -6.5922, -10.1943,  -1.8025,  -4.9096,  -4.3143,  -7.6383,
         -7.7760,  12.1170,  -9.9096,  -7.5805,  -7.1590,  -3.94

tensor([ -9.1201,  -1.1615,  -9.9387,  -5.7588,  -8.7000,  -7.6425,  -5.0804,
         -9.4109,  -4.4824,  -5.3696,  -8.5189,   7.5344,  -3.2855,   9.9323,
         -2.1436,  10.2479,  -3.2901,  -3.6909, -10.0256,  -6.7957,  -3.5949,
         -7.8055,  -5.8229,  -7.5951,  -9.3786,  -6.9725,   7.3895,   1.3473,
         -4.5069,  -5.0969,  12.1883,  -4.4892,  -8.2658,  -5.0427,  -6.1740,
         -5.6635,  -1.5642, -10.6010,  -4.9165,   7.2806,  -7.1473,  -8.1943,
         -9.3816,  -4.5277,  -7.0662,  -5.6788,  -3.4984,  -3.8833,  11.6257,
         -4.5429,  -2.5953,  -7.1074,  -7.1031,  -2.8007,  -1.1248,  -5.6868,
         -9.9334, -10.7654,  -0.5997,  -8.3475,  -4.8026,  -6.8399,  -7.0702,
         -3.7170,   1.2869,  -7.0710,  -8.9018,  -3.9597,  -7.5962,  -7.7962,
         -2.6196,  -4.7930,  -9.5777,  -3.5042,  -2.4196,  -3.0946,  -4.7999,
         -7.1697,  -8.7346, -12.3988,  -9.1917,  -3.6369,  -7.2214,  -1.4925,
         -7.8078,  -3.2533,  12.1182,  -5.5075,  -6.6859,  -0.51

tensor([  8.6449,  -7.4024,  -3.2901,  -5.3506,  11.2958,  -4.3347,  -7.2866,
          2.8508,  -7.6339,  -9.5861, -15.5690,  -9.2715,  -7.3266,  -6.8294,
         11.6658,  -3.7393,  -4.2969,  11.9028,   8.3088,  -4.5704,  -5.8673,
         -5.4429,   7.3404, -11.3269,  -5.0719,  -3.8672,  -4.1967,   4.1420,
          0.8768,  -8.7409,  -7.4165,  -7.2114,  -4.5466, -11.3251,  -8.0373,
         -7.3933,  -1.3143,  -8.0379, -10.0981, -10.6712,  -4.0782,  -7.2912,
         -7.4224, -11.3494,  -8.7579,  -2.9709,  -4.0858,  -9.1575,  -3.3576,
         -7.0270,  -3.4490,  11.4143,  -6.7103,  -7.1096,  -3.5088,  -8.1641,
         -5.8751,  -5.5016,  -5.1336,  -5.6972,  -6.1425,  -4.4106,  -5.5868,
         -6.7134,  -5.2348,  -3.7512,  -6.6533, -11.1857,  -5.8066,  -5.0281,
        -10.1199,  -2.0225,  12.6526,  -7.2896,  -6.1858,  -3.0631,  -8.0411,
        -10.3787,  -6.0865,  -7.0167,  -8.5529,  -6.7414,  -6.7829,   0.3893,
          4.5811,  -5.4002,  -4.1539,  -6.1547,  -7.8712,  -3.97

tensor([ -5.6372,  -2.2478,  11.6124,  -9.2461,  -2.6845,  -4.7306,  -4.7014,
         -1.9924,   0.2151,  10.0356,  -6.5984,  -6.5975,  -2.9258,   0.9408,
         -6.3018,  -7.7365,  -4.1215,  -6.4216,  -4.8098,  17.7318,  -3.8287,
         -4.8151,  -4.2619,  -5.1764,  -2.8673,  16.9134,  -6.4701,   1.0108,
         -8.7973,   0.0201,  -7.5904,  -9.1958,  -5.6346,  -3.6289,  -6.7558,
         -1.9282,  -3.3723,  -4.0354,  -1.6907,  -4.9780,   2.5047,   9.6241,
        -10.0273,  -6.7097,  -2.8943,  -7.9315,  -7.4551, -10.3482,  -2.9399,
         -4.6871,  -4.3796,  -0.2455,  -2.7296,  -2.2511,  -2.3088, -11.7175,
          0.1518,   1.5767,  -6.0943,  -4.4235,  -5.2608,  -1.0366,  -5.7574,
         -7.4350,  -1.4885,  -7.3377,  -3.1581,  -5.9799,  12.3591,  -1.7434,
         -9.3944,  -4.6585,  -4.2084,  -5.1886,  -4.3038,  -3.7364,   1.0025,
         -2.7028,  -4.5248,  -2.5834,  -6.5906,  -7.0931,  -5.3895,  -3.1956,
         -6.0256,  -2.9739,  -6.1313,  -2.1001,   5.3483,  -8.22

tensor([ -8.5771,   6.6546, -10.6704,  -3.4006,  -5.3622, -10.4238,  10.5339,
         -5.0790,  -5.5417,  -4.6196, -10.0771,  -4.2578,  -5.3513,  -5.0839,
         -2.1243,  18.5280,   8.7124,  -4.9175,  -6.6266,  -4.5302,   6.7145,
         -9.7682,  -7.4201,   0.0662,  -5.5828,  -5.3757,  -9.8205,  -7.3878,
         10.3386,  -5.5903, -10.3970,  -6.1176,  -5.7899, -10.2711,  -7.2449,
         -9.5792,  -6.4584,  -7.7261,  -3.4285,  -4.5837,  -5.5349,  -8.8711,
         -3.6661,  -6.5195,  -1.8324,  -4.3690, -10.6100,   1.0857,  -6.4425,
        -10.3039,  -8.7881,   7.8670,  -6.0187,  -7.0313,  -5.9117,  -7.0241,
        -10.4359,  -9.6906,  -4.2134,  -4.1320,  -4.8863, -10.5608,  -6.6533,
         -8.9846,  -6.5882,  -4.0888,   8.8381,  -7.0526,  -8.1756,  -6.2307,
         -7.7078,  -7.7362,  -8.9594,  11.7719,  13.1246,  -3.2017,  -3.5826,
         -7.5033,  -8.8196, -11.0987,  -3.9942,  -7.8447,  -7.5711,   4.7023,
         -2.5696,  -6.7541,  -5.4349,  -8.6784,   9.2340,  -5.41

tensor([ -3.3257,   0.4416, -11.2042,  -7.9736,  -5.0162,   6.6043, -10.2667,
         -4.9019,  -4.7660,  -7.4414,   7.0879,  -3.6179,  -3.6549,  -6.6904,
         -2.3947,  -8.0752,  15.3835,  -8.7609,  -5.5888,  -7.9648,  -8.0256,
         -6.0671,  -6.2865,  -6.8439, -11.0959,  -6.2246,  -5.5885, -11.6760,
         -8.4720, -11.7827,  -9.2192,  -8.2048,  -5.7391,  -9.2666, -10.1206,
         -6.8159,  -4.9795,  15.6964,  -3.5611,  -6.6406,  -6.8534,  -4.5091,
         -3.8030,  -2.6580,   9.9194,  -1.5759,  -4.2390,  -6.4965,  -6.0496,
          0.0407,   1.3635,  -2.8092,  -1.4125,   7.4157,  -8.3403,  -9.4418,
         -5.5471,  -8.9797,  -8.8016,  -7.8767,  -9.4935,  -8.9414,  -5.9053,
         -0.3944,  -9.6569,  -8.3666,  -5.7997,  -6.4431,  -5.5423,  -5.8688,
         -6.7305,  10.7167,   9.4571,  -4.1145,  -5.8223,  -4.6274,  -5.8782,
          1.5178,  -9.6184,  -3.9103,  -1.2905,  -7.3332,  -7.3930,  -5.1869,
         -9.5543,  -8.1789, -12.5427,  -6.4571,  -5.8676,  -7.92

tensor([ -7.2793,  -6.0067,  -5.2924,  -7.1064,   5.8251,  -0.7158, -11.1751,
         -9.0666,  -4.5725,  -9.7569,  -9.9314,   0.9497,  -1.0509,  -2.1962,
          7.1550,  -6.5495,  -7.1072,  -7.0709,   9.8710,  -7.1404,  -4.9745,
         -1.6691,  -2.3036,  -5.4322,  -2.3555,  17.0309,   7.5050,   0.7361,
         -2.3455,  -3.8052, -10.0083,  11.5472,  -6.0697,  -7.2905,  -8.6554,
         -7.1177,  -2.4714,  -0.3883,  -4.1634,  -4.1546,  -5.6210,  -7.7241,
         -1.4370,  -7.1149,  -0.3327,  -1.7319,  -7.3483,  -5.7567,  12.7726,
         -7.9135,   7.5387,  -0.4044,  -1.9949,  -3.6207,  -5.3018, -11.6581,
         -2.9495,   1.4969,   7.9319,  11.6926,  -4.6873,  -4.5692,  -5.5623,
         -4.2740,  -5.1775,  -6.4276,  -9.0732,  -4.4758,  -6.3633,  -3.1939,
         -3.8236,  10.2220,  -6.7816,  -4.9070,  -3.0603,   0.3043,  -5.4886,
         -1.1904, -12.3962, -10.4683,  -3.1377,  -5.2180,  -4.5262,   2.2224,
         -4.6809,  -4.7172,  -7.7201,  -8.0863, -12.0364,  -2.20

tensor([ -5.1820,  -4.5998,   5.8805,  -8.2414,  -3.5212,  -1.7946,   9.3695,
        -11.1407,  -8.4530,  -6.3904,  -7.5393,  -5.8290,  -5.5756,  -3.7935,
         -2.4167,  -6.8824,  -6.5275,  -3.6921,  -7.3212,  -1.9412,  -3.2343,
         -5.6632,  -6.6207,   6.7037,  -7.4579,  13.1405,  -7.5776,  -7.4533,
         -2.5883,  -6.3249,  -2.3730,   7.0674,  -9.7064,  -4.8938,  -7.1654,
         -5.7182,  -7.0556,  -2.0405,  -1.6644,   7.8620,  -3.7531,  -4.1067,
         -7.1574, -11.0898,  -6.9228,  -5.4797, -10.5052,  -1.7086,  -5.0205,
         11.6099,   8.7726,  -0.6260,  -5.6623,  -6.7900,   5.9570,  -5.7181,
          7.1562,   8.9253,  -4.0332,  10.8201,   9.2870, -10.4044,  -5.3164,
         -6.5018,  12.3069,  -3.4465,  -3.5710,  -5.8808,  -5.1846,  -6.0122,
         -8.4002,  -3.9440,  -5.6066,  -8.3660,  -5.9183,  -6.7330,  -5.6013,
         -4.1574,  -8.7179,  -8.5292,  -3.2420,  -3.4701,  -8.5936,  -2.6125,
        -10.0560,   0.5130,  -7.8951,  -4.7551,  -5.3271,  -7.49

tensor([ -6.2337,  -4.6542,  -9.3144,  -3.3646,  -4.9613,  -8.0232,  -6.0314,
         -9.1798,  -4.1260,  -8.0468, -13.2457, -10.5422,  -5.3078,  -3.7690,
         -6.3096, -11.1362,  -5.7947,  -8.0975,  -1.8938,  -5.6409,  -9.1397,
         -6.1978,   0.7204,  -2.0420,   9.0563, -11.7309,  -3.8407,   3.6009,
         -5.6932,  -4.2420,  -9.4889,  -4.8972,  -9.2896,  -2.2377,  -8.0583,
         -4.4918,  -3.9275,  -6.1294,  -6.6066,  -9.0773,  -1.9906,  -4.0945,
         -7.6919,  -5.1719,  -8.2126,  -5.1753,  -7.0605,  -9.4793,  -8.8075,
        -13.1909,  -5.9376,  -5.2037,  -8.2302, -12.7402,  -5.6636,  -3.9491,
         -3.5634,  -3.0129,  11.2141,  -9.2763,  -2.3508,  -6.6687,  -1.5706,
         -6.6566,  -5.6835,  -6.3421, -10.8962,  -7.9055,  -7.3988,  -4.7419,
         -8.8214,  -6.4616,  -7.3676,  -6.8979,  -1.6143,  -4.7798,  -5.2027,
         -6.9340,  -2.2962,  -6.2149,  -6.4560, -12.8862,  -4.8759,  -2.7945,
         -3.6745,  -2.5617, -12.1591,  -4.4902,  -6.0043,  -7.44

tensor([ -3.7889,  -1.4124,  -2.2203, -10.9318,   8.7920,   5.1172,   9.4071,
          9.9353,  -8.8555,  -7.0276,  -9.4528,  -2.8589,   8.6776,  -5.7172,
         -6.8282,   8.6208,  -7.5609,  -8.1268, -10.4645,   4.0480,  -0.4901,
         -9.7169,   7.0777,  -2.7692,  -8.5156,  -1.6284, -12.0630,  -6.0254,
         -8.1439,  -7.2891,  -9.2130,  -5.4199,   5.0174,  -4.4586,  -6.2939,
         -3.6968,  -9.9674,  -9.5727,  -0.5832,  -9.4632,  -9.5369,  -7.8664,
         -6.7555, -10.2316,   3.0460,  -4.3009, -12.1996,   1.0178, -10.1402,
        -14.5285,  -8.0082,  -6.9646,  -8.3560,  -8.8605,  -5.1049,   5.2411,
         -9.2338,  -2.1881,   3.7630,  11.5505,  -0.9127,  -6.3356,  -2.1906,
         -4.1433,  -7.5376,  -1.9758,  -9.8665,  -8.9675,  -4.2726,  -5.2704,
         -6.5053, -10.1584,  -1.0193,  -3.2108,   1.1095,   5.5888, -10.8565,
        -11.9329,  -4.5565,  -6.0821,  -3.7249, -11.9360,  -9.8036,  -6.7446,
          9.9677,  -0.8980,  -5.4245,  -4.3968,  -5.4280,  -5.72

tensor([ -7.6630, -12.4543,  -7.3957,  -4.9727,  -5.6066,  -7.9433, -10.9647,
         -1.7342, -10.0817,  -5.6675,  -7.3207,  -5.0516,  -6.3213,  -5.5686,
         -9.0612,  -2.9415,  -5.7521,  -6.2167,  -6.1707,  -3.6594,  -5.7020,
         -9.9221,  -7.9793,   3.6248,  11.9166,  -6.5455,  -9.4961,  -5.3632,
        -13.5794,  -5.8013, -10.9852,  -5.0319,  -3.2679,  -6.7246,   8.7229,
         -4.9026,  -3.3643,  -4.1284,  -3.3610,  -0.8512,  -2.6142,  -2.8292,
         -8.8242,   7.3313,  -5.8173,   7.3985,  -1.7866,  -4.7853,  -7.7137,
         -6.5437,   6.2514,   9.7317,  -7.1257,  -5.7932,  -9.7217, -15.6534,
         -6.0219,   6.7311, -11.3187,  -6.5630,  -0.7793,  -6.7948,  -7.1452,
         12.1939,  -7.5938,   0.1026,  -4.0302,  -1.7379,  -7.9509,  11.9332,
        -13.1466,  -5.3637,  -3.6546,  -2.9362,   8.6263,  10.0325,  -6.0358,
         -2.4728,   0.4615,   7.0337,  -3.0696,  -4.6296,  -1.5658,  -0.7973,
         -4.6503,  -5.7615,  -4.7038,  -2.0262,  -2.5167,  -4.61

tensor([ -1.7116,  -5.5044,  -8.5142, -13.1332,  -1.2947,  10.3418,  -3.4904,
         -7.4361,   1.7488,  -4.6047, -11.0991,  -4.4133,  -6.9887,  -2.4623,
        -15.5867,   1.4629,  -3.3410,  -9.7543,  -4.7643,  -5.5557,  -3.0276,
         -5.6854,  -8.2203,  13.5184,  12.5843,  -5.7417,  -8.9760,  -7.1961,
         -4.8549,  -9.2229,  11.2329,  -6.4049,  -4.0424,  -6.9866,   6.5530,
         -8.4164,  -9.7069,  -4.9526,  -2.2706, -12.3247,  -7.1891,  -8.2055,
         -8.2843,  11.7730,  -4.6758,  -3.8995,  -6.8997,  -5.7162,  -6.0785,
          5.9128,  -0.8037,  -5.8162, -11.9928,   7.0630,  -8.4834,  -6.6687,
         -6.3039,  -5.7561,  -6.4685, -11.0652,  -7.9671,  -2.8538,  -9.1251,
         -7.6031,  -2.4233,  -1.4976,  -4.2431,  -0.9290,  -5.5252,  -8.9265,
         -7.2809,   3.8082,   7.7420,  -4.4021,  -6.5112,   8.9956,  -5.3312,
         -8.4294,   7.6931,  -6.8839,  -4.3413,   5.1706,  -4.9558,  -3.4868,
          9.1999,  -2.7845,  -2.3995,   7.4011,  -6.0550,  -0.01

tensor([  7.2616,   0.4725, -10.7830,   8.0643,  -7.0201,  -8.5529,  -7.9523,
         -3.5859,  -5.2442,  -3.5475,  -8.0597,  -5.1682,  -7.7267,  -6.2104,
         -7.9924,  -6.0823,  -7.6597,  -3.7516,  -0.8378,  -6.1549,  -0.7779,
         -3.1574,  -9.5766,  -9.9001,  -7.4658,  -4.0801,  -6.8532,  -6.3311,
         -8.5711,  -6.1279,  -9.0399,  -4.5163,  -5.2682,  -9.0678,  -7.3400,
         -3.9782,  -3.7889,  -7.1427,  -6.0493,  -3.2721, -10.6424,  -6.9012,
         -5.0357,  -6.1496,  -0.9036,  -4.7703, -11.7185,  -0.6550,   5.2423,
         -3.0682,  -4.9397,  -6.5849,  -4.0994,  10.6398,  -6.0460,  -2.3296,
        -10.9616,   8.3936,  -6.8438,  -6.3946,  -7.7882,  -5.7333,   8.5564,
         -7.0720,  -6.0572,  -8.9620, -10.1728,  -5.3714, -12.4203,  -9.7834,
         -5.2041,  -7.5626,   6.2890,  -9.2948,  -7.2379,  -5.1574,  -3.4742,
         -8.6944, -10.3200,  -4.0579,   6.0263,  -1.3024,  -9.5622,  -0.5175,
         -7.7716,  -7.2068,  -5.1240, -11.4097,   0.2860,   1.88

tensor([ -4.5604,  -7.7540,  -7.6894,  -7.3142,  -7.5533,  -3.7887,  -5.6549,
        -13.8052,  -7.6725,  -4.1574,  -7.3341, -13.3098,  -2.3615,  -1.9585,
         -5.9550,  -5.7547,  -6.8165, -12.0796,  -8.2859,  -1.6996,  -3.0088,
         -5.5317,  12.5330,  -4.6592, -11.3566,  -3.0236,  -1.2617,  12.3971,
         -9.2289,  -5.3433,  -7.9556,  -1.1756, -10.0246,  -6.0965,  -5.6053,
         -8.1860,  -7.2464, -10.2257,  -7.7955, -10.1010,  -4.5501,  -1.4063,
         -4.5157,  -7.2319, -10.0326,  -6.3179,  -4.5964,  -7.8444,   7.7723,
         -8.1047,   0.7619, -11.4361,  -7.5755,  -9.0800,  -6.5008, -11.7087,
         -6.1161,  -1.5509, -11.4691,  -5.4111,  -4.3817,  -4.8315,  -7.4775,
         -3.1496,  -6.1178,  -6.4467,  -8.2151,  -4.9328,  -0.3560,   1.4441,
         -2.1744,  -7.7345,   8.1064,  -5.6763,  -6.9573,  -3.1126,  -9.1284,
         -3.8811,  -6.1166,   3.4596, -11.3352,  -8.5692, -11.9076,  -9.9529,
        -13.8427,  -9.3811,  -8.8080,  -3.3453,  -6.6926,  -4.55

tensor([ -5.0947,  -2.5028,  -5.6702, -10.2434,   7.5276,  -6.7485,  -5.6049,
         -6.5652,  -4.2876,  -2.7715,  -3.6639, -10.6340,  -6.2108,  -7.4154,
         11.1838,  -0.8900,  -2.2323,  11.8349,  -3.7625,  -3.3550,  -4.4969,
        -12.6594,  11.4380,  -4.2291,  -2.7456,  10.7927,  -4.9350,   9.8371,
        -11.9103,  -6.7003,  -1.7688,  -4.4117,  -3.2028,  -6.0176,  -5.3471,
         -1.9572,  -4.8549,  -5.1149,  -8.6294, -10.0731,  -3.5265,  -7.9628,
         -4.2715,  -6.3883,  -3.2768,  -8.6036,  -4.5958,  -4.5285,  -9.1474,
         -2.1169,  -8.9178,   0.8016,  -7.1440,  -8.0658,  -9.0173,  -1.3927,
         -7.8178,  -3.8789,  -7.2559,  11.3723,  -3.6917,  -7.0832,  -1.8505,
          6.9027,   0.3149,   4.2997,  -3.5420,  -3.1225,  -5.0554,   0.5232,
        -11.9458,  -6.8947,  -5.2608,  -7.8922,  -4.8916,  -0.3170,  -5.8734,
         -7.4239,  -3.2063,  -4.8153,  -9.2692,  -2.1824,  -5.8315,  -5.0802,
         -6.8972,  -2.9701,  14.6103,  -4.7395,  -2.8853,  -8.72

tensor([ -6.7719,  11.1734,  -5.6798,  -5.7516,   0.6417,  -3.2609,  -3.3511,
         -9.8282,  -6.6434, -12.9031,  -8.6759,  -0.6922,  -8.9917,  -8.9464,
         -9.9941,  -3.0949,  -8.9145,  -9.0990,   0.6858,  -6.4528,  -2.9076,
         -6.2643,   0.6570,  -5.3450,  -1.2708,   7.7413,  -2.1431,  -4.1857,
         -5.4533,  -5.7940,  -2.3014,  -2.7843, -12.2617,  -3.4667,  -9.3752,
         -4.9416,   8.5001,  -2.8158,   6.6066,  -5.6643,  -1.9255,  -4.7880,
         -6.8748,  -3.9658,  -2.6259,   0.0903,   8.2573,  -3.3053,  -7.6870,
         -0.5747,  -0.4554,  -1.7929,  -5.4882,  -0.1457,  -5.2630,  -3.1876,
         -4.3880,  -1.7655,  -3.9350,  -8.8603,  -6.9239,   0.0490,  -6.1298,
         -8.5732,   5.2003,   9.7904,  -2.8868, -13.4257,  -5.5801,  -2.9598,
         -4.1445,   5.9435,  10.4330, -11.1268,  -3.7603,  -3.7139,  -6.1873,
         -9.3809,  -1.1713,   0.1572,  -7.1110,  -3.2048,  11.4041,  10.1250,
         -0.1547,   8.0180,  -8.3882, -10.5834,  -6.7424,  -8.82

KeyboardInterrupt: 