In [None]:
import time
import copy
import random

import torch
import torch.nn as nn

import torchvision

from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

In [None]:
# Training hyperparameters and runtime configuration.
epochs = 10      # full passes over the training set
batch = 16       # mini-batch size, shared by the train and test loaders
lr = 0.001       # Adam learning rate
nc = 10          # number of output classes (FashionMNIST has 10)
seed = 42        # RNG seed for weight init, shuffling and the sample-picking cell
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the Python and torch RNGs so Restart & Run All gives repeatable results
# (some GPU kernels may still be nondeterministic).
random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Convert dataset images to float tensors; no augmentation is applied.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

data_root = './fashion_mnist'  # download/cache directory for FashionMNIST
num_workers = 4                # background worker processes per DataLoader

train_dataset = torchvision.datasets.FashionMNIST(data_root, train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.FashionMNIST(data_root, train=False, download=True, transform=transform)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, num_workers=num_workers, shuffle=True)
# Evaluation does not shuffle: metrics must not depend on batch composition,
# and a fixed order keeps evaluation runs comparable.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch, num_workers=num_workers, shuffle=False)

In [None]:
# Sanity check: train/eval repeat dim 1 of each batch to 3 channels, which is
# only correct if the raw images are single-channel (C, H, W) tensors.
sample_shape = train_dataset[0][0].shape
assert sample_shape[0] == 1, f'expected single-channel images, got {tuple(sample_shape)}'
sample_shape

In [None]:
def train(model, optimizer, criteria, dataloader, epochs):
    """Train `model` for `epochs` passes over `dataloader`, printing the mean loss per epoch.

    Args:
        model: network being optimized; switched to training mode each epoch.
        optimizer: optimizer over `model`'s parameters; stepped once per batch.
        criteria: loss function, called as ``criteria(output, labels)``.
        dataloader: yields ``(imgs, labels)`` batches. Images are moved to the
            global `device` and repeated to 3 channels along dim 1, so this
            assumes single-channel input.
        epochs: number of full passes over `dataloader`.

    Returns:
        List of per-epoch mean batch losses (floats), in epoch order.
    """
    history = []
    for e in range(epochs):
        desc = f'Epoch: {e+1}/{epochs}'
        total_loss = 0.0
        n_batches = 0

        model.train()
        for imgs, labels in tqdm(dataloader, desc=desc):
            imgs = imgs.to(device).repeat(1, 3, 1, 1)
            labels = labels.to(device)

            optimizer.zero_grad()
            output = model(imgs)

            loss = criteria(output, labels)
            total_loss += loss.item()
            n_batches += 1

            loss.backward()
            optimizer.step()

        # Average over the batches actually drawn from the loader passed in
        # (not a global loader); guard against an empty loader.
        avg_loss = total_loss / max(n_batches, 1)
        history.append(avg_loss)
        print('Loss', round(avg_loss, 3))
    return history

In [None]:
def eval(model, dataloader):
    """Evaluate `model` on `dataloader`, printing accuracy and macro-F1.

    Predictions and labels are collected over the whole loader and the metrics
    are computed once on the full set. The result therefore does not depend on
    batch size or order. Averaging per-batch macro-F1 is biased and over-weights
    a smaller final batch.

    Note: this name shadows Python's built-in ``eval`` in the notebook namespace.

    Args:
        model: classifier whose output is argmax-ed over dim 1 to get class ids.
        dataloader: yields ``(imgs, labels)`` batches. Images are moved to the
            global `device` and repeated to 3 channels along dim 1, as in `train`.

    Returns:
        Dict with ``'accuracy'`` and ``'f1'`` (macro-averaged) as floats.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    all_preds = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for imgs, labels in tqdm(dataloader):
            imgs = imgs.to(device).repeat(1, 3, 1, 1)
            preds = torch.argmax(model(imgs), dim=1)
            # sklearn needs host arrays; .numpy() raises on CUDA tensors.
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    if not all_preds:
        raise ValueError('dataloader yielded no batches')

    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_labels).numpy()

    accuracy = float(accuracy_score(y_true, y_pred))
    f1 = float(f1_score(y_true, y_pred, average='macro'))

    print(f'Accuracy : {round(accuracy, 3)}')
    print(f'F1 : {round(f1, 3)}')
    return {'accuracy': accuracy, 'f1': f1}

In [None]:
# Build an untrained MobileNetV2 with one output per class, move it to the
# configured device, then train it and report test-set metrics.
model = torchvision.models.mobilenet_v2(num_classes=nc)
model.to(device)  # nn.Module.to moves parameters in place

criteria = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

train(model, optimizer, criteria, train_dataloader, epochs)
eval(model, test_dataloader)

In [None]:
# Spot-check one random test image: print the model's prediction next to the true label.
# randrange excludes its upper bound; randint(0, len) could pick len and raise IndexError.
i = random.randrange(len(test_dataset))
print('Item', i)

img, label = test_dataset[i]

model.eval()
with torch.no_grad():
    # Match train/eval preprocessing: add a batch dim and repeat the single
    # channel to the 3 channels the model was trained on.
    out = model(img.unsqueeze(0).to(device).repeat(1, 3, 1, 1))
out = torch.argmax(out, dim=1)

print(f'Predict : {out.item()}. Real : {label}')

torchvision.transforms.ToPILImage()(img)