In [1]:
'''
Details for LRW dataset:
- 500 words (check lrw_list.txt)
- 800-1000 train videos per word
- 50 test and 50 validation videos per word
- Split into train, test, and validation sets
- All videos are 29 frames long (1.16 seconds)
- Word occurs roughly in the middle

TODO: (update as completed)
- Switch labels to numeric or one-hot (currently strings)
- Experiement with transformations of the data
- Experiement with trimming the video to cut out unneeded words, and reduce number of frames
'''

import torch
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import time
import pytorchvideo #look up how to use this
import cv2
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.optim as optim
from IPython.display import Video, HTML
from playsound import playsound
from tqdm import tqdm
from timesformer_pytorch.rotary import apply_rot_emb, AxialRotaryEmbedding, RotaryEmbedding

torch.cuda.empty_cache()

NUM_FRAMES = 29 # Make sure to set all new videos to this length
FPS = 25
TIME = 1.16
NUM_CLASSES = 50

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
random.seed(time.time())
RANDOM_SEED = random.randint(0, 2**32 - 1)


In [2]:
class LipReadDataset(Dataset):
    def __init__(self, root_dir, split, device, transform=None):
        """
        Initialize the dataset.
        Args:
            video_paths (list): List of file paths to video files.
            labels (list): List of corresponding labels.
            transform (callable, optional): Optional data transformations (e.g., resizing, normalization).
        """
        # self.word_folders = 
        # print(self.word_folders)
        # self.frames = []
        self.frame_titles = []
        self.labels = []
        self.device = device
        self.word_dict = {}
        self.num_classes = 0
        self.transform = transform
        # self.split = split
        # self.prefix = root_dir
        random.seed(RANDOM_SEED) #ensure words are the same bewtween datasets
        word_folders = os.listdir(root_dir)
        random.shuffle(word_folders)

        for word_folder in word_folders:
            set_path = os.path.join(root_dir, word_folder, split)
            self.word_dict[word_folder] = self.num_classes
            self.num_classes += 1
            # want to test with 10 words first
            if self.num_classes == NUM_CLASSES:
                break
            num_videos = 0
            video_files = os.listdir(set_path)
            random.seed(time.time()) #shuffle videos randomly
            random.shuffle(video_files)
            for video_file in video_files:
                video_path = os.path.join(set_path, video_file)
                if video_path.endswith('.txt'): # want to use later
                    continue
                num_videos += 1
                if num_videos == 100:
                    break
                # curr_frames = self.extract_frames(video_path)
                self.labels.append(word_folder)
                # self.frames.append(curr_frames)
                self.frame_titles.append(video_path)
                print(f"added {video_path} to {split} set")
        self.transform = transform
        assert self.num_classes == len(self.word_dict)

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Get a specific sample from the dataset.
        Args:
            idx (int): Index of the sample.
        Returns:
            video (tensor): Video frames (e.g., as a sequence of images).
            label (int): Corresponding label.
        """
        # extract video
        # Apply any necessary transformations (e.g., resizing, normalization)
        # if self.transform:
        #     video = self.transform(video)
        # Convert the label to a one-hot encoded tensor
        return self.extract_frames(self.frame_titles[idx]), self.one_hot_encode(self.labels[idx])
    
    def one_hot_encode(self, labels):
        # Convert labels to one-hot encoding
        one_hot = torch.zeros(NUM_CLASSES, device=self.device)
        one_hot[self.word_dict[labels]] = 1
        return one_hot
        

    def extract_frames(self, video_path, duration=1.16, target_resolution = (40,65)):
        # Read the video
        cap = cv2.VideoCapture(video_path)
        # EXPERIEMENT WITH THIS
        # duration_frames = int(duration * FPS)
        # end_frame = NUM_FRAMES -  9 # skip the last 9 frames
        # start_frame = end_frame - duration_frames # get middle frames
        frames = []
        # cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                # if cap.get(cv2.CAP_PROP_POS_FRAMES) == end_frame:
                #     break
                frame_cropped = frame[120:216, 80:176] # Crop the frame
                # frame_resized = cv2.resize(frame_cropped, target_resolution)
                if self.transform:
                    frame_tensor = self.transform(frame_cropped)
                # frame_gray = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2GRAY)
                # frame_tensor = torch.from_numpy(frame_gray).to(self.device)
                frame_tensor.to(self.device) #double check
                frames.append(frame_tensor)
            else:
                break
        cap.release()
        video_tensor = torch.stack(frames, dim=0).to(self.device)
        # video_tensor = video_tensor.permute(1, 0, 2, 3)  # assuming the original channel dimension is at index 1
        return video_tensor





In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert frames to PyTorch tensors
    # transforms.Resize((40, 65)),  # Resize frames to smaller resolution
    transforms.Grayscale(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])


root_dir = 'lipread_mp4'

print(device)
trainset = LipReadDataset(root_dir, split='train', device=device, transform=transform)
print("finished trainset")
testset = LipReadDataset(root_dir, split='test',device=device, transform=transform)
print("finished testset")
valset = LipReadDataset(root_dir, split='val', device=device, transform=transform)
print("finished valset")

cuda:0
added lipread_mp4\DAVID\train\DAVID_00060.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00226.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00232.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00834.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00900.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00414.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00407.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00228.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00529.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00446.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00810.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00234.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00543.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00343.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00028.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00721.mp4 to train set
added lipread_mp4\DAVID\train\DAVID_00373.mp4 to 

In [4]:
#send to dataloader
batch_size = 8
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=True)

In [5]:
# # # TESTING IMAGE SIZE FOR CROP
# video_path = trainset.frame_titles[2]

# # def loop_video(video_path, loop_count, duration = 0.27):
# #     # Read the video
# #     cap = cv2.VideoCapture(video_path)
# #     duration_frames = round(duration * FPS)
# #     end_frame = NUM_FRAMES -  9 # skip the last 9 frames
# #     start_frame = end_frame - duration_frames - 3  # get middle frames
# #     print(end_frame, start_frame)
# #     frames = []
# #     cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
# #     # Loop the video
# #     frame_count = 0
# #     for _ in range(loop_count):
# #         cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)  # Set frame position to the beginning
# #         while cap.isOpened():
# #             ret, frame = cap.read()
# #             if ret:
# #                 frame_count += 1
# #                 if cap.get(cv2.CAP_PROP_POS_FRAMES) == end_frame:
# #                     break
# #                 frame_cropped = frame[120:200, 70:200]
# #                 cv2.imshow('Looped Video', frame_cropped)
# #                 if cv2.waitKey(25) & 0xFF == ord('q'):
# #                     break
# #             else:
# #                 break
# #     print(frame_count)
    
# #     # Release video capture
# #     cap.release()
# #     cv2.destroyAllWindows()

# # loop_video(video_path, 10)
# print(video_path)
# def extract_frames( video_path, duration=1.16, target_resolution = (40,65), transform=None):
#     # Read the video
#     cap = cv2.VideoCapture(video_path)
#     # EXPERIEMENT WITH THIS
#     # duration_frames = int(duration * FPS)
#     # end_frame = NUM_FRAMES -  9 # skip the last 9 frames
#     # start_frame = end_frame - duration_frames # get middle frames
#     frames = []
#     # cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

#     while cap.isOpened():
#         ret, frame = cap.read()
#         if ret:
#             # if cap.get(cv2.CAP_PROP_POS_FRAMES) == end_frame:
#             #     break
#             frame_cropped = frame[120:200, 70:200] # Crop the frame
#             #downscale the image by half to 40x65
#             # frame_cropped = cv2.resize(frame_cropped, target_resolution)
#             # frame_resized = cv2.resize(frame_cropped, target_resolution)
#             if transform:
#                 frame_tensor = transform(frame_cropped)
#             # frame_gray = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2GRAY)
#             # frame_tensor = torch.from_numpy(frame_gray).to(self.device)
#             frame_tensor = frame_tensor.to('cpu').numpy()
#             frames.append(frame_tensor)
#         else:
#             break
#     #show video for 10 loops
#     loop_count = 10
#     frame_count = 0
#     cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
#     for i in range(loop_count):
#         cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
#         for frame in frames:
#             print(frame.shape)
#             frame_count += 1
#             cv2.imshow('Looped Video', frame[0])
#             if cv2.waitKey(25) & 0xFF == ord('q'):
#                 break
#     cap.release()
#     cv2.destroyAllWindows()
#     # return video_tensor

# extract_frames(video_path, transform=transform)





In [6]:


# helpers

def exists(val):
    return val is not None

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)

# time token shift

def shift(t, amt):
    if amt is 0:
        return t
    return F.pad(t, (0, 0, 0, 0, amt, -amt))

class PreTokenShift(nn.Module):
    def __init__(self, frames, fn):
        super().__init__()
        self.frames = frames
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        f, dim = self.frames, x.shape[-1]
        cls_x, x = x[:, :1], x[:, 1:]
        x = rearrange(x, 'b (f n) d -> b f n d', f = f)

        # shift along time frame before and after

        dim_chunk = (dim // 3)
        chunks = x.split(dim_chunk, dim = -1)
        chunks_to_shift, rest = chunks[:3], chunks[3:]
        shifted_chunks = tuple(map(lambda args: shift(*args), zip(chunks_to_shift, (-1, 0, 1))))
        x = torch.cat((*shifted_chunks, *rest), dim = -1)

        x = rearrange(x, 'b f n d -> b (f n) d')
        x = torch.cat((cls_x, x), dim = 1)
        return self.fn(x, *args, **kwargs)

# feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# attention

def attn(q, k, v, mask = None):
    sim = einsum('b i d, b j d -> b i j', q, k)

    if exists(mask):
        max_neg_value = -torch.finfo(sim.dtype).max
        sim.masked_fill_(~mask, max_neg_value)

    attn = sim.softmax(dim = -1)
    out = einsum('b i j, b j d -> b i d', attn, v)
    return out

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, einops_from, einops_to, mask = None, cls_mask = None, rot_emb = None, **einops_dims):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        q = q * self.scale

        # splice out classification token at index 1
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))

        # let classification token attend to key / values of all patches across time and space
        cls_out = attn(cls_q, k, v, mask = cls_mask)

        # rearrange across time or space
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

        # add rotary embeddings, if applicable
        if exists(rot_emb):
            q_, k_ = apply_rot_emb(q_, k_, rot_emb)

        # expand cls token keys and values across time or space and concat
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim = 1)
        v_ = torch.cat((cls_v, v_), dim = 1)

        # attention
        out = attn(q_, k_, v_, mask = mask)

        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concat back the cls token
        out = torch.cat((cls_out, out), dim = 1)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

        # combine heads out
        return self.to_out(out)

# main classes

class TimeSformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_frames,
        num_classes,
        image_size = 96,
        patch_size = 16,
        channels = 1,
        depth = 12,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_dropout = 0.,
        rotary_emb = True,
        shift_tokens = False
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_size // patch_size) ** 2
        num_positions = num_frames * num_patches
        patch_dim = channels * patch_size ** 2

        self.heads = heads
        self.patch_size = patch_size
        self.to_patch_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, dim))

        self.use_rotary_emb = rotary_emb
        if rotary_emb:
            self.frame_rot_emb = RotaryEmbedding(dim_head)
            self.image_rot_emb = AxialRotaryEmbedding(dim_head)
        else:
            self.pos_emb = nn.Embedding(num_positions + 1, dim)


        self.layers = nn.ModuleList([])
        for _ in range(depth):
            ff = FeedForward(dim, dropout = ff_dropout)
            time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
            spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)

            if shift_tokens:
                time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(num_frames, t), (time_attn, spatial_attn, ff))

            time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff))

            self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))

        self.to_out = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, video, mask = None):
        b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size
        assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}'

        # calculate num patches in height and width dimension, and number of total patches (n)

        hp, wp = (h // p), (w // p)
        n = hp * wp

        # video to patch embeddings
        video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p)

        tokens = self.to_patch_embedding(video)

        # add cls token

        cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
        x =  torch.cat((cls_token, tokens), dim = 1)

        # positional embedding

        frame_pos_emb = None
        image_pos_emb = None
        if not self.use_rotary_emb:
            x += self.pos_emb(torch.arange(x.shape[1], device = device))
        else:
            frame_pos_emb = self.frame_rot_emb(f, device = device)
            image_pos_emb = self.image_rot_emb(hp, wp, device = device)

        # calculate masking for uneven number of frames

        frame_mask = None
        cls_attn_mask = None
        if exists(mask):
            mask_with_cls = F.pad(mask, (1, 0), value = True)

            frame_mask = repeat(mask_with_cls, 'b f -> (b h n) () f', n = n, h = self.heads)

            cls_attn_mask = repeat(mask, 'b f -> (b h) () (f n)', n = n, h = self.heads)
            cls_attn_mask = F.pad(cls_attn_mask, (1, 0), value = True)

        # time and space attention

        for (time_attn, spatial_attn, ff) in self.layers:
            x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, cls_mask = cls_attn_mask, rot_emb = frame_pos_emb) + x
            x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, cls_mask = cls_attn_mask, rot_emb = image_pos_emb) + x
            x = ff(x) + x

        cls_token = x[:, 0]
        return self.to_out(cls_token)

  if amt is 0:


In [7]:
def plot(loss, acc, title):
    fig, ax = plt.subplots(2)
    fig.tight_layout(pad=3.0)
    fig.suptitle(title)
    ax[0].plot(loss)
    ax[0].set_title('Training Loss')
    ax[0].set_xlabel('Epoch')
    ax[0].set_ylabel('Loss')
    ax[1].plot(acc)
    ax[1].set_title('Training Accuracy')
    ax[1].set_xlabel('Epoch')
    ax[1].set_ylabel('Accuracy')
    plt.show()

In [8]:
def train_and_val(model, train_loader, val_loader, optimizer, criterion, num_epochs=10):
    model.train()
    train_loss_epoch = []
    train_acc_epoch = []
    val_loss_epoch = []
    val_acc_epoch = []
    for epoch in range(num_epochs):
        train_loss = 0.0
        train_total = 0
        train_correct = 0
        val_loss = 0.0
        val_total = 0
        val_correct = 0
        for data in tqdm(train_loader):
            inputs, labels = data
            inputs.to(device)
            labels.to(device)
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            _, labels_class = labels.max(dim=1)  
            _, predicted_class = outputs.max(dim=1)  
            train_correct += (predicted_class == labels_class).sum().item()
            
            train_total += labels.size(0)
            #print
            train_loss += loss.item()
        

        #save model with name based on time
        epoch_state_dict = {'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}
        torch.save(epoch_state_dict, f'transformer_{int(time.time())}.pth')
        # Validation
        model.eval()
        with torch.no_grad():
            for data in tqdm(val_loader):
                inputs, labels = data
                inputs.to(device)
                labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, labels_class = labels.max(dim=1)  
                _, predicted_class = outputs.max(dim=1)  
                val_correct += (predicted_class == labels_class).sum().item()
                val_total+= labels.size(0)
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        print(f'Epoch {epoch + 1}\nTrain Loss: {train_loss }, Accuracy: {100 * train_correct / train_total}% \nVal Loss: {val_loss}, Accuracy: {100 * val_correct / val_total}%')
        train_loss_epoch.append(train_loss)
        train_acc_epoch.append(100 * train_correct / train_total)
        val_loss_epoch.append(val_loss)
        val_acc_epoch.append(100 * val_correct / val_total)
    plot(train_loss_epoch, train_acc_epoch, 'Training')
    plot(val_loss_epoch, val_acc_epoch, 'Validation')

    print('Finished Training')




def test(model, test_loader):
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = datac
            inputs.to(device)
            labels.to(device)
            outputs = model(inputs)
            # print(labels.shape, outputs.shape)
            _, labels_class = labels.max(dim=1)  
            _, predicted_class = outputs.max(dim=1)  
            test_correct += (predicted_class == labels_class).sum().item()
            # print(predicted_class, labels_class, correct)
            # accuracy = correct_predictions / labels.size(0)  
            test_total += labels.size(0)
    print(f'Accuracy: {100 * test_correct / test_total}%')




In [9]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
#more complex
model = TimeSformer(
    dim = 512,
    num_frames = NUM_FRAMES,
    num_classes = NUM_CLASSES,
    image_size = 96,
    patch_size = 8,
    depth = 12,
    heads = 8,
    dim_head = 32,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    rotary_emb = True,
    shift_tokens = False
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)
train_and_val(model, train_loader, val_loader, optimizer, criterion, num_epochs=10)
test(model, test_loader)

cnn2_state_dict = {'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}
torch.save(cnn2_state_dict, 'transformer.pth')

 14%|█▎        | 82/607 [10:09:59<1507:40:23, 10338.33s/it]