In [37]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [38]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

Using cuda device


In [43]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)


NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)


In [44]:
X = torch.rand(1, 28, 28, device=device)
logits = model(X) 
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

Predicted class: tensor([4], device='cuda:0')


In [45]:
print(f"First Linear weights: {model.linear_relu_stack[0].weight} \n")

print(f"First Linear weights: {model.linear_relu_stack[0].bias} \n")

First Linear weights: Parameter containing:
tensor([[-0.0297, -0.0260,  0.0074,  ...,  0.0272, -0.0149, -0.0315],
        [-0.0296,  0.0112,  0.0138,  ..., -0.0003,  0.0298,  0.0237],
        [ 0.0068,  0.0220, -0.0161,  ...,  0.0206,  0.0108, -0.0067],
        ...,
        [-0.0082,  0.0343,  0.0086,  ..., -0.0318, -0.0265, -0.0214],
        [ 0.0271, -0.0088,  0.0028,  ...,  0.0256, -0.0049,  0.0064],
        [-0.0050,  0.0089, -0.0325,  ...,  0.0237, -0.0010,  0.0111]],
       device='cuda:0', requires_grad=True) 

First Linear weights: Parameter containing:
tensor([-1.4830e-02, -5.5940e-03,  2.3842e-02, -1.4581e-02,  1.4766e-02,
        -2.1641e-02, -4.5917e-03, -3.4895e-02,  2.6124e-02,  6.0647e-03,
        -6.2955e-03, -9.9045e-03,  3.3280e-03, -3.3120e-02, -2.6642e-02,
        -2.5042e-02,  3.1486e-02,  6.2855e-03,  9.0678e-03,  2.3668e-02,
        -2.2036e-02,  3.4410e-02,  1.9549e-02, -8.4453e-03,  3.0602e-02,
         1.8247e-02, -3.3615e-02,  1.8675e-02, -5.7868e-03, -2.0086