# Image classification

In [1]:
import torch
import torch.utils.data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import os
from PIL import Image

## Load data

In [2]:
train_data_path = '../pytorchupandrunning/chapter2/train'
val_data_path = '../pytorchupandrunning/chapter2/val'
test_data_path = '../pytorchupandrunning/chapter2/test'

In [3]:
img_transforms = transforms.Compose([transforms.Resize((64, 64)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                          std = [0.229, 0.224, 0.225])
                                     ])

In [4]:
def check_image(path):
    try:
        im = Image.open(path)
        return True
    except:
        return False

In [5]:
train_data = torchvision.datasets.ImageFolder(root = train_data_path, transform = img_transforms, is_valid_file = check_image)
test_data = torchvision.datasets.ImageFolder(root = test_data_path, transform = img_transforms, is_valid_file = check_image)
val_data = torchvision.datasets.ImageFolder(root = val_data_path, transform = img_transforms, is_valid_file = check_image)

In [6]:
batch_size = 64
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size = batch_size) 
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size) 

## Neural network

In [7]:
import torch.nn as nn
import torch.nn.functional as F

##  FC

In [8]:
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(12288, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50, 2)
        
    def forward(self, x):
        x = x.view(-1, 12288)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [9]:
simpleNet = SimpleNet()

In [10]:
import torch.optim as optim

In [11]:
optimizer = optim.Adam(simpleNet.parameters(), lr = 0.001)

In [12]:
def train(net, optimizer, loss_fn, train_loader, epochs = 20):
    for epoch in range(epochs):
        training_loss = 0.
        valid_loss = 0.
        for batch in train_loader:
            inputs, labels = batch
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
        training_loss /= len(train_loader.dataset)        
        
        print('Epoch: {}, Training Loss: {:.2f}'.format(epoch + 1, training_loss))

In [14]:
train(simpleNet, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, epochs = 5)

Epoch: 0, Training Loss: 0.05
Epoch: 1, Training Loss: 0.03
Epoch: 2, Training Loss: 0.01
Epoch: 3, Training Loss: 0.01
Epoch: 4, Training Loss: 0.01


In [None]:
def test(net, optimizer, loss_fn, val_loader):
    num_correct = 0
    num_examples = 0
    for batch in val_loader:
        inputs, labels = batch
        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)

In [26]:
num_correct = 0
num_examples = 0
for batch in val_data_loader:
    inputs, labels = batch
    outputs = simpleNet(inputs)
    _, predicted = torch.max(outputs.data, 1)    
    num_correct += (predicted == labels).sum().item()
    num_examples += labels.size(0)
    
print('Accuracy of the network on the test images: %d %%' % (100 * num_correct / num_examples))

Accuracy of the network on the test images: 71 %


## Test your image

In [44]:
labels = ['cat', 'fish']
img = Image.open(os.path.join(test_data_path, 'fish', '1394413521_27536c0a8f.jpg'))
img = img_transforms(img)
print(img.shape)
img = img.unsqueeze(0)
print(img.shape)
prediction = simpleNet(img)
print(prediction)
probability = F.softmax(prediction, dim = 1)[0]
print(probability, probability.shape)
prediction = prediction.argmax()
print(labels[prediction], probability[prediction].item())

torch.Size([3, 64, 64])
torch.Size([1, 3, 64, 64])
tensor([[-2.2881,  0.5201]], grad_fn=<AddmmBackward>)
tensor([0.0569, 0.9431], grad_fn=<SelectBackward>) torch.Size([2])
fish 0.9431185126304626
