In [19]:
#ref: https://github.com/yaoing/DAN
# #imports
import os


from PIL import Image
import numpy as np
import cv2

import torch
from torchvision import transforms

from torch import nn
from torch.nn import functional as F
import torch
import torch.nn.init as init
from torchvision import models

In [20]:
#DAN ARCHIETURE 

class DAN(nn.Module):
    def __init__(self, num_class=7,num_head=4, pretrained=True):
        super(DAN, self).__init__()
        
        resnet = models.resnet18(pretrained)
        
        if pretrained:
            checkpoint = torch.load('./models/resnet18_msceleb.pth')
            resnet.load_state_dict(checkpoint['state_dict'],strict=True)

        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.num_head = num_head
        for i in range(num_head):
            setattr(self,"cat_head%d" %i, CrossAttentionHead())
        self.sig = nn.Sigmoid()
        self.fc = nn.Linear(512, num_class)
        self.bn = nn.BatchNorm1d(num_class)


    def forward(self, x):
        x = self.features(x)
        heads = []
        for i in range(self.num_head):
            heads.append(getattr(self,"cat_head%d" %i)(x))
        
        heads = torch.stack(heads).permute([1,0,2])
        if heads.size(1)>1:
            heads = F.log_softmax(heads,dim=1)
            
        out = self.fc(heads.sum(dim=1))
        out = self.bn(out)
   
        return out, x, heads

class CrossAttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.sa = SpatialAttention()
        self.ca = ChannelAttention()
        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
    def forward(self, x):
        sa = self.sa(x)
        ca = self.ca(sa)

        return ca


class SpatialAttention(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1),
            nn.BatchNorm2d(256),
        )
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3,padding=1),
            nn.BatchNorm2d(512),
        )
        self.conv_1x3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(1,3),padding=(0,1)),
            nn.BatchNorm2d(512),
        )
        self.conv_3x1 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(3,1),padding=(1,0)),
            nn.BatchNorm2d(512),
        )
        self.relu = nn.ReLU()


    def forward(self, x):
        y = self.conv1x1(x)
        y = self.relu(self.conv_3x3(y) + self.conv_1x3(y) + self.conv_3x1(y))
        y = y.sum(dim=1,keepdim=True) 
        out = x*y
        
        return out 

class ChannelAttention(nn.Module):

    def __init__(self):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.Sequential(
            nn.Linear(512, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 512),
            nn.Sigmoid()    
        )


    def forward(self, sa):
        sa = self.gap(sa)
        sa = sa.view(sa.size(0),-1)
        y = self.attention(sa)
        out = sa * y
        
        return out

