In [1]:
import torch
import torch.nn as nn

from resnet_256_out import resnet20
# from resnet_64_out import resnet20

inplane=64

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = resnet20()

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), 0.1,
                            momentum=0.9,
                            weight_decay=1e-4)

lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                    milestones=[100, 150], last_epoch=-1)



In [2]:
import torchvision.transforms as transforms
import torchvision

from torch.utils.data import DataLoader

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=test_transform)

train_loader = DataLoader(trainset, batch_size=128, shuffle=True, pin_memory=True)
test_loader = DataLoader(testset, batch_size=128, shuffle=False, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
print_freq = 20

def validate(dataloader, model, criterion):
    model.eval()

    correct_count = 0
    total = 0

    run_loss = 0

    with torch.no_grad():
        for i, (input, target) in enumerate(dataloader):
            target = target.to(device)
            input = input.to(device)

            out = model(input)
            loss = criterion(out, target)

            run_loss += loss

            _, predicted = out.max(1)
            total += target.size(0)
            correct_count += predicted.eq(target).sum().item()

            if i % print_freq == 0:
                print('Testing: {} / {}'.format(i, len(dataloader)))
                print('Loss {}'.format(run_loss / print_freq))
                print('Correct image {} / {}'.format(correct_count, total))

                run_loss = 0

        print('Completed: total accuracy {} / {}'.format(correct_count, total))
        return correct_count

def train(dataloader, model, criterion, optimizer, epoch):
    model.train()
    
    running_loss = 0
    for i, (input, target) in enumerate(dataloader):
        target = target.to(device)
        input = input.to(device)

        output = model(input)
        loss = criterion(output, target)

        running_loss += loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % print_freq == 0:
            print('Epoch: {}; {} / {}'.format(epoch, i, len(dataloader)))
            print('Loss: {}'.format(running_loss / print_freq))

            running_loss = 0

In [4]:
best_acc = 0

for epoch in range(200):
    print('Current LR: {}'.format(optimizer.param_groups[0]['lr']))
    print('Starting Epoch {}'.format(epoch))

    train(train_loader, model, criterion, optimizer, epoch)
    lr_scheduler.step()

    if epoch % 5 == 0:
        acc = validate(test_loader, model, criterion)

        if best_acc < acc:
            from datetime import datetime
            filename = '{}.pt'.format(datetime.now().strftime('%Y-%m-%d-%H-%M'))
            print('Saving model with accuracy {} / {} to {}'.format(acc, 10000, filename))
            torch.save(model.state_dict(), filename)

            best_acc = acc

Current LR: 0.1
Starting Epoch 0
Epoch: 0; 0 / 391
Loss: 0.14618851244449615
Epoch: 0; 20 / 391
Loss: 5.78220796585083


KeyboardInterrupt: 