In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
# Smoke test for the PyTorch install: draw a 5x3 tensor of random values and
# show it. `torch` is already imported in the notebook's first cell, so it is
# not imported again here.
x = torch.rand(5, 3)
print(x)

tensor([[0.6931, 0.1476, 0.1960],
        [0.9572, 0.0783, 0.6816],
        [0.9648, 0.8105, 0.5239],
        [0.6730, 0.5455, 0.4506],
        [0.5758, 0.6720, 0.2585]])


In [3]:
# Report whether PyTorch can use a CUDA device; the bare expression is the
# value this cell displays. `torch` comes from the notebook's first cell.
torch.cuda.is_available()

False

In [4]:
# Build both FashionMNIST splits under the "data" directory (downloading them
# when requested via download=True), converting samples with ToTensor().
# The training split is created first, then the test split.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train_split,
        download=True,
        transform=ToTensor(),
    )
    for is_train_split in (True, False)
)

In [5]:
# Number of samples per mini-batch, shared by both loaders.
batch_size = 64

# Wrap each split in a DataLoader that yields (X, y) batches.
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)
train_dataloader = DataLoader(dataset=training_data, batch_size=batch_size)

# Pull only the first test batch to confirm the tensor shapes and label dtype.
X, y = next(iter(test_dataloader))
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


In [6]:
# Choose the compute device: CUDA when PyTorch reports it available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten, two hidden ReLU layers, linear head.

    With the default arguments this is the 784 -> 512 -> 512 -> 10 network,
    i.e. inputs that flatten to 28*28 values per sample and 10 output scores.

    Args:
        in_features: Number of values per sample after flattening.
        hidden_features: Width of each of the two hidden layers.
        num_classes: Number of output scores (logits) per sample.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        # nn.Flatten() keeps the first (batch) dimension and merges the rest.
        self.flatten = nn.Flatten()
        # Submodule order is unchanged, so state_dict keys stay
        # linear_relu_stack.{0,2,4}.{weight,bias}.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Flatten ``x`` per sample and return the stack's raw output scores.

        No softmax is applied here; the output of the final linear layer is
        returned as-is.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# Build the network, move it to the selected device, then print its layers.
model = NeuralNetwork()
model = model.to(device)
print(model)

Using cpu device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [7]:
# Plain SGD over all of the model's parameters with a learning rate of 0.001.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)
# Cross-entropy loss, applied to the model's output scores and integer labels.
loss_fn = nn.CrossEntropyLoss()

