In [1]:
import cv2
import numpy as np
import torch
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig



In [2]:
# Load the face cascade for face detection
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')


In [3]:
# Define the emotion labels
emotion_labels = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']



In [4]:
# Load your pre-trained model from the specified local path
model_path = "C:/Users/osyed/vit_model"
model = ViTForImageClassification.from_pretrained(model_path)
model.eval()  # Set the model to evaluation mode



ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [5]:
# Transformation to be applied to the images
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert image to PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the tensor
])



In [18]:
# Path to the video file
video_path = 'C:/Users/osyed/OneDrive/Desktop/MELD/MELD-RAW/MELD.Raw/dev/dev_splits_complete/dia16_utt1.mp4'



In [19]:
# Initialize the video capture
cap = cv2.VideoCapture(video_path)

# Get the video's frame count to loop it
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

while True:
    for i in range(frame_count):
        # Capture frame-by-frame
        ret, frame = cap.read()

        if not ret:
            break

        # Convert the frame to grayscale for face detection
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Detect faces in the frame
        faces = face_cascade.detectMultiScale(gray, 1.3, 5)

        # Loop over each detected face
        for (x, y, w, h) in faces:
            # Extract the face ROI (Region of Interest)
            face_roi = frame[y:y+h, x:x+w]

            # Resize the face ROI to the expected input shape of the model
            resized_face = cv2.resize(face_roi, (224, 224))

            # Convert the resized face to a PyTorch tensor
            face_tensor = transform(resized_face)
            face_tensor = face_tensor.unsqueeze(0)  # Add batch dimension

            # Make predictions using the model
            with torch.no_grad():
                predictions = model(face_tensor)
                predicted_index = torch.argmax(predictions.logits, axis=1).item()
                if predicted_index >= len(emotion_labels):
                    predicted_emotion = "Unknown"
                else:
                    predicted_emotion = emotion_labels[predicted_index]

            # Draw a rectangle around the face
            cv2.rectangle(frame, (x, y), (x+w, y+h), (255, 0, 0), 2)

            # Display the predicted emotion label above the face rectangle
            cv2.putText(frame, predicted_emotion, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        # Display the resulting frame
        cv2.imshow('Emotion Detection', frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # Reset the video to the beginning to loop
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

# Release the video capture and close all OpenCV windows
cap.release()
cv2.destroyAllWindows()

KeyboardInterrupt: 