In [1]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from configs import cifar10_config
from model import T2T_ViT
from utils import Trainer, WarmupCosineSchedule, WarmupLinearSchedule, build_model

In [2]:
# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Training hyperparameters used by the cells below.
learning_rate = 1e-3  # learning rate handed to the optimizer
batch_size = 64       # samples per mini-batch for both data loaders
num_epochs = 20       # number of training epochs; also sizes the LR schedule

# Fix the global PyTorch RNG so weight initialisation and the shuffled
# training order are repeatable on a fresh "Restart & Run All".
# Some GPU kernels may still be non-deterministic.
seed = 42
torch.manual_seed(seed)

In [4]:
# Convert images to tensors, then normalise every channel with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Both CIFAR-10 splits share one root directory and the same preprocessing.
dataset_root = '../datasets'
train_dataset = datasets.CIFAR10(root=dataset_root, train=True, transform=transform, download=True)
val_dataset = datasets.CIFAR10(root=dataset_root, train=False, transform=transform, download=True)

# Shuffle only the training split. Evaluation does not benefit from a random
# order, and a fixed order keeps validation batches identical between runs.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Build the T2T_ViT model from the configuration returned by cifar10_config(),
# then move it to the selected device.
model_config = cifar10_config()
model = build_model(T2T_ViT, model_config)
model = model.to(device)

In [6]:
# Loss function for training and validation.
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter, starting from `learning_rate`.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# One fifth of the epoch count is passed as the second scheduler argument
# (presumably the warm-up length, followed by the total length; confirm
# against utils.WarmupCosineSchedule).
# NOTE(review): epoch 1 in the recorded output sits at chance accuracy, which
# may mean the warm-up starts at a learning rate of 0; verify.
warmup_epochs = num_epochs // 5
scheduler = WarmupCosineSchedule(optimizer, warmup_epochs, num_epochs)

In [7]:
# Data loaders keyed by phase name, passed to the Trainer as a single dict.
loaders = {'train': train_loader, 'validation': val_loader}

# Arguments are positional, in the order used by utils.Trainer.
# NOTE(review): the meaning of the `(1, )` tuple is not visible from this
# notebook; confirm against the Trainer signature.
# NOTE(review): 'pretrained/cifar10.pth' is presumably the checkpoint path,
# relative to the working directory; confirm the folder exists or is created.
trainer = Trainer(
    model,
    loaders,
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    (1, ),
    'pretrained/cifar10.pth',
    device,
)

In [8]:
# Start training with the Trainer configured above. The recorded output shows
# one progress bar per phase ('train', 'validation') for each of the epochs.
trainer.train()

train Epoch: 1: 100%|██████████| 782/782 [01:01<00:00, 12.73Batch/s, loss=0.217143, top1=9.16%, top5=49.96%]
validation Epoch: 1: 100%|██████████| 157/157 [00:03<00:00, 40.87Batch/s, loss=0.212786, top1=9.17%, top5=50.18%]






train Epoch: 2: 100%|██████████| 782/782 [00:59<00:00, 13.08Batch/s, loss=0.111519, top1=36.51%, top5=85.44%]
validation Epoch: 2: 100%|██████████| 157/157 [00:03<00:00, 41.01Batch/s, loss=0.095891, top1=44.43%, top5=90.64%]






train Epoch: 3: 100%|██████████| 782/782 [00:59<00:00, 13.13Batch/s, loss=0.097234, top1=44.07%, top5=90.25%]
validation Epoch: 3: 100%|██████████| 157/157 [00:03<00:00, 40.93Batch/s, loss=0.093101, top1=47.47%, top5=91.64%]






train Epoch: 4: 100%|██████████| 782/782 [00:59<00:00, 13.21Batch/s, loss=0.091427, top1=47.74%, top5=91.95%]
validation Epoch: 4: 100%|██████████| 157/157 [00:03<00:00, 40.17Batch/s, loss=0.091377, top1=47.94%, top5=92.39%]






train Epoch: 5: 100%|██████████| 782/782 [01:00<00:00, 12.89Batch/s, loss=0.088000, top1=49.61%, top5=92.69%]
validation Epoch: 5: 100%|██████████| 157/157 [00:03<00:00, 40.98Batch/s, loss=0.086105, top1=51.18%, top5=93.01%]






train Epoch: 6: 100%|██████████| 782/782 [00:59<00:00, 13.15Batch/s, loss=0.081952, top1=53.22%, top5=93.87%]
validation Epoch: 6: 100%|██████████| 157/157 [00:03<00:00, 39.68Batch/s, loss=0.086006, top1=51.23%, top5=92.96%]






