In [24]:
import sys
import os
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data.dataset import Dataset
from PIL import Image


In [None]:
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


In [38]:
Epoch = 100

## Dataloader

In [23]:
datalist = []
facePath = './dataset/AFDB_face_dataset'
maskPath = './dataset/AFDB_masked_face_dataset'

## Get face data
for root, dirs, files in os.walk(facePath):
    if files == []:
        continue
    path1 = os.path.join(root, files[0])
    path2 = os.path.join(root, files[1])
    datalist.append([path1, 0])
    datalist.append([path2, 0])
# print(len(datalist))

## Get masked data
for root, dirs, files in os.walk(maskPath):
    if files == []:
        continue
    path1 = os.path.join(root, files[0])
    if len(files) > 1:
        path2 = os.path.join(root, files[1])
    datalist.append([path1, 1])
    datalist.append([path2, 1])

# print(len(datalist))

920
1882


In [55]:
class faceMaskerDataset(Dataset):
    def __init__(self, filelist, transformation=None):
        self.transforms = transformation
        self.filelist = filelist
    
    def __getitem__(self, index):
        imagePath = self.filelist[index][0]
        label = self.filelist[index][1]
        img = Image.open(imagePath).resize([32,32])
        if self.transforms is not None:
            img = self.transforms(img)
        return (img, label)
    
    def __len__(self):
        return len(self.filelist)

transformation = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [56]:
dataset = faceMaskerDataset(datalist, transformation)
img, label = dataset[1]

## Conv Net

In [57]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1);
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x
net = Net()

## Optimization Method

In [58]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [None]:
trainloader = torch.utils.data.DataLoader(dataset, batch_size=4, 
                                         shuffle=True, num_workers=2)
for epoch in range(Epoch):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        optimizer.step()
        
        running_loss += loss.item()
        if i % 20 == 19:
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / 20))
            running_loss = 0.0
print('Finished Training')

[1,    20] loss: 4.403
[1,    40] loss: 4.405
[1,    60] loss: 4.399
[1,    80] loss: 4.398
[1,   100] loss: 4.410
[1,   120] loss: 4.396
[1,   140] loss: 4.394
[1,   160] loss: 4.413
[1,   180] loss: 4.404
[1,   200] loss: 4.396
[1,   220] loss: 4.394
[1,   240] loss: 4.406
[1,   260] loss: 4.397
[1,   280] loss: 4.400
[1,   300] loss: 4.406
[1,   320] loss: 4.397
[1,   340] loss: 4.405
[1,   360] loss: 4.403
[1,   380] loss: 4.399
[1,   400] loss: 4.405
[1,   420] loss: 4.412
[1,   440] loss: 4.392
[1,   460] loss: 4.407
[2,    20] loss: 4.398
[2,    40] loss: 4.395
[2,    60] loss: 4.404
[2,    80] loss: 4.402
[2,   100] loss: 4.403
[2,   120] loss: 4.408
[2,   140] loss: 4.405
[2,   160] loss: 4.401
[2,   180] loss: 4.402
[2,   200] loss: 4.398
[2,   220] loss: 4.404
[2,   240] loss: 4.405
[2,   260] loss: 4.402
[2,   280] loss: 4.398
[2,   300] loss: 4.401
[2,   320] loss: 4.398
[2,   340] loss: 4.402
[2,   360] loss: 4.401
[2,   380] loss: 4.410
[2,   400] loss: 4.411
[2,   420] 

[16,   120] loss: 4.404
[16,   140] loss: 4.395
[16,   160] loss: 4.409
[16,   180] loss: 4.410
[16,   200] loss: 4.397
[16,   220] loss: 4.404
[16,   240] loss: 4.393
[16,   260] loss: 4.415
[16,   280] loss: 4.395
[16,   300] loss: 4.395
[16,   320] loss: 4.385
[16,   340] loss: 4.393
[16,   360] loss: 4.401
[16,   380] loss: 4.409
[16,   400] loss: 4.408
[16,   420] loss: 4.401
[16,   440] loss: 4.403
[16,   460] loss: 4.406
[17,    20] loss: 4.402
[17,    40] loss: 4.404
[17,    60] loss: 4.399
[17,    80] loss: 4.407
[17,   100] loss: 4.394
[17,   120] loss: 4.399
[17,   140] loss: 4.414
[17,   160] loss: 4.398
[17,   180] loss: 4.408
[17,   200] loss: 4.398
[17,   220] loss: 4.410
[17,   240] loss: 4.406
[17,   260] loss: 4.394
[17,   280] loss: 4.399
[17,   300] loss: 4.394
[17,   320] loss: 4.408
[17,   340] loss: 4.410
[17,   360] loss: 4.385
[17,   380] loss: 4.410
[17,   400] loss: 4.401
[17,   420] loss: 4.402
[17,   440] loss: 4.399
[17,   460] loss: 4.399
[18,    20] loss

