In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class InvertedBlock(nn.Module):
    """Residual bottleneck: 1x1 expand -> 3x3 depthwise -> 1x1 linear projection.

    The projection maps back to ``squeeze`` channels and every conv preserves the
    spatial size, so the block output is added onto its input (identity shortcut).

    Args:
        squeeze: number of input (and output) channels.
        expand: number of channels in the hidden, expanded representation.
    """

    def __init__(self, squeeze=16, expand=64):
        super().__init__()
        expand_stage = [
            nn.Conv2d(squeeze, expand, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(expand),
            nn.ReLU6(inplace=True),
        ]
        # groups == channels makes this a per-channel (depthwise) 3x3 conv;
        # padding=1 keeps height and width unchanged.
        depthwise_stage = [
            nn.Conv2d(expand, expand, kernel_size=3, stride=1, padding=1, groups=expand, bias=False),
            nn.BatchNorm2d(expand),
            nn.ReLU6(inplace=True),
        ]
        # No activation after the projection: the bottleneck output stays linear.
        project_stage = [
            nn.Conv2d(expand, squeeze, kernel_size=1, stride=1, padding=0, bias=False),
        ]
        # One flat Sequential named `conv`, so state_dict keys stay
        # conv.0.weight, conv.1.running_mean, ... as in saved checkpoints.
        self.conv = nn.Sequential(*expand_stage, *depthwise_stage, *project_stage)

    def forward(self, x):
        residual = self.conv(x)
        return x + residual


class VggFeatures(nn.Module):
    """Feature extractor: two conv stages, four inverted residual blocks, two FC layers.

    For an (N, 1, 40, 40) input the spatial size goes 40 -> pool 20 -> 5x5 conv
    with padding=1 -> 18 -> pool 9 -> pool 4 -> pool 2, so each sample flattens
    to 128 * 2 * 2 = 512 features and the output is (N, 128).

    Args:
        drop: accepted for interface compatibility but currently unused
            (no dropout layer is built).
    """

    def __init__(self, drop=0.2):
        super().__init__()

        def conv_bn(inp, oup, ks):
            # Conv -> BN -> ReLU. NOTE(review): padding is fixed at 1, so only
            # ks=3 preserves the spatial size (ks=5 shrinks each side by 2). Kept
            # as-is because the restored weights were presumably trained this way.
            return nn.Sequential(
                nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=ks, padding=1),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
                )

        def invert(squeeze, expand):
            return InvertedBlock(squeeze, expand)

        self.layer1 = conv_bn(1, 64, 3)
        self.layer2 = conv_bn(64, 128, 5)
        self.layer3 = invert(128, 256)
        self.layer4 = invert(128, 256)
        self.layer5 = invert(128, 512)
        self.layer6 = invert(128, 512)
        self.lin1 = nn.Linear(128*2*2, 256)
        self.lin2 = nn.Linear(256, 128)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # conv_bn stages already end in ReLU, so no extra activation is applied.
        x = self.pool(self.layer1(x))
        x = self.pool(self.layer2(x))

        x = self.pool(self.layer4(self.layer3(x)))
        x = self.pool(self.layer6(self.layer5(x)))

        # Flatten per sample. view(-1, 512) would silently fold extra spatial
        # positions into the batch dimension whenever the final map is not 2x2;
        # flatten(1) keeps the batch size, so such a mismatch fails loudly in lin1.
        x = x.flatten(1)
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))

        return x


class Vgg(VggFeatures):
    """Complete classifier: ``VggFeatures`` followed by a 128 -> 7 linear head.

    ``forward`` returns the raw linear outputs (no softmax) for 7 classes.

    Args:
        drop: forwarded unchanged to ``VggFeatures``.
    """

    def __init__(self, drop=0.2):
        super().__init__(drop)
        # 128 inputs match the width of VggFeatures' last FC layer (lin2).
        self.lin3 = nn.Linear(128, 7)

    def forward(self, x):
        features = super().forward(x)
        logits = self.lin3(features)
        return logits


from torchsummary import summary

