In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms.functional import resize
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import glob
import torchvision.io
import cv2
from tqdm.notebook import tqdm

In [2]:
# Custom dataset class for Lip Reading
class LipReadingDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.video_paths = []
        self.labels = []
        self.classes = sorted(os.listdir(root_dir))
        if '.DS_Store' in self.classes:
            self.classes.remove('.DS_Store')
        
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name, 'train')
            video_files = glob.glob(os.path.join(class_dir, "*.mp4"))
            self.video_paths.extend(video_files)
            self.labels.extend([label] * len(video_files))

        # Debug prints
        print(f"Found {len(self.video_paths)} videos across {len(self.classes)} classes.")
        if len(self.video_paths) == 0:
            print("No videos found. Please check the dataset directory structure and paths.")

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        label = self.labels[idx]
        frames = self.load_video_frames(video_path)

        if self.transform:
            frames = self.transform(frames)

        return frames, label

    def load_video_frames(self, video_path):
        frames = []
        cap = cv2.VideoCapture(video_path)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)  # Convert frame to grayscale
            frame = Image.fromarray(frame)  # Convert numpy array to PIL Image
            frames.append(frame)
        cap.release()
        return frames

In [3]:
# Transform for video frames
class ToTensor:
    def __call__(self, frames):
        tensor = torch.stack([transforms.ToTensor()(frame) for frame in frames])  
        return tensor

In [4]:
# Path to the processed_selected_mp4_files directory
root_dir = './processed_selected_mp4_files'  # Update this path

# Dataset and DataLoader
transform = ToTensor()
train_dataset = LipReadingDataset(root_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, collate_fn=lambda x: collate_fn(x))

# Collate function to handle variable-length sequences
def collate_fn(batch):
    videos, labels = zip(*batch)
    max_len = max(len(video) for video in videos)
    padded_videos = []
    for video in videos:
        pad_size = max_len - len(video)
        # Ensure the tensor dimensions match for concatenation

        padded_video = torch.cat([video, torch.zeros((pad_size, video.shape[1], video.shape[2], video.shape[3]))], dim=0)  # Add padding for grayscale frames
        
        # Resize video to (max_len, 1, 224, 224)
        resized_video = torch.stack([resize(frame, (224, 224)) for frame in padded_video])
        padded_videos.append(resized_video)
        # padded_videos.append(padded_video)
    return torch.stack(padded_videos), torch.tensor(labels)


Found 20 videos across 2 classes.


In [5]:
# Define the model architecture using VGG
class LipReadingModel(nn.Module):
    def __init__(self, num_classes=500):
        super(LipReadingModel, self).__init__()
        # VGG16 as feature extractor
        self.vgg = torchvision.models.vgg16(weights='DEFAULT')
        self.vgg.features[0] = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)  # Modify first conv layer for grayscale

        self.vgg.classifier = nn.Identity()  # Remove final classification layer

        # RNN for sequence modeling
        self.rnn = nn.LSTM(input_size=512*7*7, hidden_size=256, num_layers=2, batch_first=True, bidirectional=True)
        
        # Fully connected layer for classification
        self.fc = nn.Linear(256*2, num_classes)  # bidirectional doubles the output features

    def forward(self, x):
        batch_size, timesteps, C, H, W = x.size()
        c_in = x.view(batch_size * timesteps, C, H, W)
        c_out = self.vgg.features(c_in)
        c_out = c_out.view(batch_size, timesteps, -1)  # Flatten for LSTM
        r_out, _ = self.rnn(c_out)
        out = self.fc(r_out[:, -1, :])
        return out

In [6]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
num_classes = len(train_dataset.classes)  # Automatically get the number of classes
learning_rate = 0.001
num_epochs = 20
batch_size = 10

# Initialize model, loss, and optimizer
model = LipReadingModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop with progress bar
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")
    for i, (videos, labels) in enumerate(progress_bar):
        videos = videos.to(device)
        labels = labels.to(device)

        outputs = model(videos)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    avg_loss = epoch_loss / len(train_loader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}')


Epoch 1/20:   0%|          | 0/2 [00:00<?, ?batch/s]

Epoch [1/20], Average Loss: 0.7823


Epoch 2/20:   0%|          | 0/2 [00:00<?, ?batch/s]

Epoch [2/20], Average Loss: 0.7630


Epoch 3/20:   0%|          | 0/2 [00:00<?, ?batch/s]

Epoch [3/20], Average Loss: 0.7145


Epoch 4/20:   0%|          | 0/2 [00:00<?, ?batch/s]

Epoch [4/20], Average Loss: 0.6950


Epoch 5/20:   0%|          | 0/2 [00:00<?, ?batch/s]