In [1]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from configs import mnist_config
from model import ViT
from utils import Trainer, WarmupCosineSchedule, WarmupLinearSchedule, build_model

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
learning_rate = 1e-3
batch_size = 64
num_epochs = 20

In [4]:
# MNIST images are single-channel, so mean/std are one-element tuples.
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

data_root = '../datasets'
# download=True fetches MNIST only if it is not already present under data_root,
# so the notebook also runs on a fresh checkout.
train_dataset = datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
val_dataset = datasets.MNIST(root=data_root, train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation does not need shuffling; a fixed order keeps evaluation deterministic
# across epochs and runs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Build a ViT from the configuration returned by configs.mnist_config(),
# then move its parameters to the selected device.
model_config = mnist_config()
model = build_model(ViT, model_config)
model = model.to(device)

In [6]:
# Loss applied to the model's outputs against the integer class labels.
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Warmup length is one fifth of the total; WarmupCosineSchedule (from utils)
# presumably warms up linearly and then decays along a cosine curve.
# NOTE(review): both lengths are given in epochs, which assumes Trainer calls
# scheduler.step() once per epoch rather than per batch — confirm.
warmup_epochs = num_epochs // 5
scheduler = WarmupCosineSchedule(optimizer, warmup_epochs, num_epochs)

In [7]:
# Loaders keyed by split name; these keys appear to label the progress bars in
# the training output ("train Epoch ...", "validation Epoch ...").
dataloaders = {'train': train_loader, 'validation': val_loader}

trainer = Trainer(
    model,
    dataloaders,
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    (1,),         # NOTE(review): meaning is defined by utils.Trainer — confirm
    'mnist.pth',  # presumably where the trained weights are saved — confirm
    device,
)

In [8]:
# Run training for num_epochs epochs. The progress bars below report per-epoch
# loss and top-1/top-5 accuracy for the train and validation splits.
trainer.train()

train Epoch: 1: 100%|██████████| 938/938 [01:39<00:00,  9.44Batch/s, loss=3.855174, top1=10.24%, top5=52.10%]
validation Epoch: 1: 100%|██████████| 157/157 [00:05<00:00, 30.27Batch/s, loss=7.109620, top1=10.84%, top5=53.26%]






train Epoch: 2: 100%|██████████| 938/938 [01:31<00:00, 10.20Batch/s, loss=0.418589, top1=24.77%, top5=72.72%]
validation Epoch: 2: 100%|██████████| 157/157 [00:05<00:00, 29.82Batch/s, loss=0.138408, top1=31.89%, top5=86.01%]






train Epoch: 3: 100%|██████████| 938/938 [01:31<00:00, 10.29Batch/s, loss=0.070869, top1=34.84%, top5=84.19%]
validation Epoch: 3: 100%|██████████| 157/157 [00:05<00:00, 29.77Batch/s, loss=0.073229, top1=58.06%, top5=96.32%]






train Epoch: 4: 100%|██████████| 938/938 [01:31<00:00, 10.25Batch/s, loss=0.044581, top1=50.86%, top5=94.09%]
validation Epoch: 4: 100%|██████████| 157/157 [00:05<00:00, 29.71Batch/s, loss=0.067861, top1=60.14%, top5=97.10%]






train Epoch: 5: 100%|██████████| 938/938 [01:31<00:00, 10.29Batch/s, loss=0.032961, top1=63.64%, top5=97.17%]
validation Epoch: 5: 100%|██████████| 157/157 [00:05<00:00, 29.95Batch/s, loss=0.042689, top1=75.86%, top5=98.91%]






train Epoch: 6: 100%|██████████| 938/938 [01:31<00:00, 10.22Batch/s, loss=0.023611, top1=74.54%, top5=98.51%]
validation Epoch: 6: 100%|██████████| 157/157 [00:05<00:00, 29.76Batch/s, loss=0.029268, top1=83.78%, top5=99.42%]






train Epoch: 7: 100%|██████████| 938/938 [01:31<00:00, 10.20Batch/s, loss=0.018757, top1=80.28%, top5=98.99%]
validation Epoch: 7: 100%|██████████| 157/157 [00:05<00:00, 30.04Batch/s, loss=0.019869, top1=89.89%, top5=99.64%]






train Epoch: 8: 100%|██████████| 938/938 [01:30<00:00, 10.33Batch/s, loss=0.015670, top1=83.81%, top5=99.27%]
validation Epoch: 8: 100%|██████████| 157/157 [00:05<00:00, 27.44Batch/s, loss=0.021660, top1=88.92%, top5=99.50%]






train Epoch: 9: 100%|██████████| 938/938 [01:31<00:00, 10.29Batch/s, loss=0.013186, top1=86.65%, top5=99.48%]
validation Epoch: 9: 100%|██████████| 157/157 [00:05<00:00, 29.90Batch/s, loss=0.015664, top1=92.28%, top5=99.73%]






train Epoch: 10: 100%|██████████| 938/938 [01:30<00:00, 10.38Batch/s, loss=0.011144, top1=88.64%, top5=99.58%]
validation Epoch: 10: 100%|██████████| 157/157 [00:05<00:00, 30.16Batch/s, loss=0.016610, top1=91.45%, top5=99.71%]






train Epoch: 11: 100%|██████████| 938/938 [01:30<00:00, 10.41Batch/s, loss=0.009795, top1=90.08%, top5=99.71%]
validation Epoch: 11: 100%|██████████| 157/157 [00:05<00:00, 30.52Batch/s, loss=0.011803, top1=94.03%, top5=99.79%]






train Epoch: 12: 100%|██████████| 938/938 [01:30<00:00, 10.35Batch/s, loss=0.008203, top1=91.63%, top5=99.75%]
validation Epoch: 12: 100%|██████████| 157/157 [00:05<00:00, 29.88Batch/s, loss=0.011505, top1=94.46%, top5=99.84%]






train Epoch: 13: 100%|██████████| 938/938 [01:30<00:00, 10.36Batch/s, loss=0.006906, top1=92.99%, top5=99.81%]
validation Epoch: 13: 100%|██████████| 157/157 [00:05<00:00, 29.87Batch/s, loss=0.009659, top1=95.30%, top5=99.87%]






train Epoch: 14: 100%|██████████| 938/938 [01:31<00:00, 10.27Batch/s, loss=0.005664, top1=94.22%, top5=99.87%]
validation Epoch: 14: 100%|██████████| 157/157 [00:05<00:00, 30.19Batch/s, loss=0.009829, top1=95.18%, top5=99.82%]






train Epoch: 15: 100%|██████████| 938/938 [01:30<00:00, 10.40Batch/s, loss=0.004863, top1=95.07%, top5=99.91%]
validation Epoch: 15: 100%|██████████| 157/157 [00:05<00:00, 30.84Batch/s, loss=0.008379, top1=96.07%, top5=99.89%]






train Epoch: 16: 100%|██████████| 938/938 [01:31<00:00, 10.20Batch/s, loss=0.004096, top1=95.76%, top5=99.92%]
validation Epoch: 16: 100%|██████████| 157/157 [00:05<00:00, 30.37Batch/s, loss=0.007145, top1=96.65%, top5=99.93%]






train Epoch: 17: 100%|██████████| 938/938 [01:37<00:00,  9.59Batch/s, loss=0.003372, top1=96.48%, top5=99.94%]
validation Epoch: 17: 100%|██████████| 157/157 [00:05<00:00, 27.38Batch/s, loss=0.006713, top1=96.76%, top5=99.91%]






train Epoch: 18: 100%|██████████| 938/938 [01:44<00:00,  8.96Batch/s, loss=0.002779, top1=97.06%, top5=99.97%]
validation Epoch: 18: 100%|██████████| 157/157 [00:05<00:00, 28.02Batch/s, loss=0.005788, top1=97.28%, top5=99.94%]






train Epoch: 19: 100%|██████████| 938/938 [01:36<00:00,  9.74Batch/s, loss=0.002415, top1=97.39%, top5=99.98%] 
validation Epoch: 19: 100%|██████████| 157/157 [00:05<00:00, 30.83Batch/s, loss=0.005331, top1=97.38%, top5=99.94%]






train Epoch: 20: 100%|██████████| 938/938 [01:32<00:00, 10.15Batch/s, loss=0.002221, top1=97.66%, top5=99.99%]
validation Epoch: 20: 100%|██████████| 157/157 [00:05<00:00, 29.86Batch/s, loss=0.005229, top1=97.47%, top5=99.94%]






