In [1]:
import glob
import os
import pathlib

import numpy as np
import pandas as pd

from util.image import unnormalize

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch import functional as F

from torchvision import transforms
from torchinfo import torchinfo
from tqdm import tqdm
import matplotlib.pyplot as plt

import albumentations as A
import torch.nn.functional as F

from PIL import Image

import torchmetrics
from torchvision.utils import save_image, make_grid
import cv2
from util.io import load_ckpt

from util.loss import  InpaintingLoss
import os, glob

import efficientunet
import random

# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset_path = 'Datasets/Mask detection/archive/Face Mask Dataset/'
# One sub-folder per split. The trailing slash matters: get_files() appends
# 'WithMask/' / 'WithoutMask/' by plain string concatenation.
train_dir, val_dir, test_dir = (
    dataset_path + split_dir for split_dir in ('Train/', 'Validation/', 'Test/')
)

In [3]:
sizes = (64, 64)

# Per-image preprocessing applied at load time: PIL image -> CHW float tensor
# (ToTensor), then resize to `sizes` with antialiasing disabled.
rescale_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(sizes, antialias=False),
    ]
)

In [4]:
# Sanity check of the label layout get_files() builds: a run of 1s (masked)
# followed by a run of 0s (unmasked), joined end to end.
torch.cat((torch.ones(5), torch.zeros(4)))

tensor([1., 1., 1., 1., 1., 0., 0., 0., 0.])

In [5]:
def get_files(path, transform=None):
    """Load every PNG of one dataset split into memory as tensors.

    Reads ``path + 'WithMask/*.png'`` and ``path + 'WithoutMask/*.png'``.

    Args:
        path: split directory as a string ending in '/' (sub-folder names are
            appended by plain concatenation).
        transform: callable applied to each RGB PIL image. Defaults to the
            module-level ``rescale_transform``.

    Returns:
        (images, labels): ``images`` stacks every transformed image along a new
        dim 0, masked images first, then unmasked ones. ``labels`` is a float
        tensor in the same order, holding 1. for masked and 0. for unmasked.

    Raises:
        FileNotFoundError: if neither sub-folder contains any PNG file.
    """
    if transform is None:
        transform = rescale_transform

    # sorted() makes sample order independent of the filesystem listing order.
    mask_files = sorted(glob.glob(path + 'WithMask/' + '*.png'))
    nomask_files = sorted(glob.glob(path + 'WithoutMask/' + '*.png'))
    if not mask_files and not nomask_files:
        raise FileNotFoundError(f"no .png files found under {path}WithMask/ or {path}WithoutMask/")

    def _load(file_path):
        # The context manager closes the file handle (the original leaked one per
        # image). convert('RGB') guarantees 3 channels even for palette,
        # grayscale or RGBA PNGs, which would otherwise break stacking / conv1.
        with Image.open(file_path) as img:
            return transform(img.convert('RGB'))

    mask_images = [_load(x) for x in tqdm(mask_files)]
    unmasked_images = [_load(x) for x in tqdm(nomask_files)]

    # Stack once over both lists so a split with one empty class still works.
    images = torch.stack(mask_images + unmasked_images)
    labels = torch.hstack([torch.ones(len(mask_images)), torch.zeros(len(unmasked_images))])

    return images, labels

In [6]:
train_images, train_labels = get_files(path=train_dir)  # whole training split, held in memory

100%|██████████| 5000/5000 [00:17<00:00, 280.50it/s]
100%|██████████| 5000/5000 [00:14<00:00, 341.31it/s]


In [7]:
val_images, val_labels = get_files(path=val_dir)  # whole validation split, held in memory

100%|██████████| 400/400 [00:01<00:00, 311.74it/s]
100%|██████████| 400/400 [00:01<00:00, 390.79it/s]


In [58]:
# Applied to every sample: each channel becomes (x - 0.5) / 0.5.
normalisation = transforms.Normalize(mean=0.5, std=0.5)

# Random geometric augmentation; only the training dataset is given this.
augmentation = transforms.RandomAffine(
    degrees=30,
    translate=(0.15, 0.15),
    scale=(0.8, 1.2),
    shear=10,
)

