In [1]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from configs import mnist_config
from model import ViT
from utils import Trainer, WarmupCosineSchedule, WarmupLinearSchedule, build_model

In [2]:
# Pick the compute device as a plain string: CUDA when PyTorch reports it
# as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Training hyperparameters used by the cells that build the loaders,
# optimizer, scheduler and trainer.
learning_rate = 1e-3
batch_size = 64
num_epochs = 20

# Seed PyTorch's RNG so weight initialization and DataLoader shuffling are
# repeatable after Restart & Run All. Some CUDA kernels are still
# non-deterministic, so GPU runs may differ slightly between executions.
seed = 42
torch.manual_seed(seed)

In [4]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# with mean 0.5 and std 0.5. The statistics are written as 1-tuples because
# `(0.5)` is only a parenthesized float, not a per-channel sequence.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=True fetches MNIST only when it is not already present under
# `root`, so the cell also works on a fresh checkout without the data.
train_dataset = datasets.MNIST(root='../datasets', train=True, transform=transform, download=True)
val_dataset = datasets.MNIST(root='../datasets', train=False, transform=transform, download=True)

# Shuffle the training split only; evaluation gains nothing from shuffling,
# and a fixed order keeps validation runs comparable.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Fetch the MNIST model configuration and hand it to build_model together
# with the ViT class (presumably build_model instantiates ViT from the
# config -- confirm in utils). The result is then moved to `device`.
model_config = mnist_config()
model = build_model(ViT, model_config)
model = model.to(device)

In [6]:
# Classification loss.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters at the configured base learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Warm-up covers one fifth of the schedule length, followed by the cosine
# phase. NOTE(review): both values are passed in units of epochs here --
# confirm that Trainer steps the scheduler once per epoch, not per batch.
warmup_length = num_epochs // 5
scheduler = WarmupCosineSchedule(optimizer, warmup_length, num_epochs)

In [7]:
# Loaders keyed by phase name, as passed to Trainer.
phase_loaders = {'train': train_loader, 'validation': val_loader}

# Arguments are positional because Trainer's signature is not visible here.
# NOTE(review): `(1, )` looks like a top-k selection, yet the recorded output
# reports both top1 and top5 -- confirm what this argument controls.
# NOTE(review): 'pretrained/mnist.pth' is presumably the checkpoint path,
# resolved relative to the notebook's working directory -- verify.
trainer = Trainer(
    model,
    phase_loaders,
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    (1, ),
    'pretrained/mnist.pth',
    device,
)

In [8]:
# Run training. Per the recorded output, each epoch performs a 'train' pass
# and a 'validation' pass and reports loss, top1 and top5 on a progress bar.
# NOTE(review): the recorded run takes roughly 30 s per training epoch.
trainer.train()

train Epoch: 1: 100%|██████████| 938/938 [00:30<00:00, 30.64Batch/s, loss=6.023933, top1=11.47%, top5=49.95%]
validation Epoch: 1: 100%|██████████| 157/157 [00:02<00:00, 78.36Batch/s, loss=12.052503, top1=11.17%, top5=49.73%]






train Epoch: 2: 100%|██████████| 938/938 [00:29<00:00, 31.59Batch/s, loss=0.512244, top1=23.11%, top5=69.78%]
validation Epoch: 2: 100%|██████████| 157/157 [00:01<00:00, 81.10Batch/s, loss=0.126851, top1=38.09%, top5=90.56%]






train Epoch: 3: 100%|██████████| 938/938 [00:28<00:00, 32.95Batch/s, loss=0.085171, top1=32.54%, top5=82.55%]
validation Epoch: 3: 100%|██████████| 157/157 [00:01<00:00, 81.73Batch/s, loss=0.079065, top1=53.83%, top5=95.20%]






train Epoch: 4: 100%|██████████| 938/938 [00:28<00:00, 32.74Batch/s, loss=0.047338, top1=48.90%, top5=93.36%]
validation Epoch: 4: 100%|██████████| 157/157 [00:02<00:00, 77.59Batch/s, loss=0.054267, top1=68.33%, top5=98.37%]






train Epoch: 5: 100%|██████████| 938/938 [00:29<00:00, 31.94Batch/s, loss=0.032758, top1=63.91%, top5=97.34%]
validation Epoch: 5: 100%|██████████| 157/157 [00:02<00:00, 78.43Batch/s, loss=0.036202, top1=80.50%, top5=99.20%]






train Epoch: 6: 100%|██████████| 938/938 [00:29<00:00, 32.05Batch/s, loss=0.023088, top1=75.01%, top5=98.73%]
validation Epoch: 6: 100%|██████████| 157/157 [00:01<00:00, 78.53Batch/s, loss=0.024910, top1=86.58%, top5=99.42%]






