In [None]:
import os

def getlabel(file_name):
    """Return the class label encoded in an image file name.

    The label is the second ``'-'``-separated field of the name once the file
    extension has been removed, e.g. ``'0001-door.png'`` -> ``'door'``.
    Stripping the extension with ``os.path.splitext`` (rather than slicing off a
    fixed 4 characters) also handles extensions such as ``.jpeg``.
    """
    stem, _ = os.path.splitext(file_name)
    return stem.split('-')[1]
    

# Sorted, de-duplicated label names found in the 'doors' directory; a label's
# position in this list is used as its integer class index.
classes = sorted({getlabel(n) for n in os.listdir('doors')})
num_classes = len(classes)

In [None]:
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim


import cv2

import numpy as np

import matplotlib.pyplot as plt

from tqdm import tqdm


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_image(img_path, size=40):
    """Load an image, pad it to a square with white borders, and resize it.

    Args:
        img_path: path to an image file readable by ``cv2.imread``.
        size: side length in pixels of the square output (default 40).

    Returns:
        A ``(size, size, 3)`` uint8 array in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'Could not read image: {img_path}')
    # cv2.imread yields BGR; swap to RGB for matplotlib / model input.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    height, width, _ = image.shape
    diff = height - width
    # Split the padding so an odd difference still produces a square image
    # (diff // 2 on both sides would leave it one pixel short).
    pad_before = abs(diff) // 2
    pad_after = abs(diff) - pad_before
    if diff > 0:
        # Taller than wide: pad left/right.
        image = cv2.copyMakeBorder(image, 0, 0, pad_before, pad_after, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
    elif diff < 0:
        # Wider than tall: pad top/bottom.
        image = cv2.copyMakeBorder(image, pad_before, pad_after, 0, 0, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
    image = cv2.resize(image, (size, size), interpolation=cv2.INTER_AREA)
    return image

class Doors(data.Dataset):
    """Door-image classification dataset read from a flat directory of files.

    Each file name encodes its class label (parsed by ``getlabel``). The file
    list is split 70/20/remainder into train/val/test with a fixed seed, so
    every ``Doors`` instance built from the same directory agrees on which
    files belong to which split.
    """

    def __init__(self, split, root='doors', max_files=2186, class_names=None):
        """
        Args:
            split: one of ``'train'``, ``'val'`` or ``'test'``.
            root: directory containing the image files.
            max_files: keep only the first ``max_files`` names (after sorting);
                ``None`` keeps all of them.
            class_names: ordered class names; defaults to the module-level
                ``classes`` list.

        Raises:
            KeyError: if ``split`` is not one of the known split names.
        """
        self.root = root
        self.split = split
        if class_names is None:
            class_names = classes
        self.classes = {c: idx for idx, c in enumerate(class_names)}
        # os.listdir order is arbitrary and filesystem-dependent; sort it so the
        # truncation and the seeded split below are reproducible across machines.
        files = sorted(os.listdir(root))[:max_files]
        n = len(files)
        n_train, n_val = int(0.7 * n), int(0.2 * n)
        train, val, test = data.random_split(
            files,
            [n_train, n_val, n - n_train - n_val],
            generator=torch.Generator().manual_seed(42),
        )
        self.items = {'train': train, 'val': val, 'test': test}[split]

    def __len__(self):
        """Number of files in this split."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(image, label_index)``.

        ``image`` is a channels-first float32 tensor scaled to [-1, 1];
        ``label_index`` is the integer class index of the file's label.
        """
        name = self.items[idx]
        img = load_image(os.path.join(self.root, name))
        img = np.moveaxis(img, 2, 0)  # HWC -> CHW, the layout nn.Conv2d expects
        img = (img / 255) * 2 - 1     # [0, 255] -> [-1, 1]
        return torch.tensor(img).float(), self.classes[getlabel(name)]

class Net(nn.Module):
    """Small two-convolution CNN classifier for square 3-channel images.

    ``forward`` returns raw, unnormalised logits of shape
    ``(batch, num_classes)``; apply ``softmax`` yourself if probabilities are needed.
    """

    def __init__(self, classes, in_size=40):
        """
        Args:
            classes: sequence of class names; only its length is used.
            in_size: side length of the square input images (default 40).
        """
        super().__init__()
        self.num_classes = len(classes)

        self.conv1 = nn.Conv2d(3, 32, (3, 3))
        self.conv2 = nn.Conv2d(32, 64, (3, 3))
        # Each unpadded 3x3 conv trims 2 pixels and each 2x2 max-pool halves
        # (floor) the side: 40 -> 38 -> 19 -> 17 -> 8, giving 8*8*64 features.
        side = ((in_size - 2) // 2 - 2) // 2
        self.fc1 = nn.Linear(side * side * 64, self.num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        # No softmax: nn.CrossEntropyLoss applies log-softmax internally, and
        # argmax over logits equals argmax over probabilities.
        return self.fc1(x)


# Seed PyTorch's global RNG so weight init, dropout and shuffling are reproducible.
torch.manual_seed(42)

model = Net(classes=classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
# Shuffle training batches every epoch; the seeded generator keeps the order reproducible.
trainloader = data.DataLoader(Doors('train'), batch_size=16, shuffle=True,
                              generator=torch.Generator().manual_seed(42))
valloader = data.DataLoader(Doors('val'))


# Early stopping: stop after `patience` consecutive epochs without a new best
# validation loss, then restore the weights from the best epoch.
patience = 2
curr_patience = 0
best_loss = np.inf
best_state = None
for epoch in range(100):
    print('Epoch:', epoch)
    model.train()
    running_loss = 0.0
    for x, y in tqdm(trainloader):
        inputs, labels = x.to(device), y.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # Mean per-batch training loss for this epoch.
    print(f'[{epoch + 1}] train loss: {running_loss / max(len(trainloader), 1):.3f}')

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in valloader:
            inputs, labels = x.to(device), y.to(device)
            val_loss += criterion(model(inputs), labels).item()
    print('Validation loss', val_loss / max(len(valloader), 1))

    if val_loss < best_loss:
        best_loss = val_loss
        curr_patience = 0
        # state_dict() returns live references, so clone to freeze this epoch's weights.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    else:
        curr_patience += 1
        if curr_patience >= patience:
            break

# Leave `model` holding the best-validation weights (they are what gets saved next).
if best_state is not None:
    model.load_state_dict(best_state)
            

In [None]:
# Persist only the learned parameters (not the module object) to model.pth.
torch.save(obj=model.state_dict(), f='model.pth')

In [None]:
# Held-out test split, one image per batch (DataLoader's default batch_size);
# `testiter` lets later cells step through it one example at a time.
testloader = data.DataLoader(dataset=Doors('test'), batch_size=1)
testiter = testloader.__iter__()

In [None]:
# Advance the test iterator by one batch; re-running this cell moves to the next example.
example = testiter.__next__()

In [None]:
img, label = example
# Undo the [-1, 1] scaling applied by the dataset and round back to 0-255.
# np.int no longer exists in NumPy >= 1.24; uint8 is also what imshow expects
# for 0-255 RGB data. Rounding avoids off-by-one pixels from float truncation.
img_rgb = np.clip(np.rint((img.numpy()[0] + 1) / 2 * 255), 0, 255).astype(np.uint8)
img_rgb = np.moveaxis(img_rgb, 0, 2)  # CHW -> HWC for matplotlib
plt.imshow(img_rgb)
label_idx = label.item()  # label is a 1-element tensor (batch of one)
print(f'{label_idx}: {classes[label_idx]}')


In [None]:
model.eval()
# Inference only: no_grad skips building the autograd graph.
with torch.no_grad():
    predict = model(img.to(device)).argmax(dim=1).item()
print(f'{predict}: {classes[predict]}')