In [59]:
class MaskDataset(Dataset):
    """Map-style dataset over pre-loaded image and label tensors.

    Each item is passed through the module-level ``normalisation`` transform
    and then, when an ``augmentation`` callable was supplied, through that.
    Items are returned as ``(image, label)`` pairs.
    """

    def __init__(self, images, labels, augmentation = None):
        # images and labels are indexed in parallel; presumably equal length — TODO confirm at call site
        self.images, self.augmentation, self.labels = images, augmentation, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        normalised = normalisation(self.images[idx])
        sample = normalised if self.augmentation is None else self.augmentation(normalised)
        return sample, self.labels[idx]

In [60]:
# Only the training split is augmented; validation items are just normalised.
train_dataset = MaskDataset(images=train_images, labels=train_labels, augmentation=augmentation)
val_dataset = MaskDataset(images=val_images, labels=val_labels)

In [61]:
BATCH_SIZE = 250

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation does not need shuffling, so keep the order fixed. When the split size
# is not a multiple of BATCH_SIZE (e.g. 400 + 400 images here -> final batch of 50),
# shuffling changes which samples land in the short final batch every epoch. That
# adds noise to a batch-averaged validation loss, which drives early stopping.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [62]:
class CNN(nn.Module):
    """Small three-conv binary classifier producing one logit per image.

    Sized for 3x64x64 inputs (``fc1`` hard-codes the flattened size).
    Spatial sizes: 64 -conv1-> 62 -pool-> 31 -conv2-> 29 -pool-> 14
    -conv3-> 12 -pool-> 6, so the flattened feature has 10 * 6 * 6 = 360
    values, as the torchinfo summary shows.
    """

    def __init__(self):
        super().__init__()

        # Registration order is kept as-is so parameter/state_dict order and
        # weight initialisation from the RNG are unchanged.
        self.conv1 = nn.Conv2d(3, 5, 3)     # 64 -> 62, pooled to 31
        self.conv2 = nn.Conv2d(5, 10, 3)    # 31 -> 29, pooled to 14
        self.conv3 = nn.Conv2d(10, 10, 3)   # 14 -> 12, pooled to 6

        self.pool = nn.MaxPool2d(2)

        self.fc1 = nn.Linear(10 * 6 * 6, 64)
        self.o_n = nn.Linear(64, 1)         # single output logit

        self.flatten = nn.Flatten()
        self.activation = nn.ReLU()

    def forward(self, inpt):
        features = inpt
        # Each stage: conv -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2, self.conv3):
            features = self.pool(self.activation(conv(features)))

        hidden = self.activation(self.fc1(self.flatten(features)))
        return self.o_n(hidden)

