In [1]:
import cv2
import os
import torch
import torch.nn as nn
from torchvision import models
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
def load_and_prepare_data(video_dir, frame_dir):
    def count_frames(directory):
        frame_count = 0
        for root, dirs, files in os.walk(directory):
            frame_count += len(files)
        return frame_count

    def extract_frames(video_dir, frame_dir):
        os.makedirs(frame_dir, exist_ok=True)
        for video_file in os.listdir(video_dir):
            if video_file.endswith(".mp4"):
                video_path = os.path.join(video_dir, video_file)
                video_name = os.path.splitext(video_file)[0]
                output_folder = os.path.join(frame_dir, video_name)
                os.makedirs(output_folder, exist_ok=True)

                cap = cv2.VideoCapture(video_path)
                frame_count = 0
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frame_path = os.path.join(output_folder, f"frame_{frame_count:04d}.jpg")
                    cv2.imwrite(frame_path, frame)
                    frame_count += 1

                cap.release()
                print(f"Extracted {frame_count} frames from {video_file}")

    total_frames = count_frames(frame_dir)
    print(f"Total number of frames: {total_frames}")
    extract_frames(video_dir, frame_dir)


In [3]:
def preprocess_data(frame_dir, transform):
    class CREMADataset(Dataset):
        def __init__(self, frame_dir, transform=None):
            self.frame_dir = frame_dir
            self.transform = transform
            self.data = []

            self._label_encoder = preprocessing.LabelEncoder()
            emotions = ["ANG", "DIS", "FEA", "HAP", "NEU", "SAD", "SUR"]
            self._label_encoder.fit(emotions)

            for video_folder in os.listdir(frame_dir):
                video_path = os.path.join(frame_dir, video_folder)
                for frame_file in os.listdir(video_path):
                    if frame_file.endswith(".jpg"):
                        frame_path = os.path.join(video_path, frame_file)
                        label = self.get_label_from_video(video_folder)
                        self.data.append((frame_path, label))

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            frame_path, label = self.data[idx]
            image = Image.open(frame_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, label

        def get_label_from_video(self, video_name):
            emotion_code = video_name.split('_')[2]
            emotion_map = {"ANG": 0, "DIS": 1, "FEA": 2, "HAP": 3, "NEU": 4, "SAD": 5, "SUR": 6}
            return emotion_map.get(emotion_code, -1)

        @property
        def label_encoder(self):
            return self._label_encoder

    dataset = CREMADataset(frame_dir=frame_dir, transform=transform)
    indices = list(range(len(dataset)))
    train_indices, val_test_indices = train_test_split(indices, test_size=0.4, random_state=42)
    val_indices, test_indices = train_test_split(val_test_indices, test_size=0.5, random_state=42)

    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    test_subset = Subset(dataset, test_indices)

    train_loader = DataLoader(train_subset, batch_size=16, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_subset, batch_size=16, shuffle=False, num_workers=0, pin_memory=True)
    test_loader = DataLoader(test_subset, batch_size=16, shuffle=False, num_workers=0, pin_memory=True)

    return train_loader, val_loader, test_loader, dataset.label_encoder


In [4]:
def build_model(model_type, num_classes, device):
    if model_type == "resnet":
        model = models.resnet50(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_type == "mobilenet":
        model = models.mobilenet_v2(pretrained=True)
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    else:
        raise ValueError("Invalid model_type. Choose 'resnet' or 'mobilenet'.")

    return model.to(device)


In [5]:
def train_and_evaluate_model(model, train_loader, val_loader, test_loader, device, num_epochs, checkpoint_dir, label_encoder):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_accuracy = correct_train / total_train
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_loss += criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = correct_val / total_val
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train Acc: {train_accuracy*100:.2f}% | Val Acc: {val_accuracy*100:.2f}%")
        print("-" * 50)

        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss
        }, checkpoint_path)

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    cm = confusion_matrix(all_labels, all_preds)
    cm_df = pd.DataFrame(cm, index=label_encoder.classes_, columns=label_encoder.classes_)

    plt.figure(figsize=(12, 8))
    sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

    print(classification_report(all_labels, all_preds, target_names=label_encoder.classes_))

    test_accuracy = np.mean(np.array(all_preds) == np.array(all_labels))
    print(f"Final Test Accuracy: {test_accuracy * 100:.2f}%")

    torch.save(model.state_dict(), "emotion_recognition_model.pth")


In [6]:
def main():
    video_dir = r"../datasets/CREMA_D"
    frame_dir = r"../datasets/frames"
    checkpoint_dir = "./checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    load_and_prepare_data(video_dir, frame_dir)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_loader, val_loader, test_loader, label_encoder = preprocess_data(frame_dir, transform)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model("resnet", num_classes=7, device=device)

    train_and_evaluate_model(model, train_loader, val_loader, test_loader, 
                            device, num_epochs=10, checkpoint_dir=checkpoint_dir,
                            label_encoder=label_encoder)

if __name__ == "__main__":
    main()


Total number of frames: 564916
Extracted 77 frames from 1002_DFA_ANG_XX.mp4
Extracted 76 frames from 1002_DFA_DIS_XX.mp4
Extracted 82 frames from 1002_DFA_FEA_XX.mp4
Extracted 66 frames from 1002_DFA_HAP_XX.mp4
Extracted 75 frames from 1002_DFA_NEU_XX.mp4
Extracted 71 frames from 1002_DFA_SAD_XX.mp4


KeyboardInterrupt: 