In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
# Preprocessing shared by the train and test splits: upscale the small
# CIFAR-10 images to the 224x224 resolution conventionally used with ResNet,
# convert to a tensor, then standardize each colour channel.
transform = transforms.Compose(
    [
        transforms.Resize(256),      # shorter side -> 256 px
        transforms.CenterCrop(224),  # keep the central 224x224 patch
        transforms.ToTensor(),       # PIL image -> float tensor in [0, 1]
        # Hard-coded per-channel mean/std (presumably CIFAR-10 training-set
        # statistics — confirm). TODO: Automate calculation given a dataset
        transforms.Normalize(mean=[0.491, 0.482, 0.446], std=[0.247, 0.243, 0.261]),
    ]
)

# Both CIFAR-10 splits live under ./raid/data (fetched there if not already
# present) and share the same `transform`.
trainset = torchvision.datasets.CIFAR10(
    root='./raid/data',
    train=True,
    download=True,
    transform=transform,
)

testset = torchvision.datasets.CIFAR10(
    root='./raid/data',
    train=False,
    download=True,
    transform=transform,
)

# Human-readable label names; presumably ordered to match the dataset's integer
# targets (classes[i] names label i) — confirm against trainset.classes.
classes = tuple("plane car bird cat deer dog frog horse ship truck".split())

In [None]:
# Mini-batches of 4 samples, loaded by 2 worker processes. The test loader keeps
# a fixed order; the training loader reshuffles its data every epoch.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

In [None]:
from torchvision import models
import torch.nn as nn

# ResNet-18 with randomly initialized weights (trained from scratch).
# NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in favour of
# `weights=None`; kept here for compatibility with older torchvision releases.
model = models.resnet18(pretrained=False)
# Replace the classification head. The input width is read from the existing
# head instead of hard-coding 512, so this keeps working if the backbone is
# swapped; the output size follows `classes` rather than a separate magic 10.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=len(classes))

In [None]:
import torch.optim as optim

# Put the model on its target device *before* constructing the optimizer, as the
# torch.optim docs recommend, so the optimizer is built over the parameters that
# will actually be trained (on CPU-only machines this is a no-op).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
# SGD with momentum and L2 weight decay; lr stays fixed at 0.001 (no learning-rate
# scheduler is created in this cell).
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

In [None]:
NUM_EPOCHS = 15     # full passes over the training set
LOG_INTERVAL = 200  # mini-batches between progress lines

# Resolve the device and move the model there once, rather than re-creating the
# device object and calling model.cuda() on every mini-batch.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    print("########## {} ##########".format(epoch + 1))
    # train_loss is averaged over the last LOG_INTERVAL batches and then reset;
    # correct/total accumulate over the whole epoch, so "Acc" is the running
    # accuracy for the current epoch.
    train_loss = 0.0
    total = 0
    correct = 0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate statistics
        train_loss += loss.item()
        _, predicted = outputs.max(1)  # index of the highest logit per sample
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        if i % LOG_INTERVAL == LOG_INTERVAL - 1:
            print("Loss: {} | Acc: {} | {}/{}".format(
                train_loss / LOG_INTERVAL, 100. * correct / total, correct, total))
            train_loss = 0.0

    # Batches after the last complete logging window are never printed above,
    # so report the epoch's overall accuracy once the epoch ends.
    if total:
        print("Epoch {} | Acc: {} | {}/{}".format(
            epoch + 1, 100. * correct / total, correct, total))

print('Finished Training')

In [None]:
from pathlib import Path

# Persist only the learned weights (state_dict). The target directory is created
# first because torch.save fails if the parent directory does not exist.
MODEL_PATH = Path("./raid/trained_models/pred_class.pt")
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)