In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import *

In [2]:
from torchvision import transforms

transform = transforms.Compose([
    # you can add other transformations in this list
    transforms.ToTensor()
])

In [3]:
mnist_train = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 13250925.26it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 209016.24it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:07<00:00, 598769.18it/s] 


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 13683318.75it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [4]:
batch_size=64
trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [5]:
device="cuda"

In [6]:
MobileNet = mobilenet_v2(pretrained = True)
MobileNet = resnet34(pretrained = True)

for param in MobileNet.parameters():
    param.requires_grad = False

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 36.1MB/s]
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:01<00:00, 81.1MB/s]


The MNIST dataset has input size of (28, 28, 1), but the pretrained mobile net v2 accepts input shape of (224, 224, 3).
Instead of just using Resize and interpolate, which due to the large size difference can cause the final image to be very blurry, I used FCN ideas of deconvolution to upsample the image.

In [7]:
class FCN_MobileNet(nn.Module):
    def __init__(self, pretrained, num_classes=10):
        super().__init__()
        self.convT1 = nn.ConvTranspose2d(1, 32, kernel_size=(3,3), stride=2)
        self.convT2 = nn.ConvTranspose2d(32, 32, kernel_size=(5,5), stride=4)
        self.conv1 = nn.Conv2d(32, 3, kernel_size=(6,6))
        self.pretrained = pretrained
        self.fc1 = nn.LazyLinear(512)
        self.fc2 = nn.Linear(512, num_classes)
        for param in self.pretrained.parameters():
            param.requires_grad = False
    def forward(self, x):
        x = self.convT1(x)
        x = self.convT2(x)
        x = self.conv1(x)
        x = self.pretrained(x)
        x = F.tanh(self.fc1(x))
        x = self.fc2(x)
#         print(x.shape)
        return x

In [8]:
funni_net = FCN_MobileNet(MobileNet).to(device)



In [9]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(funni_net.parameters(), lr=0.0005) # take note to pass in the correct network for network.parameters()

In [None]:
for epoch in range(10):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = funni_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 99 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 99:.10f}')
            running_loss = 0.0

print('Finished Training')

[1,   100] loss: 0.7458684804
[1,   200] loss: 0.5284394974


In [None]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        # calculate outputs by running images through the network
        outputs = funni_net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')