In [23]:
import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [19]:
# Setup training data
train_data = datasets.FashionMNIST(
    root="data", # where to download data to?
    train=True, # get training data
    download=True, # download data if it doesn't exist on disk
    transform=ToTensor(), # images come as PIL format, we want to turn into Torch tensors
    target_transform=None # you can transform labels as well
)

# Setup testing data
test_data = datasets.FashionMNIST(
    root="data",
    train=False, # get test data
    download=True,
    transform=ToTensor()
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


0it [00:00, ?it/s]

Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


0it [00:00, ?it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


  return torch.from_numpy(parsed).view(length, num_rows, num_cols)


In [24]:
# Setup the batch size hyperparameter
BATCH_SIZE = 32

# Turn datasets into iterables (batches)
train_dataloader = DataLoader(train_data, # dataset to turn into iterable
    batch_size=BATCH_SIZE, # how many samples per batch? 
    shuffle=True # shuffle data every epoch?
)

test_dataloader = DataLoader(test_data,
    batch_size=BATCH_SIZE,
    shuffle=False # don't necessarily have to shuffle the testing data
)

# Let's check out what we've created
print(f"Dataloaders: {train_dataloader, test_dataloader}") 
print(f"Length of train dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")
print(f"Length of test dataloader: {len(test_dataloader)} batches of {BATCH_SIZE}")

Dataloaders: (<torch.utils.data.dataloader.DataLoader object at 0x7fadcc730b50>, <torch.utils.data.dataloader.DataLoader object at 0x7fadcc7306a0>)
Length of train dataloader: 1875 batches of 32
Length of test dataloader: 313 batches of 32


In [49]:
class TorchVision(nn.Module):
    def __init__(self):
        super().__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels = 1,
                      out_channels=10,
                      kernel_size = 3,
                      stride=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels = 10,
                      out_channels=10,
                      kernel_size = 3,
                      stride=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.linear = nn.Sequential(
            nn.Flatten(),
            nn.Linear(10*7*7, 10),
            nn.ReLU(),
            nn.Linear(10, len(train_data.classes))
        )
        
    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        # x = self.linear(x)
        return x

In [50]:
model = TorchVision()

In [51]:
torch.manual_seed(42)

# Create sample batch of random numbers with same size as image batch
images = torch.randn(size=(32, 1, 64, 64)) # [batch_size, color_channels, height, width]
test_image = images[0] 

In [52]:
test_image.shape

torch.Size([1, 64, 64])

In [53]:
model(test_image)

tensor([[[0.3248, 0.2961, 0.5884,  ..., 0.3674, 0.7405, 0.2887],
         [0.5716, 0.3999, 0.6613,  ..., 0.2170, 0.4192, 0.4345],
         [0.3965, 0.5439, 0.5014,  ..., 0.6156, 0.2756, 0.1799],
         ...,
         [0.5585, 0.4988, 0.6639,  ..., 0.4348, 0.3700, 0.4278],
         [0.3271, 0.5307, 0.4942,  ..., 0.5747, 0.4504, 0.2288],
         [0.4078, 0.6205, 0.6547,  ..., 0.4013, 0.5744, 0.5313]],

        [[0.2417, 0.2717, 0.3061,  ..., 0.3539, 0.1219, 0.0000],
         [0.0968, 0.0054, 0.4220,  ..., 0.2191, 0.2372, 0.0958],
         [0.3004, 0.1026, 0.3374,  ..., 0.1954, 0.4295, 0.2660],
         ...,
         [0.1222, 0.1351, 0.7613,  ..., 0.2414, 0.2301, 0.4514],
         [0.3146, 0.5164, 0.4512,  ..., 0.1715, 0.0000, 0.1450],
         [0.0867, 0.0888, 0.4363,  ..., 0.1313, 0.7915, 0.2709]],

        [[0.5177, 0.4038, 0.1598,  ..., 0.4074, 0.0000, 0.4368],
         [0.0997, 0.1811, 0.1872,  ..., 0.3580, 0.4082, 0.6276],
         [0.5486, 0.7778, 0.3136,  ..., 0.5472, 0.3442, 0.