In [1]:
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100

In [2]:
# --- Training hyper-parameters -------------------------------------------
batch_size = 64          # samples per mini-batch for both train and test loaders
epochs = 3               # number of passes over the training set
lr = 0.01                # initial learning rate handed to the optimizer
weight_decay = 0.0001    # weight-decay coefficient handed to the optimizer

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of crashing
# at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Logging cadence (all counted in training iterations / epochs) -------
plot_every = 5           # report the running training loss every N iterations
valid_every = 20         # run a validation pass every N iterations
save_every_epoch = 1     # checkpoint interval in epochs (not referenced in the cells shown here)

In [3]:
# Per-channel normalisation with mean 0.5 / std 0.5, shared by both pipelines.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: upscale the images to 224 px and apply a random
# horizontal flip as data augmentation.
transform = transforms.Compose([transforms.Resize(224),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                normalize])

# Evaluation pipeline: identical preprocessing but WITHOUT random
# augmentation, so validation loss and test accuracy are deterministic.
test_transform = transforms.Compose([transforms.Resize(224),
                                     transforms.ToTensor(),
                                     normalize])

train_data = CIFAR100(root='./', train=True, transform=transform, download=True)
test_data = CIFAR100(root='./', train=False, transform=test_transform, download=True)

# Only the training loader is shuffled; the order of the evaluation samples
# does not affect the averaged metrics, so shuffling there is wasted work.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=12)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=12)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
import timm

# Pretrained ViT-B/16 from timm; the Linear layer below consumes its
# 1000-dimensional output and maps it to the 100 target classes.
backbone = timm.create_model('vit_base_patch16_224', pretrained=True)
class_head = nn.Linear(1000, 100)

# Chain backbone and head, then move the whole network to the target device.
model = nn.Sequential(backbone, class_head)
model = model.to(device)

In [5]:
# Classification loss applied to the model outputs and integer class labels.
criterion = nn.CrossEntropyLoss()

# SGD over every parameter of the model, using the learning rate and weight
# decay defined in the configuration cell.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr,
    weight_decay=weight_decay,
)

In [6]:
from visdom import Visdom

# Client used by the training loop to stream loss curves.
# NOTE(review): assumes a Visdom server is already listening on this port.
visdom_port = 6006
vis = Visdom(port=visdom_port)

Setting up a new session...


In [7]:
def reduce_lr(factor=0.1, opt=None):
    """Multiply the learning rate of every parameter group by ``factor``.

    Each group is scaled relative to its *own* current learning rate, so
    optimizers configured with different per-group rates keep their ratios
    (previously every group was overwritten with group 0's value).

    Args:
        factor: Positive multiplier applied to each group's ``"lr"``.
            Defaults to 0.1, i.e. a ten-fold reduction.
        opt: Object exposing ``param_groups`` (a list of dicts with an
            ``"lr"`` key). Defaults to the notebook-level ``optimizer``.

    Raises:
        ValueError: If ``factor`` is not strictly positive.
    """
    if factor <= 0:
        raise ValueError(f"factor must be positive, got {factor!r}")
    if opt is None:
        # Resolved lazily so the global is only required when no explicit
        # optimizer is supplied.
        opt = optimizer
    for param_group in opt.param_groups:
        param_group["lr"] *= factor

In [None]:
def _mean_loss(net, loader):
    """Return the average per-batch loss of ``net`` over ``loader``.

    Runs under ``torch.no_grad()``; the caller is responsible for putting
    ``net`` into the desired train/eval mode beforehand.
    """
    loss_sum = 0
    with torch.no_grad():
        for val_img, val_label in loader:
            val_img, val_label = val_img.to(device), val_label.to(device)
            loss_sum += criterion(net(val_img), val_label).item()
    return loss_sum / len(loader)


iters_per_epoch = len(train_loader)

for epoch in range(epochs):
    # Re-enter training mode at the start of every epoch: an earlier cell
    # (or the validation pass below) may have left the model in eval mode.
    model.train()
    running_loss = 0
    # Count iterations from zero in every epoch so the progress print-out
    # and the global plotting step stay within [1, iters_per_epoch].
    for i, (img, label) in enumerate(train_loader):
        img, label = img.to(device), label.to(device)
        pred = model(img)
        loss = criterion(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # x-coordinate shared by the train and validation curves.
        global_step = i + 1 + epoch * iters_per_epoch

        if (i+1) % plot_every == 0:
            # Report the mean training loss over the last `plot_every` steps.
            print(f'epoch: {epoch+1}/{epochs}, iter: {i+1}/{len(train_loader)}, loss: {running_loss/plot_every}')
            vis.line(Y=[running_loss/plot_every], X=[global_step], win='train', name='train', update='append',
                     opts={'showlegend': True,
                           'xlabel': "iter",
                           'ylabel': "loss"})
            running_loss = 0

        if (i+1) % valid_every == 0:
            # Evaluate in eval mode, then switch back so training resumes
            # with train-time layer behaviour.
            model.eval()
            valid_loss = _mean_loss(model, test_loader)
            model.train()

            print(f'epoch: {epoch+1}/{epochs}, iter: {i+1}/{len(train_loader)}, valid_loss: {valid_loss}')
            vis.line(Y=[valid_loss], X=[global_step], win='train',
                     name='test', update='append', opts={'showlegend': True})

In [11]:
# Lower the optimizer's learning rate once via the reduce_lr helper.
# NOTE(review): this cell's execution count suggests it was run by hand
# between training runs — confirm the intended learning-rate schedule.
reduce_lr()

In [13]:
# Final evaluation: top-1 accuracy over the test set.
model.eval()  # switch layers such as dropout to inference behaviour
correct = 0
total = 0
# Evaluation order does not influence accuracy, so do not shuffle.
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=12)
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest output along the class dim.
        predict = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predict == labels).sum().item()

# Guard against an empty loader and report the number of images actually seen.
accuracy = 100 * correct / total if total else 0.0
print(f'Accuracy of the network on the {total} test images: {accuracy}%')

Accuracy of the network on the 10000 test images: 91.05%


In [14]:
# Persist the trained weights. Create the checkpoint directory first, since
# torch.save does not create missing parent directories.
checkpoint_dir = Path('./model_check_points')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_dir / 'finished.pt')