In [63]:
class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Call the instance once per epoch with that epoch's validation loss and the
    model. A loss that improves on the best seen so far by at least ``delta``
    saves ``model.state_dict()`` to ``path`` and resets the patience counter.
    After ``patience`` consecutive non-improving epochs, ``early_stop`` becomes
    True. A NaN loss never counts as an improvement.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience: consecutive non-improving epochs tolerated before stopping.
            verbose: if True, report every checkpoint save via ``trace_func``.
            delta: minimum decrease in loss that counts as an improvement.
            path: file the best model's state_dict is written to.
            trace_func: callable that receives all progress messages.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None  # negated best validation loss (higher is better)
        self.early_stop = False
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss; checkpoint ``model`` on improvement."""
        score = -val_loss

        if np.isnan(float(score)):
            # NaN compares False against everything, so the original if/elif chain
            # treated it as an improvement (and the first NaN became best_score,
            # making every later epoch an "improvement" too).
            improved = False
        elif self.best_score is None:
            improved = True
        else:
            improved = score >= self.best_score + self.delta

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember ``val_loss`` as the best."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [64]:
# Layer-by-layer shapes/parameter counts for a dummy batch of five 3x64x64 images.
torchinfo.summary(CNN(), input_size=(5, 3, 64, 64))

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      [5, 1]                    --
├─Conv2d: 1-1                            [5, 5, 62, 62]            140
├─ReLU: 1-2                              [5, 5, 62, 62]            --
├─MaxPool2d: 1-3                         [5, 5, 31, 31]            --
├─Conv2d: 1-4                            [5, 10, 29, 29]           460
├─ReLU: 1-5                              [5, 10, 29, 29]           --
├─MaxPool2d: 1-6                         [5, 10, 14, 14]           --
├─Conv2d: 1-7                            [5, 10, 12, 12]           910
├─ReLU: 1-8                              [5, 10, 12, 12]           --
├─MaxPool2d: 1-9                         [5, 10, 6, 6]             --
├─Flatten: 1-10                          [5, 360]                  --
├─Linear: 1-11                           [5, 64]                   23,104
├─ReLU: 1-12                             [5, 64]                   --
├─Linear

In [66]:
EPOCHS = 25
CHECKPOINT_PATH = 'mask_model.pth'

train_acc = torchmetrics.classification.BinaryAccuracy().to(device)
val_acc = torchmetrics.classification.BinaryAccuracy().to(device)

model = CNN().to(device)
optim = torch.optim.Adam(model.parameters(), lr = 3e-3)
# The model outputs raw logits. BCEWithLogitsLoss fuses the sigmoid into the loss,
# which is numerically more stable than sigmoid followed by BCELoss.
criterion = torch.nn.BCEWithLogitsLoss()
early_stopping = EarlyStopping(patience=3, verbose=True, path = CHECKPOINT_PATH)


def run_epoch(loader, metric, train, desc):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    With ``train=True`` the model is put in train mode and updated by ``optim``.
    Otherwise it runs in eval mode with gradients disabled. ``metric`` accumulates
    accuracy from sigmoid probabilities and is not reset here.
    """
    model.train(train)
    total_loss, n_seen = 0.0, 0
    bar = tqdm(loader)
    with torch.set_grad_enabled(train):
        for img, label in bar:
            img = img.to(device)
            label = label.to(device).unsqueeze(1)
            logits = model(img)
            batch_loss = criterion(logits, label)

            if train:
                optim.zero_grad()
                batch_loss.backward()
                optim.step()

            metric(torch.sigmoid(logits.detach()), label)
            # Weight by batch size so a short final batch is not over-counted
            # relative to full batches.
            total_loss += batch_loss.item() * img.size(0)
            n_seen += img.size(0)
            bar.set_description_str("{} loss: {:.4f}, accuracy = {:.4f}".format(desc, total_loss / n_seen, metric.compute()))

    return total_loss / max(n_seen, 1)


for epoch_num in range(EPOCHS):
    train_loss = run_epoch(train_dataloader, train_acc, True, "Training")
    val_loss = run_epoch(val_dataloader, val_acc, False, "Validation")

    print("Epoch [{}/{}], Train Loss: {:.4f}, Train Accuracy: {:.4f}".format(epoch_num+1, EPOCHS, train_loss, train_acc.compute()))
    print("Epoch [{}/{}], Val Loss: {:.4f}, Val Accuracy: {:.4f}".format(epoch_num+1, EPOCHS, val_loss, val_acc.compute()))
    early_stopping(val_loss, model)

    train_acc.reset()
    val_acc.reset()

    if early_stopping.early_stop:
        print("Early stopping")
        print('-'*60)
        break

# EarlyStopping only writes the best weights to disk; the in-memory model is from
# the last epoch run. Reload the best checkpoint so later cells use it.
if os.path.exists(CHECKPOINT_PATH):
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

Training loss: 0.4149, accuracy = 0.7985: 100%|██████████| 40/40 [00:08<00:00,  4.79it/s]
Validation loss: 0.2145, accuracy = 0.9225: 100%|██████████| 4/4 [00:00<00:00, 52.33it/s]


Epoch [1/25], Train Loss: 0.4149, Train Accuracy: 0.7985
Epoch [1/25], Val Loss: 0.2145, Val Accuracy: 0.9225
Validation loss decreased (inf --> 0.214532).  Saving model ...


Training loss: 0.2544, accuracy = 0.9038: 100%|██████████| 40/40 [00:08<00:00,  4.96it/s]
Validation loss: 0.1746, accuracy = 0.9237: 100%|██████████| 4/4 [00:00<00:00, 50.48it/s]


Epoch [2/25], Train Loss: 0.2544, Train Accuracy: 0.9038
Epoch [2/25], Val Loss: 0.1746, Val Accuracy: 0.9237
Validation loss decreased (0.214532 --> 0.174557).  Saving model ...


Training loss: 0.2170, accuracy = 0.9187: 100%|██████████| 40/40 [00:08<00:00,  4.84it/s]
Validation loss: 0.1281, accuracy = 0.9463: 100%|██████████| 4/4 [00:00<00:00, 52.71it/s]


