In [1]:
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F


# Haar-cascade face detector; the XML file is resolved relative to the
# current working directory.
face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')
if face_cascade.empty():
    # CascadeClassifier does not raise when the file cannot be loaded; fail
    # here with a clear message instead of later inside detectMultiScale.
    raise IOError("Could not load 'haarcascade_frontalface_default.xml'")

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from torch_snippets import *

In [2]:
from torchvision import models
def get_model():
    """Build a MobileNetV2 backbone with a two-headed age/gender classifier.

    The pretrained backbone is frozen; only the replacement ``classifier``
    (created after the freeze) keeps trainable parameters.

    Returns:
        The assembled model, moved to the module-level ``device``. Its forward
        pass returns a ``(gender, age)`` tuple; each element is the sigmoid
        output of a single-unit linear layer, i.e. values in (0, 1).
    """
    model = models.mobilenet_v2(pretrained = True)
    # Freeze parameters so we don't backprop through them
    for param in model.parameters():
        param.requires_grad = False

    # NOTE: no custom `model.avgpool` is attached. torchvision's MobileNetV2
    # pools with adaptive_avg_pool2d inside its forward pass and never reads
    # an `avgpool` attribute, so such a module would only add unused weights.

    class ageGenderClassifier(nn.Module):
        """Shared MLP trunk followed by separate age and gender heads."""

        def __init__(self):
            super(ageGenderClassifier, self).__init__()
            # Trunk: 1280-d backbone features -> 64-d shared representation.
            self.intermediate = nn.Sequential(
                nn.Linear(1280,512),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(512,128),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(128,64),
                nn.ReLU(),
            )
            # Both heads squash their single output into (0, 1).
            self.age_classifier = nn.Sequential(
                nn.Linear(64, 1),
                nn.Sigmoid()
            )
            self.gender_classifier = nn.Sequential(
                nn.Linear(64, 1),
                nn.Sigmoid()
            )

        def forward(self, x):
            """Return ``(gender, age)`` for a batch of backbone features."""
            x = self.intermediate(x)
            age = self.age_classifier(x)
            gender = self.gender_classifier(x)
            # Order matters: callers unpack as (gender, age).
            return gender, age

    model.classifier = ageGenderClassifier()

    return model.to(device)

# Instantiate the network. The following cell calls get_model() again before
# loading the trained weights, so this instance is replaced there.
model = get_model()

In [3]:
# Rebuild the network and restore the trained weights of the classifier head
# (only `model.classifier` is loaded from the checkpoint file 'model').
model = get_model()
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded when running on a different device.
model.classifier.load_state_dict(torch.load('model', map_location=device))
# Inference mode (e.g. disables dropout); eval() returns the module, so the
# cell displays the model structure.
model.eval()

MobileNetV2(
  (features): Sequential(
    (0): ConvBNActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momen

In [4]:
# Display labels (in Portuguese) keyed by the thresholded gender output:
# True -> 'Feminino' (female), False -> 'Masculino' (male).
label_gender = {True: 'Feminino', False: 'Masculino'}

In [5]:
import albumentations as A

# Preprocessing pipeline: resize to 224x224, then normalise each channel with
# the given per-channel mean and standard deviation.
transform = A.Compose([
    A.Resize(height=224, width=224),
    A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

In [27]:
def detect_face(img, rst):
    """Detect faces in ``img`` and annotate them with predicted gender and age.

    Args:
        img: Frame to run the Haar cascade on. It is drawn on in place.
        rst: Preprocessed input tensor that is fed to ``model``.

    Returns:
        The same image object as ``img``, with a rectangle plus "Genero" and
        "Idade" labels drawn for every detected face. If no face is found the
        image is returned untouched and the model is not run.
    """
    face_img = img
    face_rects = face_cascade.detectMultiScale(face_img,scaleFactor=1.2,minNeighbors=5)
    if len(face_rects) == 0:
        return face_img

    # The prediction depends only on `rst`, not on the individual rectangle,
    # so run the network once per frame instead of once per face. no_grad()
    # avoids building an autograd graph during inference.
    # NOTE(review): every face in the frame therefore receives the same label;
    # per-face predictions would require feeding a crop of each face.
    with torch.no_grad():
        gender, age = model(rst.to(device))
    age = age.to('cpu').numpy()
    gender = gender.to('cpu').numpy()

    # Threshold the sigmoid output at 0.5 and map it to its display label.
    gender = (gender > 0.5).squeeze()
    gender = label_gender[bool(gender)]
    # The age head outputs a value in (0, 1); it is rescaled by 80 for display.
    age = str(int(age[0][0]*80))

    for (x,y,w,h) in face_rects:
        cv2.rectangle(face_img,(x,y),(x+w,y+h),(0,0,255),2)
        cv2.putText(face_img, ("Genero: " + gender), (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255,255,255), 2)
        # Place the age label just below the box, relative to the face height
        # rather than at a fixed pixel offset.
        cv2.putText(face_img, ("Idade: " + age), (x, y+h+30), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255,255,255), 2)
    return face_img

In [28]:
# Open the default webcam. Only device 0 is opened: a capture on another
# index that is immediately overwritten would never be released.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise IOError('Could not open video capture device 0')

try:
    while True:
        ret,frame = cap.read()
        if not ret:
            # No frame delivered (camera unplugged / stream ended): stop
            # instead of passing None to the preprocessing below.
            break
        # Build the model input: 224x224 image, batch axis added, scaled to
        # [0, 1] and reordered from NHWC to NCHW.
        # NOTE(review): the `transform` pipeline defined earlier is not applied
        # here; confirm this matches the preprocessing used in training.
        im = resize(frame, 224)
        im = np.expand_dims(im, axis=0)
        im = torch.tensor(im/255).permute(0,3,1,2).float()
        frame = detect_face(frame, im)
        cv2.imshow('detect gender and age', frame)
        # Press 'q' in the preview window to quit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the window, even on an exception or
    # a kernel interrupt.
    cap.release()
    cv2.destroyAllWindows()