In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import torchvision.transforms as transforms
from torchvision.transforms.functional import resize
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import glob
import torchvision.io
import cv2
from tqdm.notebook import tqdm


In [2]:
# Custom dataset class for Lip Reading
class LipReadingDataset(Dataset):
    def __init__(self, root_dir, train=True, less_data=False, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.video_paths = []
        self.labels = []
        self.classes = sorted(os.listdir(root_dir))
        if '.DS_Store' in self.classes:
            self.classes.remove('.DS_Store')
        
        t_set = ("train" if train == True else "test")
        
        self.classes = self.classes[:15]
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name, t_set)
            video_files = glob.glob(os.path.join(class_dir, "*.mp4"))
            self.video_paths.extend(video_files)
            self.labels.extend([label] * len(video_files))

        if less_data:
            self.video_paths = self.video_paths[::2]
            self.labels = self.labels[::2]


        # Debug prints
        print(f"Found {len(self.video_paths)} videos across {len(self.classes)} classes.")
        if len(self.video_paths) == 0:
            print("No videos found. Please check the dataset directory structure and paths.")

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        label = self.labels[idx]
        frames = self.load_video_frames(video_path)

        if self.transform:
            frames = self.transform(frames)

        return frames, label

    def load_video_frames(self, video_path):
        frames = []
        cap = cv2.VideoCapture(video_path)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)  # Convert frame to grayscale
            frame = Image.fromarray(frame)  # Convert numpy array to PIL Image
            frames.append(frame)
        cap.release()
        return frames

In [3]:
# Transform for video frames
class ToTensor:
    def __call__(self, frames):
        tensor = torch.stack([transforms.ToTensor()(frame) for frame in frames])  
        return tensor

In [4]:
# Path to the processed_selected_mp4_files directory
root_dir = '/Users/Zachary/Documents/Courses/ECE228/project/processed_selected_mp4_files'  # Update this path

# Dataset and DataLoader
transform = ToTensor()
train_dataset = LipReadingDataset(root_dir, train=True, less_data=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, collate_fn=lambda x: collate_fn(x))
test_dataset = LipReadingDataset(root_dir, train=False, less_data=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=True, collate_fn=lambda x: collate_fn(x))

# Collate function to handle variable-length sequences
def collate_fn(batch):
    videos, labels = zip(*batch)
    max_len = max(len(video) for video in videos)
    padded_videos = []
    for video in videos:
        pad_size = max_len - len(video)
        padded_video = torch.cat([video, torch.zeros((pad_size, video.shape[1], video.shape[2], video.shape[3]))], dim=0)
        resized_video = torch.stack([resize(frame, (224, 224)) for frame in padded_video])
        padded_videos.append(resized_video)
    return torch.stack(padded_videos), torch.tensor(labels)

Found 750 videos across 15 classes.
Found 375 videos across 15 classes.


In [5]:
class CNNFeatureExtractor(nn.Module):
    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.cnn = models.resnet50(weights="DEFAULT")
        self.cnn.conv1 = nn.Conv2d(1, self.cnn.conv1.out_channels, kernel_size=self.cnn.conv1.kernel_size, 
                                   stride=self.cnn.conv1.stride, padding=self.cnn.conv1.padding, bias=False)
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-2])  # Remove the classification layers

    def forward(self, x):
        batch_size, seq_length, c, h, w = x.size()
        x = x.view(batch_size * seq_length, c, h, w)
        features = self.cnn(x)
        features = features.view(batch_size, seq_length, -1)
        return features


In [6]:
class AttentionLSTM(nn.Module):
    def __init__(self, cnn_model, hidden_dim, num_classes, num_layers=2):
        super(AttentionLSTM, self).__init__()
        self.cnn = cnn_model
        self.lstm = nn.LSTM(input_size=2048 * 7 * 7, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.attention = nn.MultiheadAttention(hidden_dim * 2, num_heads=8, batch_first=True)
        self.fc = nn.Linear(hidden_dim*2, num_classes)

    def forward(self, x):
        with torch.no_grad():
            cnn_features = self.cnn(x)
        lstm_out, _ = self.lstm(cnn_features)
        attention_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        output = self.fc(attention_out[:, -1, :])  # Use the output from the last LSTM cell
        return output

In [7]:
def train(model, dataloader, criterion, optimizer, scheduler, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0

        train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

        for i, (inputs, labels) in enumerate(train_progress_bar):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            train_progress_bar.set_postfix(loss=loss.item())

        scheduler.step()
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

num_classes = len(train_dataset.classes)  # Automatically get the number of classes
learning_rate = 0.01
num_epochs = 10

cnn_model = CNNFeatureExtractor().to(device)
model = AttentionLSTM(cnn_model, hidden_dim=256, num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

train(model, train_loader, criterion, optimizer, scheduler, num_epochs=num_epochs)

# Save the trained model
model_save_path = 'resnet_lstm_model.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

cuda


Epoch 1/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 1/10, Loss: 5.9460, LR: 0.009755


Epoch 2/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 2/10, Loss: 2.8175, LR: 0.009045


Epoch 3/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 3/10, Loss: 2.7561, LR: 0.007939


Epoch 4/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 4/10, Loss: 2.7541, LR: 0.006545


Epoch 5/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 5/10, Loss: 2.7476, LR: 0.005000


Epoch 6/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 6/10, Loss: 2.7016, LR: 0.003455


Epoch 7/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 7/10, Loss: 2.6685, LR: 0.002061


Epoch 8/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 8/10, Loss: 2.6034, LR: 0.000955


Epoch 9/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 9/10, Loss: 2.5337, LR: 0.000245


Epoch 10/10:   0%|          | 0/75 [00:00<?, ?batch/s]

Epoch 10/10, Loss: 2.4709, LR: 0.000000
Model saved to resnet_lstm_model.pth


In [8]:
def check_accuracy(loader, model):

    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)  # move to device, e.g. GPU
            y = y.to(device)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, acc))
        return acc

acc = check_accuracy(test_loader, model)

Got 47 / 375 correct (0.13)