train Epoch: 7: 100%|██████████| 938/938 [00:29<00:00, 31.94Batch/s, loss=0.018993, top1=80.17%, top5=99.01%]
validation Epoch: 7: 100%|██████████| 157/157 [00:01<00:00, 80.09Batch/s, loss=0.025400, top1=87.35%, top5=99.48%]






train Epoch: 8: 100%|██████████| 938/938 [00:29<00:00, 32.13Batch/s, loss=0.015387, top1=84.12%, top5=99.32%]
validation Epoch: 8: 100%|██████████| 157/157 [00:01<00:00, 78.52Batch/s, loss=0.017745, top1=90.84%, top5=99.66%]






train Epoch: 9: 100%|██████████| 938/938 [00:29<00:00, 31.99Batch/s, loss=0.012963, top1=86.80%, top5=99.47%]
validation Epoch: 9: 100%|██████████| 157/157 [00:01<00:00, 78.70Batch/s, loss=0.020843, top1=89.77%, top5=99.62%]






train Epoch: 10: 100%|██████████| 938/938 [00:29<00:00, 32.10Batch/s, loss=0.011146, top1=88.80%, top5=99.59%]
validation Epoch: 10: 100%|██████████| 157/157 [00:02<00:00, 78.23Batch/s, loss=0.017775, top1=91.28%, top5=99.69%]






train Epoch: 11: 100%|██████████| 938/938 [00:29<00:00, 31.71Batch/s, loss=0.009525, top1=90.45%, top5=99.69%]
validation Epoch: 11: 100%|██████████| 157/157 [00:01<00:00, 79.94Batch/s, loss=0.013373, top1=93.48%, top5=99.74%]






train Epoch: 12: 100%|██████████| 938/938 [00:29<00:00, 31.73Batch/s, loss=0.007984, top1=91.98%, top5=99.77%]
validation Epoch: 12: 100%|██████████| 157/157 [00:02<00:00, 76.67Batch/s, loss=0.012360, top1=93.73%, top5=99.83%]






train Epoch: 13: 100%|██████████| 938/938 [00:29<00:00, 31.64Batch/s, loss=0.006957, top1=93.05%, top5=99.80%]
validation Epoch: 13: 100%|██████████| 157/157 [00:02<00:00, 75.01Batch/s, loss=0.010246, top1=94.78%, top5=99.88%]






train Epoch: 14: 100%|██████████| 938/938 [00:32<00:00, 28.87Batch/s, loss=0.005799, top1=94.19%, top5=99.88%]
validation Epoch: 14: 100%|██████████| 157/157 [00:02<00:00, 68.66Batch/s, loss=0.009229, top1=95.45%, top5=99.85%]






train Epoch: 15: 100%|██████████| 938/938 [00:32<00:00, 28.95Batch/s, loss=0.004856, top1=95.01%, top5=99.88%]
validation Epoch: 15: 100%|██████████| 157/157 [00:02<00:00, 67.63Batch/s, loss=0.007627, top1=96.13%, top5=99.92%]






train Epoch: 16: 100%|██████████| 938/938 [00:32<00:00, 29.07Batch/s, loss=0.004025, top1=95.81%, top5=99.91%]
validation Epoch: 16: 100%|██████████| 157/157 [00:02<00:00, 69.23Batch/s, loss=0.007265, top1=96.47%, top5=99.93%]






train Epoch: 17: 100%|██████████| 938/938 [00:32<00:00, 28.76Batch/s, loss=0.003402, top1=96.46%, top5=99.93%]
validation Epoch: 17: 100%|██████████| 157/157 [00:02<00:00, 68.46Batch/s, loss=0.005981, top1=97.10%, top5=99.93%]






train Epoch: 18: 100%|██████████| 938/938 [00:31<00:00, 29.47Batch/s, loss=0.002861, top1=96.99%, top5=99.97%]
validation Epoch: 18: 100%|██████████| 157/157 [00:02<00:00, 69.42Batch/s, loss=0.005784, top1=97.22%, top5=99.94%]






train Epoch: 19: 100%|██████████| 938/938 [00:31<00:00, 29.35Batch/s, loss=0.002492, top1=97.31%, top5=99.97%]
validation Epoch: 19: 100%|██████████| 157/157 [00:02<00:00, 69.84Batch/s, loss=0.005484, top1=97.32%, top5=99.94%]






train Epoch: 20: 100%|██████████| 938/938 [00:32<00:00, 29.22Batch/s, loss=0.002273, top1=97.58%, top5=99.97%]
validation Epoch: 20: 100%|██████████| 157/157 [00:02<00:00, 68.20Batch/s, loss=0.005378, top1=97.32%, top5=99.95%]