In [21]:
class Model():
    def __init__(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.data_transforms = transforms.Compose([
                                    transforms.Resize((224, 224)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
                                ])
        self.labels = ['neutral', 'happy', 'sad', 'surprise', 'fear', 'disgust', 'anger']

        self.model = DAN(num_head=4, num_class=7, pretrained=False)
        checkpoint = torch.load(r"affecnet7_epoch6_acc0.6569.pth",
            map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'],strict=True)
        self.model.to(self.device)
        self.model.eval()

        self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades+'haarcascade_frontalface_default.xml')
    
    def detect(self, img0):
        img = cv2.cvtColor(np.asarray(img0),cv2.COLOR_RGB2BGR)
        faces = self.face_cascade.detectMultiScale(img,minNeighbors = 10,minSize=(60,60))
    
        return faces

    def fer(self, frame):

        img0 = Image.fromarray(frame)

        faces = self.detect(img0)

        if len(faces) == 0:
            return [],[]

        #  single face detection
        labels = []
        
        for (x,y,w,h) in faces:
             img = img0.crop((x,y, x+w, y+h))

             img = self.data_transforms(img)
             img = img.view(1,3,224,224)
             img = img.to(self.device)
            
             with torch.set_grad_enabled(False):
                 out, _, _ = self.model(img)
                 _, pred = torch.max(out,1)
                 index = int(pred)
                 label = self.labels[index]
                 labels.append(label)
        return labels,faces

In [47]:
def most_frequent(List):
    return max(set(List), key = List.count)

def get_no_faces(arr):
    try:
        return arr.shape[1]
    except IndexError:
        return 1

def most_frequent(List):
    counter = 0
    value = List[0]
     
    for i in List:
        curr_frequency = List.count(i)
        if(curr_frequency> counter):
            counter = curr_frequency
            value = i
 
    return value

In [51]:
if __name__ == "__main__":

    model = Model()



    cap = cv2.VideoCapture(0)
    
    fps = 0
    while True:
            tot_labels_set = []
            tot_faces_set = []

            fin_faces = []
            fin_labels = []

            for frm_of_set in range(20):
                _, frame = cap.read()
                labels,faces = model.fer(frame)
                if labels != []:
                    tot_labels_set.append(np.array(labels))
                    tot_faces_set.append(np.array(faces))
            tot_faces_set = np.array(tot_faces_set)
            tot_labels_set = np.array(tot_labels_set)

            
            if (tot_labels_set[1].shape)[0] == 1:
                fin_labels = most_frequent([i[0] for i in tot_labels_set if len(i)>0 ])
                fin_faces = tot_faces_set[np.where(tot_labels_set == [fin_labels])][0]
                label_position = (fin_faces[0],fin_faces[1])
                cv2.rectangle(frame,(fin_faces[0],fin_faces[1]),(fin_faces[0]+fin_faces[2],fin_faces[1]+fin_faces[3]),(0,145,255),1)
                cv2.putText(frame,fin_labels,label_position,cv2.FONT_HERSHEY_SIMPLEX,1,(0,255,0),2)
                cv2.imshow('Emotion Detector',frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

            

            

            break
            




            


            
            # if len(faces) >0:
            #     for idx,(x,y,w,h) in enumerate(faces):
                
            #         label_position = (x,y)
            #         cv2.rectangle(frame,(x,y),(x+w,y+h),(0,145,255),1)
            #         cv2.putText(frame,labels[idx],label_position,cv2.FONT_HERSHEY_SIMPLEX,1,(0,255,0),2)

            #     cv2.imshow('Emotion Detector',frame)
            #     if cv2.waitKey(1) & 0xFF == ord('q'):
            #         break
            # else:
            #     cv2.putText(frame,'No Faces',(179,59),cv2.FONT_HERSHEY_SIMPLEX,1,(0,255,0),2)
            #     cv2.imshow('Emotion Detector',frame)
            #     if cv2.waitKey(1) & 0xFF == ord('q'):
            #         break
                    
    cap.release()
    cv2.destroyAllWindows()     
            


In [45]:
# print(tot_faces_set)

ls = [i.shape[0] for i in tot_faces_set]
print(max(set(ls), key = ls.count))




2


In [41]:
print(tot_labels_set)

[array(['fear'], dtype='<U4') array(['fear'], dtype='<U4')
 array(['neutral'], dtype='<U7') array(['neutral', 'anger'], dtype='<U7')
 array(['neutral', 'anger'], dtype='<U7')
 array(['neutral', 'neutral'], dtype='<U7')
 array(['neutral', 'anger'], dtype='<U7')
 array(['neutral', 'neutral'], dtype='<U7')
 array(['neutral', 'anger'], dtype='<U7') array(['fear'], dtype='<U4')
 array(['neutral', 'neutral'], dtype='<U7')
 array(['neutral', 'anger'], dtype='<U7')
 array(['neutral', 'neutral'], dtype='<U7')
 array(['neutral', 'anger'], dtype='<U7')
 array(['neutral', 'neutral'], dtype='<U7')
 array(['neutral', 'neutral'], dtype='<U7')
 array(['neutral', 'neutral'], dtype='<U7')
 array(['neutral', 'neutral'], dtype='<U7')
 array(['neutral', 'anger'], dtype='<U7')
 array(['neutral', 'neutral'], dtype='<U7')]


In [46]:
v = ['ashwin','arun','balaji','arun']
print(max(v))

balaji
