In [1]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from configs import base_config
from model import T2T_ViT
from utils import Trainer, WarmupCosineSchedule, WarmupLinearSchedule, build_model

In [2]:
# Train on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Training hyperparameters.
# NOTE(review): the recorded run shows training-loss spikes (243.16 at epoch 4, 170.57 at
# epoch 6, 74.26 at epoch 9) and top-1 collapsing from ~30% to ~15%, which suggests 1e-3 is
# too aggressive for this model at batch size 16 — consider lowering it.
learning_rate = 1e-3
batch_size = 16
num_epochs = 20

# Seed torch's global RNG before the model is built and the loaders are created, so weight
# initialisation and DataLoader shuffle order are reproducible on Restart & Run All
# (GPU kernels may still introduce some nondeterminism).
seed = 42
torch.manual_seed(seed)

In [4]:
# Upsample CIFAR-10's 32x32 images to 224x224 (matching the `_224` checkpoint used for this
# run) and map pixel values from [0, 1] to [-1, 1] via mean=0.5, std=0.5 per channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# download=True lets the notebook run on a fresh machine where the dataset folder is empty;
# when the archive is already present and verified, torchvision skips the download.
data_root = '../datasets'
train_dataset = datasets.CIFAR10(root=data_root, train=True, transform=transform, download=True)
val_dataset = datasets.CIFAR10(root=data_root, train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation order does not change aggregate metrics. Shuffling it would only consume the
# global RNG and perturb the training shuffle order, so keep it fixed.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Start from the project's base T2T-ViT config and override the classifier head size for
# CIFAR-10's 10 classes, then build the network and move it to the training device.
model_config = base_config()
model_config.update({'num_classes': 10})

model = build_model(T2T_ViT, model_config)
model = model.to(device)

In [6]:
# Loss: multi-class cross-entropy over the model's class logits.
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter at the configured base learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Warm-up + cosine LR schedule. The positional args are presumably (warmup, total) steps
# — TODO confirm against utils.WarmupCosineSchedule.
# NOTE(review): warm-up and total are expressed in epochs here. If Trainer steps the
# scheduler once per batch rather than once per epoch, the schedule would run far past
# its intended horizon. Verify the step unit; the periodic loss spikes in the recorded
# log are worth checking against it.
scheduler = WarmupCosineSchedule(
    optimizer,
    num_epochs // 5,
    num_epochs,
)

In [7]:
# Bundle everything the training loop needs. The arguments are positional, so their order
# matters: model, loaders keyed 'train'/'validation', loss, optimizer, LR scheduler, epoch
# count, `(1, )`, checkpoint path, device.
# NOTE(review): the meaning of `(1, )` is defined in utils.Trainer and is not visible here
# (the log reports both top1 and top5) — TODO document it.
# NOTE(review): 'pretrained/cifar10_224.pth' is relative to the notebook's working
# directory; confirm Trainer creates 'pretrained/' before saving.
trainer = Trainer(
    model,
    {
        'train': train_loader,
        'validation': val_loader,
    },
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    (1,),
    'pretrained/cifar10_224.pth',
    device,
)

In [8]:
# Launch the train/validation loop configured in the Trainer above.
# NOTE(review): in the recorded run each epoch took roughly 10.5 min (~10 min train +
# ~41 s validation), i.e. about 3.5 h for 20 epochs. The saved output stops partway
# through epoch 17, so the run appears to have been interrupted.
trainer.train()

train Epoch: 1: 100%|██████████| 3125/3125 [10:02<00:00,  5.19Batch/s, loss=2.557824, top1=9.36%, top5=48.94%]
validation Epoch: 1: 100%|██████████| 625/625 [00:42<00:00, 14.71Batch/s, loss=2.783238, top1=8.11%, top5=48.12%]






train Epoch: 2: 100%|██████████| 3125/3125 [10:01<00:00,  5.19Batch/s, loss=2.292878, top1=20.64%, top5=69.61%]
validation Epoch: 2: 100%|██████████| 625/625 [00:41<00:00, 14.95Batch/s, loss=0.129469, top1=29.95%, top5=78.50%]






train Epoch: 3: 100%|██████████| 3125/3125 [09:57<00:00,  5.23Batch/s, loss=0.138167, top1=24.90%, top5=77.29%]
validation Epoch: 3: 100%|██████████| 625/625 [00:43<00:00, 14.49Batch/s, loss=0.123360, top1=27.36%, top5=83.54%]






train Epoch: 4: 100%|██████████| 3125/3125 [09:54<00:00,  5.25Batch/s, loss=243.161665, top1=13.94%, top5=58.18%]
validation Epoch: 4: 100%|██████████| 625/625 [00:41<00:00, 14.97Batch/s, loss=0.956823, top1=14.97%, top5=57.78%]






train Epoch: 5: 100%|██████████| 3125/3125 [09:51<00:00,  5.28Batch/s, loss=0.683481, top1=15.45%, top5=59.70%]
validation Epoch: 5: 100%|██████████| 625/625 [00:40<00:00, 15.28Batch/s, loss=0.246628, top1=14.93%, top5=61.63%]






train Epoch: 6: 100%|██████████| 3125/3125 [09:49<00:00,  5.30Batch/s, loss=170.570733, top1=19.10%, top5=66.77%]
validation Epoch: 6: 100%|██████████| 625/625 [00:41<00:00, 15.03Batch/s, loss=1.373369, top1=22.86%, top5=69.94%]






train Epoch: 7: 100%|██████████| 3125/3125 [09:53<00:00,  5.26Batch/s, loss=0.856761, top1=20.94%, top5=70.43%]
validation Epoch: 7: 100%|██████████| 625/625 [00:41<00:00, 15.23Batch/s, loss=0.292182, top1=22.26%, top5=71.69%]






train Epoch: 8: 100%|██████████| 3125/3125 [09:54<00:00,  5.25Batch/s, loss=0.277133, top1=20.38%, top5=70.10%]
validation Epoch: 8: 100%|██████████| 625/625 [00:42<00:00, 14.67Batch/s, loss=0.224168, top1=21.11%, top5=66.05%]






train Epoch: 9: 100%|██████████| 3125/3125 [09:51<00:00,  5.29Batch/s, loss=74.257877, top1=18.91%, top5=68.12%] 
validation Epoch: 9: 100%|██████████| 625/625 [00:40<00:00, 15.25Batch/s, loss=1.359918, top1=21.10%, top5=68.34%]






train Epoch: 10: 100%|██████████| 3125/3125 [09:49<00:00,  5.30Batch/s, loss=0.671820, top1=17.00%, top5=65.80%]
validation Epoch: 10: 100%|██████████| 625/625 [00:41<00:00, 14.95Batch/s, loss=0.308222, top1=20.99%, top5=70.43%]






train Epoch: 11: 100%|██████████| 3125/3125 [09:55<00:00,  5.25Batch/s, loss=0.223888, top1=17.79%, top5=66.25%]
validation Epoch: 11: 100%|██████████| 625/625 [00:41<00:00, 14.96Batch/s, loss=0.157791, top1=21.70%, top5=70.83%]






train Epoch: 12: 100%|██████████| 3125/3125 [09:54<00:00,  5.25Batch/s, loss=5.244358, top1=17.95%, top5=67.16%]
validation Epoch: 12: 100%|██████████| 625/625 [00:41<00:00, 14.92Batch/s, loss=0.211208, top1=24.36%, top5=77.96%]






train Epoch: 13: 100%|██████████| 3125/3125 [09:53<00:00,  5.26Batch/s, loss=0.206504, top1=20.22%, top5=70.35%]
validation Epoch: 13: 100%|██████████| 625/625 [00:41<00:00, 14.93Batch/s, loss=0.128721, top1=25.13%, top5=78.68%]






train Epoch: 14: 100%|██████████| 3125/3125 [09:54<00:00,  5.26Batch/s, loss=0.148343, top1=21.84%, top5=73.81%]
validation Epoch: 14: 100%|██████████| 625/625 [00:41<00:00, 14.94Batch/s, loss=0.126898, top1=25.88%, top5=76.26%]






train Epoch: 15: 100%|██████████| 3125/3125 [09:55<00:00,  5.25Batch/s, loss=0.140731, top1=23.09%, top5=75.80%]
validation Epoch: 15: 100%|██████████| 625/625 [00:41<00:00, 14.94Batch/s, loss=0.124482, top1=26.87%, top5=79.38%]






train Epoch: 16: 100%|██████████| 3125/3125 [09:54<00:00,  5.25Batch/s, loss=0.140983, top1=23.77%, top5=76.91%]
validation Epoch: 16: 100%|██████████| 625/625 [00:42<00:00, 14.62Batch/s, loss=0.118874, top1=28.16%, top5=82.51%]






train Epoch: 17:  96%|█████████▋| 3014/3125 [09:34<00:21,  5.21Batch/s, loss=0.121873, top1=27.97%, top5=81.37%]