In [1]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F 

In [2]:
class AlexNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(6400, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)

    def forward(self, X):
        X = self.pool1(F.relu(self.conv1(X)))
        X = self.pool2(F.relu(self.conv2(X)))
        X = F.relu(self.conv3(X))
        X = F.relu(self.conv4(X))
        X = self.pool3(F.relu(self.conv5(X)))
        X = self.flatten(X)
        X = F.relu(self.fc1(X))
        X = F.dropout(X, p=0.5)
        X = F.relu(self.fc2(X))
        X = F.dropout(X, p=0.5)
        X = self.fc3(X)
        return X

In [3]:
trans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(224)
])
train_data = torchvision.datasets.FashionMNIST(
    root='../data', train=True, transform=trans, download=False
)
test_data = torchvision.datasets.FashionMNIST(
    root='../data', train=False, transform=trans, download=False
)
print('The number of training data:', len(train_data))
print('The number of test data:', len(test_data))

The number of training data: 60000
The number of test data: 10000


In [4]:
batch_size = 64
trian_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
for X,y in test_dataloader:
    print('The shape of X:', X.shape)
    print('The shape of y:', y.shape)
    break

The shape of X: torch.Size([64, 1, 224, 224])
The shape of y: torch.Size([64])


In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = AlexNet().to(device)
loss = nn.CrossEntropyLoss()
trainer = torch.optim.Adam(net.parameters(), lr=1e-3)

In [6]:
def train(net, dataloader, loss, trainer):
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        trainer.zero_grad()
        y_hat = net(X)
        l = loss(y_hat, y)
        l.backward()
        trainer.step()
        if batch%100 == 0:
            training_loss = l.item()
            current_batch = batch * len(X)
            print('Trianing loss: %.4f\t[%d/%d]' % (training_loss, current_batch, len(dataloader.dataset)))


In [7]:
def test(net, dataloader, loss):
    num_batches = len(dataloader)
    test_loss = 0.0
    accuracy = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            test_loss += loss(y_hat, y)
            accuracy += (y_hat.argmax(dim=1)==y).type(torch.float32).sum().item()
    test_loss /= num_batches
    accuracy /= len(dataloader.dataset)
    print('Test loss: %.4f\tTest accuracy: %.4f' % (test_loss, accuracy))

In [8]:
EPOCHS = 5
print('training on', device)
for epoch in range(EPOCHS):
    print('Epoch %d' % (epoch+1))
    train(net, trian_dataloader, loss, trainer)
    test(net, test_dataloader, loss)
print('Done!')

training on cuda:0
Epoch 1
Trianing loss: 2.3042	[0/60000]
Trianing loss: 0.7100	[6400/60000]
Trianing loss: 0.3935	[12800/60000]
Trianing loss: 0.5874	[19200/60000]
Trianing loss: 0.5461	[25600/60000]
Trianing loss: 0.4507	[32000/60000]
Trianing loss: 0.4712	[38400/60000]
Trianing loss: 0.6649	[44800/60000]
Trianing loss: 0.4081	[51200/60000]
Trianing loss: 0.4042	[57600/60000]
Test loss: 0.4040	Test accuracy: 0.8486
Epoch 2
Trianing loss: 0.3396	[0/60000]
Trianing loss: 0.3733	[6400/60000]
Trianing loss: 0.3204	[12800/60000]
Trianing loss: 0.4584	[19200/60000]
Trianing loss: 0.3872	[25600/60000]
Trianing loss: 0.3715	[32000/60000]
Trianing loss: 0.2932	[38400/60000]
Trianing loss: 0.5631	[44800/60000]
Trianing loss: 0.4125	[51200/60000]
Trianing loss: 0.2918	[57600/60000]
Test loss: 0.3623	Test accuracy: 0.8662
Epoch 3
Trianing loss: 0.2629	[0/60000]
Trianing loss: 0.3396	[6400/60000]
Trianing loss: 0.2523	[12800/60000]
Trianing loss: 0.4614	[19200/60000]
Trianing loss: 0.3777	[25600