In [8]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimization pass over every batch in ``dataloader``.

    Batches are moved to the device holding the model's parameters, so the
    function does not rely on a notebook-level ``device`` variable.

    Args:
        dataloader: Iterable of ``(X, y)`` batches. Must expose ``.dataset``
            with a length; that length is used only in the progress display.
        model: Module to optimize. It is switched to training mode and must
            have at least one parameter.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer whose ``zero_grad``/``step`` update ``model``.

    Returns:
        None. On every 100th batch (starting with the first) the batch loss
        and the number of samples processed so far are printed.
    """
    size = len(dataloader.dataset)
    # Send data wherever the model lives instead of to a global device name.
    model_device = next(model.parameters()).device
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(model_device), y.to(model_device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: clear old gradients, compute new ones, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # The current batch has already been processed, hence batch + 1.
            # Assumes all batches so far had len(X) samples each.
            loss_value, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

In [9]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [10]:
# Number of full passes over the training data; every pass is followed by an
# evaluation on the test loader.
epochs = 5
for epoch_number in range(1, epochs + 1):
    print(f"Epoch {epoch_number}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
tensor([[-0.0706, -0.0615, -0.0502,  0.0008, -0.0235, -0.0705, -0.0473, -0.0443,
          0.0382, -0.0224],
        [-0.0532, -0.0850, -0.0412, -0.0248, -0.0559,  0.0201, -0.0593, -0.0343,
          0.0826, -0.0057],
        [-0.0309, -0.0555,  0.0070, -0.0306, -0.0077,  0.0190, -0.0400, -0.0008,
          0.0645,  0.0024],
        [-0.0506, -0.0684, -0.0192, -0.0069, -0.0135,  0.0181, -0.0254, -0.0090,
          0.0855, -0.0043],
        [-0.0384, -0.1089, -0.0475, -0.0101, -0.0202,  0.0214, -0.0782,  0.0015,
          0.0945, -0.0170],
        [-0.0862, -0.0670, -0.0536, -0.0373, -0.0111,  0.0043, -0.0509, -0.0531,
          0.0887, -0.0274],
        [-0.0158, -0.0785, -0.0097, -0.0283, -0.0071,  0.0094,  0.0198,  0.0041,
          0.0941, -0.0271],
        [-0.1230, -0.0645, -0.0850, -0.0319, -0.0041,  0.0194, -0.0426, -0.0644,
          0.1131, -0.0147],
        [-0.0171, -0.0473,  0.0027, -0.0424,  0.0118,  0.0158, -0.0281, -0.0337,
       

tensor([[-0.0042, -0.0463,  0.0240, -0.0481,  0.0070,  0.0097, -0.0033, -0.0018,
          0.0451, -0.0503],
        [-0.0435, -0.0609,  0.0083, -0.0580, -0.0070, -0.0206, -0.0279, -0.0422,
          0.0512, -0.0398],
        [-0.0225, -0.0826, -0.0149, -0.0213, -0.0025, -0.0357, -0.0672, -0.0338,
          0.0564, -0.0138],
        [-0.0365, -0.0579,  0.0151, -0.0353,  0.0236, -0.0346, -0.0049, -0.0447,
          0.0625, -0.0232],
        [-0.0719, -0.0857, -0.0318, -0.0595,  0.0268, -0.0056, -0.0080, -0.0248,
          0.1209, -0.0178],
        [-0.0092, -0.0882, -0.0023, -0.0336,  0.0202, -0.0026, -0.0200,  0.0163,
          0.0814,  0.0183],
        [-0.0697, -0.0958, -0.0232,  0.0094,  0.0241, -0.1044, -0.0617, -0.0806,
          0.0402,  0.0394],
        [-0.0899, -0.0727, -0.0399, -0.0783,  0.0291, -0.0532, -0.0030, -0.0274,
          0.1004, -0.0027],
        [-0.0433, -0.0757, -0.0019, -0.0697,  0.0229, -0.0279, -0.0022, -0.0617,
          0.0937, -0.0337],
        [-0.0105, -

tensor([[-6.7164e-02, -8.0041e-02, -2.3350e-02, -5.6985e-02,  1.2827e-02,
         -8.9423e-02, -4.1963e-02, -5.9741e-02,  6.9930e-02,  6.3737e-02],
        [-4.0196e-03, -9.3299e-02,  1.9703e-02, -3.0717e-02,  6.7947e-02,
         -1.1349e-01, -2.1273e-02, -1.2251e-01,  6.3727e-02, -3.6366e-02],
        [ 5.0999e-02, -9.4945e-02,  1.4682e-02, -9.1865e-03,  2.9278e-02,
         -1.5827e-01, -6.3574e-02, -1.2907e-01,  1.5031e-02, -3.3542e-02],
        [-3.4374e-02, -6.3139e-02, -1.8655e-03, -5.4932e-02, -5.2310e-03,
         -1.5299e-02, -3.9445e-03,  2.8027e-02,  8.2846e-02, -6.3588e-03],
        [-4.4670e-03, -7.9982e-02,  4.0320e-02, -3.6982e-02,  8.9553e-02,
         -8.5667e-02,  7.5714e-03, -6.8692e-02,  5.7399e-02,  2.4187e-02],
        [ 2.3629e-02, -6.7671e-02,  2.4373e-02,  3.8593e-02,  6.1063e-02,
         -1.2552e-01, -5.9131e-02, -9.0713e-02,  4.9027e-02, -6.0626e-02],
        [ 2.5644e-02, -3.5141e-02, -9.4957e-03,  3.7131e-02, -9.5104e-03,
         -7.2848e-02, -6.3989e-0

tensor([[ 1.2299e-02, -1.6521e-02, -1.0059e-02,  3.5104e-02,  2.4587e-02,
         -6.9084e-02, -6.7141e-02, -6.2930e-02,  2.6072e-02, -6.2360e-02],
        [-6.4502e-02, -9.6949e-02,  4.5326e-02, -7.8073e-02,  8.6463e-02,
         -6.1354e-02,  1.5081e-02, -8.1866e-02,  7.1890e-02, -1.3203e-02],
        [-5.3219e-02, -1.1855e-01,  5.8436e-02, -5.6386e-02,  9.4884e-02,
         -1.1783e-01,  2.7142e-02, -1.2466e-01,  7.8377e-02, -6.2592e-02],
        [ 7.0317e-02, -8.0864e-02,  2.5846e-02, -1.4291e-02,  5.9408e-02,
         -1.4219e-01, -2.2323e-02, -1.1228e-01,  4.2263e-03, -4.6167e-02],
        [-7.2603e-02, -1.0679e-01, -2.9676e-03, -8.8401e-02,  3.4162e-02,
         -6.7713e-02, -3.2305e-02,  1.4399e-02,  8.3567e-02,  1.0156e-01],
        [ 1.5653e-02, -7.9122e-02, -2.3789e-02,  5.3314e-02,  2.7191e-02,
         -9.3753e-02, -4.7187e-02, -7.8702e-02,  4.5353e-02, -4.3735e-02],
        [ 1.7239e-02, -2.2904e-02,  7.7517e-03,  5.6146e-02,  3.6514e-02,
         -1.2995e-01, -6.2474e-0

tensor([[-4.2305e-02, -6.8194e-02,  2.9983e-03, -7.3250e-02,  1.9374e-02,
         -1.8312e-02, -1.8034e-03, -5.2796e-02,  7.1366e-02,  3.7342e-03],
        [-5.6760e-02, -9.1036e-02, -1.4361e-02, -8.4231e-02, -4.6591e-03,
         -4.1311e-02, -6.6963e-03,  6.2170e-02,  1.1670e-01,  3.5783e-02],
        [-2.9188e-02, -1.2092e-01,  1.6318e-02, -5.1330e-02,  4.0309e-02,
         -4.6119e-02, -5.8998e-03,  4.5005e-02,  1.1445e-01,  2.3781e-02],
        [-1.1046e-02, -1.7455e-01,  5.7677e-02, -1.1738e-01,  1.3259e-01,
         -1.7485e-01,  4.0900e-02, -2.4315e-01,  8.7956e-02, -3.7067e-02],
        [-3.2827e-02, -6.7900e-02, -7.6777e-03, -4.9050e-02,  2.7861e-02,
         -1.8763e-02,  1.3931e-02,  2.0249e-02,  5.7402e-02, -2.0339e-02],
        [-7.9606e-02, -1.2465e-01, -3.9206e-03, -1.0322e-01,  6.5180e-02,
         -1.0755e-01, -8.1542e-02, -2.2861e-02,  5.5592e-02,  1.9570e-01],
        [ 4.1339e-02, -3.0752e-02,  4.3311e-02,  6.8260e-02,  7.1750e-02,
         -1.8956e-01, -5.9923e-0

tensor([[ 2.7445e-02,  1.7822e-02,  1.3022e-02,  2.7319e-02,  4.3915e-02,
         -9.7718e-02, -4.0493e-02, -9.2706e-02,  7.7168e-03, -6.7656e-02],
        [-1.2848e-01, -2.0977e-01,  2.3902e-02, -1.8034e-01,  1.1209e-01,
         -1.2059e-01,  6.3938e-03, -6.7979e-02,  2.0975e-01,  9.2041e-02],
        [-1.7433e-02, -7.0635e-02,  4.9908e-02, -7.2270e-02,  8.4852e-02,
         -5.9965e-02,  2.9623e-02, -6.1568e-02,  6.1580e-02, -2.7732e-02],
        [-5.2516e-02, -9.4907e-02, -6.1349e-03, -8.3469e-02,  1.0853e-02,
         -3.8261e-02, -8.6510e-03,  6.0935e-02,  1.0606e-01,  3.8091e-03],
        [-2.7682e-02, -1.6086e-01,  1.2525e-01, -8.1347e-02,  2.0941e-01,
         -2.5757e-01,  6.2523e-02, -2.2097e-01,  6.6254e-02, -5.5225e-02],
        [ 3.5533e-02, -2.3915e-02,  1.1122e-02,  1.2330e-01,  5.4883e-02,
         -1.8414e-01, -8.0928e-02, -1.4510e-01,  6.1421e-03, -1.1740e-01],
        [-9.7356e-02, -2.0428e-01,  1.5773e-01, -1.7816e-01,  2.6414e-01,
         -2.8590e-01,  4.7673e-0

tensor([[-6.5139e-02, -9.7582e-02, -7.0762e-03, -9.5178e-02,  2.3589e-03,
         -3.4073e-02, -3.8715e-02, -2.7775e-02,  1.0647e-01,  4.5030e-02],
        [ 1.6985e-01, -8.5525e-02,  8.1062e-02, -8.8711e-03,  7.7592e-02,
         -2.8210e-01, -6.4202e-03, -2.3534e-01, -3.4806e-02, -7.8815e-02],
        [ 1.1111e-01, -2.7972e-02,  4.9040e-02,  9.2047e-02,  6.6825e-02,
         -2.4552e-01, -4.6832e-02, -2.1920e-01, -2.4011e-02, -1.0452e-01],
        [ 7.9590e-02,  1.9093e-02,  3.0075e-02,  7.1314e-02,  4.0963e-02,
         -2.1066e-01, -6.8827e-02, -1.3967e-01, -3.3419e-02, -9.9732e-02],
        [-9.0809e-02, -1.7280e-01, -3.5856e-02, -1.3055e-01,  1.0571e-01,
         -2.3441e-01, -1.1226e-01, -1.3529e-01,  8.1886e-02,  2.3302e-01],
        [ 1.0723e-01,  2.1774e-02,  4.6519e-02,  8.7224e-02,  7.7343e-02,
         -2.4837e-01, -9.8628e-02, -1.7145e-01, -2.4349e-02, -9.9036e-02],
        [-8.7666e-03, -1.4242e-01,  1.1760e-01, -7.6683e-02,  1.6948e-01,
         -1.9663e-01,  3.2871e-0

tensor([[ 1.3779e-01, -1.3308e-02,  9.7613e-02,  4.5829e-02,  9.1331e-02,
         -2.1097e-01, -1.5300e-02, -1.7285e-01, -5.6916e-02, -8.2303e-02],
        [-9.4433e-02, -1.4540e-01,  2.8003e-02, -1.6059e-01,  1.8047e-02,
         -5.2803e-02, -3.8407e-02,  1.0948e-01,  1.3643e-01,  4.9082e-02],
        [ 3.6754e-02, -1.7541e-01,  2.0156e-01, -8.7569e-02,  2.7738e-01,
         -3.8979e-01,  3.4193e-02, -3.1399e-01,  4.5468e-02, -1.2649e-01],
        [-3.4163e-02, -8.1347e-02,  9.3201e-02, -8.0002e-02,  1.2709e-01,
         -1.0501e-01,  3.3983e-03, -1.1223e-01,  4.9713e-02, -2.0533e-02],
        [-2.9756e-02, -2.0170e-01,  2.6194e-01, -1.3873e-01,  3.4345e-01,
         -3.9556e-01,  3.4093e-02, -3.5062e-01,  4.1982e-02, -7.8649e-02],
        [-8.0093e-02, -8.4152e-02,  2.4927e-02, -1.0865e-01,  1.9895e-02,
         -3.2343e-02, -3.8979e-03,  3.6883e-02,  1.1877e-01,  5.0499e-02],
        [-4.6239e-02, -2.1892e-01,  2.4749e-01, -1.9647e-01,  3.3005e-01,
         -3.4793e-01,  1.3998e-0

Test Error: 
 Accuracy: 41.1%, Avg loss: 2.165946 

Epoch 2
-------------------------------
tensor([[-1.3790e-01, -2.1632e-01, -8.8627e-03, -1.6831e-01,  8.9713e-02,
         -2.2967e-01, -1.0970e-01, -1.1493e-01,  1.4975e-01,  2.7700e-01],
        [ 2.1515e-01, -1.3094e-01,  1.2012e-01, -2.9763e-03,  1.2587e-01,
         -3.1913e-01,  2.1713e-02, -3.3072e-01, -4.6461e-02, -1.0455e-01],
        [ 4.8803e-02,  1.4786e-02,  4.7905e-02,  5.1280e-02,  2.7240e-02,
         -1.1740e-01, -3.8788e-02, -1.0475e-01, -1.3028e-02, -7.0435e-02],
        [ 8.6444e-02, -2.3709e-02,  6.0280e-02,  6.7506e-02,  8.0762e-02,
         -1.9617e-01,  3.2226e-03, -1.7853e-01, -1.2130e-02, -8.2707e-02],
        [ 1.1653e-01,  2.2658e-02,  2.9989e-02,  1.6684e-01,  7.8008e-02,
         -2.8270e-01, -8.6152e-02, -2.3714e-01, -3.7625e-02, -1.4904e-01],
        [-1.5933e-02, -1.6921e-01,  2.1289e-01, -1.5894e-01,  2.6170e-01,
         -3.2930e-01,  1.1358e-02, -3.3267e-01,  5.8335e-02, -5.7604e-02],
        [-7.07

tensor([[ 4.3753e-02, -1.7741e-02,  5.7618e-02,  1.9857e-02,  6.1241e-02,
         -7.4526e-02, -2.6293e-02, -6.7276e-02,  2.2657e-02, -5.3093e-02],
        [-7.6960e-02, -2.3513e-01,  2.4357e-01, -2.2875e-01,  3.5298e-01,
         -3.0547e-01,  5.4898e-02, -2.9868e-01,  1.1002e-01, -4.6898e-02],
        [-9.7645e-03, -5.8646e-02,  4.3089e-02, -4.6433e-02,  7.4613e-02,
         -6.9469e-02,  3.8388e-02, -7.4061e-02,  5.2811e-02, -4.2318e-02],
        [-1.8661e-01, -2.4395e-01,  5.3026e-02, -2.4215e-01,  1.6955e-01,
         -1.8182e-01, -9.8498e-03, -6.3627e-02,  2.5672e-01,  1.2093e-01],
        [-1.3414e-03, -6.4928e-02,  8.3520e-02, -4.5032e-02,  1.5965e-01,
         -1.5013e-01,  3.6383e-02, -1.4106e-01,  2.9079e-02, -5.8169e-02],
        [-1.0849e-01, -1.4717e-01,  2.5202e-02, -1.1785e-01,  3.2903e-02,
         -3.3430e-02, -4.2494e-02,  1.2658e-01,  1.5729e-01,  1.0620e-01],
        [-6.0956e-02, -1.3947e-01,  1.0170e-01, -1.1784e-01,  1.2878e-01,
         -1.7954e-01,  1.0348e-0

tensor([[-1.6042e-02, -5.2855e-02,  3.9367e-02, -6.3977e-02,  1.6916e-02,
         -9.5692e-03, -1.1275e-02,  1.1094e-02,  5.9105e-02, -5.4943e-02],
        [-7.9378e-02, -1.4681e-01,  4.4653e-02, -1.6101e-01,  6.0981e-02,
         -8.8751e-02, -2.5080e-02, -7.1629e-02,  1.1297e-01,  7.5225e-02],
        [ 9.6076e-02,  3.1345e-02,  2.6392e-02,  1.5505e-01,  5.3430e-02,
         -3.1253e-01, -7.7008e-02, -2.3760e-01, -4.7771e-02, -1.1964e-01],
        [ 8.3865e-03, -9.2515e-02,  1.1412e-01, -5.4473e-02,  1.3887e-01,
         -1.9406e-01,  4.2263e-02, -1.8102e-01,  5.3145e-02, -8.4142e-02],
        [-1.6474e-01, -2.4620e-01,  6.8294e-02, -2.1793e-01,  1.6531e-01,
         -1.7514e-01,  1.1971e-02, -1.1189e-01,  2.7282e-01,  7.8076e-02],
        [-8.0544e-02, -1.7466e-01,  5.7341e-03, -1.5464e-01,  3.2480e-02,
         -3.0988e-02, -5.5847e-02,  1.3863e-01,  1.1971e-01,  1.8152e-01],
        [-1.2212e-01, -2.2090e-01, -1.6087e-03, -1.4057e-01,  1.2361e-01,
         -3.0539e-01, -9.2020e-0

tensor([[-6.7927e-02, -1.6865e-01,  2.1769e-01, -1.5699e-01,  2.7648e-01,
         -2.3508e-01,  8.4356e-02, -2.3753e-01,  8.8185e-02, -7.4208e-02],
        [-1.0075e-01, -1.4062e-01, -2.1785e-02, -1.6500e-01, -9.2818e-04,
         -7.2446e-02, -5.0501e-02,  7.1669e-02,  1.1047e-01,  2.3117e-01],
        [-1.5847e-01, -2.4936e-01,  3.6302e-04, -2.4473e-01,  8.0647e-02,
         -1.6083e-01, -9.9377e-02, -5.0596e-02,  1.7601e-01,  3.8223e-01],
        [-1.5532e-01, -2.0180e-01,  1.7183e-02, -1.7250e-01,  2.0205e-02,
         -1.9111e-02, -6.4040e-02,  1.7921e-01,  1.8870e-01,  1.4474e-01],
        [-1.0662e-01, -1.3710e-01,  1.2649e-02, -1.3056e-01,  3.2005e-02,
         -1.2397e-02, -3.1423e-02,  4.9044e-02,  1.4713e-01,  5.0244e-02],
        [-1.4989e-01, -1.7221e-01,  1.0885e-03, -1.8979e-01,  2.0123e-02,
         -5.4066e-02, -4.5108e-02,  1.9279e-01,  1.3949e-01,  1.4821e-01],
        [ 8.8600e-02,  1.5793e-01,  4.6639e-02,  1.6749e-01,  4.5026e-02,
         -2.5723e-01, -6.2256e-0

tensor([[-9.8647e-02, -1.8820e-01,  7.9758e-03, -1.8130e-01,  4.8102e-02,
         -2.0147e-01, -6.3360e-02, -1.4206e-01,  1.2360e-01,  2.7769e-01],
        [ 2.0424e-01, -7.9959e-02,  2.0864e-01,  5.1296e-02,  2.7026e-01,
         -4.9973e-01,  5.7833e-02, -4.8022e-01, -4.2638e-02, -2.6009e-01],
        [ 3.9453e-01, -3.2663e-02,  1.8746e-01,  1.5814e-01,  2.1185e-01,
         -5.9532e-01,  4.2533e-02, -5.6278e-01, -1.7193e-01, -2.6573e-01],
        [-1.0582e-01, -1.2003e-01, -9.1331e-03, -1.3441e-01, -1.3804e-02,
         -9.3129e-03, -3.3099e-02,  1.5179e-01,  1.2478e-01,  6.4185e-02],
        [ 6.1411e-02, -7.7847e-02,  1.6301e-01, -2.1485e-02,  2.4714e-01,
         -3.0719e-01,  5.5165e-02, -2.2346e-01, -1.6208e-02, -5.0545e-02],
        [ 2.7636e-01,  1.0446e-01,  1.4386e-01,  2.9601e-01,  2.1752e-01,
         -5.6461e-01, -1.0348e-02, -4.8326e-01, -1.6722e-01, -3.1921e-01],
        [ 1.7036e-01,  2.7246e-01,  6.4607e-02,  2.7670e-01,  3.9019e-02,
         -3.7703e-01, -8.8090e-0

tensor([[ 0.1030,  0.1515, -0.0048,  0.2314,  0.0160, -0.2527, -0.0850, -0.2164,
         -0.0653, -0.1595],
        [-0.1098, -0.2384,  0.2130, -0.2518,  0.2268, -0.1970,  0.0690, -0.1848,
          0.1507,  0.0134],
        [ 0.0183, -0.2009,  0.2518, -0.0813,  0.3206, -0.4170,  0.1233, -0.4201,
          0.0692, -0.1877],
        [ 0.3497,  0.0011,  0.1683,  0.1378,  0.1912, -0.5072,  0.0595, -0.4605,
         -0.1694, -0.2638],
        [-0.2370, -0.3161, -0.0417, -0.3376,  0.0134, -0.1223, -0.1123,  0.1167,
          0.2466,  0.5259],
        [ 0.1810,  0.0811,  0.0211,  0.2834,  0.0648, -0.3956, -0.0422, -0.3480,
         -0.0529, -0.2183],
        [ 0.1749,  0.2172,  0.0477,  0.3593,  0.0745, -0.4341, -0.0900, -0.3593,
         -0.1508, -0.2376],
        [ 0.3286, -0.0117,  0.1736,  0.1721,  0.2813, -0.6526,  0.0699, -0.5185,
         -0.1561, -0.3302],
        [ 0.1228,  0.2111,  0.0375,  0.2083,  0.0276, -0.2749, -0.0468, -0.2629,
         -0.1235, -0.1896],
        [-0.0739, -

tensor([[ 1.7535e-01,  2.9238e-01,  6.0893e-02,  2.4890e-01,  7.9574e-02,
         -3.4247e-01, -3.0213e-02, -3.3067e-01, -1.6034e-01, -2.9274e-01],
        [-3.8056e-01, -5.8375e-01,  1.0488e-01, -5.8105e-01,  1.8363e-01,
         -2.4968e-01, -3.5359e-02,  5.2439e-02,  5.6431e-01,  3.8765e-01],
        [-2.8167e-02, -1.4981e-01,  1.5922e-01, -1.6194e-01,  1.9885e-01,
         -1.7187e-01,  8.5633e-02, -1.5945e-01,  1.0441e-01, -5.5987e-02],
        [-1.7276e-01, -2.0775e-01, -3.1354e-02, -2.0652e-01, -2.3882e-02,
         -4.0305e-03, -4.8704e-02,  2.6223e-01,  1.9673e-01,  1.0844e-01],
        [ 1.0344e-01, -2.6435e-01,  4.0020e-01, -9.3668e-02,  5.5331e-01,
         -7.2241e-01,  2.3002e-01, -6.6712e-01,  5.5344e-02, -2.7551e-01],
        [ 2.6028e-01,  2.5166e-01,  5.4496e-02,  4.8194e-01,  1.1304e-01,
         -5.7232e-01, -6.2377e-02, -4.9970e-01, -1.8777e-01, -3.8663e-01],
        [-8.7335e-02, -4.2479e-01,  5.5396e-01, -3.7753e-01,  6.5427e-01,
         -7.8798e-01,  1.7053e-0

tensor([[-0.1855, -0.2261, -0.0441, -0.2476, -0.0429, -0.0057, -0.0843,  0.0493,
          0.2177,  0.2822],
        [ 0.6111,  0.0843,  0.2607,  0.2253,  0.2228, -0.7880,  0.1281, -0.7647,
         -0.2856, -0.4021],
        [ 0.4549,  0.2537,  0.1824,  0.4246,  0.1863, -0.7228,  0.0407, -0.7119,
         -0.2749, -0.4669],
        [ 0.3380,  0.3442,  0.1149,  0.3810,  0.1099, -0.5915, -0.0291, -0.5280,
         -0.2606, -0.4051],
        [-0.2550, -0.4563, -0.0422, -0.4344,  0.0841, -0.3559, -0.1942, -0.1343,
          0.2680,  0.8303],
        [ 0.4188,  0.3998,  0.1316,  0.5023,  0.1609, -0.7149, -0.0515, -0.6287,
         -0.3323, -0.4657],
        [ 0.0855, -0.1872,  0.3520, -0.0808,  0.4146, -0.5434,  0.1693, -0.5349,
          0.0451, -0.2608],
        [-0.1302, -0.2444,  0.0449, -0.2479,  0.0855, -0.1188,  0.0120, -0.0049,
          0.2535,  0.1312],
        [-0.1411, -0.2502,  0.0528, -0.2471,  0.0476, -0.0428, -0.0485,  0.1115,
          0.2527,  0.1377],
        [-0.0550, -

Test Error: 
 Accuracy: 58.4%, Avg loss: 1.908212 

Epoch 3
-------------------------------
tensor([[-0.4245, -0.6194, -0.0430, -0.6130,  0.0297, -0.2670, -0.2384,  0.0407,
          0.4721,  0.9672],
        [ 0.7155, -0.0093,  0.3192,  0.1961,  0.3023, -0.8573,  0.2140, -0.9076,
         -0.2863, -0.4532],
        [ 0.2000,  0.2384,  0.0741,  0.2711,  0.0376, -0.3318, -0.0263, -0.3262,
         -0.1734, -0.2774],
        [ 0.3443,  0.1787,  0.1638,  0.2888,  0.1875, -0.5318,  0.0911, -0.5419,
         -0.2039, -0.3788],
        [ 0.4309,  0.4376,  0.0790,  0.6173,  0.1381, -0.7382, -0.0406, -0.7062,
         -0.2933, -0.5393],
        [ 0.0739, -0.2937,  0.5861, -0.2695,  0.5911, -0.7665,  0.1888, -0.7586,
          0.0793, -0.2702],
        [-0.2853, -0.3566, -0.0595, -0.3676, -0.0640,  0.0624, -0.0897,  0.4190,
          0.3108,  0.3389],
        [-0.1727, -0.6042,  0.6345, -0.6148,  0.7106, -0.8428,  0.1816, -0.7148,
          0.3507,  0.0483],
        [-0.0437, -0.1015, -0.0084, 

tensor([[-8.3985e-02, -8.8543e-02,  2.4008e-02, -1.1625e-01, -2.1384e-02,
          2.6076e-02, -3.0940e-02,  1.1098e-01,  1.1297e-01, -2.7369e-02],
        [-2.8247e-01, -4.0244e-01,  2.8135e-02, -4.4371e-01,  2.8255e-02,
         -5.2360e-02, -7.1376e-02,  6.9565e-02,  3.2153e-01,  3.8881e-01],
        [ 3.6864e-01,  4.4471e-01, -4.6809e-03,  5.9770e-01,  4.0199e-02,
         -7.3659e-01, -3.3758e-02, -6.2428e-01, -2.9116e-01, -4.4524e-01],
        [ 7.7237e-02, -1.0847e-01,  2.5400e-01, -4.3830e-02,  2.8055e-01,
         -4.2675e-01,  1.6548e-01, -4.3023e-01,  3.3297e-02, -2.7213e-01],
        [-4.8091e-01, -6.8221e-01,  1.6122e-01, -6.2437e-01,  2.3641e-01,
         -2.4851e-01,  1.2021e-02, -2.1178e-02,  6.8220e-01,  3.6563e-01],
        [-3.6440e-01, -4.5498e-01, -1.2469e-01, -4.7239e-01, -1.4172e-01,
          1.2406e-01, -1.8425e-01,  5.8251e-01,  3.1840e-01,  6.4139e-01],
        [-4.0138e-01, -5.7805e-01, -7.0564e-02, -5.2907e-01,  1.8468e-02,
         -3.6435e-01, -2.0232e-0

tensor([[-9.7409e-02, -3.5177e-01,  4.7474e-01, -3.1728e-01,  5.1926e-01,
         -4.8963e-01,  2.5020e-01, -4.7900e-01,  1.6346e-01, -2.1069e-01],
        [-3.9919e-01, -3.9372e-01, -1.6290e-01, -4.7232e-01, -1.8257e-01,
          8.3897e-02, -1.8831e-01,  4.3471e-01,  3.4829e-01,  7.1885e-01],
        [-5.3601e-01, -7.3096e-01, -1.2919e-01, -7.6469e-01, -9.0979e-02,
         -2.0075e-02, -2.7073e-01,  3.0365e-01,  5.5360e-01,  1.1824e+00],
        [-5.5328e-01, -5.6696e-01, -1.3317e-01, -5.7549e-01, -2.1204e-01,
          2.1673e-01, -2.4030e-01,  7.6648e-01,  5.0151e-01,  6.3696e-01],
        [-3.6911e-01, -3.9473e-01, -4.0055e-02, -4.1515e-01, -7.1355e-02,
          1.2021e-01, -1.1613e-01,  3.5184e-01,  3.8375e-01,  3.1619e-01],
        [-5.0913e-01, -4.6696e-01, -1.4159e-01, -5.3264e-01, -1.8017e-01,
          1.5502e-01, -1.9859e-01,  7.5785e-01,  3.9766e-01,  5.7160e-01],
        [ 3.7222e-01,  6.9843e-01,  5.8879e-02,  6.5216e-01,  5.1202e-02,
         -6.3419e-01, -3.2918e-0

tensor([[-3.0934e-01, -5.1648e-01, -5.1993e-02, -5.1819e-01, -6.4316e-02,
         -1.7589e-01, -1.3905e-01, -4.4777e-02,  3.6363e-01,  8.4468e-01],
        [ 6.8532e-01,  1.6078e-01,  4.7316e-01,  3.7234e-01,  5.2661e-01,
         -1.1356e+00,  3.6281e-01, -1.1855e+00, -3.0302e-01, -8.9982e-01],
        [ 1.1496e+00,  4.1379e-01,  4.0075e-01,  6.9461e-01,  4.1366e-01,
         -1.3722e+00,  3.8919e-01, -1.4921e+00, -6.3802e-01, -9.9855e-01],
        [-3.6200e-01, -3.1665e-01, -1.2023e-01, -3.7222e-01, -1.7278e-01,
          1.7083e-01, -1.4788e-01,  5.6277e-01,  3.1056e-01,  3.3340e-01],
        [ 1.9871e-01, -1.0644e-02,  3.0284e-01,  4.9371e-02,  4.1862e-01,
         -6.3082e-01,  2.0074e-01, -5.1241e-01, -1.3039e-01, -2.8567e-01],
        [ 9.2968e-01,  7.5356e-01,  2.8681e-01,  1.0317e+00,  3.9746e-01,
         -1.3418e+00,  2.3801e-01, -1.3732e+00, -7.0241e-01, -1.1086e+00],
        [ 5.6770e-01,  1.1539e+00,  3.9386e-02,  9.6935e-01,  8.4537e-03,
         -9.2710e-01, -4.9707e-0

tensor([[ 3.3210e-01,  6.8916e-01, -1.1039e-01,  7.3557e-01, -7.9809e-02,
         -5.5107e-01, -1.0756e-01, -5.1914e-01, -3.2598e-01, -4.8945e-01],
        [-2.9946e-01, -6.2427e-01,  4.1876e-01, -6.4430e-01,  3.4981e-01,
         -2.3582e-01,  1.8198e-01, -1.7548e-01,  4.0785e-01,  9.1031e-02],
        [ 1.7073e-01, -3.0682e-01,  5.4486e-01, -8.6335e-02,  6.3610e-01,
         -8.4185e-01,  4.0316e-01, -9.0174e-01,  6.9745e-02, -5.4713e-01],
        [ 9.4934e-01,  3.9774e-01,  3.2996e-01,  5.8987e-01,  3.3807e-01,
         -1.1210e+00,  3.4417e-01, -1.1945e+00, -5.5883e-01, -9.0735e-01],
        [-8.2159e-01, -9.6481e-01, -2.8948e-01, -1.0212e+00, -2.9674e-01,
          1.9964e-01, -3.8654e-01,  7.8073e-01,  7.5815e-01,  1.5725e+00],
        [ 5.6704e-01,  6.0372e-01, -2.0595e-02,  8.4839e-01,  6.3906e-02,
         -8.6640e-01,  5.5809e-02, -8.7012e-01, -3.1866e-01, -7.0891e-01],
        [ 5.8059e-01,  9.5236e-01, -1.5421e-02,  1.0928e+00,  5.0003e-02,
         -9.5824e-01, -4.2869e-0

tensor([[-4.4484e-01, -5.4944e-01, -3.8271e-02, -5.6856e-01, -4.3491e-02,
          1.3636e-01, -1.0701e-01,  2.2461e-01,  4.9862e-01,  5.1804e-01],
        [-6.4373e-01, -5.9472e-01, -2.2177e-01, -6.7377e-01, -3.0113e-01,
          2.9514e-01, -2.7590e-01,  9.9085e-01,  5.8118e-01,  6.6141e-01],
        [-6.5135e-01, -7.4942e-01, -2.1704e-01, -7.1882e-01, -2.5916e-01,
          3.1410e-01, -2.7791e-01,  1.0040e+00,  6.0501e-01,  8.5063e-01],
        [ 3.5469e-01, -5.9675e-01,  7.2898e-01, -3.1960e-01,  7.9446e-01,
         -1.1076e+00,  5.8786e-01, -1.2943e+00,  2.1984e-01, -5.7855e-01],
        [-3.6490e-01, -3.4802e-01, -1.2868e-01, -4.0678e-01, -1.3984e-01,
          2.3433e-01, -1.2040e-01,  5.2014e-01,  3.0318e-01,  3.5283e-01],
        [-8.7920e-01, -1.0617e+00, -2.9154e-01, -1.1475e+00, -3.1786e-01,
          1.8707e-01, -4.6687e-01,  7.7935e-01,  7.6515e-01,  1.8786e+00],
        [ 8.9514e-01,  9.2798e-01,  2.2355e-01,  1.1851e+00,  3.5071e-01,
         -1.3514e+00,  2.3170e-0

tensor([[ 0.5454,  1.0257,  0.0532,  0.8307,  0.0924, -0.7874,  0.0542, -0.8624,
         -0.5755, -0.8931],
        [-1.1244, -1.5789,  0.0533, -1.5551,  0.0249,  0.0948, -0.2103,  0.8045,
          1.4660,  1.1624],
        [-0.0931, -0.3629,  0.3397, -0.3617,  0.3542, -0.2539,  0.2098, -0.2605,
          0.2305, -0.1411],
        [-0.5234, -0.5270, -0.1773, -0.5372, -0.2275,  0.2746, -0.2076,  0.8486,
          0.4755,  0.4570],
        [ 0.3404, -0.3991,  0.8428, -0.0902,  1.0485, -1.3252,  0.6624, -1.3891,
          0.0672, -0.8618],
        [ 0.7861,  1.0415, -0.0046,  1.2801,  0.1434, -1.1994,  0.0856, -1.2196,
         -0.6383, -1.1199],
        [-0.1005, -0.9025,  1.1464, -0.7663,  1.1762, -1.3107,  0.5964, -1.2283,
          0.4087, -0.5520],
        [ 0.5951,  1.1266,  0.0314,  0.8864,  0.0548, -0.8345,  0.0476, -0.9325,
         -0.6081, -0.9501],
        [ 0.2762, -0.6978,  0.8673, -0.4027,  1.0047, -1.2118,  0.6746, -1.3148,
          0.2981, -0.6424],
        [-0.2050, -

tensor([[ 1.0823e+00,  7.8135e-01,  3.3823e-01,  8.4124e-01,  2.8611e-01,
         -1.1478e+00,  3.4391e-01, -1.2930e+00, -7.6747e-01, -1.0896e+00],
        [-9.0805e-01, -9.3130e-01, -2.3258e-01, -9.8332e-01, -3.9591e-01,
          5.2090e-01, -3.9283e-01,  1.3545e+00,  8.3419e-01,  8.3999e-01],
        [ 7.4732e-01, -1.2434e-01,  1.0301e+00,  2.2883e-01,  1.2373e+00,
         -1.7772e+00,  8.2755e-01, -1.9024e+00, -2.2919e-01, -1.4574e+00],
        [-7.4330e-02, -3.0249e-01,  4.5772e-01, -2.7235e-01,  4.2201e-01,
         -3.7582e-01,  2.0679e-01, -4.3281e-01,  1.4750e-01, -2.6096e-01],
        [ 3.8119e-01, -5.5799e-01,  1.5251e+00, -3.3648e-01,  1.4699e+00,
         -1.7345e+00,  8.6730e-01, -1.8789e+00, -6.0003e-03, -1.2285e+00],
        [-7.0173e-01, -6.9584e-01, -1.5590e-01, -7.6784e-01, -2.7787e-01,
          4.0434e-01, -2.5915e-01,  8.8852e-01,  6.5993e-01,  7.0677e-01],
        [ 9.5463e-02, -7.6951e-01,  1.3664e+00, -6.2227e-01,  1.3393e+00,
         -1.4272e+00,  7.3549e-0

Test Error: 
 Accuracy: 62.5%, Avg loss: 1.529228 

Epoch 4
-------------------------------
tensor([[-1.2892e+00, -1.6433e+00, -3.1056e-01, -1.6321e+00, -3.5020e-01,
          3.4282e-01, -6.0601e-01,  9.5047e-01,  1.2931e+00,  2.3367e+00],
        [ 1.5749e+00,  3.4388e-01,  5.8268e-01,  6.5249e-01,  5.5704e-01,
         -1.6233e+00,  6.8650e-01, -1.9396e+00, -6.9789e-01, -1.3078e+00],
        [ 5.2190e-01,  7.5584e-01,  5.8275e-02,  7.2698e-01,  5.6432e-02,
         -6.6555e-01,  6.8728e-02, -7.5606e-01, -5.0383e-01, -7.5203e-01],
        [ 8.4380e-01,  6.6799e-01,  2.8566e-01,  7.6806e-01,  3.5235e-01,
         -1.0479e+00,  3.4864e-01, -1.2368e+00, -5.8273e-01, -1.0675e+00],
        [ 1.0432e+00,  1.3920e+00,  2.7119e-02,  1.5144e+00,  1.7981e-01,
         -1.4331e+00,  1.5161e-01, -1.5605e+00, -8.3022e-01, -1.4348e+00],
        [ 1.9771e-01, -5.9810e-01,  1.1457e+00, -4.7261e-01,  1.0583e+00,
         -1.2228e+00,  6.2408e-01, -1.3887e+00,  1.8124e-01, -7.8828e-01],
        [-8.52

tensor([[ 3.4698e-01,  3.0029e-01,  1.0237e-01,  3.3807e-01,  6.7994e-02,
         -3.8090e-01,  8.1657e-02, -4.3564e-01, -1.7484e-01, -4.9116e-01],
        [-2.4722e-01, -1.1565e+00,  1.1814e+00, -1.0101e+00,  1.1944e+00,
         -9.5618e-01,  6.7045e-01, -1.0333e+00,  5.8530e-01, -4.7876e-01],
        [ 7.5639e-02, -1.7025e-01,  2.3285e-01, -1.0424e-01,  2.7560e-01,
         -2.6315e-01,  2.4440e-01, -3.4754e-01,  4.1148e-02, -3.0386e-01],
        [-1.4198e+00, -1.8548e+00,  7.9035e-02, -1.8074e+00,  6.1165e-03,
          3.3865e-01, -2.5315e-01,  1.0320e+00,  1.7219e+00,  1.2664e+00],
        [ 2.1094e-01, -6.2888e-02,  4.7698e-01,  2.6759e-02,  5.7720e-01,
         -6.3480e-01,  4.0617e-01, -7.9742e-01, -9.2970e-02, -6.5504e-01],
        [-1.1058e+00, -1.0475e+00, -3.9037e-01, -1.0715e+00, -5.3359e-01,
          7.3297e-01, -5.2699e-01,  1.6126e+00,  9.2247e-01,  1.1775e+00],
        [-1.0078e-01, -5.5385e-01,  4.7444e-01, -4.6324e-01,  4.2625e-01,
         -4.8042e-01,  3.0217e-0

tensor([[-2.4740e-01, -2.1901e-01, -2.5867e-02, -2.4619e-01, -1.0469e-01,
          2.0119e-01, -8.0902e-02,  3.5514e-01,  2.4317e-01,  2.3297e-02],
        [-7.3048e-01, -1.0099e+00, -3.6487e-02, -1.0023e+00, -1.1048e-01,
          3.2870e-01, -1.9326e-01,  5.2835e-01,  7.6713e-01,  9.1070e-01],
        [ 8.1524e-01,  1.2235e+00, -1.7685e-01,  1.3557e+00, -1.1217e-02,
         -1.2327e+00,  1.0481e-01, -1.2353e+00, -6.9393e-01, -1.0920e+00],
        [ 1.8153e-01, -1.9727e-01,  4.8745e-01, -6.4442e-02,  5.2260e-01,
         -6.5397e-01,  4.4124e-01, -8.1135e-01,  5.0994e-02, -6.6550e-01],
        [-1.1572e+00, -1.6739e+00,  2.4982e-01, -1.4582e+00,  2.4433e-01,
          1.2959e-01, -3.7230e-03,  4.7068e-01,  1.5225e+00,  7.7967e-01],
        [-9.9951e-01, -1.0831e+00, -4.5454e-01, -1.0964e+00, -5.6040e-01,
          7.6290e-01, -5.2050e-01,  1.6191e+00,  7.6624e-01,  1.4218e+00],
        [-1.1057e+00, -1.4738e+00, -3.3399e-01, -1.3814e+00, -3.2047e-01,
          2.0226e-01, -4.8140e-0

tensor([[-0.1503, -0.7690,  0.9115, -0.6084,  0.9069, -0.6572,  0.6011, -0.8345,
          0.3559, -0.5725],
        [-0.9659, -0.9628, -0.4675, -1.0248, -0.5589,  0.6874, -0.5059,  1.2649,
          0.7792,  1.4594],
        [-1.3342, -1.7247, -0.4201, -1.7078, -0.4971,  0.7492, -0.6298,  1.2405,
          1.2739,  2.3986],
        [-1.3419, -1.3178, -0.4853, -1.3095, -0.6829,  1.0185, -0.6460,  2.0060,
          1.0910,  1.4089],
        [-0.8773, -0.9521, -0.1680, -0.9395, -0.2806,  0.6104, -0.3045,  1.0220,
          0.8346,  0.7108],
        [-1.2203, -1.0986, -0.4700, -1.1721, -0.6119,  0.8917, -0.5705,  1.9360,
          0.9052,  1.2128],
        [ 0.8484,  1.6971, -0.0127,  1.4895,  0.0577, -1.1815,  0.0963, -1.3548,
         -1.0006, -1.3711],
        [ 1.5895,  1.4235,  0.5278,  1.6724,  0.7056, -2.1098,  0.7501, -2.4357,
         -1.2110, -2.1813],
        [-1.3756, -1.8623,  0.0660, -1.7774, -0.0431,  0.5065, -0.2358,  1.0665,
          1.7683,  1.1137],
        [-0.7742, -

tensor([[ 0.6372,  1.5144, -0.3620,  1.4865, -0.2197, -0.8714, -0.1253, -0.9375,
         -0.7237, -0.9815],
        [-0.5721, -1.3440,  0.7756, -1.2399,  0.5484, -0.0345,  0.3860, -0.1469,
          0.8335,  0.0315],
        [ 0.3729, -0.6456,  1.0279, -0.1800,  1.1346, -1.1940,  0.9202, -1.6291,
          0.1720, -1.1955],
        [ 1.8395,  0.8734,  0.5480,  1.2055,  0.6015, -1.8620,  0.8735, -2.3222,
         -1.0488, -1.8963],
        [-1.7989, -2.0991, -0.7312, -2.0560, -0.8499,  1.2684, -0.8908,  2.0566,
          1.5569,  2.9192],
        [ 1.0947,  1.3088, -0.1602,  1.6397,  0.0708, -1.3751,  0.2480, -1.6176,
         -0.6748, -1.4328],
        [ 1.1470,  2.0524, -0.2271,  2.1478,  0.0098, -1.5766,  0.0776, -1.7639,
         -1.2152, -1.5987],
        [ 1.8594,  0.9650,  0.5973,  1.4608,  0.8598, -2.2590,  1.0044, -2.6701,
         -1.0465, -2.2689],
        [ 0.8719,  1.8664, -0.1095,  1.4949, -0.0221, -1.1468,  0.1078, -1.3676,
         -1.0376, -1.3789],
        [-0.6657, -

tensor([[ 1.0109,  2.0025, -0.0148,  1.5896,  0.1393, -1.3018,  0.2137, -1.5865,
         -1.1233, -1.6520],
        [-2.1151, -3.1284,  0.0507, -2.8105, -0.1347,  1.1330, -0.4121,  1.8017,
          2.6797,  1.7941],
        [-0.1667, -0.7505,  0.6455, -0.6343,  0.5938, -0.2409,  0.4272, -0.4597,
          0.4400, -0.3569],
        [-1.0344, -1.0155, -0.3972, -0.9764, -0.5132,  0.8953, -0.4671,  1.7037,
          0.8535,  0.8108],
        [ 0.6111, -0.8893,  1.5212, -0.2131,  1.7539, -1.7824,  1.3544, -2.4210,
          0.2397, -1.7285],
        [ 1.4159,  1.9793, -0.1569,  2.2766,  0.2020, -1.8532,  0.3159, -2.1631,
         -1.1534, -2.0096],
        [-0.0412, -1.8161,  2.0666, -1.3107,  1.9575, -1.5985,  1.3247, -2.1343,
          0.8157, -1.3358],
        [ 1.0784,  2.1902, -0.0639,  1.6923,  0.0891, -1.3714,  0.2077, -1.6929,
         -1.1872, -1.7402],
        [ 0.4796, -1.4370,  1.5843, -0.7261,  1.6828, -1.5621,  1.3825, -2.2787,
          0.6225, -1.3958],
        [-0.4510, -

tensor([[ 1.8367,  1.2994,  0.4474,  1.4619,  0.4577, -1.7459,  0.7346, -2.2507,
         -1.2831, -1.8454],
        [-1.6415, -1.6633, -0.5265, -1.6241, -0.7762,  1.4446, -0.7531,  2.4895,
          1.4014,  1.2995],
        [ 1.2672, -0.4252,  1.7204,  0.3838,  2.0065, -2.4386,  1.6419, -3.2550,
         -0.2154, -2.5826],
        [-0.0790, -0.5981,  0.7723, -0.4289,  0.6803, -0.4196,  0.4448, -0.7526,
          0.2688, -0.5442],
        [ 0.7477, -1.2051,  2.5432, -0.5251,  2.3783, -2.2852,  1.7566, -3.2587,
          0.1161, -2.2886],
        [-1.2545, -1.2726, -0.3489, -1.2710, -0.5333,  1.0942, -0.5112,  1.6489,
          1.0890,  1.0974],
        [ 0.2699, -1.5152,  2.2929, -0.9629,  2.1765, -1.8025,  1.5046, -2.6429,
          0.4524, -1.7411],
        [ 2.9286,  0.9879,  0.9799,  1.6023,  1.0608, -2.7812,  1.4533, -3.7303,
         -1.4187, -2.8023],
        [-1.7398, -1.9250, -0.7538, -1.8604, -0.9322,  1.6239, -0.9239,  2.3135,
          1.4397,  2.3374],
        [-1.5366, -

tensor([[ 5.7509e-01,  4.5312e-01,  1.0126e-01,  5.7293e-01,  8.8945e-02,
         -5.0766e-01,  2.1372e-01, -7.3720e-01, -2.9088e-01, -8.0747e-01],
        [-3.1743e-01, -2.1157e+00,  1.9646e+00, -1.5492e+00,  1.8702e+00,
         -1.0318e+00,  1.2823e+00, -1.7742e+00,  1.0414e+00, -1.0167e+00],
        [ 1.3972e-01, -3.6948e-01,  4.2532e-01, -1.7058e-01,  4.6434e-01,
         -3.1362e-01,  4.5277e-01, -6.1005e-01,  9.3171e-02, -5.3462e-01],
        [-2.2779e+00, -3.2581e+00,  1.3850e-01, -2.8430e+00, -6.7001e-02,
          1.3208e+00, -3.7790e-01,  1.7176e+00,  2.8020e+00,  1.7088e+00],
        [ 3.7174e-01, -2.1457e-01,  8.1194e-01,  4.2915e-02,  9.3285e-01,
         -8.6088e-01,  7.5865e-01, -1.3581e+00, -9.2598e-02, -1.1246e+00],
        [-1.9165e+00, -1.7959e+00, -7.6352e-01, -1.7289e+00, -9.7916e-01,
          1.7799e+00, -9.6245e-01,  2.8353e+00,  1.4674e+00,  1.7765e+00],
        [-8.1459e-02, -1.0682e+00,  8.0985e-01, -7.1072e-01,  6.9410e-01,
         -4.7313e-01,  6.2492e-0

tensor([[-4.3642e-01, -3.9108e-01, -8.2553e-02, -3.7298e-01, -1.7820e-01,
          4.6887e-01, -1.4124e-01,  6.1506e-01,  3.8340e-01,  2.8328e-02],
        [-1.1730e+00, -1.7000e+00, -9.4026e-02, -1.5337e+00, -2.3377e-01,
          9.2674e-01, -3.1971e-01,  8.9963e-01,  1.2122e+00,  1.3489e+00],
        [ 1.2050e+00,  1.8458e+00, -4.2942e-01,  2.1135e+00, -2.1290e-02,
         -1.6003e+00,  2.5047e-01, -1.8844e+00, -1.0101e+00, -1.6030e+00],
        [ 3.1426e-01, -4.4116e-01,  7.8727e-01, -1.0716e-01,  8.3420e-01,
         -8.1891e-01,  7.8780e-01, -1.3228e+00,  1.4144e-01, -1.0812e+00],
        [-1.8622e+00, -2.8936e+00,  3.9873e-01, -2.2662e+00,  3.4470e-01,
          8.3936e-01,  2.8549e-02,  7.9019e-01,  2.4791e+00,  9.9261e-01],
        [-1.6915e+00, -1.7407e+00, -8.6263e-01, -1.6794e+00, -1.0238e+00,
          1.7570e+00, -9.4245e-01,  2.7491e+00,  1.1554e+00,  2.1151e+00],
        [-1.8074e+00, -2.5481e+00, -6.6510e-01, -2.1557e+00, -6.8718e-01,
          1.2239e+00, -7.9775e-0

tensor([[-0.1466, -1.3119,  1.4614, -0.8578,  1.3958, -0.7632,  1.0294, -1.4254,
          0.6027, -1.0346],
        [-1.5594, -1.5388, -0.8084, -1.5130, -0.9375,  1.5493, -0.8641,  2.0741,
          1.1380,  2.1096],
        [-2.0906, -2.7922, -0.7341, -2.5367, -0.8959,  1.8522, -1.0107,  2.0097,
          1.8937,  3.5179],
        [-2.1836, -2.0878, -0.8995, -1.9510, -1.1547,  2.1510, -1.1161,  3.2837,
          1.6127,  2.0370],
        [-1.4048, -1.5516, -0.2949, -1.4130, -0.4554,  1.2929, -0.5063,  1.6488,
          1.2723,  0.9827],
        [-1.9877, -1.7308, -0.8616, -1.7180, -1.0463,  1.9193, -1.0095,  3.1711,
          1.3404,  1.7089],
        [ 1.2420,  2.6799, -0.1519,  2.3143,  0.0886, -1.6762,  0.2106, -2.0601,
         -1.5670, -1.9904],
        [ 2.4438,  2.0520,  0.7078,  2.5691,  1.0891, -2.9709,  1.3198, -3.8169,
         -1.7511, -3.2686],
        [-2.0871, -3.0875,  0.1206, -2.6066, -0.0810,  1.3653, -0.3237,  1.5999,
          2.7416,  1.3769],
        [-1.0962, -

tensor([[ 8.0075e-01,  2.2265e+00, -6.9972e-01,  2.1928e+00, -3.4706e-01,
         -1.0856e+00, -1.9172e-01, -1.2682e+00, -1.0589e+00, -1.2720e+00],
        [-7.2558e-01, -2.0784e+00,  1.2148e+00, -1.7008e+00,  8.2127e-01,
          1.8727e-01,  6.4960e-01, -3.8748e-01,  1.2230e+00, -1.6305e-01],
        [ 5.8584e-01, -1.1488e+00,  1.5642e+00, -2.4335e-01,  1.6921e+00,
         -1.4915e+00,  1.4705e+00, -2.5305e+00,  3.7363e-01, -1.7853e+00],
        [ 2.7483e+00,  1.0882e+00,  7.6824e-01,  1.7516e+00,  8.9305e-01,
         -2.5466e+00,  1.4459e+00, -3.5326e+00, -1.4154e+00, -2.6655e+00],
        [-2.6890e+00, -3.1796e+00, -1.1951e+00, -2.8819e+00, -1.3621e+00,
          2.5657e+00, -1.3943e+00,  3.0742e+00,  2.1756e+00,  4.1474e+00],
        [ 1.5071e+00,  1.7462e+00, -3.5903e-01,  2.3517e+00,  1.0897e-01,
         -1.7623e+00,  4.2583e-01, -2.3202e+00, -8.8118e-01, -1.9054e+00],
        [ 1.5361e+00,  2.9283e+00, -5.3082e-01,  3.1128e+00,  1.4052e-05,
         -2.0660e+00,  1.5602e-0

tensor([[-1.3019, -1.7888, -0.1368, -1.5285, -0.2394,  1.1762, -0.3305,  0.9349,
          1.4009,  1.2411],
        [-2.0397, -1.8276, -0.8716, -1.7416, -1.0357,  2.0001, -0.9971,  3.1652,
          1.5282,  1.5519],
        [-2.1103, -2.1841, -0.9662, -1.9333, -1.1132,  2.2054, -1.0912,  3.3262,
          1.5218,  2.1491],
        [ 1.0645, -2.1040,  2.0582, -0.7929,  1.9915, -1.8004,  2.0126, -3.4279,
          0.8758, -1.9855],
        [-1.1786, -1.0923, -0.4589, -1.0801, -0.5452,  1.2890, -0.5177,  1.7285,
          0.8109,  0.9103],
        [-2.6015, -3.2472, -1.1581, -2.9816, -1.3612,  2.5489, -1.4399,  2.8093,
          2.0087,  4.6052],
        [ 2.4028,  2.5019,  0.2798,  3.1295,  0.8770, -2.9797,  1.0704, -3.7578,
         -1.9202, -3.1875],
        [-1.4753, -1.4776, -0.5563, -1.4068, -0.7049,  1.5493, -0.7227,  1.9202,
          1.1328,  1.3864],
        [-1.2585, -2.0028, -0.5773, -1.6525, -0.6299,  1.1789, -0.6025,  0.9077,
          1.1286,  2.7555],
        [ 1.9764,  

tensor([[-1.3136e+00, -1.5936e+00, -5.2856e-01, -1.4590e+00, -6.9918e-01,
          1.4551e+00, -6.6297e-01,  1.2999e+00,  1.1072e+00,  1.9132e+00],
        [ 3.5545e+00,  1.2817e+00,  9.9054e-01,  2.0515e+00,  9.6677e-01,
         -3.1571e+00,  1.7509e+00, -4.5421e+00, -1.6816e+00, -3.0588e+00],
        [ 2.7847e+00,  2.2779e+00,  4.8778e-01,  2.7716e+00,  7.8334e-01,
         -2.9856e+00,  1.2399e+00, -4.0586e+00, -1.8034e+00, -3.1639e+00],
        [ 2.1362e+00,  2.8848e+00,  1.5456e-01,  2.6727e+00,  4.4907e-01,
         -2.5116e+00,  7.5973e-01, -3.2527e+00, -1.8802e+00, -2.7543e+00],
        [-2.1760e+00, -3.5807e+00, -8.0064e-01, -2.9499e+00, -1.0442e+00,
          2.0624e+00, -1.1431e+00,  1.5323e+00,  2.0659e+00,  4.7856e+00],
        [ 2.4588e+00,  3.3096e+00, -3.5287e-02,  3.4493e+00,  4.9815e-01,
         -2.9842e+00,  7.4519e-01, -3.7332e+00, -2.2722e+00, -3.0105e+00],
        [ 6.4771e-01, -1.0120e+00,  1.7554e+00, -1.1372e-01,  1.7290e+00,
         -1.6369e+00,  1.5296e+0

tensor([[ 2.5133,  1.6309,  0.5502,  1.9838,  0.6199, -2.2791,  1.1066, -3.1402,
         -1.6992, -2.3551],
        [-2.3075, -2.2719, -0.8316, -2.0469, -1.0802,  2.3355, -1.0845,  3.4881,
          1.8409,  1.5821],
        [ 1.6906, -0.9093,  2.3776,  0.5829,  2.7848, -3.0569,  2.4233, -4.6472,
         -0.0653, -3.4838],
        [-0.0598, -0.8703,  1.0920, -0.5081,  0.9526, -0.4875,  0.6875, -1.1573,
          0.3732, -0.8108],
        [ 1.1397, -1.8605,  3.5463, -0.5537,  3.2818, -2.8718,  2.6275, -4.8085,
          0.2546, -3.1774],
        [-1.7495, -1.7636, -0.5451, -1.6172, -0.7301,  1.7569, -0.7400,  2.2795,
          1.4265,  1.3676],
        [ 0.4887, -2.2063,  3.2140, -1.1063,  3.0287, -2.2282,  2.2694, -3.9644,
          0.6603, -2.4934],
        [ 4.0788,  1.0439,  1.3023,  2.1727,  1.4459, -3.6298,  2.1991, -5.2716,
         -1.7650, -3.6457],
        [-2.4327, -2.6637, -1.1486, -2.3869, -1.3305,  2.6361, -1.3351,  3.2126,
          1.8485,  3.1072],
        [-2.0484, -

In [11]:
# Write the current `model` object to ./model.pth so the next cell can reload it.
# NOTE(review): passing the module itself (rather than model.state_dict())
# pickles the whole object, so loading it later needs the NeuralNetwork class
# to be importable. Kept as-is because the following cell calls torch.load()
# and expects a complete model back, not a state dict.
torch.save(model, './model.pth')

In [12]:
# Reload the checkpoint written by the previous cell and rebuild the training
# objects around it, so training can continue from the saved weights.
#
# map_location remaps every stored tensor onto `device` while unpickling.
# Without it, a checkpoint saved from a CUDA-resident model fails to
# deserialize on a CPU-only machine before .to(device) ever gets to run.
#
# NOTE(review): the file holds a fully pickled module, so only load checkpoints
# you created yourself. On PyTorch >= 2.6 this call additionally needs
# weights_only=False, because the default there rejects pickled modules.
model = torch.load('model.pth', map_location=device).to(device)

# The loss and optimizer are recreated here; the optimizer has to be built from
# the reloaded model's parameters rather than those of the earlier instance.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [None]:
# Run `epochs` rounds over the reloaded model. Each round prints a header,
# then calls train() with the training loader and test() with the test loader.
# train() and test() are defined in earlier cells of the notebook.
# NOTE(review): the tensor dumps in this cell's output presumably come from a
# leftover print inside train() -- confirm and remove it to keep output readable.
epochs = 5
for t in range(epochs):
    # t is zero-based; the header shows a one-based epoch number.
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
tensor([[-3.0409e+00, -4.1448e+00, -1.0211e+00, -3.3991e+00, -1.1289e+00,
          2.7220e+00, -1.4238e+00,  2.5203e+00,  2.8514e+00,  4.8063e+00],
        [ 3.6793e+00,  2.9352e-01,  1.1727e+00,  1.5390e+00,  1.2051e+00,
         -3.0194e+00,  2.0353e+00, -4.6530e+00, -1.2751e+00, -2.8644e+00],
        [ 1.1023e+00,  1.6804e+00, -4.1790e-02,  1.6332e+00,  1.7854e-01,
         -1.2655e+00,  3.3381e-01, -1.7099e+00, -1.1206e+00, -1.5388e+00],
        [ 1.9184e+00,  1.3312e+00,  5.6435e-01,  1.6845e+00,  8.0868e-01,
         -2.0304e+00,  1.0674e+00, -2.9210e+00, -1.2030e+00, -2.2968e+00],
        [ 2.1140e+00,  2.9446e+00, -2.7909e-01,  3.3465e+00,  4.0417e-01,
         -2.6128e+00,  6.2761e-01, -3.3895e+00, -1.6943e+00, -2.8750e+00],
        [ 6.6157e-01, -1.8558e+00,  2.6068e+00, -8.2397e-01,  2.3068e+00,
         -1.8303e+00,  1.8973e+00, -3.4455e+00,  6.0563e-01, -2.1089e+00],
        [-2.1950e+00, -2.0939e+00, -9.8639e-01, -1.8590e+00, -1.11

tensor([[ 7.6767e-01,  5.1309e-01,  9.5841e-02,  7.8077e-01,  1.2168e-01,
         -6.1973e-01,  3.3632e-01, -1.0206e+00, -3.6564e-01, -1.0311e+00],
        [-3.0854e-01, -2.9531e+00,  2.7316e+00, -1.8531e+00,  2.5541e+00,
         -1.1968e+00,  1.8870e+00, -2.7046e+00,  1.4321e+00, -1.5310e+00],
        [ 2.0123e-01, -5.8116e-01,  6.1543e-01, -2.0210e-01,  6.5491e-01,
         -3.7566e-01,  6.5760e-01, -9.0229e-01,  1.6334e-01, -7.3677e-01],
        [-2.9170e+00, -4.4722e+00,  2.2291e-01, -3.4840e+00, -1.2171e-02,
          2.0872e+00, -3.8994e-01,  1.9961e+00,  3.7017e+00,  1.9126e+00],
        [ 4.9960e-01, -3.9449e-01,  1.1398e+00,  7.8797e-02,  1.2845e+00,
         -1.0843e+00,  1.0903e+00, -1.9301e+00, -5.5504e-02, -1.5113e+00],
        [-2.6555e+00, -2.4081e+00, -1.1474e+00, -2.1600e+00, -1.3422e+00,
          2.7536e+00, -1.3692e+00,  3.9133e+00,  1.8674e+00,  2.2176e+00],
        [-2.1076e-03, -1.5649e+00,  1.1437e+00, -8.4661e-01,  9.6887e-01,
         -5.1231e-01,  9.5538e-0

tensor([[-6.1442e-01, -5.3837e-01, -1.4039e-01, -4.4748e-01, -2.2975e-01,
          7.0748e-01, -1.9989e-01,  8.4446e-01,  5.0061e-01,  1.3164e-03],
        [-1.5139e+00, -2.2785e+00, -1.5380e-01, -1.8831e+00, -3.1427e-01,
          1.4371e+00, -4.1644e-01,  1.1143e+00,  1.5498e+00,  1.6992e+00],
        [ 1.4087e+00,  2.1699e+00, -6.8376e-01,  2.7279e+00,  2.9069e-02,
         -1.8616e+00,  3.5427e-01, -2.3793e+00, -1.1721e+00, -1.8610e+00],
        [ 4.3363e-01, -6.9350e-01,  1.0744e+00, -9.6535e-02,  1.1371e+00,
         -9.9074e-01,  1.1075e+00, -1.8487e+00,  2.5194e-01, -1.4287e+00],
        [-2.4582e+00, -3.9812e+00,  5.2452e-01, -2.7749e+00,  5.0974e-01,
          1.4206e+00,  8.6982e-02,  8.9927e-01,  3.3446e+00,  1.0675e+00],
        [-2.2919e+00, -2.2533e+00, -1.2774e+00, -2.0530e+00, -1.4188e+00,
          2.6591e+00, -1.3380e+00,  3.7405e+00,  1.3906e+00,  2.6762e+00],
        [-2.3339e+00, -3.4603e+00, -1.0256e+00, -2.6111e+00, -1.0081e+00,
          2.1348e+00, -1.0882e+0

tensor([[-0.1206, -1.7555,  1.9821, -0.9727,  1.8820, -0.9348,  1.4380, -2.0939,
          0.8017, -1.4681],
        [-2.0595, -1.9863, -1.1424, -1.8130, -1.2400,  2.2940, -1.1802,  2.7268,
          1.3686,  2.6452],
        [-2.6572, -3.6921, -1.0484, -3.0696, -1.2212,  2.7936, -1.3313,  2.5210,
          2.3106,  4.5027],
        [-2.9365, -2.7062, -1.3277, -2.3493, -1.5356,  3.1590, -1.5574,  4.4108,
          1.9839,  2.5142],
        [-1.8599, -2.0489, -0.4208, -1.7166, -0.5672,  1.8745, -0.6737,  2.1401,
          1.6192,  1.1553],
        [-2.6852, -2.2054, -1.2657, -2.0391, -1.3946,  2.8249, -1.4173,  4.2733,
          1.6331,  2.0592],
        [ 1.4395,  3.4822, -0.2752,  2.9606,  0.1696, -2.0659,  0.2801, -2.6066,
         -2.0291, -2.3925],
        [ 3.0713,  2.4413,  0.8847,  3.3009,  1.4807, -3.7201,  1.8141, -5.0083,
         -2.1274, -4.0337],
        [-2.6120, -4.1464,  0.1873, -3.1039, -0.0150,  1.9937, -0.3246,  1.8203,
          3.5795,  1.4459],
        [-1.2347, -

tensor([[ 0.7971,  2.7165, -1.0155,  2.7498, -0.4125, -1.2040, -0.2843, -1.4458,
         -1.2889, -1.3932],
        [-0.7714, -2.6319,  1.6352, -1.9595,  1.0926,  0.2786,  0.9273, -0.7710,
          1.4950, -0.4059],
        [ 0.7536, -1.6038,  2.0453, -0.2086,  2.1980, -1.7956,  1.9636, -3.4023,
          0.5763, -2.2742],
        [ 3.5539,  1.1301,  0.9874,  2.1792,  1.1329, -3.1469,  1.9743, -4.6116,
         -1.6788, -3.2089],
        [-3.3792, -4.0666, -1.6534, -3.3939, -1.7774,  3.6461, -1.8271,  3.8287,
          2.5821,  5.2213],
        [ 1.7450,  1.9023, -0.5585,  2.9138,  0.1767, -2.0361,  0.5754, -2.8552,
         -0.9365, -2.1687],
        [ 1.6896,  3.4691, -0.8172,  3.8536,  0.0554, -2.4024,  0.1938, -2.9863,
         -2.0375, -2.3118],
        [ 3.4409,  1.1292,  1.0560,  2.6463,  1.6723, -3.7104,  2.2426, -5.2692,
         -1.5096, -3.8493],
        [ 1.3328,  3.5954, -0.4043,  2.7164,  0.0552, -1.8711,  0.2388, -2.4245,
         -1.9698, -2.2471],
        [-1.3790, -

tensor([[ 1.4710,  3.5917, -0.1641,  2.7301,  0.3235, -2.0781,  0.4201, -2.7094,
         -1.9858, -2.5786],
        [-3.4076, -5.7555,  0.1746, -4.1942, -0.0694,  2.7839, -0.4449,  2.4861,
          4.6395,  2.1996],
        [-0.1989, -1.4334,  1.2839, -0.9411,  1.1154, -0.2861,  0.8981, -1.1340,
          0.8034, -0.8448],
        [-1.9676, -1.7656, -0.8999, -1.4580, -0.9810,  2.1276, -0.9732,  3.2413,
          1.3625,  1.1748],
        [ 1.0398, -2.1039,  2.8138, -0.1973,  3.1157, -2.5957,  2.6591, -4.6913,
          0.8101, -3.0558],
        [ 2.0724,  2.8728, -0.5519,  3.8081,  0.4024, -2.7199,  0.6581, -3.5473,
         -1.6215, -2.8939],
        [ 0.3726, -3.4404,  3.9432, -1.6787,  3.5703, -2.3607,  2.8615, -4.7301,
          1.4787, -2.8550],
        [ 1.5268,  3.9479, -0.2665,  2.8843,  0.2714, -2.1610,  0.4018, -2.8465,
         -2.1005, -2.6969],
        [ 1.0020, -3.0500,  2.9658, -0.9315,  2.9707, -2.2409,  2.7766, -4.5922,
          1.3848, -2.6463],
        [-0.8044, -