In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np


class CNN(nn.Module):
    """Small convolutional classifier producing 10 logits per 1-channel image.

    Two conv+ReLU+2x2-max-pool stages shrink the spatial size by 4x, then two
    conv+ReLU layers with dropout, then three fully connected layers. ``fc1``
    expects a 64x7x7 feature map, i.e. inputs of shape (N, 1, 28, 28).
    """

    def __init__(self):
        super().__init__()

        # 3x3 kernels with padding=1 keep the spatial size; only the pools shrink it.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        # 64 channels at 7x7 after two 2x2 pools of a 28x28 input.
        self.fc1 = nn.Linear(7 * 7 * 64, 600)
        self.fc2 = nn.Linear(600, 120)
        self.fc3 = nn.Linear(120, 10)
        # Shared (stateless) dropout applied after conv3 and conv4.
        self.drop_layer = nn.Dropout(p=0.2)

    def last_hidden_layer_output(self, x):
        """Return the 120-d ReLU activations that feed the final classifier layer.

        Args:
            x: image batch of shape (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 120).

        Raises:
            RuntimeError: if the input's spatial size does not yield a 64x7x7
                feature map (raised by ``fc1``).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = self.drop_layer(F.relu(self.conv3(x)))
        x = self.drop_layer(F.relu(self.conv4(x)))
        # Flatten per sample (keep the batch dim). The old view(-1, 3136) silently
        # re-split a mis-sized input into extra "samples" instead of failing.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

    def forward(self, x):
        """Return unnormalized class logits of shape (N, 10)."""
        x = self.last_hidden_layer_output(x)
        x = self.fc3(x)
        return x

# Single source of truth for the mini-batch size used by both loaders.
batch_size = 64

fashion_train = datasets.FashionMNIST("FashionData", train=True, download=True, transform=transforms.ToTensor())
fashion_test = datasets.FashionMNIST("FashionData", train=False, download=True, transform=transforms.ToTensor())

# Only the training set is shuffled; the test set keeps a fixed order.
train_loader = DataLoader(fashion_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(fashion_test, batch_size=batch_size, shuffle=False)


# Prefer the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one.

    Implemented with ``view``, so the output shares storage with the input.
    Not referenced elsewhere in this cell.
    """

    def forward(self, x):
        n_rows = x.size(0)
        return x.view(n_rows, -1)

# Seed torch's global RNG; the model created further down draws its initial weights from it.
torch.manual_seed(2)

# Softmax over dim 1 (the class dimension of an (N, C) logit tensor); unused in this cell.
softmax = nn.Softmax(dim=1)

# Base learning rate read by adjust_learning_rate's step-decay schedule.
learning_rate = 1e-2