[31,    60] loss: 4.391
[31,    80] loss: 4.406
[31,   100] loss: 4.404
[31,   120] loss: 4.403
[31,   140] loss: 4.406
[31,   160] loss: 4.402
[31,   180] loss: 4.409
[31,   200] loss: 4.409
[31,   220] loss: 4.400
[31,   240] loss: 4.402
[31,   260] loss: 4.408
[31,   280] loss: 4.397
[31,   300] loss: 4.397
[31,   320] loss: 4.404
[31,   340] loss: 4.402
[31,   360] loss: 4.397
[31,   380] loss: 4.390
[31,   400] loss: 4.401
[31,   420] loss: 4.405
[31,   440] loss: 4.397
[31,   460] loss: 4.402
[32,    20] loss: 4.407
[32,    40] loss: 4.402
[32,    60] loss: 4.409
[32,    80] loss: 4.404
[32,   100] loss: 4.404
[32,   120] loss: 4.401
[32,   140] loss: 4.403
[32,   160] loss: 4.396
[32,   180] loss: 4.398
[32,   200] loss: 4.401
[32,   220] loss: 4.405
[32,   240] loss: 4.408
[32,   260] loss: 4.391
[32,   280] loss: 4.401
[32,   300] loss: 4.406
[32,   320] loss: 4.398
[32,   340] loss: 4.399
[32,   360] loss: 4.400
[32,   380] loss: 4.407
[32,   400] loss: 4.406
[32,   420] loss

[46,    20] loss: 4.395
[46,    40] loss: 4.403
[46,    60] loss: 4.396
[46,    80] loss: 4.415
[46,   100] loss: 4.388
[46,   120] loss: 4.401
[46,   140] loss: 4.392
[46,   160] loss: 4.402
[46,   180] loss: 4.401
[46,   200] loss: 4.416
[46,   220] loss: 4.408
[46,   240] loss: 4.412
[46,   260] loss: 4.396
[46,   280] loss: 4.395
[46,   300] loss: 4.406
[46,   320] loss: 4.402
[46,   340] loss: 4.403
[46,   360] loss: 4.414
[46,   380] loss: 4.388
[46,   400] loss: 4.396
[46,   420] loss: 4.409
[46,   440] loss: 4.406
[46,   460] loss: 4.397
[47,    20] loss: 4.396
[47,    40] loss: 4.412
[47,    60] loss: 4.395
[47,    80] loss: 4.401
[47,   100] loss: 4.402
[47,   120] loss: 4.408
[47,   140] loss: 4.389
[47,   160] loss: 4.404
[47,   180] loss: 4.400
[47,   200] loss: 4.403
[47,   220] loss: 4.409
[47,   240] loss: 4.405
[47,   260] loss: 4.404
[47,   280] loss: 4.400
[47,   300] loss: 4.399
[47,   320] loss: 4.405
[47,   340] loss: 4.405
[47,   360] loss: 4.400
[47,   380] loss

[60,   440] loss: 4.395
[60,   460] loss: 4.404
[61,    20] loss: 4.398
[61,    40] loss: 4.405
[61,    60] loss: 4.397
[61,    80] loss: 4.410
[61,   100] loss: 4.391
[61,   120] loss: 4.402
[61,   140] loss: 4.388
[61,   160] loss: 4.414
[61,   180] loss: 4.401
[61,   200] loss: 4.403
[61,   220] loss: 4.402
[61,   240] loss: 4.396
[61,   260] loss: 4.407
[61,   280] loss: 4.405
[61,   300] loss: 4.398
[61,   320] loss: 4.404
[61,   340] loss: 4.416
[61,   360] loss: 4.415
[61,   380] loss: 4.400
[61,   400] loss: 4.402
[61,   420] loss: 4.394
[61,   440] loss: 4.405
[61,   460] loss: 4.391
[62,    20] loss: 4.397
[62,    40] loss: 4.403
[62,    60] loss: 4.391
[62,    80] loss: 4.403
[62,   100] loss: 4.398
[62,   120] loss: 4.399
[62,   140] loss: 4.402
[62,   160] loss: 4.403
[62,   180] loss: 4.392
[62,   200] loss: 4.409
[62,   220] loss: 4.400
[62,   240] loss: 4.395
[62,   260] loss: 4.409
[62,   280] loss: 4.401
[62,   300] loss: 4.410
[62,   320] loss: 4.405
[62,   340] loss

[75,   400] loss: 4.398
[75,   420] loss: 4.398
[75,   440] loss: 4.399
[75,   460] loss: 4.399
[76,    20] loss: 4.400
[76,    40] loss: 4.395
[76,    60] loss: 4.406
[76,    80] loss: 4.394
[76,   100] loss: 4.399
[76,   120] loss: 4.401
[76,   140] loss: 4.398
[76,   160] loss: 4.404
[76,   180] loss: 4.396
[76,   200] loss: 4.399
[76,   220] loss: 4.405
[76,   240] loss: 4.408
[76,   260] loss: 4.412
[76,   280] loss: 4.403
[76,   300] loss: 4.408
[76,   320] loss: 4.405
[76,   340] loss: 4.401
[76,   360] loss: 4.406
[76,   380] loss: 4.397
[76,   400] loss: 4.397
[76,   420] loss: 4.398
[76,   440] loss: 4.399
[76,   460] loss: 4.407
[77,    20] loss: 4.406
[77,    40] loss: 4.400
[77,    60] loss: 4.398
[77,    80] loss: 4.405
[77,   100] loss: 4.402
[77,   120] loss: 4.393
[77,   140] loss: 4.393
[77,   160] loss: 4.406
[77,   180] loss: 4.405
[77,   200] loss: 4.395
[77,   220] loss: 4.403
[77,   240] loss: 4.400
[77,   260] loss: 4.403
[77,   280] loss: 4.408
[77,   300] loss