In [98]:
import glob
import re

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import numpy as np

import torch
import torchvision
from torchvision import transforms, models
from torch.autograd import Variable
import torch.nn as nn

from PIL import Image

In [99]:
class TinaNet(nn.Module):
    def __init__(self):
        super(TinaNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 60, kernel_size=5, stride=1, padding=2),
                               nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(nn.Conv2d(60, 120, kernel_size=5, stride=1, padding=2),
                               nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc1 = nn.Linear(27000, 1000)
        self.fc2 = nn.Linear(1000, 2)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
#         out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

In [100]:
# Load the Cats Dataset.
images = np.array(glob.glob('/Users/carlos_guzman/Pictures/Cats/*/*.png'))
# images = np.array([img for img in images if 'Archie' not in img or 'Ella' not in img])

In [101]:
## Names of each cat to classify.
# classes = ['Tina', 'Ben', 'Jojo', 'Lluvia', 'Stretch']
classes = ['Tina', 'Not Tina']

## Set random seed and shuffle images.
np.random.seed(4)
np.random.shuffle(images)

## Create a list of labels to each image.
y = []
for img in images:
    if re.search('Tina', img):
        y.append('Tina')
    else:
        y.append('Not Tina')

Y = []
for img in images:
    if re.search('Tina', img):
        Y.append(torch.tensor([1]))
    else:
        Y.append(torch.tensor([0]))

In [120]:
## Open all images and transform to torch tensors.
X = []
transform = transforms.Compose([transforms.Resize((60,60)), transforms.ToTensor()])
for img in images:
#     X.append(mpimg.imread(img))
    X.append(transform(Image.open(img)))
for i in range(len(X)):
    if X[i].size(0) == 4:
        print(images[i])
X = [X[i] for i in range(len(X)) if X[i].size(0) != 4]

/Users/carlos_guzman/Pictures/Cats/Tina/Tina_105.png
/Users/carlos_guzman/Pictures/Cats/Tina/Tina_91.png


In [160]:
model = TinaNet()

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Train the model
total_step = len(X)
loss_list = []
acc_list = []
num_epochs = 100

for epoch in range(num_epochs):
    print('epoch', epoch)
    for i, (x,y) in enumerate(zip(X,Y)):
        ## Run the forward pass
        out = model(x.unsqueeze(0))

        loss = criterion(out, y)
        loss_list.append(loss.item())

        # Backprop and perform Adam optimisation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

#         # Track the accuracy
#         total = y.size(0)
#         _, predicted = torch.max(out.data, 1)
#         correct = (predicted == y).sum().item()
#         acc_list.append(correct / total)

epoch 0
epoch 1
epoch 2
epoch 3
epoch 4
epoch 5
epoch 6
epoch 7
epoch 8
epoch 9


KeyboardInterrupt: 

In [None]:
# ## Quick display of 
# print(X[0].size())
# plt.imshow(X[0].view(X[0].shape[1], X[0].shape[2], X[0].shape[0]))

In [None]:
# ## Display tensors using PIL and MatPlot Lib
# trans = transforms.ToPILImage()
# for x,y in zip(X, Y):
#     plt.imshow(trans(x.squeeze()))
#     plt.title(y)
#     plt.show()

In [None]:
test_images = np.array(glob.glob('/Users/carlos_guzman/Pictures/Cats/*.png'))
## Open all images and transform to torch tensors.
Test_X, Test_Y = [], []
for i in range(len(test_images)):
    if re.search('Tina', test_images[i]):
        Test_Y.append(torch.tensor([1]))
    elif re.search('Jojo|Steven|Gryif', test_images[i]):
        Test_Y.append(torch.tensor([0]))
    else:
        continue
    Test_X.append(transform(Image.open(test_images[i])))

Test_X = [Test_X[i] for i in range(len(Test_X)) if Test_X[i].size(0) != 4]

In [None]:
for x, y in zip(Test_X, Test_Y):
    out = model(x.unsqueeze(0))
    _, predicted = torch.max(out.data, 1)
    print('pred', predicted)
    print('y', y)

In [None]:
## Display tensors using PIL and MatPlot Lib
trans = transforms.ToPILImage()
for x,y in zip(Test_X, Test_Y):
    plt.imshow(trans(x.squeeze()))
    plt.title(y)
    plt.show()