# Convolutional Neural Network: Cat vs Fish Classification

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES=True

### Setting up data loaders

In [18]:
# sanity check
def check_image(path):
    try:
        im = Image.open(path)
        return True
    except:
        return False

In [19]:
img_transformers = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
])

In [20]:
train_data_path = './train'
train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=img_transformers, is_valid_file=check_image)

val_data_path = './val'
valid_data = torchvision.datasets.ImageFolder(root=val_data_path, transform=img_transformers, is_valid_file=check_image)

test_data_path = './test'
test_data = torchvision.datasets.ImageFolder(root=test_data_path, transform=img_transformers, is_valid_file=check_image)


In [21]:
batch_size = 64

In [22]:
# data loaders
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
val_data_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

## Convolutional Neural Network

In [23]:
class CNN(nn.Module):
    def __init__(self, num_classes=2):
        super(CNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3,64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64,192, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256,256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
    )
    
        self.avgpool = nn.AdaptiveAvgPool2d((6,6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256*6*6, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096,4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes)
    )

    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x,1)
        x = self.classifier(x)
        return x

In [24]:
cnnnet = CNN()

In [25]:
optimizer = optim.Adam(cnnnet.parameters(), lr=0.001)

In [26]:
if torch.cuda.is_available():
    device = torch.device("cuda") 
else:
    device = torch.device("cpu")

cnnnet.to(device)

CNN(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): L

In [27]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
     for epoch in range(epochs):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)
        
        model.eval()
        num_correct = 0 
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output,targets) 
            valid_loss += loss.data.item() * inputs.size(0)
            correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,
        valid_loss, num_correct / num_examples))

In [28]:
train(cnnnet, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, val_data_loader, epochs=10, device=device)

Epoch: 0, Training Loss: 689.92, Validation Loss: 0.71, accuracy = 0.22
Epoch: 1, Training Loss: 0.71, Validation Loss: 0.69, accuracy = 0.22
Epoch: 2, Training Loss: 0.69, Validation Loss: 0.70, accuracy = 0.22
Epoch: 3, Training Loss: 0.70, Validation Loss: 0.70, accuracy = 0.22
Epoch: 4, Training Loss: 0.69, Validation Loss: 0.70, accuracy = 0.22
Epoch: 5, Training Loss: 0.69, Validation Loss: 0.70, accuracy = 0.22
Epoch: 6, Training Loss: 0.69, Validation Loss: 0.70, accuracy = 0.22
Epoch: 7, Training Loss: 0.69, Validation Loss: 0.70, accuracy = 0.22
Epoch: 8, Training Loss: 0.69, Validation Loss: 0.70, accuracy = 0.22
Epoch: 9, Training Loss: 0.69, Validation Loss: 0.70, accuracy = 0.22
