In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
import torch
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

In [2]:
def hot_encode(x):
    num_classes = 2
    return torch.eye(num_classes)[x]

In [3]:
data_transforms = transforms.Compose([transforms.Resize((100,100)),             # resize the input to 224x224
                                      transforms.ToTensor(),              # put the input to tensor format
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # normalize the input
                                    ])
## load dataset from the folder
dataset = datasets.ImageFolder("images",transform=data_transforms,target_transform=hot_encode)

## split data
train_size = int(np.floor(len(dataset)*0.8))
test_size = len(dataset) - train_size
train_data, test_data = torch.utils.data.random_split(dataset, [train_size, test_size])

## create data loader for both training and testing data
train_loader = torch.utils.data.DataLoader(train_data, batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=10, shuffle=True)

## Build Neural Network

In [7]:
class Identity(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,x):
        return x        
    
class classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4608,265)
        self.fc2 = nn.Linear(265,2)
    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x,dim = 1)

In [10]:
vgg16 = models.vgg16(pretrained=True)
for param in vgg16.parameters():
    param.requires_grad = False
vgg16.avgpool = Identity()
vgg16.classifier = classifier()
vgg16.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [None]:
device = "cuda" if torch.cuda.is_available()  else "cpu"

## Training

In [12]:
EPOCH = 10
LR = 0.001
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(vgg16.parameters(), lr=LR)

In [23]:
def train(model):
    for epoch in range(EPOCH):
        losses = []
        accuracies = []
        for batch in tqdm(train_loader):
            X_batch ,Y_batch = batch[0].view(-1,3,100,100).to(device), batch[1].to(device)

            output = model(X_batch)
            loss = loss_function(output, Y_batch)
            losses.append(loss)
            vgg16.zero_grad()
            loss.backward()
            optimizer.step()
            ## find accuarcy
            matches = [torch.argmax(a) == torch.argmax(b) for a,b in zip(output, Y_batch)]
            accuracy = matches.count(True)/len(matches)
            accuracies.append(accuracy)
        print(f"Loss at epoch {epoch} is {sum(losses)/len(losses):.5f}")
        print(f"Accuracy at epoch {epoch} is {sum(accuracies)/len(accuracies):.5f}")


In [24]:
train(vgg16)

100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:11<00:00,  6.56it/s]
  0%|▏                                                                                 | 1/469 [00:00<00:59,  7.84it/s]

Loss at epoch 0 is 0.01515
Accuracy at epoch 0 is 0.98166


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:10<00:00,  6.66it/s]
  0%|▏                                                                                 | 1/469 [00:00<01:13,  6.37it/s]

Loss at epoch 1 is 0.01176
Accuracy at epoch 1 is 0.98678


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:10<00:00,  6.70it/s]
  0%|▏                                                                                 | 1/469 [00:00<01:04,  7.25it/s]

Loss at epoch 2 is 0.01014
Accuracy at epoch 2 is 0.98742


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:10<00:00,  6.68it/s]
  0%|▏                                                                                 | 1/469 [00:00<01:00,  7.69it/s]

Loss at epoch 3 is 0.01319
Accuracy at epoch 3 is 0.98443


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:10<00:00,  6.69it/s]
  0%|▏                                                                                 | 1/469 [00:00<00:59,  7.81it/s]

Loss at epoch 4 is 0.01042
Accuracy at epoch 4 is 0.98763


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:09<00:00,  6.72it/s]
  0%|▏                                                                                 | 1/469 [00:00<01:07,  6.94it/s]

Loss at epoch 5 is 0.01008
Accuracy at epoch 5 is 0.98849


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:09<00:00,  6.75it/s]
  0%|                                                                                          | 0/469 [00:00<?, ?it/s]

Loss at epoch 6 is 0.01619
Accuracy at epoch 6 is 0.98145


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:09<00:00,  6.76it/s]
  0%|▏                                                                                 | 1/469 [00:00<01:07,  6.90it/s]

Loss at epoch 7 is 0.01298
Accuracy at epoch 7 is 0.98593


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:09<00:00,  6.77it/s]
  0%|▏                                                                                 | 1/469 [00:00<01:09,  6.71it/s]

Loss at epoch 8 is 0.01056
Accuracy at epoch 8 is 0.98806


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [01:09<00:00,  6.76it/s]

Loss at epoch 9 is 0.00976
Accuracy at epoch 9 is 0.98849





In [25]:
def test(net):
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(test_loader):
            X_batch ,Y_batch = batch[0].view(-1,3,100,100).to(device), batch[1].to(device)
            real_class = torch.argmax(Y_batch).to(device)
            net_out = vgg16(X_batch.view(-1, 3, 100, 100).to(device))[0]

            predicted_class = torch.argmax(net_out)
            if predicted_class == real_class:
                correct += 1
            total += 1
    print("Accuracy:", round(correct/total,3))

In [26]:
test(vgg16)

100%|████████████████████████████████████████████████████████████████████████████████| 118/118 [00:18<00:00,  6.45it/s]

Accuracy: 0.983