Epoch [3/25], Train Loss: 0.2170, Train Accuracy: 0.9187
Epoch [3/25], Val Loss: 0.1281, Val Accuracy: 0.9463
Validation loss decreased (0.174557 --> 0.128084).  Saving model ...


Training loss: 0.1803, accuracy = 0.9322: 100%|██████████| 40/40 [00:08<00:00,  4.88it/s]
Validation loss: 0.1316, accuracy = 0.9650: 100%|██████████| 4/4 [00:00<00:00, 50.02it/s]


Epoch [4/25], Train Loss: 0.1803, Train Accuracy: 0.9322
Epoch [4/25], Val Loss: 0.1316, Val Accuracy: 0.9650
EarlyStopping counter: 1 out of 3


Training loss: 0.1550, accuracy = 0.9392: 100%|██████████| 40/40 [00:08<00:00,  4.91it/s]
Validation loss: 0.1015, accuracy = 0.9725: 100%|██████████| 4/4 [00:00<00:00, 43.13it/s]


Epoch [5/25], Train Loss: 0.1550, Train Accuracy: 0.9392
Epoch [5/25], Val Loss: 0.1015, Val Accuracy: 0.9725
Validation loss decreased (0.128084 --> 0.101484).  Saving model ...


Training loss: 0.1321, accuracy = 0.9503: 100%|██████████| 40/40 [00:08<00:00,  4.84it/s]
Validation loss: 0.0774, accuracy = 0.9750: 100%|██████████| 4/4 [00:00<00:00, 53.81it/s]


Epoch [6/25], Train Loss: 0.1321, Train Accuracy: 0.9503
Epoch [6/25], Val Loss: 0.0774, Val Accuracy: 0.9750
Validation loss decreased (0.101484 --> 0.077367).  Saving model ...


Training loss: 0.1189, accuracy = 0.9553: 100%|██████████| 40/40 [00:08<00:00,  4.85it/s]
Validation loss: 0.0674, accuracy = 0.9800: 100%|██████████| 4/4 [00:00<00:00, 51.12it/s]


Epoch [7/25], Train Loss: 0.1189, Train Accuracy: 0.9553
Epoch [7/25], Val Loss: 0.0674, Val Accuracy: 0.9800
Validation loss decreased (0.077367 --> 0.067354).  Saving model ...


Training loss: 0.1167, accuracy = 0.9569: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s]
Validation loss: 0.0759, accuracy = 0.9800: 100%|██████████| 4/4 [00:00<00:00, 49.89it/s]


Epoch [8/25], Train Loss: 0.1167, Train Accuracy: 0.9569
Epoch [8/25], Val Loss: 0.0759, Val Accuracy: 0.9800
EarlyStopping counter: 1 out of 3


Training loss: 0.1097, accuracy = 0.9606: 100%|██████████| 40/40 [00:08<00:00,  4.85it/s]
Validation loss: 0.0551, accuracy = 0.9837: 100%|██████████| 4/4 [00:00<00:00, 50.59it/s]


Epoch [9/25], Train Loss: 0.1097, Train Accuracy: 0.9606
Epoch [9/25], Val Loss: 0.0551, Val Accuracy: 0.9837
Validation loss decreased (0.067354 --> 0.055099).  Saving model ...


Training loss: 0.1042, accuracy = 0.9619: 100%|██████████| 40/40 [00:08<00:00,  4.89it/s]
Validation loss: 0.0469, accuracy = 0.9800: 100%|██████████| 4/4 [00:00<00:00, 26.84it/s]


Epoch [10/25], Train Loss: 0.1042, Train Accuracy: 0.9619
Epoch [10/25], Val Loss: 0.0469, Val Accuracy: 0.9800
Validation loss decreased (0.055099 --> 0.046921).  Saving model ...


Training loss: 0.0927, accuracy = 0.9680: 100%|██████████| 40/40 [00:07<00:00,  5.01it/s]
Validation loss: 0.0484, accuracy = 0.9837: 100%|██████████| 4/4 [00:00<00:00, 50.20it/s]


Epoch [11/25], Train Loss: 0.0927, Train Accuracy: 0.9680
Epoch [11/25], Val Loss: 0.0484, Val Accuracy: 0.9837
EarlyStopping counter: 1 out of 3


Training loss: 0.0953, accuracy = 0.9653: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s]
Validation loss: 0.0446, accuracy = 0.9825: 100%|██████████| 4/4 [00:00<00:00, 51.52it/s]


