In [1]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from configs import b16_config
from model import ViT
from utils import Trainer, WarmupCosineSchedule, WarmupLinearSchedule, build_model

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
learning_rate = 1e-3
batch_size = 32
num_epochs = 20

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root='../datasets', train=True, transform=transform)
val_dataset = datasets.CIFAR10(root='../datasets', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [5]:
model_config = b16_config()
model_config.update({'num_classes':10})
model = build_model(ViT, model_config).to(device)

In [6]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = WarmupCosineSchedule(optimizer, num_epochs//5, num_epochs)
criterion = nn.CrossEntropyLoss()

In [7]:
trainer = Trainer(model, {'train':train_loader, 'validation':val_loader}, criterion, optimizer, scheduler, num_epochs, (1, ), 'pretrained/cifar10_224.pth', device)

In [8]:
trainer.train()

train Epoch: 1: 100%|██████████| 1563/1563 [05:49<00:00,  4.47Batch/s, loss=51.686659, top1=11.24%, top5=52.33%]
validation Epoch: 1: 100%|██████████| 313/313 [00:26<00:00, 11.79Batch/s, loss=57.020582, top1=11.78%, top5=51.69%]






train Epoch: 2: 100%|██████████| 1563/1563 [05:44<00:00,  4.54Batch/s, loss=233.141750, top1=16.09%, top5=62.53%]
validation Epoch: 2: 100%|██████████| 313/313 [00:27<00:00, 11.43Batch/s, loss=4.434375, top1=17.94%, top5=68.37%]






train Epoch: 3: 100%|██████████| 1563/1563 [05:47<00:00,  4.49Batch/s, loss=1229.258404, top1=16.52%, top5=63.02%]
validation Epoch: 3: 100%|██████████| 313/313 [00:26<00:00, 11.67Batch/s, loss=7.564097, top1=20.91%, top5=70.12%]






train Epoch: 4: 100%|██████████| 1563/1563 [05:50<00:00,  4.45Batch/s, loss=2209.669599, top1=17.14%, top5=63.44%]
validation Epoch: 4: 100%|██████████| 313/313 [00:27<00:00, 11.53Batch/s, loss=24.505027, top1=17.50%, top5=63.80%]






train Epoch: 5: 100%|██████████| 1563/1563 [05:53<00:00,  4.42Batch/s, loss=14.057748, top1=18.09%, top5=66.04%]
validation Epoch: 5: 100%|██████████| 313/313 [00:27<00:00, 11.31Batch/s, loss=5.040315, top1=17.82%, top5=68.38%]






train Epoch: 6: 100%|██████████| 1563/1563 [06:01<00:00,  4.32Batch/s, loss=3.273657, top1=18.03%, top5=66.02%]
validation Epoch: 6: 100%|██████████| 313/313 [00:27<00:00, 11.31Batch/s, loss=2.131788, top1=17.86%, top5=66.22%]






train Epoch: 7: 100%|██████████| 1563/1563 [05:57<00:00,  4.37Batch/s, loss=27859.990741, top1=17.48%, top5=64.41%]
validation Epoch: 7: 100%|██████████| 313/313 [00:28<00:00, 11.18Batch/s, loss=885.275010, top1=20.95%, top5=68.36%]






train Epoch: 8: 100%|██████████| 1563/1563 [05:58<00:00,  4.36Batch/s, loss=348.645903, top1=19.12%, top5=68.35%]
validation Epoch: 8: 100%|██████████| 313/313 [00:28<00:00, 11.06Batch/s, loss=9611.929273, top1=9.80%, top5=56.03%] 






train Epoch: 9: 100%|██████████| 1563/1563 [05:53<00:00,  4.42Batch/s, loss=630.789184, top1=17.42%, top5=65.18%]
validation Epoch: 9: 100%|██████████| 313/313 [00:27<00:00, 11.54Batch/s, loss=75.625204, top1=20.37%, top5=68.29%]






train Epoch: 10: 100%|██████████| 1563/1563 [05:55<00:00,  4.40Batch/s, loss=80.913816, top1=17.90%, top5=65.38%]
validation Epoch: 10: 100%|██████████| 313/313 [00:28<00:00, 11.01Batch/s, loss=29.668176, top1=16.01%, top5=69.06%]






train Epoch: 11: 100%|██████████| 1563/1563 [06:43<00:00,  3.87Batch/s, loss=4232.542893, top1=16.93%, top5=63.44%]
validation Epoch: 11: 100%|██████████| 313/313 [00:28<00:00, 10.93Batch/s, loss=490.847967, top1=17.36%, top5=60.59%]






train Epoch: 12: 100%|██████████| 1563/1563 [06:41<00:00,  3.90Batch/s, loss=231.117285, top1=16.69%, top5=63.60%] 
validation Epoch: 12: 100%|██████████| 313/313 [00:29<00:00, 10.77Batch/s, loss=54.812037, top1=19.11%, top5=69.71%]






train Epoch: 13: 100%|██████████| 1563/1563 [07:14<00:00,  3.59Batch/s, loss=159.230325, top1=16.64%, top5=63.71%]
validation Epoch: 13: 100%|██████████| 313/313 [00:32<00:00,  9.64Batch/s, loss=85.173231, top1=20.21%, top5=63.88%]






train Epoch: 14: 100%|██████████| 1563/1563 [06:45<00:00,  3.86Batch/s, loss=48.804994, top1=17.54%, top5=65.40%]
validation Epoch: 14: 100%|██████████| 313/313 [00:30<00:00, 10.22Batch/s, loss=24.145651, top1=20.67%, top5=70.34%]






train Epoch: 15: 100%|██████████| 1563/1563 [06:55<00:00,  3.77Batch/s, loss=37.141108, top1=17.88%, top5=65.28%] 
validation Epoch: 15: 100%|██████████| 313/313 [00:31<00:00,  9.94Batch/s, loss=19.235947, top1=15.84%, top5=65.04%]






train Epoch: 16: 100%|██████████| 1563/1563 [07:08<00:00,  3.65Batch/s, loss=13.508015, top1=18.22%, top5=66.00%]
validation Epoch: 16: 100%|██████████| 313/313 [00:32<00:00,  9.78Batch/s, loss=11.642723, top1=14.70%, top5=59.68%]






train Epoch: 17: 100%|██████████| 1563/1563 [07:09<00:00,  3.64Batch/s, loss=9.097540, top1=18.42%, top5=66.07%]
validation Epoch: 17: 100%|██████████| 313/313 [00:33<00:00,  9.31Batch/s, loss=364.442802, top1=9.98%, top5=50.51%] 






train Epoch: 18: 100%|██████████| 1563/1563 [06:44<00:00,  3.87Batch/s, loss=6.383918, top1=18.24%, top5=66.73%]
validation Epoch: 18: 100%|██████████| 313/313 [00:33<00:00,  9.35Batch/s, loss=3.695233, top1=15.53%, top5=66.32%]






train Epoch: 19: 100%|██████████| 1563/1563 [06:44<00:00,  3.87Batch/s, loss=1.782707, top1=18.97%, top5=68.35%]
validation Epoch: 19: 100%|██████████| 313/313 [00:31<00:00, 10.03Batch/s, loss=1.928793, top1=16.53%, top5=62.42%]






train Epoch: 20: 100%|██████████| 1563/1563 [06:58<00:00,  3.74Batch/s, loss=1.037173, top1=19.63%, top5=69.54%]
validation Epoch: 20: 100%|██████████| 313/313 [00:30<00:00, 10.17Batch/s, loss=1.888574, top1=13.74%, top5=67.04%]