train Epoch: 7: 100%|██████████| 782/782 [00:59<00:00, 13.06Batch/s, loss=0.078123, top1=55.76%, top5=94.54%]
validation Epoch: 7: 100%|██████████| 157/157 [00:03<00:00, 41.54Batch/s, loss=0.084623, top1=52.59%, top5=93.53%]






train Epoch: 8: 100%|██████████| 782/782 [01:00<00:00, 13.02Batch/s, loss=0.074098, top1=57.82%, top5=95.55%]
validation Epoch: 8: 100%|██████████| 157/157 [00:03<00:00, 40.57Batch/s, loss=0.086486, top1=51.63%, top5=93.38%]






train Epoch: 9: 100%|██████████| 782/782 [01:00<00:00, 13.03Batch/s, loss=0.069753, top1=60.05%, top5=96.11%]
validation Epoch: 9: 100%|██████████| 157/157 [00:03<00:00, 40.47Batch/s, loss=0.084594, top1=53.41%, top5=93.74%]






train Epoch: 10: 100%|██████████| 782/782 [01:00<00:00, 12.92Batch/s, loss=0.063709, top1=63.27%, top5=96.95%]
validation Epoch: 10: 100%|██████████| 157/157 [00:04<00:00, 37.87Batch/s, loss=0.086203, top1=54.28%, top5=94.02%]






train Epoch: 11: 100%|██████████| 782/782 [01:03<00:00, 12.31Batch/s, loss=0.056937, top1=67.49%, top5=97.65%]
validation Epoch: 11: 100%|██████████| 157/157 [00:04<00:00, 37.93Batch/s, loss=0.083679, top1=54.67%, top5=94.01%]






train Epoch: 12: 100%|██████████| 782/782 [01:03<00:00, 12.28Batch/s, loss=0.048464, top1=72.20%, top5=98.56%]
validation Epoch: 12: 100%|██████████| 157/157 [00:04<00:00, 37.78Batch/s, loss=0.090819, top1=55.47%, top5=93.95%]






train Epoch: 13: 100%|██████████| 782/782 [01:04<00:00, 12.13Batch/s, loss=0.038292, top1=77.97%, top5=99.23%]
validation Epoch: 13: 100%|██████████| 157/157 [00:04<00:00, 37.47Batch/s, loss=0.093736, top1=56.61%, top5=93.66%]






train Epoch: 14: 100%|██████████| 782/782 [01:03<00:00, 12.37Batch/s, loss=0.027211, top1=84.24%, top5=99.69%]
validation Epoch: 14: 100%|██████████| 157/157 [00:03<00:00, 40.82Batch/s, loss=0.111820, top1=56.15%, top5=93.56%]






train Epoch: 15: 100%|██████████| 782/782 [00:59<00:00, 13.16Batch/s, loss=0.017911, top1=89.58%, top5=99.90%]
validation Epoch: 15: 100%|██████████| 157/157 [00:03<00:00, 40.79Batch/s, loss=0.126546, top1=57.13%, top5=93.44%]






train Epoch: 16: 100%|██████████| 782/782 [00:59<00:00, 13.25Batch/s, loss=0.011247, top1=93.64%, top5=99.98%] 
validation Epoch: 16: 100%|██████████| 157/157 [00:03<00:00, 40.95Batch/s, loss=0.139922, top1=56.95%, top5=93.33%]






train Epoch: 17: 100%|██████████| 782/782 [00:58<00:00, 13.28Batch/s, loss=0.006937, top1=96.09%, top5=99.99%] 
validation Epoch: 17: 100%|██████████| 157/157 [00:03<00:00, 41.09Batch/s, loss=0.154883, top1=57.74%, top5=93.52%]






train Epoch: 18: 100%|██████████| 782/782 [00:58<00:00, 13.27Batch/s, loss=0.004388, top1=97.59%, top5=100.00%]
validation Epoch: 18: 100%|██████████| 157/157 [00:03<00:00, 41.04Batch/s, loss=0.163486, top1=58.33%, top5=93.49%]






train Epoch: 19: 100%|██████████| 782/782 [01:02<00:00, 12.61Batch/s, loss=0.003146, top1=98.34%, top5=100.00%]
validation Epoch: 19: 100%|██████████| 157/157 [00:03<00:00, 39.80Batch/s, loss=0.168109, top1=58.56%, top5=93.75%]






train Epoch: 20: 100%|██████████| 782/782 [01:00<00:00, 13.00Batch/s, loss=0.002476, top1=98.74%, top5=100.00%]
validation Epoch: 20: 100%|██████████| 157/157 [00:03<00:00, 40.97Batch/s, loss=0.169486, top1=58.58%, top5=93.85%]






