In [1]:
import torch
import torchvision
import torch.optim as optim
from tqdm import tqdm

import torchvision.transforms as transforms
from helpers import *
from networks import *



## CIFAR-10

In [2]:
# CIFAR-10 preprocessing: convert each image to a tensor, then normalize all
# three channels with mean 0.5 and std 0.5.
transform_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# load_dataset is not defined in this notebook (presumably it arrives through
# the star import from helpers -- confirm). Its result is unpacked as
# (train loader, test loader).
train_loader, test_loader = load_dataset(
    dataset_name='CIFAR10',
    batch_size=100,
    transform=transform_cifar10,
    num_workers=4,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# build network
# CNN1 is not defined in this notebook (presumably it comes from the star
# import of networks -- confirm). `net` is the notebook-level model that the
# optimizer, training and checkpoint cells below operate on.
net = CNN1()

In [4]:
# Loss function and optimizer used by the training cell below.
# NOTE(review): `nn` is not imported explicitly in this notebook; it
# presumably leaks in through one of the star imports -- confirm.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over the parameters of the notebook-level `net`.
optimizer = optim.SGD(
    params=net.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [9]:
def train_model(train_loader, optimizer, criterion, max_iter= 1000, model=None):
    """Train a network and record the summed batch loss of every epoch.

    Args:
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        optimizer: object exposing ``zero_grad()`` and ``step()``. It must have
            been built over the parameters of the model being trained.
        criterion: callable ``criterion(outputs, labels)`` returning a loss
            tensor that supports ``backward()`` and ``item()``.
        max_iter: number of full passes (epochs) over ``train_loader``.
        model: network to train. When omitted, the notebook-level ``net`` is
            used, which is what the original cell relied on implicitly.

    Returns:
        list[float]: one entry per epoch, the sum of the batch losses seen
        during that epoch.
    """
    if model is None:
        # Backward-compatible fallback to the global defined in the
        # "build network" cell.
        model = net

    losses = []
    for epoch in tqdm(range(max_iter)):
        # Reset per epoch: the accumulator was previously never initialised,
        # so the first `+=` raised UnboundLocalError.
        running_loss = 0.0
        for inputs, labels in train_loader:
            # zero the parameter gradients left over from the previous step
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Per-epoch report, matching the "[k] loss: ..." lines in the recorded
        # cell output.
        print(f'[{epoch}] loss: {running_loss:.3f}')
        losses.append(running_loss)
    return losses


In [10]:
# Run 10 epochs over the training loader; `losses` holds one summed loss per epoch.
losses = train_model(
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    max_iter=10,
)

 10%|█         | 1/10 [00:16<02:32, 16.90s/it]

[0] loss: 850.808


 20%|██        | 2/10 [00:33<02:15, 16.99s/it]

[1] loss: 814.922


 30%|███       | 3/10 [00:51<01:59, 17.10s/it]

[2] loss: 786.597


 40%|████      | 4/10 [01:07<01:41, 16.87s/it]

[3] loss: 762.337


 50%|█████     | 5/10 [01:24<01:24, 16.91s/it]

[4] loss: 741.346


 60%|██████    | 6/10 [01:40<01:05, 16.47s/it]

[5] loss: 725.589


 70%|███████   | 7/10 [01:57<00:49, 16.67s/it]

[6] loss: 708.673


 80%|████████  | 8/10 [02:14<00:33, 16.78s/it]

[7] loss: 694.237


 90%|█████████ | 9/10 [02:31<00:16, 16.78s/it]

[8] loss: 680.477


100%|██████████| 10/10 [02:47<00:00, 16.75s/it]

[9] loss: 666.367





In [None]:
# Checkpoint file (relative to the notebook's working directory); the reload
# cell further down reads the same PATH.
PATH = './cifar_net.pth'

# Persist only the parameters (state_dict), not the whole module object.
torch.save(obj=net.state_dict(), f=PATH)

In [None]:
# Grab one batch from the test split. The loader created in the data-loading
# cell is named `test_loader`; `testloader` was never defined.
dataiter = iter(test_loader)
images, labels = next(dataiter)

# Show only the first 4 images so the grid lines up with the 4 labels printed
# below (a full batch holds 100 images). `images` itself is left untouched.
# NOTE(review): `imshow` and `classes` are not defined in this notebook;
# presumably they come from the star import of helpers -- confirm.
imshow(torchvision.utils.make_grid(images[:4]))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

In [None]:
# Rebuild the same architecture that produced the checkpoint: the model trained
# and saved above is a CNN1, and load_state_dict requires matching parameter
# names and shapes. `Net` is not defined anywhere in this notebook.
net = CNN1()
net.load_state_dict(torch.load(PATH))

In [None]:
# Forward pass of the reloaded network on the preview batch; the next cell
# takes the per-row maximum over dim 1 to obtain predicted class indices.
outputs = net(images)

In [None]:
# For every row of `outputs`, keep the index of its largest entry along dim 1.
predicted = torch.max(outputs, dim=1).indices

# Print the class names for the first 4 predictions.
print(
    'Predicted: ',
    ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)),
)

In [None]:
correct = 0
total = 0

# Switch to inference mode before evaluating. This matters only if the network
# contains layers that behave differently in training (e.g. dropout or batch
# norm -- not visible from this notebook); otherwise it is a no-op.
net.eval()

# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    # The loader built in the data-loading cell is `test_loader`
    # (`testloader` was never defined).
    for images, labels in test_loader:
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest score is what we choose as prediction
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

In [None]:
# prepare to count predictions for each class
# NOTE(review): `classes` is not defined in this notebook; presumably it comes
# from the star import of helpers -- confirm.
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# Inference mode; a no-op unless the network has train/eval-dependent layers.
net.eval()

# again no gradients needed
with torch.no_grad():
    # The loader built in the data-loading cell is `test_loader`
    # (`testloader` was never defined).
    for images, labels in test_loader:
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    seen = total_pred[classname]
    if seen == 0:
        # No test sample carried this label: report it instead of dividing by zero.
        print(f'Accuracy for class: {classname:5s} is n/a (no samples)')
        continue
    accuracy = 100 * float(correct_count) / seen
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')