# Build the classifier in inference mode and print a per-layer shape/parameter
# table for a 1x40x40 input. The model is never moved to a GPU, so the dummy
# input must stay on the CPU too: torchsummary's default device is 'cuda', which
# places the input on the GPU whenever CUDA is available and breaks the forward.
net = Vgg()
net = net.eval()
summary(net, (1, 40, 40), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 40, 40]             640
       BatchNorm2d-2           [-1, 64, 40, 40]             128
              ReLU-3           [-1, 64, 40, 40]               0
         MaxPool2d-4           [-1, 64, 20, 20]               0
            Conv2d-5          [-1, 128, 18, 18]         204,928
       BatchNorm2d-6          [-1, 128, 18, 18]             256
              ReLU-7          [-1, 128, 18, 18]               0
         MaxPool2d-8            [-1, 128, 9, 9]               0
            Conv2d-9            [-1, 256, 9, 9]          32,768
      BatchNorm2d-10            [-1, 256, 9, 9]             512
            ReLU6-11            [-1, 256, 9, 9]               0
           Conv2d-12            [-1, 256, 9, 9]           2,304
      BatchNorm2d-13            [-1, 256, 9, 9]             512
            ReLU6-14            [-1, 25

In [2]:
from matplotlib import pyplot as plt
import os

class Logger:
    """Training/validation loss and accuracy histories, with plotting helpers.

    Callers append to the four list attributes directly; ``get_logs`` and
    ``restore_logs`` convert to and from the 4-tuple stored in a checkpoint.
    """

    def __init__(self):
        self.loss_train = []
        self.loss_val = []

        self.acc_train = []
        self.acc_val = []

    def get_logs(self):
        """Return ``(loss_train, loss_val, acc_train, acc_val)`` (the lists themselves, not copies)."""
        return self.loss_train, self.loss_val, self.acc_train, self.acc_val

    def restore_logs(self, logs):
        """Replace all four histories from a 4-tuple in ``get_logs`` order."""
        self.loss_train, self.loss_val, self.acc_train, self.acc_val = logs

    def _curve_specs(self):
        """Yield (file stem, title, y-label, train series, val series) per plot, in drawing order."""
        return (
            ('acc', 'Accuracy', 'Acc', self.acc_train, self.acc_val),
            ('loss', 'Loss', 'Loss', self.loss_train, self.loss_val),
        )

    @staticmethod
    def _plot_curves(train, val, title, ylabel):
        """Draw training (green) vs validation (blue) curves on a new figure and return it."""
        fig, ax = plt.subplots()
        ax.plot(train, 'g', label='Training ' + ylabel)
        ax.plot(val, 'b', label='Validation ' + ylabel)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid()
        return fig

    def save_plt(self, hps):
        """Write ``acc.jpg`` and ``loss.jpg`` into ``hps['model_save_dir']``.

        Each figure is closed after saving (even if saving fails), so repeated
        calls do not accumulate open matplotlib figures.

        Args:
            hps: mapping with a ``'model_save_dir'`` entry naming an existing directory.
        """
        for stem, title, ylabel, train, val in self._curve_specs():
            fig = self._plot_curves(train, val, title, ylabel)
            try:
                fig.savefig(os.path.join(hps['model_save_dir'], stem + '.jpg'))
            finally:
                plt.close(fig)

    def show_plt(self):
        """Display the accuracy figure, then the loss figure."""
        for _, title, ylabel, train, val in self._curve_specs():
            self._plot_curves(train, val, title, ylabel)
            plt.show()

In [3]:
import os
import torch
# Restore weights and training history from a saved checkpoint. The checkpoint
# is read as a dict with 'logs' (the Logger history 4-tuple) and 'params' (the
# model state_dict). torch.load unpickles the file, so only load trusted checkpoints.
epoch = 158  # which saved epoch under cp_demo/ to restore
path = os.path.join('cp_demo', f'epoch_{epoch}')
logger = Logger()
checkpoint = torch.load(path, map_location=torch.device('cpu'))  # keep all tensors on CPU

logger.restore_logs(checkpoint['logs'])
net.load_state_dict(checkpoint['params'])
print("Network Restored!")

Network Restored!


In [4]:
import cv2
from torchvision import transforms
import numpy as np
from PIL import Image

# Load the Haar cascade used for face detection (expected in the working
# directory). cv2.CascadeClassifier does not necessarily raise on a missing or
# invalid file: it can return an empty classifier that only fails later inside
# detectMultiScale, so fail fast with a clear message instead.
face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')
if face_cascade.empty():
    raise FileNotFoundError(
        "Could not load 'haarcascade_frontalface_default.xml'; place it in the "
        "working directory (opencv-python ships it under cv2.data.haarcascades)."
    )

# Per-face preprocessing used by check(): PIL image -> tensor, then (x - mu) / st.
mu,st = 0,255
test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(mu,), std=(st,))
    ])

# Network output index -> expression name (7 classes, matching Vgg's 7 outputs).
lb = {0: "angry", 1: "disgust", 2: "fear", 3: "happy", 4: "sad", 5: "surprise", 6: "neutral"}

def check(gray_face):
    """Classify the expression shown in one grayscale face crop.

    Args:
        gray_face: 2-D grayscale image array (a crop of the frame after
            ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``).

    Returns:
        int: index of the highest-scoring class; look it up in ``lb`` for the name.
    """
    # Resize to the 40x40 input the network was built for.
    img = cv2.resize(gray_face, (40, 40)).astype(np.float64)
    img = Image.fromarray(img)
    img = test_transform(img)
    img.unsqueeze_(0)  # add the batch dimension
    # Inference only: no_grad stops autograd from recording a graph for every
    # face in every frame.
    with torch.no_grad():
        outputs = net(img)
    return int(outputs.argmax(dim=1)[0])
# Capture video from the default webcam (device 0), label every detected face
# with its predicted expression, and stop when Esc is pressed.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise RuntimeError('Could not open webcam (device 0)')

try:
    while True:
        ok, img = cap.read()
        if not ok:
            # No frame available (camera unplugged / stream ended): stop instead
            # of passing None into cvtColor.
            break
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=4)
        # Draw a rectangle and the predicted label for each face.
        for (x, y, w, h) in faces:
            face = gray[y:y+h, x:x+w]
            a = check(face)
            cv2.rectangle(img, (x, y), (x+w, y+h), (255, 0, 0), 2)
            img = cv2.putText(img, lb[a], (x, y), cv2.FONT_HERSHEY_SIMPLEX,
                              1, (0, 0, 255), 1, cv2.LINE_AA)
        cv2.imshow('img', img)
        # 27 is the Esc key code.
        k = cv2.waitKey(30) & 0xff
        if k == 27:
            break
finally:
    # Always free the camera and close the window, even if an exception or a
    # kernel interrupt ends the loop early; otherwise the device stays locked.
    cap.release()
    cv2.destroyAllWindows()