In [15]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as f
import torchvision
from torchvision import transforms
import os
import csv
from skimage import io, transform

In [16]:
# Dataset layout (relative to the working directory):
#   ./dataset/with_mask/      images of faces wearing a mask
#   ./dataset/without_mask/   images of faces without a mask
#   ./dataset/face-mask-detection-dataset.csv   index of (image_name, label)
dataset_path = os.path.join(os.curdir, 'dataset')
with_mask_ds_path, without_mask_ds_path = (
    os.path.join(dataset_path, class_dir)
    for class_dir in ('with_mask', 'without_mask')
)
csv_file = os.path.join(dataset_path, 'face-mask-detection-dataset.csv')

In [17]:
# Build the CSV index for the dataset: one row per image (file name + label).
# The index is only generated when it does not exist yet.
if os.path.exists(csv_file):
    print("CSV File Found")
else:
    print("CSV Not Found")
    print("Creating CSV File")
    # Write to `csv_file` -- the exact path tested above and later handed to
    # the dataset -- rather than to a same-named file in the working directory.
    # The handle is deliberately not called `f`, which would overwrite the
    # `torch.nn.functional as f` alias from the imports cell.
    # newline='' is what the csv module asks for so that no blank rows are
    # inserted on platforms that translate line endings.
    with open(csv_file, mode='w', newline='') as csv_handle:
        col_names = ['image_name', 'label']
        writer = csv.DictWriter(csv_handle, fieldnames=col_names)
        writer.writeheader()
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # generated index identical from one run to the next.
        # label 0 -> image taken from the without_mask directory
        for img in sorted(os.listdir(without_mask_ds_path)):
            writer.writerow({'image_name': img, 'label': 0})
        # label 1 -> image taken from the with_mask directory
        for img in sorted(os.listdir(with_mask_ds_path)):
            writer.writerow({'image_name': img, 'label': 1})
    # No explicit close(): the `with` block closes the file on exit.

CSV File Found


In [18]:
class MaskDetectionDataset(Dataset):
    """Face-mask image dataset driven by a CSV index.

    The CSV must have a header row with the columns ``image_name`` and
    ``label``. Rows whose label equals 0 are loaded from
    ``<root_dir>/without_mask``; every other row is loaded from
    ``<root_dir>/with_mask`` and reported with label 1.

    Args:
        csv_file: Path to the CSV index.
        root_dir: Directory containing the ``with_mask`` and ``without_mask``
            image folders.
        transforms: Optional callable applied to each loaded image.

    Note:
        The CSV is parsed once, at construction time; later edits to the file
        on disk are not picked up by an existing instance.
    """

    # Sub-directory of ``root_dir`` that holds the images for each label.
    _LABEL_DIRS = {0: 'without_mask', 1: 'with_mask'}

    def __init__(self, csv_file, root_dir, transforms=None):
        self.csv_file = csv_file
        self.root_dir = root_dir
        self.transforms = transforms
        # Parse the index a single time instead of on every __len__ /
        # __getitem__ call.
        self.annotations = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of samples listed in the CSV index."""
        return len(self.annotations)

    def _read_image(self, path):
        """Load the image stored at ``path`` (kept separate so the decoding
        backend can be swapped in a subclass)."""
        return io.imread(path)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            IndexError: if ``idx`` is outside the range of the CSV rows.
        """
        # read_csv already consumed the header row and iloc is 0-based, so the
        # sample index maps directly onto the row position (no +2 offset, which
        # skipped the first two rows and overflowed on the last two indices).
        row = self.annotations.iloc[int(idx)]
        label = 0 if row['label'] == 0 else 1
        img_name = os.path.join(
            self.root_dir, self._LABEL_DIRS[label], row['image_name']
        )
        image = self._read_image(img_name)
        if self.transforms is not None:
            image = self.transforms(image)
        return (image, label)


In [19]:
# Prepare images by transformation.
# Compose applies its transforms in sequence, so they must be passed as an
# ordered list. A set literal ({...}) has no defined iteration order and could
# run the steps in any order.
augmentation_transform = transforms.Compose([
    transforms.ToPILImage(),        # convert the loaded image to a PIL image
    transforms.Resize((300, 300)),  # resize every image to 300x300
    transforms.ToTensor(),          # convert the PIL image to a tensor
])

dataset = MaskDetectionDataset(csv_file=csv_file, root_dir=dataset_path, transforms=augmentation_transform)

In [20]:
# Use the GPU when torch reports one as available, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training hyperparameters.
lr = 0.0001      # learning rate
epochs = 20      # number of training epochs
batch_size = 32  # samples per mini-batch

In [21]:
# Sanity check: number of samples exposed by the dataset.
num_samples = len(dataset)
print(num_samples)

820


In [22]:
# Splitting dataset
# random_split requires the lengths to add up to len(dataset) exactly, so only
# the held-out size is fixed and the training size is derived from it. The
# split then keeps working when images are added to or removed from the dataset.
test_size = 70
train_size = len(dataset) - test_size
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps test runs
# comparable.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [24]:
# Peek at the first training batch only: print its index, the image tensor and
# the labels, then stop iterating.
for batch_idx, first_batch in enumerate(train_loader):
    data, label = first_batch
    for shown in (batch_idx, data, label):
        print(shown)
    break

0.3373, 0.3373],
          ...,
          [0.5765, 0.5804, 0.6039,  ..., 0.6627, 0.6667, 0.6667],
          [0.6314, 0.6275, 0.6275,  ..., 0.6353, 0.6392, 0.6392],
          [0.6549, 0.6510, 0.6353,  ..., 0.6235, 0.6275, 0.6275]],

         [[0.9098, 0.9098, 0.9098,  ..., 0.5020, 0.5098, 0.5098],
          [0.9098, 0.9098, 0.9098,  ..., 0.5059, 0.5098, 0.5098],
          [0.9059, 0.9059, 0.9059,  ..., 0.5137, 0.5059, 0.5059],
          ...,
          [0.5843, 0.5882, 0.6118,  ..., 0.7725, 0.7765, 0.7765],
          [0.6431, 0.6392, 0.6353,  ..., 0.7451, 0.7490, 0.7490],
          [0.6667, 0.6627, 0.6471,  ..., 0.7333, 0.7373, 0.7373]],

         [[0.8980, 0.8980, 0.8980,  ..., 0.3882, 0.3961, 0.3961],
          [0.8980, 0.8980, 0.8980,  ..., 0.3922, 0.3961, 0.3961],
          [0.8941, 0.8941, 0.8941,  ..., 0.4000, 0.3961, 0.3961],
          ...,
          [0.6667, 0.6706, 0.6980,  ..., 0.9176, 0.9216, 0.9216],
          [0.7137, 0.7098, 0.7098,  ..., 0.8902, 0.8941, 0.8941],
          