In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader
import torchsummary

## **Data**
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz  
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz  
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz  
http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz  
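Download the four archives above manually and place them in `{root}\FashionMNIST\raw` (the datasets below are created with `download=False`).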

In [2]:
# Preprocessing applied to every sample: upscale to 224x224, then convert to a tensor.
trans = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
])

# Both splits are read from ./data; download=False means the raw files must already be on disk.
data_train = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=False,
    transform=trans,
)
data_val = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    download=False,
    transform=trans,
)

In [3]:
# Show the training dataset's summary (size, root, split, transform).
data_train

Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
           )

In [4]:
# Show the validation dataset's summary (size, root, split, transform).
data_val

Dataset FashionMNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
           )

In [5]:
# Inspect the first training sample; indexing the dataset yields an (image, label) pair.
image, label = data_train[0]  # [image, label]
print(image.shape) # (channel, height, width)
print(label)

torch.Size([1, 224, 224])
9


## **Alexnet**

In [6]:
class Alexnet(nn.Module):
    """AlexNet-style CNN for single-channel images.

    Five convolutional layers (each followed by ReLU, with three max-pooling
    stages) feed a three-layer fully connected classifier. The first fully
    connected layer is lazy, so its input width is inferred from the first
    batch that passes through the network.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Shape comments below are for a 1x224x224 input.
        self.net = nn.Sequential(
            # Feature extractor
            nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4, padding=1), nn.ReLU(),  # 96x54x54
            nn.MaxPool2d(kernel_size=3, stride=2),  # 96x26x26
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2), nn.ReLU(),  # 256x26x26
            nn.MaxPool2d(kernel_size=3, stride=2),  # 256x12x12
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2), nn.Flatten(),  # 256x5x5 -> flat vector
            # Classifier head. Each hidden fully connected layer needs its own
            # non-linearity; without it the three stacked linear layers reduce
            # to a single affine map whenever dropout is inactive (eval mode).
            nn.LazyLinear(out_features=4096), nn.ReLU(), nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096), nn.ReLU(), nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=num_classes)
        )

    def forward(self, X):  # X.shape = (batch_size, channel, height, width)
        """Return the class logits, shape (batch_size, num_classes), for batch X."""
        return self.net(X)

In [7]:
# Instance used for the architecture summary in the next cell; the training
# section later rebinds `model` to a fresh Alexnet().
model = Alexnet()



In [8]:
# Print per-layer output shapes and parameter counts for a 1x224x224 input.
torchsummary.summary(model, input_size=(1, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 54, 54]          11,712
              ReLU-2           [-1, 96, 54, 54]               0
         MaxPool2d-3           [-1, 96, 26, 26]               0
            Conv2d-4          [-1, 256, 26, 26]         614,656
              ReLU-5          [-1, 256, 26, 26]               0
         MaxPool2d-6          [-1, 256, 12, 12]               0
            Conv2d-7          [-1, 384, 12, 12]         885,120
              ReLU-8          [-1, 384, 12, 12]               0
            Conv2d-9          [-1, 384, 12, 12]       1,327,488
             ReLU-10          [-1, 384, 12, 12]               0
           Conv2d-11          [-1, 256, 12, 12]         884,992
             ReLU-12          [-1, 256, 12, 12]               0
        MaxPool2d-13            [-1, 256, 5, 5]               0
          Flatten-14                 [-

## **Training**

In [9]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 128

# Only the training split is shuffled; validation keeps the dataset order.
train_loader = DataLoader(dataset=data_train, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(dataset=data_val, shuffle=False, batch_size=batch_size)

In [10]:
# Fresh, untrained network for the training run below.
model = Alexnet()

In [11]:
# nn.LazyLinear only learns its in_features on the first forward pass, so push
# one sample through the fresh model before building the optimizer. PyTorch's
# lazy-module documentation attaches the optimizer after such a dry run, once
# every parameter has been materialised.
with torch.no_grad():
    model(data_train[0][0].unsqueeze(0))

# Plain SGD with a fixed learning rate of 0.01.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

In [12]:
def accuracy(y_hat, y):
    """Count how many predictions in a batch match their labels.

    Args:
        y_hat: Per-class scores, shape (B, q); the predicted class is the
            index of the largest score along dimension 1.
        y: Ground-truth class indices, shape (B).

    Returns:
        A float32 scalar tensor holding the number of correct predictions
        (a count, not a ratio; the caller divides by the sample total).
    """
    predicted = y_hat.argmax(dim=1).to(y.dtype)  # (B)
    hits = predicted.eq(y).to(torch.float32)  # (B), 1.0 where correct
    return hits.sum()

In [None]:
%%time
# Train for 10 epochs; after each epoch, evaluate on the validation split.
for i in range(10):
    model.train()  # dropout active during optimisation

    train_loss = 0
    num_train_batches = 0
    for b, (X, y) in enumerate(train_loader):
        optimizer.zero_grad()
        y_hat = model(X)
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        num_train_batches += 1
        if b % 10 == 0:
            # Progress line: running mean of the batch losses seen so far this epoch.
            print(f'epoch={i} | batch={b} | train_loss={train_loss/num_train_batches:.4f}')

    model.eval()  # dropout disabled for evaluation
    with torch.no_grad():
        val_loss = 0  # sum of per-sample losses over the whole validation split
        num_val_batches = 0
        val_acc = 0  # number of correctly classified validation samples
        total = 0  # number of validation samples seen
        for X, y in val_loader:
            y_hat = model(X)
            # Sum-reduce instead of averaging per batch: the final batch is
            # usually shorter, so a mean of batch means would over-weight its
            # samples. Dividing the summed loss by `total` gives the exact
            # per-sample mean.
            loss = F.cross_entropy(y_hat, y, reduction='sum')
            val_loss += loss.item()
            num_val_batches += 1
            val_acc += accuracy(y_hat, y).item()
            total += y.numel()

    print(f'epoch={i} | train_loss={train_loss/num_train_batches:.4f} | val_loss={val_loss/total:.4f} | val_acc={val_acc/total:.4f}')