In [1]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import tqdm

In [2]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms

In [3]:
train = FashionMNIST('/fashionmnist', train = True, download = True, transform = transforms.ToTensor())
test = FashionMNIST('/fashionmnist', train = False, download = True, transform = transforms.ToTensor())

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /fashionmnist/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting /fashionmnist/FashionMNIST/raw/train-images-idx3-ubyte.gz to /fashionmnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /fashionmnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting /fashionmnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /fashionmnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /fashionmnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting /fashionmnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /fashionmnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /fashionmnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting /fashionmnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /fashionmnist/FashionMNIST/raw



In [4]:
bs = 64

In [5]:
tr_loader = DataLoader(train, batch_size = bs, shuffle = True)
te_loader = DataLoader(test, batch_size = bs, shuffle = True)

In [6]:
class Flatten_Layer(nn.Module) :
  def forward(self, x) :
    sizes = x.size()

    return x.view(sizes[0], -1)

In [7]:
conv_net = nn.Sequential(
    nn.Conv2d(1, 64, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.Dropout2d(.3),
    nn.Conv2d(64, 64, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.Dropout2d(.2),
    Flatten_Layer()
)

In [8]:
te_input = torch.ones(1, 1, 28, 28)
conv_output_size = conv_net(te_input).size()[-1]

In [9]:
mlp = nn.Sequential(
    nn.Linear(conv_output_size, 128),
    nn.ReLU(),
    nn.BatchNorm1d(128),
    nn.Dropout(.25),
    nn.Linear(128, 10)
)

In [10]:
net = nn.Sequential(conv_net, mlp)

In [11]:
def eval_net(net, data_loader, device = 'cuda:0') :
  net.eval()
  ys = []
  y_preds = []

  for x, y in data_loader :
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad() :
      _, y_pred = net(x).max(1)
    ys.append(y)
    y_preds.append(y_pred)
  
  ys = torch.cat(ys)
  y_preds = torch.cat(y_preds)

  acc = (ys == y_preds).float().sum() / len(ys)

  return acc.item()

In [12]:
def train_net(net, train_loader, test_loader, optimizer = optim.Adam, loss_fn = nn.CrossEntropyLoss(), n_iter = 10, device = 'cuda:0') :

  tr_losses = []
  tr_acc = []
  val_acc = []
  optimizer = optimizer(net.parameters())

  for epoch in range(n_iter) :

    running_loss = 0
    net.train()
    n = 0
    acc = 0

    for i, (xx, yy) in tqdm.tqdm(enumerate(train_loader), total = len(train_loader)) :
      xx = xx.to(device)
      yy = yy.to(device)

      h = net(xx)
      loss = loss_fn(h, yy)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      running_loss += loss.item()
      n += len(xx)
      _, y_pred = h.max(1)
      acc += (yy == y_pred).float().sum().item()
    tr_losses.append(running_loss / i)
    tr_acc.append(acc / n)

    val_acc.append(eval_net(net, test_loader, device))

    print(epoch, tr_losses[-1], tr_acc[-1], val_acc[-1], flush = True)

In [13]:
net.to('cuda:0')

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU()
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Dropout2d(p=0.3, inplace=False)
    (5): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): ReLU()
    (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Dropout2d(p=0.2, inplace=False)
    (10): Flatten_Layer()
  )
  (1): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.25, inplace=False)
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [None]:
train_net(net, tr_loader, te_loader, n_iter = 20, device = 'cuda:0')

100%|██████████| 938/938 [00:17<00:00, 54.36it/s]


0 0.4334274571122584 0.8436666666666667 0.8855999708175659


100%|██████████| 938/938 [00:09<00:00, 99.86it/s]


1 0.31379523994987524 0.88475 0.8976999521255493


100%|██████████| 938/938 [00:09<00:00, 99.81it/s] 


2 0.28070970502231773 0.8956 0.8955000042915344


100%|██████████| 938/938 [00:09<00:00, 100.80it/s]


3 0.2597260402720282 0.9044833333333333 0.9047999978065491


100%|██████████| 938/938 [00:09<00:00, 100.52it/s]


4 0.24273627452162694 0.9096333333333333 0.9095999598503113


100%|██████████| 938/938 [00:09<00:00, 101.22it/s]


5 0.2320371137548282 0.9147333333333333 0.9107999801635742


100%|██████████| 938/938 [00:09<00:00, 101.65it/s]


6 0.2221260939107506 0.9172 0.915399968624115


100%|██████████| 938/938 [00:09<00:00, 96.37it/s]


7 0.21211149889415204 0.9222666666666667 0.9122999906539917


100%|██████████| 938/938 [00:09<00:00, 102.14it/s]


8 0.2023742173737752 0.9251666666666667 0.9115999937057495


100%|██████████| 938/938 [00:09<00:00, 102.96it/s]


9 0.19610185265588786 0.9276 0.9192000031471252


100%|██████████| 938/938 [00:09<00:00, 103.08it/s]


10 0.19068603722523728 0.9282166666666667 0.9138000011444092


100%|██████████| 938/938 [00:09<00:00, 101.81it/s]


11 0.1828310275798228 0.9318 0.9210999608039856


100%|██████████| 938/938 [00:09<00:00, 102.31it/s]


12 0.17616237088314307 0.9336333333333333 0.9197999835014343


100%|██████████| 938/938 [00:09<00:00, 102.95it/s]


13 0.17333801955048278 0.9355166666666667 0.9204999804496765


100%|██████████| 938/938 [00:09<00:00, 101.39it/s]


14 0.16739496400056997 0.9389 0.92249995470047


100%|██████████| 938/938 [00:09<00:00, 103.09it/s]


15 0.16208654864549063 0.93905 0.9208999872207642


100%|██████████| 938/938 [00:09<00:00, 102.62it/s]


16 0.1583249273301443 0.9403 0.9172999858856201


100%|██████████| 938/938 [00:09<00:00, 102.28it/s]


17 0.15685327319088396 0.9407666666666666 0.9217000007629395


 20%|█▉        | 186/938 [00:01<00:07, 102.24it/s]