In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms

In [2]:
# Preprocessing applied to every image as it is loaded:
#   1. ToTensor  - converts the PIL image to a float tensor scaled to [0, 1]
#   2. Normalize - computes (x - 0.5) / 0.5 on the single grayscale channel,
#                  which maps [0, 1] onto [-1, 1]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [3]:
# Number of samples per mini-batch; passed to both DataLoaders below.
batch_size: int = 64

In [4]:
# Download (skipped when the files are already on disk) and load the training data
trainset = datasets.FashionMNIST(
    '~/.pytorch/F_MNIST_data/',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
)

# Download and load the test data (same root directory, held-out split)
testset = datasets.FashionMNIST(
    '~/.pytorch/F_MNIST_data/',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=True,
)

In [5]:
class FashionNetwork(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 128 -> 10.

    Two hidden ``Linear`` layers with ReLU activations are followed by a
    10-way output layer whose scores are turned into class probabilities
    with softmax.

    NOTE(review): ``forward`` returns probabilities, not logits. The training
    loss is not visible in this part of the notebook; ``nn.CrossEntropyLoss``
    expects raw logits, so confirm the loss used matches this output.
    """

    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 256)
        self.hidden2 = nn.Linear(256, 128)
        self.output = nn.Linear(128, 10)
        # Normalise over the class axis explicitly. Omitting ``dim`` is
        # deprecated, triggers a UserWarning on every call, and for 3-D
        # input silently picks dim 0 (the batch). ``dim=-1`` matches the old
        # implicit choice for 1-D and 2-D inputs.
        self.softmax = nn.Softmax(dim=-1)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Compute class probabilities for a batch of inputs.

        Args:
            x: Either already-flattened features of shape ``(N, 784)`` (or a
                single sample of shape ``(784,)``), or image batches with
                extra dimensions such as ``(N, 1, 28, 28)``, which are
                flattened to ``(N, 784)`` here.

        Returns:
            Tensor of shape ``(N, 10)`` (``(10,)`` for a single flat sample)
            whose last dimension sums to 1.
        """
        # Accept image-shaped batches straight from a DataLoader; inputs that
        # are already flat pass through unchanged.
        if x.dim() > 2:
            x = x.flatten(start_dim=1)
        x = self.activation(self.hidden1(x))
        x = self.activation(self.hidden2(x))
        scores = self.output(x)
        return self.softmax(scores)

In [6]:
# Build the network. No random seed is set in the cells shown here, so the
# initial nn.Linear weights will differ from run to run.
model = FashionNetwork()

In [7]:
# Show the registered sub-modules and their configuration.
print(model)

FashionNetwork(
  (hidden1): Linear(in_features=784, out_features=256, bias=True)
  (hidden2): Linear(in_features=256, out_features=128, bias=True)
  (output): Linear(in_features=128, out_features=10, bias=True)
  (softmax): Softmax()
  (activation): ReLU()
)


In [9]:
# Inspect the first layer's weight matrix (nn.Linear stores it as
# (out_features, in_features), i.e. 256 x 784 here).
model.hidden1.weight

Parameter containing:
tensor([[ 0.0051,  0.0109,  0.0304,  ...,  0.0087, -0.0166,  0.0081],
        [ 0.0142,  0.0108, -0.0318,  ...,  0.0108,  0.0090, -0.0317],
        [-0.0169, -0.0322,  0.0043,  ...,  0.0006,  0.0030,  0.0154],
        ...,
        [-0.0122, -0.0053,  0.0026,  ..., -0.0260,  0.0298, -0.0224],
        [-0.0060,  0.0077, -0.0282,  ...,  0.0131, -0.0032, -0.0195],
        [ 0.0169, -0.0032,  0.0209,  ..., -0.0286,  0.0186, -0.0037]],
       requires_grad=True)