In this notebook we will use facenet_pytorch to detect faces in video,
and apply our mask classification to predict if person is wearing a mask 
or not.


In [9]:
from facenet_pytorch import MTCNN
import torch
import numpy as np
import cv2
from PIL import Image,ImageDraw
from IPython import display
import imutils

from torch import nn
from torch import optim
import os

import torch.nn.functional as F
from torchvision import datasets ,transforms, models
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
from torch.autograd import Variable



In [19]:
# transformation on Images
test_transforms=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),])

check_transforms=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),
                                     transforms.ToPILImage(),])

# check if gpu is available
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

#Load Model
model=torch.load('classifymodel.pth')
model.eval()

Running on device: cuda:0


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [11]:
mtcnn=MTCNN(keep_all=True, device=device)


In [12]:
def predict_image(image):
    image_tensor=test_transforms(image).float()
    image_tensor=image_tensor.unsqueeze_(0)
    input= Variable(image_tensor).to(device)
    output=model(input)
    
    index=output.data.cpu().numpy().argmax()
    return index


In [21]:
classes=['with mask','without mask']

In [23]:
cap=cv2.VideoCapture('Samples/masks_video.mp4')

while True:
    ret,image=cap.read()
    frame=image.copy()
    frame=cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)

    boxes,_=mtcnn.detect(frame)

    color=(0,255,0)
    if boxes is not None:
        for box in boxes:
            startX,endX=int(box[0])-1,int(box[2])+1
            startY,endY=int(box[1])-1,int(box[3])+1
            cv2.rectangle(image,(box[0],box[1]),(box[2],box[3]) , color, 2)
            face = frame[startY:endY, startX:endX]
            img = Image.fromarray(face)
            img=check_transforms(img)
            index=predict_image(img)
            label=str(classes[index])
            color = (0, 255, 0)
            cv2.putText(image, label, (startX, startY - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 2)

    
    cv2.imshow("frame",image)
    k=cv2.waitKey(1) & 0xFF
    if k == ord("q"):
        print("Escapehit, closing...")
        break
cap.release()
cv2.destroyAllWindows()

None
None
None
None
None
None
None
[[867.7613  169.19453 909.641   230.05716]]
[[864.03973 171.12119 910.3459  233.61478]]
[[861.8857  171.92815 908.5219  234.23053]]
[[861.62646 169.4353  908.9337  233.67407]]
[[861.8046  171.44017 908.2118  233.86226]]
[[858.99896 170.24014 906.25665 234.7356 ]]
[[860.5203  170.38196 905.0072  233.14336]]
[[859.4362  170.44008 905.4762  234.04858]]
[[859.27515 171.17697 905.07446 234.524  ]]
[[856.3367  170.27267 904.487   234.42412]]
[[858.0965  170.11058 905.3689  237.01233]]
[[856.7256  170.32019 902.86304 235.60287]]
[[855.1676  170.63187 904.17914 234.16267]]
[[854.643   173.13678 901.7468  232.98206]]
[[853.1492  172.80197 900.4316  234.77664]]
[[853.4271  171.37517 901.431   234.3659 ]]
[[852.47    171.7359  899.3385  236.87485]]
[[849.95703 170.08063 898.7422  235.9982 ]]
[[849.6678  170.99136 897.4483  236.21764]]
[[848.5794  171.64142 896.83105 235.43431]]
[[850.693   173.28674 896.0069  237.9667 ]]
[[847.7133  171.2587  896.82495 236.5084 

[[ 996.9767   176.17357 1057.7069   259.31802]]
[[ 998.0751   176.54694 1058.9927   262.20074]]
[[ 998.2815   175.93588 1059.7275   261.77524]]
[[ 999.95325  176.54184 1061.6394   261.994  ]]
[[ 999.81134  176.06558 1061.4619   261.69263]]
[[ 999.4228   175.88356 1063.9646   260.63284]]
[[1001.7085   176.18243 1065.4137   260.4052 ]]
[[1002.1641   175.91121 1064.9371   261.23016]]
[[1003.2495  176.1388 1065.4147  261.5418]]
[[1003.61066  176.73781 1067.5642   263.27457]]
[[1004.0307   176.24855 1069.2877   262.22693]]
[[1004.79553  175.6243  1070.6616   262.05313]]
[[1009.62213  176.17967 1073.2347   261.13583]]
[[1009.027    173.5019  1073.9918   263.66754]]
[[1010.2999   174.69987 1075.022    264.47452]]
[[1009.1553   175.5522  1075.0236   264.64844]]
[[1011.78644  175.47104 1076.7483   264.5647 ]]
[[1012.9469   178.27267 1077.7344   266.4496 ]]
[[1011.4329   178.66585 1078.1552   266.3079 ]]
[[1012.11383  180.54033 1079.5331   271.06552]]
[[1014.00446  183.64824 1079.0703   265.9482

 [257.493   252.78036 280.42188 286.719  ]]
[[747.3393  166.83766 816.1482  243.25471]
 [568.1257  237.08206 629.3543  318.72034]
 [330.32178 258.9105  383.37265 315.56613]]
[[748.42505 166.72464 817.66943 243.50429]
 [571.085   236.88979 631.1087  318.72998]
 [333.28384 259.49307 379.13812 307.00668]]
[[748.44714 164.07384 820.1583  245.4445 ]
 [573.2756  238.64336 632.7265  318.85272]
 [334.3828  258.10114 379.93262 308.57153]]
[[750.77716 162.389   823.80194 242.71564]
 [578.90753 239.73877 636.70197 317.27374]
 [520.6445  204.44699 573.58203 280.59262]
 [336.23486 258.18158 386.87848 312.4145 ]]
[[752.37836 164.40248 822.4204  244.75618]
 [579.19836 234.35916 640.5104  320.97357]
 [511.02002 201.86049 569.4607  279.04724]
 [340.68088 257.47028 388.46512 309.0047 ]]
[[752.9995  164.01036 822.1584  245.10487]
 [580.75507 235.52911 641.9073  320.0245 ]
 [513.0945  209.87021 571.7759  283.12558]
 [340.44913 259.55182 385.9246  308.01263]]
[[756.80505 163.44032 827.8285  246.19972]
 [58

 [ 632.3004   250.79193  652.53503  279.3544 ]]
[[1087.4938   310.83392 1180.2614   435.7399 ]
 [ 726.60834  153.91823  798.50415  235.36368]
 [ 516.4953   157.06644  576.717    233.07492]]
[[1101.9364   300.80652 1204.0419   434.85025]
 [ 724.8462   155.31665  791.3245   235.32353]
 [ 518.7595   161.06526  578.1882   233.8216 ]]
[[1128.1705   295.4782  1226.3448   428.5256 ]
 [ 720.93524  156.31198  788.91473  236.14967]
 [ 520.042    156.64578  582.1659   234.99348]
 [ 634.3549   255.75615  650.5415   279.77655]]
[[1160.1333   290.2553  1254.2986   427.6028 ]
 [ 717.7746   155.54297  789.7784   241.32167]
 [ 521.9721   156.32886  583.41626  234.372  ]
 [ 636.5281   252.1394   655.24426  279.57596]]
[[1196.3702   295.12427 1296.1266   420.68124]
 [ 716.32935  156.25533  790.298    242.9229 ]
 [ 524.9311   159.93044  584.1559   233.61636]]
[[717.2677  156.24783 788.5923  239.5672 ]
 [525.8072  158.09099 586.05133 234.72714]]
[[712.1144  154.57567 789.4056  244.34712]
 [528.0896  159.44

 [506.758   199.6443  563.52893 271.58286]]
[[678.7947  202.6625  742.8289  294.1576 ]
 [836.6287  152.36034 901.1805  232.5017 ]
 [511.49146 200.36017 566.6003  271.07938]]
[[678.9568  200.56308 745.61786 293.57953]
 [845.5235  153.66089 908.3423  233.46637]
 [517.083   200.10036 571.6871  271.45688]]
[[838.3667  148.87375 908.1466  236.41489]
 [681.2792  202.88626 742.9917  289.7855 ]
 [521.396   192.964   582.90485 274.641  ]]
[[678.99054 202.60393 745.7914  292.08405]
 [844.6356  154.02069 909.144   238.95702]
 [528.91534 197.47366 585.2784  272.04678]]
[[843.6084  148.89378 917.6306  245.47273]
 [679.223   201.1335  747.15216 292.97522]
 [534.039   198.17249 590.4287  272.90192]]
[[842.61334 148.17856 919.0203  243.72818]
 [679.82733 199.33673 747.0992  293.61682]
 [538.1014  198.48894 596.062   275.7147 ]]
[[828.61804 141.4118  919.4133  249.75624]
 [681.2323  203.58565 745.4475  290.61896]
 [541.0975  199.98747 598.24817 274.6696 ]]
[[837.0023  146.865   913.1711  238.07094]
 [6