In [1]:
import warnings
from google.colab import drive

drive.mount("/content/drive")
warnings.filterwarnings("ignore")

MODEL_PATH = "/content/drive/MyDrive/model.pth"
DATA_PATH = "/content/drive/MyDrive/MNIST"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchsummary import summary
from torch.utils.data import random_split

BATCH_SIZE = 32
LABELS = {
    0: "T-shirt/top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot",
}
LR = 0.0001
N_EPOCH = 20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Current device:", DEVICE)

Current device: cuda


In [3]:
# resize to match MobileNetV2 input
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Load FashionMNIST Dataset
train_dataset = datasets.FashionMNIST(
    root=DATA_PATH, train=True, transform=transform, download=True
)
test_dataset = datasets.FashionMNIST(
    root=DATA_PATH, train=False, transform=transform, download=True
)

train_size = int(0.9 * len(train_dataset))
valid_size = len(train_dataset) - train_size
train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
model = models.mobilenet_v2(weights=None)
model.classifier[1] = nn.Linear(model.last_channel, len(LABELS))
model.load_state_dict(torch.load(MODEL_PATH))
model = model.to(DEVICE)
# summary(model, (3, 224, 224))

In [5]:
class EarlyStopping(object):
    def __init__(self, patience, save_path):
        self._min_loss = np.inf
        self._patience = patience
        self._path = save_path
        self.__counter = 0

    def should_stop(self, model, loss):
        if loss < self._min_loss:
            self._min_loss = loss
            self.__counter = 0
            torch.save(model.state_dict(), self._path)
        elif loss > self._min_loss:
            self.__counter += 1
            if self.__counter >= self._patience:
                return True
        return False

    def load(self, model):
        model.load_state_dict(torch.load(self._path))
        return model

    @property
    def counter(self):
        return self.__counter

In [6]:
early_stopper = EarlyStopping(2, MODEL_PATH)
criterion = nn.CrossEntropyLoss()
criterion.to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)

In [7]:
def evaluate(model, criterion, valid_loader, device):
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)

            loss = criterion(outputs, labels)
            val_loss += loss.item()

        return val_loss / len(valid_loader)


def train(model, optimizer, criterion, train_loader, valid_loader, device):
    for epoch in range(N_EPOCH):
        model.train()
        train_loss = 0

        for batch_id, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        train_loss /= len(train_loader)
        val_loss = evaluate(model, criterion, valid_loader, device)
        print(f"[{epoch + 1}] Train: {train_loss:.5f} | Validation: {val_loss:.5f}")

        if early_stopper.should_stop(model, val_loss):
            print(f"EarlyStopping: [Epoch: {epoch - early_stopper.counter + 1}]")
            break

In [8]:
train(model, optimizer, criterion, train_loader, valid_loader, DEVICE)

[1] Train: 0.09615 | Validation: 0.07122
[2] Train: 0.07905 | Validation: 0.06717
[3] Train: 0.07033 | Validation: 0.06387
[4] Train: 0.06384 | Validation: 0.06231
[5] Train: 0.05764 | Validation: 0.06196
[6] Train: 0.05107 | Validation: 0.06074
[7] Train: 0.04638 | Validation: 0.06033
[8] Train: 0.04151 | Validation: 0.06148
[9] Train: 0.03675 | Validation: 0.06153
EarlyStopping: [Epoch: 7]


## Evaluate

In [9]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images)

        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy: {accuracy:.2f}")

Accuracy: 94.36
