In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
plt.rc('font', size=16)

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Use the first CUDA GPU when one is visible to PyTorch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cpu')

## FashionMNIST dataset

In [2]:
# Root directory that holds (or will receive) the FashionMNIST files.
# Set the FMNIST_DATA_DIR environment variable to point somewhere else;
# when it is unset the original location is used, so existing setups keep working.
data_dir = os.environ.get(
    'FMNIST_DATA_DIR', '/Users/hhg/Research/machine_learning/data/')

# Training and test splits, both passed through ToTensor so that the
# DataLoader yields tensors. download=True lets torchvision fetch the files
# into data_dir when needed.
fmnist_train = datasets.FashionMNIST(
    root=data_dir, train=True, transform=transforms.ToTensor(), download=True)
fmnist_test = datasets.FashionMNIST(
    root=data_dir, train=False, transform=transforms.ToTensor(), download=True)

In [3]:
# Number of samples per mini-batch for both loaders.
batch_size = 128

# Training data is reshuffled each time the loader is iterated.
train_loader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps per-batch
# results reproducible from one run to the next.
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=False)

In [4]:
# Classification loss; the reduction is spelled out but equals the library default.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [5]:
# Grab one mini-batch from the training loader so the layer-by-layer
# shape checks below have a concrete input to work with.
first_batch = next(iter(train_loader))
images, labels = first_batch
images.shape

torch.Size([128, 1, 28, 28])

In [6]:
# First convolution stage (1 -> 6 channels, 5x5 kernel).
# c0 is the plain variant; cB inserts batch normalisation before the sigmoid.
c0 = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    nn.Sigmoid(),
)
cB = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    nn.BatchNorm2d(num_features=6),
    nn.Sigmoid(),
)

# Run the same batch through both variants and compare the output shapes.
out0, outB = c0(images), cB(images)

print(out0.shape, outB.shape, sep="\n")

torch.Size([128, 6, 24, 24])
torch.Size([128, 6, 24, 24])


In [7]:
# First pooling stage: 2x2 average pooling with stride 2, one module per branch.
p0 = nn.Sequential(
    nn.AvgPool2d(kernel_size=2, stride=2),
)
pB = nn.Sequential(
    nn.AvgPool2d(kernel_size=2, stride=2),
)

# Each branch is pooled by its own module; the activations are overwritten in place.
out0, outB = p0(out0), pB(outB)

print(out0.shape, outB.shape, sep="\n")

torch.Size([128, 6, 12, 12])
torch.Size([128, 6, 12, 12])


In [8]:
# Second convolution stage (6 -> 16 channels, 5x5 kernel).
# c20 is the plain variant; c2B adds batch normalisation before the sigmoid.
c20 = nn.Sequential(
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.Sigmoid(),
)
c2B = nn.Sequential(
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.BatchNorm2d(num_features=16),
    nn.Sigmoid(),
)

# Feed each branch's pooled activations through its own second conv stage.
out0, outB = c20(out0), c2B(outB)

print(out0.shape, outB.shape, sep="\n")

torch.Size([128, 16, 8, 8])
torch.Size([128, 16, 8, 8])


In [9]:
# Second pooling stage: 2x2 average pooling with stride 2, one module per branch.
p20 = nn.Sequential(
    nn.AvgPool2d(kernel_size=2, stride=2),
)
p2B = nn.Sequential(
    nn.AvgPool2d(kernel_size=2, stride=2),
)

# Pool both branches again, overwriting the running activations.
out0, outB = p20(out0), p2B(outB)

print(out0.shape, outB.shape, sep="\n")

torch.Size([128, 16, 4, 4])
torch.Size([128, 16, 4, 4])


In [10]:
# Collapse everything after the batch dimension into a single feature vector
# per sample, with a separate Flatten module for each branch.
f0, fB = nn.Flatten(), nn.Flatten()

out0, outB = f0(out0), fB(outB)

print(out0.shape, outB.shape, sep="\n")

torch.Size([128, 256])
torch.Size([128, 256])
