In [1]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from configs import mnist_config
from model import ViT
from utils import Trainer, WarmupCosineSchedule, WarmupLinearSchedule, build_model

In [2]:
# Pick the compute device once: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Training hyperparameters — everything a reader is likely to tune lives here.
learning_rate = 1e-3  # initial learning rate handed to the optimizer
batch_size = 64       # samples per mini-batch for both data loaders
num_epochs = 20       # total number of epochs

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable across "Restart Kernel -> Run All" (on CUDA, bitwise-identical
# results may additionally require deterministic algorithms).
seed = 42
torch.manual_seed(seed)

In [4]:
# Preprocessing: ToTensor converts the image to a float tensor in [0, 1];
# Normalize with mean 0.5 / std 0.5 then rescales it to [-1, 1].
# Note the trailing commas: `(0.5)` is just a float, `(0.5,)` is the
# one-element per-channel tuple that Normalize documents.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=True fetches MNIST into `root` only when it is not already there,
# so the notebook also runs on a fresh checkout.
train_dataset = datasets.MNIST(root='../datasets', train=True, transform=transform, download=True)
val_dataset = datasets.MNIST(root='../datasets', train=False, transform=transform, download=True)

# Shuffle only the training data; a fixed validation order keeps evaluation
# deterministic and leaves the aggregate metrics unchanged.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Build the ViT from the MNIST configuration, then move it onto the selected device.
model_config = mnist_config()
model = build_model(ViT, model_config)
model = model.to(device)

In [6]:
# Loss, optimizer and learning-rate schedule used by the trainer.
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# The schedule receives num_epochs // 5 and num_epochs positionally.
# NOTE(review): presumably (warmup length, total length) — confirm against
# utils.WarmupCosineSchedule.
warmup_epochs = num_epochs // 5
scheduler = WarmupCosineSchedule(optimizer, warmup_epochs, num_epochs)

In [7]:
# Group the two data loaders under the keys handed to the Trainer.
loaders = {'train': train_loader, 'validation': val_loader}

# All arguments are positional, in the same order as before.
# NOTE(review): (1,) and 'mnist.pth' look like top-k settings and a checkpoint
# path respectively — confirm against utils.Trainer.
trainer = Trainer(
    model,
    loaders,
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    (1,),
    'mnist.pth',
    device,
)

In [8]:
# Start training with the Trainer configured above. The recorded output shows
# one "train" and one "validation" progress bar per epoch with loss/top1/top5.
# NOTE(review): this is the expensive cell (~30 s per epoch in the recorded run).
trainer.train()

train Epoch: 1: 100%|██████████| 938/938 [00:30<00:00, 30.83Batch/s, loss=4.866860, top1=11.47%, top5=49.27%]
validation Epoch: 1: 100%|██████████| 157/157 [00:02<00:00, 76.93Batch/s, loss=3.627524, top1=13.71%, top5=47.57%]






train Epoch: 2: 100%|██████████| 938/938 [00:28<00:00, 32.36Batch/s, loss=0.201183, top1=72.01%, top5=95.86%]
validation Epoch: 2: 100%|██████████| 157/157 [00:02<00:00, 77.41Batch/s, loss=0.160860, top1=89.34%, top5=99.43%]






train Epoch: 3: 100%|██████████| 938/938 [00:29<00:00, 32.26Batch/s, loss=0.176811, top1=87.50%, top5=99.40%]
validation Epoch: 3: 100%|██████████| 157/157 [00:02<00:00, 76.59Batch/s, loss=0.269000, top1=92.57%, top5=99.59%]






train Epoch: 4: 100%|██████████| 938/938 [00:28<00:00, 32.45Batch/s, loss=0.444269, top1=90.78%, top5=99.65%]
validation Epoch: 4: 100%|██████████| 157/157 [00:01<00:00, 79.22Batch/s, loss=0.002629, top1=94.07%, top5=99.78%]






train Epoch: 5: 100%|██████████| 938/938 [00:28<00:00, 32.59Batch/s, loss=0.088275, top1=92.42%, top5=99.73%]
validation Epoch: 5: 100%|██████████| 157/157 [00:01<00:00, 80.43Batch/s, loss=0.014029, top1=94.89%, top5=99.88%]






train Epoch: 6: 100%|██████████| 938/938 [00:28<00:00, 32.74Batch/s, loss=0.046392, top1=94.03%, top5=99.81%]
validation Epoch: 6: 100%|██████████| 157/157 [00:01<00:00, 79.37Batch/s, loss=0.203331, top1=94.39%, top5=99.79%]






train Epoch: 7: 100%|██████████| 938/938 [00:28<00:00, 32.56Batch/s, loss=0.102743, top1=94.90%, top5=99.86%]
validation Epoch: 7: 100%|██████████| 157/157 [00:02<00:00, 77.83Batch/s, loss=0.112094, top1=95.99%, top5=99.89%]






train Epoch: 8: 100%|██████████| 938/938 [00:28<00:00, 32.43Batch/s, loss=0.221623, top1=95.64%, top5=99.90%]
validation Epoch: 8: 100%|██████████| 157/157 [00:01<00:00, 79.28Batch/s, loss=0.011450, top1=96.54%, top5=99.91%]






train Epoch: 9: 100%|██████████| 938/938 [00:28<00:00, 32.56Batch/s, loss=0.050441, top1=96.25%, top5=99.93%]
validation Epoch: 9: 100%|██████████| 157/157 [00:01<00:00, 79.17Batch/s, loss=0.017448, top1=95.89%, top5=99.93%]






train Epoch: 10: 100%|██████████| 938/938 [00:28<00:00, 32.65Batch/s, loss=0.109524, top1=96.63%, top5=99.94%]
validation Epoch: 10: 100%|██████████| 157/157 [00:01<00:00, 79.60Batch/s, loss=0.042526, top1=97.10%, top5=99.95%]






train Epoch: 11: 100%|██████████| 938/938 [00:28<00:00, 32.60Batch/s, loss=0.101646, top1=97.11%, top5=99.96%]
validation Epoch: 11: 100%|██████████| 157/157 [00:01<00:00, 79.15Batch/s, loss=0.369809, top1=97.06%, top5=99.93%]






train Epoch: 12: 100%|██████████| 938/938 [00:29<00:00, 32.05Batch/s, loss=0.261646, top1=97.51%, top5=99.97%]
validation Epoch: 12: 100%|██████████| 157/157 [00:01<00:00, 78.85Batch/s, loss=0.011615, top1=97.33%, top5=99.95%]






train Epoch: 13: 100%|██████████| 938/938 [00:28<00:00, 32.35Batch/s, loss=0.055056, top1=97.99%, top5=99.97%]
validation Epoch: 13: 100%|██████████| 157/157 [00:01<00:00, 80.93Batch/s, loss=0.692583, top1=97.47%, top5=99.91%]






train Epoch: 14: 100%|██████████| 938/938 [00:28<00:00, 33.34Batch/s, loss=0.008043, top1=98.42%, top5=99.99%] 
validation Epoch: 14: 100%|██████████| 157/157 [00:01<00:00, 80.39Batch/s, loss=0.442880, top1=97.80%, top5=99.96%]






train Epoch: 15: 100%|██████████| 938/938 [00:28<00:00, 32.78Batch/s, loss=0.004279, top1=98.84%, top5=100.00%]
validation Epoch: 15: 100%|██████████| 157/157 [00:01<00:00, 78.81Batch/s, loss=0.005285, top1=97.69%, top5=99.96%]






train Epoch: 16: 100%|██████████| 938/938 [00:28<00:00, 32.93Batch/s, loss=0.008796, top1=99.17%, top5=100.00%]
validation Epoch: 16: 100%|██████████| 157/157 [00:01<00:00, 80.36Batch/s, loss=0.042066, top1=98.16%, top5=99.98%]






train Epoch: 17: 100%|██████████| 938/938 [00:28<00:00, 33.44Batch/s, loss=0.024984, top1=99.50%, top5=100.00%]
validation Epoch: 17: 100%|██████████| 157/157 [00:01<00:00, 79.57Batch/s, loss=0.000354, top1=98.29%, top5=99.94%]






train Epoch: 18: 100%|██████████| 938/938 [00:28<00:00, 32.64Batch/s, loss=0.000611, top1=99.67%, top5=100.00%]
validation Epoch: 18: 100%|██████████| 157/157 [00:01<00:00, 80.39Batch/s, loss=0.000009, top1=98.33%, top5=99.98%]






train Epoch: 19: 100%|██████████| 938/938 [00:28<00:00, 32.37Batch/s, loss=0.000144, top1=99.79%, top5=100.00%]
validation Epoch: 19: 100%|██████████| 157/157 [00:01<00:00, 79.80Batch/s, loss=0.156243, top1=98.42%, top5=99.97%]






train Epoch: 20: 100%|██████████| 938/938 [00:28<00:00, 32.36Batch/s, loss=0.000197, top1=99.85%, top5=100.00%]
validation Epoch: 20: 100%|██████████| 157/157 [00:01<00:00, 79.10Batch/s, loss=0.001372, top1=98.49%, top5=99.97%]






