In [1]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from utils import Trainer, build_teacher_model
from teacher_model import ResNet50

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Training hyperparameters.
learning_rate = 1e-3   # Adam step size
batch_size = 64        # samples per batch for both train and validation loaders
num_epochs = 10        # full passes over the training set

# Seed PyTorch's RNGs (CPU and all CUDA devices) so model initialisation and the
# training-set shuffle order are repeatable across Restart & Run All.
# Note: some cuDNN kernels can still be nondeterministic on GPU.
seed = 42
torch.manual_seed(seed)

In [4]:
# Local directory where CIFAR-10 is stored (downloaded there on first run).
data_root = '../datasets'

# Preprocessing shared by both splits: resize every image to 224x224, convert it to a
# tensor in [0, 1], then Normalize with mean = std = 0.5 per channel, mapping values to [-1, 1].
# NOTE(review): no augmentation is applied to the training split; the log below shows
# train top-1 reaching ~97% while validation top-1 ends at ~90% — consider random
# crop/flip for training if the gap matters.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root=data_root, train=True, transform=transform, download=True)
val_dataset = datasets.CIFAR10(root=data_root, train=False, transform=transform, download=True)

# Shuffle only the training data. Validation is read in a fixed order: this makes
# evaluation deterministic and avoids drawing from the global torch RNG every epoch,
# which would otherwise perturb the training shuffle sequence.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Build the teacher network from the project's ResNet50; CIFAR-10 has ten classes,
# so num_classes is 10.
# NOTE(review): whether build_teacher_model loads pretrained weights is not visible
# here — confirm in utils.
teacher_arch = ResNet50
model = build_teacher_model(teacher_arch, num_classes=10)

In [6]:
# Cross-entropy loss for single-label multi-class classification, and Adam over
# every trainable parameter of the teacher model.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [7]:
# Give the Trainer's positional arguments descriptive names so the call reads clearly.
# NOTE(review): these names are inferred from argument position and from the training
# log below; confirm them against utils.Trainer's signature.
dataloaders = {'train': train_loader, 'validation': val_loader}
scheduler = None  # presumably an optional LR scheduler slot — none is used here
topk = (1, 5)     # the log reports top1 and top5 accuracy for each phase
checkpoint_path = 'pretrained/cifar10_teacher_resnet.pth'  # presumably where weights are saved

trainer = Trainer(model, dataloaders, criterion, optimizer, scheduler, num_epochs, topk, checkpoint_path, device)

In [8]:
# Run the training loop. Per the log below, each epoch runs a 'train' pass followed by
# a 'validation' pass, reporting loss and top-1/top-5 accuracy in progress bars.
# NOTE(review): validation top-1 fluctuates across epochs (e.g. 88.55% at epoch 5,
# 85.97% at epoch 9, 90.01% at epoch 10). Which epoch's weights end up at the
# checkpoint path (last vs. best) is decided inside utils.Trainer — confirm.
trainer.train()

train Epoch: 1: 100%|██████████| 782/782 [02:53<00:00,  4.50Batch/s, loss=0.047543, top1=73.91%, top5=97.87%]
validation Epoch: 1: 100%|██████████| 157/157 [00:15<00:00, 10.19Batch/s, loss=0.039286, top1=78.67%, top5=98.71%]






train Epoch: 2: 100%|██████████| 782/782 [02:54<00:00,  4.47Batch/s, loss=0.027283, top1=84.93%, top5=99.43%]
validation Epoch: 2: 100%|██████████| 157/157 [00:15<00:00, 10.29Batch/s, loss=0.026519, top1=85.26%, top5=99.47%]






train Epoch: 3: 100%|██████████| 782/782 [02:53<00:00,  4.52Batch/s, loss=0.020852, top1=88.36%, top5=99.69%]
validation Epoch: 3: 100%|██████████| 157/157 [00:14<00:00, 10.59Batch/s, loss=0.024294, top1=86.76%, top5=99.58%]






train Epoch: 4: 100%|██████████| 782/782 [02:45<00:00,  4.72Batch/s, loss=0.016361, top1=90.94%, top5=99.79%]
validation Epoch: 4: 100%|██████████| 157/157 [00:14<00:00, 10.80Batch/s, loss=0.023306, top1=87.28%, top5=99.61%]






train Epoch: 5: 100%|██████████| 782/782 [02:48<00:00,  4.63Batch/s, loss=0.013402, top1=92.58%, top5=99.90%]
validation Epoch: 5: 100%|██████████| 157/157 [00:16<00:00,  9.75Batch/s, loss=0.021524, top1=88.55%, top5=99.54%]






train Epoch: 6: 100%|██████████| 782/782 [02:46<00:00,  4.69Batch/s, loss=0.010330, top1=94.24%, top5=99.93%]
validation Epoch: 6: 100%|██████████| 157/157 [00:15<00:00, 10.44Batch/s, loss=0.027616, top1=86.91%, top5=99.43%]






train Epoch: 7: 100%|██████████| 782/782 [02:52<00:00,  4.54Batch/s, loss=0.008646, top1=95.10%, top5=99.96%] 
validation Epoch: 7: 100%|██████████| 157/157 [00:15<00:00, 10.24Batch/s, loss=0.025704, top1=88.13%, top5=99.61%]






train Epoch: 8: 100%|██████████| 782/782 [02:50<00:00,  4.59Batch/s, loss=0.006975, top1=96.06%, top5=99.98%] 
validation Epoch: 8: 100%|██████████| 157/157 [00:16<00:00,  9.80Batch/s, loss=0.024061, top1=88.63%, top5=99.63%]






train Epoch: 9: 100%|██████████| 782/782 [02:52<00:00,  4.54Batch/s, loss=0.006239, top1=96.56%, top5=99.97%]
validation Epoch: 9: 100%|██████████| 157/157 [00:15<00:00, 10.46Batch/s, loss=0.032146, top1=85.97%, top5=99.63%]






train Epoch: 10: 100%|██████████| 782/782 [02:55<00:00,  4.47Batch/s, loss=0.004889, top1=97.36%, top5=99.99%] 
validation Epoch: 10: 100%|██████████| 157/157 [00:15<00:00, 10.22Batch/s, loss=0.021441, top1=90.01%, top5=99.67%]