Epoch [12/25], Train Loss: 0.0953, Train Accuracy: 0.9653
Epoch [12/25], Val Loss: 0.0446, Val Accuracy: 0.9825
Validation loss decreased (0.046921 --> 0.044623).  Saving model ...


Training loss: 0.0869, accuracy = 0.9692: 100%|██████████| 40/40 [00:08<00:00,  4.79it/s]
Validation loss: 0.0438, accuracy = 0.9850: 100%|██████████| 4/4 [00:00<00:00, 28.01it/s]


Epoch [13/25], Train Loss: 0.0869, Train Accuracy: 0.9692
Epoch [13/25], Val Loss: 0.0438, Val Accuracy: 0.9850
Validation loss decreased (0.044623 --> 0.043768).  Saving model ...


Training loss: 0.0880, accuracy = 0.9686: 100%|██████████| 40/40 [00:08<00:00,  4.98it/s]
Validation loss: 0.0413, accuracy = 0.9850: 100%|██████████| 4/4 [00:00<00:00, 53.06it/s]


Epoch [14/25], Train Loss: 0.0880, Train Accuracy: 0.9686
Epoch [14/25], Val Loss: 0.0413, Val Accuracy: 0.9850
Validation loss decreased (0.043768 --> 0.041348).  Saving model ...


Training loss: 0.0805, accuracy = 0.9706: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s]
Validation loss: 0.0466, accuracy = 0.9800: 100%|██████████| 4/4 [00:00<00:00, 54.32it/s]


Epoch [15/25], Train Loss: 0.0805, Train Accuracy: 0.9706
Epoch [15/25], Val Loss: 0.0466, Val Accuracy: 0.9800
EarlyStopping counter: 1 out of 3


Training loss: 0.0920, accuracy = 0.9665: 100%|██████████| 40/40 [00:08<00:00,  4.91it/s]
Validation loss: 0.0468, accuracy = 0.9787: 100%|██████████| 4/4 [00:00<00:00, 52.48it/s]


Epoch [16/25], Train Loss: 0.0920, Train Accuracy: 0.9665
Epoch [16/25], Val Loss: 0.0468, Val Accuracy: 0.9787
EarlyStopping counter: 2 out of 3


Training loss: 0.0801, accuracy = 0.9682: 100%|██████████| 40/40 [00:07<00:00,  5.00it/s]
Validation loss: 0.0333, accuracy = 0.9850: 100%|██████████| 4/4 [00:00<00:00, 46.44it/s]


Epoch [17/25], Train Loss: 0.0801, Train Accuracy: 0.9682
Epoch [17/25], Val Loss: 0.0333, Val Accuracy: 0.9850
Validation loss decreased (0.041348 --> 0.033337).  Saving model ...


Training loss: 0.0804, accuracy = 0.9720: 100%|██████████| 40/40 [00:08<00:00,  4.91it/s]
Validation loss: 0.0517, accuracy = 0.9850: 100%|██████████| 4/4 [00:00<00:00, 53.88it/s]


Epoch [18/25], Train Loss: 0.0804, Train Accuracy: 0.9720
Epoch [18/25], Val Loss: 0.0517, Val Accuracy: 0.9850
EarlyStopping counter: 1 out of 3


Training loss: 0.0810, accuracy = 0.9713: 100%|██████████| 40/40 [00:08<00:00,  4.80it/s]
Validation loss: 0.0460, accuracy = 0.9825: 100%|██████████| 4/4 [00:00<00:00, 51.55it/s]


Epoch [19/25], Train Loss: 0.0810, Train Accuracy: 0.9713
Epoch [19/25], Val Loss: 0.0460, Val Accuracy: 0.9825
EarlyStopping counter: 2 out of 3


Training loss: 0.0755, accuracy = 0.9735: 100%|██████████| 40/40 [00:08<00:00,  4.85it/s]
Validation loss: 0.0407, accuracy = 0.9837: 100%|██████████| 4/4 [00:00<00:00, 43.59it/s]

Epoch [20/25], Train Loss: 0.0755, Train Accuracy: 0.9735
Epoch [20/25], Val Loss: 0.0407, Val Accuracy: 0.9837
EarlyStopping counter: 3 out of 3
Early stopping
------------------------------------------------------------



