In [1]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from configs import cifar10_config
from model import ViT
from utils import Trainer, WarmupCosineSchedule, WarmupLinearSchedule, build_model

In [2]:
# Pick the compute device: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Seed PyTorch's global RNG before any model or loader is created, so that a
# Restart & Run All starts from the same random state every time (DataLoader
# shuffling draws from this generator when no explicit generator is given).
# NOTE: seeding alone does not make every CUDA kernel deterministic.
seed = 42
torch.manual_seed(seed)

# Training hyperparameters shared by the cells below.
learning_rate = 1e-3  # learning rate handed to the Adam optimizer
batch_size = 64       # mini-batch size used by both data loaders
num_epochs = 20       # epoch count passed to the scheduler and the Trainer

In [4]:
# Convert PIL images to float tensors, then normalise each RGB channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# download=True fetches CIFAR-10 into `root` only when it is not already
# present, so the notebook also runs on a fresh checkout instead of raising
# "Dataset not found".
train_dataset = datasets.CIFAR10(root='../datasets', train=True, transform=transform, download=True)
val_dataset = datasets.CIFAR10(root='../datasets', train=False, transform=transform, download=True)

# Shuffle only the training split. Evaluation does not need a random order,
# and a fixed order keeps validation batches identical between runs.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Build the ViT from the CIFAR-10 configuration, then move it to the selected device.
model_config = cifar10_config()
model = build_model(ViT, model_config)
model = model.to(device)

In [6]:
# Classification loss used for both training and validation.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters at the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# One fifth of the run is passed as the schedule's second argument and the
# full epoch count as the third.
# NOTE(review): presumably (warmup length, total length) — confirm against
# utils.WarmupCosineSchedule.
warmup_epochs = num_epochs // 5
scheduler = WarmupCosineSchedule(optimizer, warmup_epochs, num_epochs)

In [7]:
# Data loaders for each phase, keyed by phase name.
loaders = {'train': train_loader, 'validation': val_loader}

# All arguments are passed positionally, in the order utils.Trainer expects.
trainer = Trainer(
    model,
    loaders,
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    (1, ),                     # NOTE(review): meaning not visible here — confirm against utils.Trainer
    'pretrained/cifar10.pth',  # NOTE(review): presumably the checkpoint path — confirm against utils.Trainer
    device,
)

In [8]:
# Run the training loop on the trainer built above; the per-epoch train and
# validation progress lines below are the output captured from this call.
trainer.train()

train Epoch: 1: 100%|██████████| 782/782 [00:27<00:00, 28.14Batch/s, loss=12.131070, top1=10.31%, top5=49.01%]
validation Epoch: 1: 100%|██████████| 157/157 [00:02<00:00, 59.60Batch/s, loss=13.643114, top1=10.14%, top5=49.45%]






train Epoch: 2: 100%|██████████| 782/782 [00:29<00:00, 26.81Batch/s, loss=1.777609, top1=15.60%, top5=61.33%]
validation Epoch: 2: 100%|██████████| 157/157 [00:02<00:00, 58.46Batch/s, loss=0.309538, top1=22.44%, top5=71.61%]






train Epoch: 3: 100%|██████████| 782/782 [00:29<00:00, 26.89Batch/s, loss=0.276000, top1=17.89%, top5=66.64%]
validation Epoch: 3: 100%|██████████| 157/157 [00:02<00:00, 58.59Batch/s, loss=0.146424, top1=23.87%, top5=74.56%]






train Epoch: 4: 100%|██████████| 782/782 [00:28<00:00, 27.63Batch/s, loss=0.164749, top1=20.72%, top5=71.64%]
validation Epoch: 4: 100%|██████████| 157/157 [00:02<00:00, 65.96Batch/s, loss=0.140339, top1=25.71%, top5=78.63%]






train Epoch: 5: 100%|██████████| 782/782 [00:27<00:00, 28.56Batch/s, loss=0.138626, top1=23.88%, top5=76.45%]
validation Epoch: 5: 100%|██████████| 157/157 [00:02<00:00, 62.79Batch/s, loss=0.123583, top1=28.14%, top5=79.87%]






train Epoch: 6: 100%|██████████| 782/782 [00:26<00:00, 30.00Batch/s, loss=0.135023, top1=25.61%, top5=78.91%]
validation Epoch: 6: 100%|██████████| 157/157 [00:02<00:00, 66.78Batch/s, loss=0.118324, top1=30.99%, top5=84.38%]