def adjust_learning_rate(optimizer, epoch, base_lr=None, decay=0.5, step_size=10):
    """Apply a step-decay schedule: ``lr = base_lr * decay ** (epoch // step_size)``.

    With the defaults this reproduces the original behavior: start from the
    module-level ``learning_rate`` and halve it every 10 epochs.

    Args:
        optimizer: torch optimizer whose ``param_groups`` are updated in place.
        epoch: zero-based epoch index.
        base_lr: initial learning rate; ``None`` means the global ``learning_rate``.
        decay: multiplicative factor applied once every ``step_size`` epochs.
        step_size: number of epochs between successive decays.

    Returns:
        The learning rate that was written to every param group.
    """
    if base_lr is None:
        base_lr = learning_rate
    lr = base_lr * (decay ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def enable_dropout(model):
    """Switch every dropout submodule of ``model`` back into training mode.

    Meant to be called after ``model.eval()`` so that, at test time, only the
    layers whose class name starts with ``'Dropout'`` (e.g. ``Dropout``,
    ``Dropout2d``) keep sampling masks while all other modules stay in eval mode.
    """
    dropout_layers = (
        module
        for module in model.modules()
        if type(module).__name__.startswith('Dropout')
    )
    for layer in dropout_layers:
        layer.train()

def epoch(loader, model, opt=None, device=None):
    """Run one pass over ``loader`` and return ``(error_rate, mean_loss)``.

    If ``opt`` is given, the model is put in train mode and updated after every
    batch. Otherwise it is put in eval mode and autograd is disabled, so no
    computation graph is built for the evaluation pass.

    Args:
        loader: DataLoader yielding ``(X, y)`` batches with integer class labels ``y``.
        model: module mapping ``X`` to class logits of shape ``(N, C)``.
        opt: optimizer to step, or ``None`` for evaluation only.
        device: where each batch is moved. ``None`` means the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(error_rate, mean_loss)``: both sums are divided by ``len(loader.dataset)``.
    """
    training = opt is not None
    model.train(training)  # train(False) is equivalent to eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss, total_err = 0., 0.
    # Only record autograd history when we are going to backpropagate.
    with torch.set_grad_enabled(training):
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            yp = model(X)
            # Equivalent to nll_loss(log_softmax(yp, dim=1), y); batch-mean loss.
            loss = F.cross_entropy(yp, y)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()

            total_err += (yp.argmax(dim=1) != y).sum().item()
            # Undo the batch mean so the final division gives a per-sample mean.
            total_loss += loss.item() * X.shape[0]

    n_samples = len(loader.dataset)
    return total_err / n_samples, total_loss / n_samples

model_cnn = CNN().to(device)

# Start from the configured base rate (adjust_learning_rate re-sets it every epoch anyway).
opt = optim.SGD(model_cnn.parameters(), lr=learning_rate, momentum=0.9)

num_epochs = 50
checkpoint_path = "model_cnn_mnist_fashion.pt"

for t in range(num_epochs):
    adjust_learning_rate(opt, t)
    train_err, train_loss = epoch(train_loader, model_cnn, opt)
    test_err, test_loss = epoch(test_loader, model_cnn)
    # Columns: training error rate, test error rate.
    print(*("{:.6f}".format(i) for i in (train_err, test_err)), sep="\t")

torch.save(model_cnn.state_dict(), checkpoint_path)

# Round-trip the checkpoint; map_location puts the tensors on `device`
# regardless of which device they were saved from.
model_cnn.load_state_dict(torch.load(checkpoint_path, map_location=device))
model_cnn.eval()


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to FashionData/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting FashionData/FashionMNIST/raw/train-images-idx3-ubyte.gz to FashionData/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to FashionData/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting FashionData/FashionMNIST/raw/train-labels-idx1-ubyte.gz to FashionData/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to FashionData/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting FashionData/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to FashionData/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to FashionData/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting FashionData/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to FashionData/FashionMNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


0.448250	0.205100
0.159667	0.143700
0.131400	0.120200
0.115500	0.112500
0.104283	0.112000
0.098033	0.098300
0.091600	0.097700
0.085800	0.096200
0.081367	0.091900
0.077633	0.089300
0.067717	0.083900
0.064967	0.080900
0.062133	0.077900
0.059833	0.086600
0.057000	0.080500
0.054933	0.082500
0.052800	0.079800
0.050983	0.078800
0.049017	0.077900
0.045883	0.080000
0.038700	0.077700
0.036783	0.074600
0.036017	0.078500
0.034617	0.075700
0.032733	0.079800
0.031400	0.081100
0.030700	0.075900
0.030167	0.074500
0.028833	0.076100
0.027567	0.079000
0.023333	0.074500
0.022133	0.073400
0.020717	0.076600
0.020133	0.073700
0.020050	0.075600
0.019200	0.075300
0.018600	0.075800
0.018200	0.074200
0.017583	0.076600
0.017983	0.078200
0.015100	0.073900
0.013950	0.074500
0.014017	0.074300
0.013733	0.073700
0.012717	0.075500
0.012517	0.073600
0.013467	0.075000
0.011433	0.072500
0.012983	0.075500
0.012317	0.073700


CNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=3136, out_features=600, bias=True)
  (fc2): Linear(in_features=600, out_features=120, bias=True)
  (fc3): Linear(in_features=120, out_features=10, bias=True)
  (drop_layer): Dropout(p=0.2, inplace=False)
)