In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using mps device


In [4]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten, then three linear layers with ReLU in between.

    With the default arguments the layer widths are 28*28 -> 512 -> 512 -> 10.

    Args:
        in_features: Number of values per sample once the input is flattened.
        hidden_features: Width shared by the two hidden layers.
        num_classes: Number of logits produced per sample.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        # nn.Flatten() keeps dim 0 and collapses every remaining dim into one.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            # No activation after the last layer: forward() returns raw logits.
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Flatten ``x`` and return the logits from the linear/ReLU stack.

        Args:
            x: Tensor whose non-leading dims multiply to ``in_features``.

        Returns:
            Tensor of shape ``(x.shape[0], num_classes)``; softmax is not applied.
        """
        return self.linear_relu_stack(self.flatten(x))

In [5]:
# Build the network, then move its parameters onto the selected device.
model = NeuralNetwork()
model = model.to(device)
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [6]:
# One random 1x28x28 sample, created directly on the model's device.
X = torch.rand((1, 28, 28), device=device)
logits = model(X)
# Normalise the logits along dim=1, then take the index of the largest entry per row.
pred_probab = nn.functional.softmax(logits, dim=1)
y_pred = torch.argmax(pred_probab, dim=1)
print(f"Predicted class: {y_pred}")

Predicted class: tensor([5], device='mps:0')


In [7]:
# Three random 28x28 tensors stacked along dim 0.
input_image = torch.rand((3, 28, 28))
print(input_image.shape)

torch.Size([3, 28, 28])


In [8]:
# Defaults spelled out: dim 0 is kept, dims 1..-1 are merged (3x28x28 -> 3x784).
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flat_image = flatten(input_image)
print(flat_image.shape)

torch.Size([3, 784])


In [12]:
# Linear map from the 784 flattened values to 20 hidden features per sample.
layer1 = nn.Linear(28 * 28, 20)
hidden1 = layer1(flat_image)
print(hidden1.shape)

torch.Size([3, 20])


In [13]:
# Show the activations, clamp negatives to zero, and show them again.
# `hidden1` is rebound here, so re-running only this cell prints already-rectified
# values under "Before ReLU".
print(f"Before ReLU:\n{hidden1}\n\n")
hidden1 = nn.functional.relu(hidden1)
print(f"After ReLU:\n{hidden1}")

Before ReLU:
tensor([[ 0.1854,  0.3809,  0.3456, -0.1637,  0.0711, -0.0115, -0.4402,  0.1982,
          0.2651, -0.1735, -0.2796,  0.1575, -0.2278, -0.3386, -0.1113, -0.0799,
          0.1830, -0.1368, -0.1551,  0.0797],
        [-0.1550,  0.1638,  0.3078,  0.1729,  0.2166,  0.1950, -0.3863,  0.1925,
          0.2642, -0.2710, -0.2435,  0.0754, -0.3075, -0.6059, -0.2064, -0.0512,
          0.0686, -0.2877, -0.1517,  0.0331],
        [ 0.0155,  0.1021,  0.2586,  0.0730,  0.3154,  0.2574, -0.4803,  0.4183,
          0.4274, -0.5746, -0.3251,  0.2643, -0.5747, -0.1880, -0.4754, -0.0236,
          0.2875, -0.4868,  0.0820,  0.0721]], grad_fn=<AddmmBackward0>)


After ReLU:
tensor([[0.1854, 0.3809, 0.3456, 0.0000, 0.0711, 0.0000, 0.0000, 0.1982, 0.2651,
         0.0000, 0.0000, 0.1575, 0.0000, 0.0000, 0.0000, 0.0000, 0.1830, 0.0000,
         0.0000, 0.0797],
        [0.0000, 0.1638, 0.3078, 0.1729, 0.2166, 0.1950, 0.0000, 0.1925, 0.2642,
         0.0000, 0.0000, 0.0754, 0.0000, 0.0000, 0.00

In [14]:
# Chain the modules built above with a fresh ReLU and a 20 -> 10 output layer;
# data flows through them in the order listed.
seq_modules = nn.Sequential(flatten, layer1, nn.ReLU(), nn.Linear(20, 10))
# A new random batch is drawn after the layers are constructed (order kept on purpose,
# since both layer initialisation and torch.rand draw from the RNG).
input_image = torch.rand((3, 28, 28))
logits = seq_modules(input_image)

In [15]:
# Softmax along dim=1: each row of `logits` is rescaled so its entries sum to 1.
softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)

In [16]:
# Print the module tree, then one line per parameter tensor of `model`.
print(f"Model structure: {model}\n\n")

# named_parameters() yields (qualified name, tensor) pairs; param[:2] limits the
# printed values to the first two entries along dim 0.
for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

Model structure: NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


Layer: linear_relu_stack.0.weight | Size: torch.Size([512, 784]) | Values : tensor([[-0.0062,  0.0310,  0.0208,  ...,  0.0285,  0.0328,  0.0305],
        [ 0.0321, -0.0253, -0.0178,  ...,  0.0149, -0.0255,  0.0186]],
       device='mps:0', grad_fn=<SliceBackward0>) 

Layer: linear_relu_stack.0.bias | Size: torch.Size([512]) | Values : tensor([-0.0089, -0.0061], device='mps:0', grad_fn=<SliceBackward0>) 

Layer: linear_relu_stack.2.weight | Size: torch.Size([512, 512]) | Values : tensor([[ 0.0081, -0.0164,  0.0347,  ..., -0.0050, -0.0069,  0.0196],
        [-0.0219,  0.0051, -0.0091,  ..., -0.0186, -0.0088, -0.0036]],
       device='mps:0', grad_fn=<Slice