train Epoch: 7: 100%|██████████| 782/782 [00:26<00:00, 30.04Batch/s, loss=0.122149, top1=29.55%, top5=82.42%]
validation Epoch: 7: 100%|██████████| 157/157 [00:02<00:00, 64.36Batch/s, loss=0.115094, top1=32.57%, top5=85.25%]






train Epoch: 8: 100%|██████████| 782/782 [00:25<00:00, 30.16Batch/s, loss=0.121431, top1=30.28%, top5=83.33%]
validation Epoch: 8: 100%|██████████| 157/157 [00:02<00:00, 68.24Batch/s, loss=0.113139, top1=35.18%, top5=85.31%]






train Epoch: 9: 100%|██████████| 782/782 [00:26<00:00, 29.57Batch/s, loss=0.113101, top1=33.97%, top5=85.73%]
validation Epoch: 9: 100%|██████████| 157/157 [00:02<00:00, 64.37Batch/s, loss=0.104506, top1=37.91%, top5=88.80%]






train Epoch: 10: 100%|██████████| 782/782 [00:26<00:00, 29.49Batch/s, loss=0.109588, top1=36.03%, top5=87.09%]
validation Epoch: 10: 100%|██████████| 157/157 [00:02<00:00, 65.85Batch/s, loss=0.103108, top1=39.96%, top5=89.15%]






train Epoch: 11: 100%|██████████| 782/782 [00:26<00:00, 29.02Batch/s, loss=0.113757, top1=36.36%, top5=86.96%]
validation Epoch: 11: 100%|██████████| 157/157 [00:02<00:00, 66.10Batch/s, loss=0.107495, top1=37.99%, top5=88.43%]






train Epoch: 12: 100%|██████████| 782/782 [00:26<00:00, 30.07Batch/s, loss=0.102813, top1=40.06%, top5=89.07%]
validation Epoch: 12: 100%|██████████| 157/157 [00:02<00:00, 67.47Batch/s, loss=0.100002, top1=40.94%, top5=90.90%]






train Epoch: 13: 100%|██████████| 782/782 [00:26<00:00, 29.70Batch/s, loss=0.098743, top1=42.22%, top5=90.19%]
validation Epoch: 13: 100%|██████████| 157/157 [00:02<00:00, 65.09Batch/s, loss=0.094414, top1=45.43%, top5=91.59%]






train Epoch: 14: 100%|██████████| 782/782 [00:26<00:00, 29.62Batch/s, loss=0.095013, top1=44.77%, top5=91.10%]
validation Epoch: 14: 100%|██████████| 157/157 [00:02<00:00, 65.97Batch/s, loss=0.090790, top1=47.34%, top5=92.27%]






train Epoch: 15: 100%|██████████| 782/782 [00:26<00:00, 29.08Batch/s, loss=0.091193, top1=46.82%, top5=92.04%]
validation Epoch: 15: 100%|██████████| 157/157 [00:02<00:00, 62.42Batch/s, loss=0.088553, top1=48.77%, top5=93.22%]






train Epoch: 16: 100%|██████████| 782/782 [00:27<00:00, 28.89Batch/s, loss=0.087658, top1=48.95%, top5=92.89%]
validation Epoch: 16: 100%|██████████| 157/157 [00:02<00:00, 64.71Batch/s, loss=0.084790, top1=50.62%, top5=93.72%]






train Epoch: 17: 100%|██████████| 782/782 [00:27<00:00, 28.50Batch/s, loss=0.084262, top1=51.18%, top5=93.49%]
validation Epoch: 17: 100%|██████████| 157/157 [00:02<00:00, 61.13Batch/s, loss=0.082263, top1=51.97%, top5=94.04%]






train Epoch: 18: 100%|██████████| 782/782 [00:26<00:00, 29.01Batch/s, loss=0.081422, top1=52.70%, top5=93.99%]
validation Epoch: 18: 100%|██████████| 157/157 [00:02<00:00, 68.07Batch/s, loss=0.080995, top1=52.99%, top5=94.42%]






train Epoch: 19: 100%|██████████| 782/782 [00:25<00:00, 30.82Batch/s, loss=0.079762, top1=53.83%, top5=94.43%]
validation Epoch: 19: 100%|██████████| 157/157 [00:02<00:00, 67.99Batch/s, loss=0.079668, top1=54.29%, top5=94.44%]






train Epoch: 20: 100%|██████████| 782/782 [00:25<00:00, 31.14Batch/s, loss=0.078313, top1=54.52%, top5=94.66%]
validation Epoch: 20: 100%|██████████| 157/157 [00:02<00:00, 68.83Batch/s, loss=0.079280, top1=54.49%, top5=94.45%]






