In [91]:
import os 
data_pth = '../datasets'
videos_pth = os.path.join(data_pth,'videos')
aligns_pth = os.path.join(data_pth,'alignments')

In [92]:
speakers = os.listdir(videos_pth)
video_batches: list[str] =[ os.path.join(videos_pth,speaker) for speaker in speakers] 
align_batches = [os.path.join(aligns_pth,speaker) for speaker in speakers]


In [93]:
videos = []
aligns = []
for video_batch,align_batch in zip(video_batches,align_batches):
    video_files = os.listdir(video_batch)
    align_files = os.listdir(align_batch)
    for video_file in video_files:
        if not video_file.endswith('.mpg'):
            continue
        video_name = video_file.replace('.mpg','')
        align_name = video_name + '.align'
        align_path = os.path.join(align_batch,align_name)
        
        try:
            with open(align_path , 'r') as f:
                text = f.read()
            if len(text) < 1:
                continue
        except:
            print('align not found:',align_path)
            continue    

        videos.append(os.path.join(video_batch,video_file))
        aligns.append(os.path.join(align_batch,align_name))


In [94]:
len(videos)

1000

In [95]:
pip install opencv-python         

Note: you may need to restart the kernel to use updated packages.


In [96]:
import random 
import cv2 
import matplotlib.pyplot as plt
import numpy as np
import torch


In [97]:
vocab = 'abcdefghijklmnopqrstuvwxyz- '
vocab_size = len(vocab)
vti: dict[str, int] = {vocab[i]:i+1 for i in range(vocab_size)}
vti['-']=0
itv = {i:j for j,i in vti.items()}




In [98]:
from typing import Any

def extract_frames(video):
    frames: list[Any] = []
    
    cap = cv2.VideoCapture(video)

    while True :
        res , frame = cap.read()
        if not res:
            break 
        
        frames.append(frame) 
    return np.stack(frames) 

    
def extract_text(align):
    with open(align, 'r') as f:
        text = f.read()
    text = ''.join(char for char in text if char in vocab)
    return [vti[char] for char in text] 

In [99]:
from sklearn.model_selection import train_test_split 

X_train,X_test,y_train,y_test = train_test_split(videos,aligns,test_size=0.2,random_state=32)


In [100]:
import random
import numpy as np 

In [101]:
def get_batch(X,y,batch_size=4):
    assert len(X)==len(y), """X and y must be of the same size"""
    idxs = np.random.randint(0,len(X),batch_size)
    xs = []
    ys = []
    for idx in idxs:
        xs.append(X[idx])
        ys.append(y[idx])
    return np.stack([extract_frames(video) for video in xs]),[extract_text(align) for align in ys]


        
    

In [102]:
Xs,ys = get_batch(X_train,y_train)
# ys.shape


In [103]:
from typing import Any


import torch 
import torch.nn as nn

class Model(nn.Module):
    def __init__(self,in_channels,frames,height,width,hidden_size=100) -> None:
        super().__init__()
        self.width = width
        self.height  = height 
        self.frames = frames 
        self.in_channels= in_channels
        self.conv = nn.Sequential(
        torch.nn.Conv3d(in_channels=in_channels,out_channels=16,kernel_size=(3,3,3),stride=(1,1,1),padding=(1,1,1)),
        torch.nn.MaxPool3d(kernel_size=(1,2,2),stride=(1,2,2),padding=(0,0,0))  ,       
        torch.nn.Conv3d(in_channels=16,out_channels=32,kernel_size=(3,3,3),stride=(1,1,1),padding=(1,1,1)),
        torch.nn.MaxPool3d(kernel_size=(1,2,2),stride=(1,2,2),padding=(0,0,0)),
        torch.nn.Conv3d(in_channels=32,out_channels=64,kernel_size=(3,3,3),stride=(1,1,1),padding=(1,1,1)),
        torch.nn.MaxPool3d(kernel_size=(1,2,2),stride=(1,2,2),padding=(0,0,0)))
        self.hidden_size = hidden_size
        #b,c,f,h,w 
        inp_shape: Any = (width//8)*(height//8)*64
        self.forget =nn.Sequential( nn.Linear( in_features=inp_shape + hidden_size,out_features= hidden_size),
                                    nn.Sigmoid()  # Sigmoid activation function
        )
        
        self.candidate =nn.Sequential( nn.Linear(in_features=hidden_size + inp_shape ,out_features=hidden_size),
                                      nn.Tanh())
    
        self.input =nn.Sequential( nn.Linear(in_features=hidden_size + inp_shape,out_features=hidden_size),nn.Sigmoid()  # Sigmoid activation function
        )
        self.output = nn.Sequential(
            nn.Linear(hidden_size+inp_shape,300),
            nn.Tanh(),
            nn.Linear(300,400),
            
            nn.Tanh(),
            
            nn.Linear(400,hidden_size),
            nn.LayerNorm(hidden_size),
                                    nn.Tanh()
                                    
                                    )
        
       

        self.f = nn.Linear(2*hidden_size,out_features=vocab_size)
    
    def bidirectional(self, X, isbackward=False):
        conv_out = self.conv(X).contiguous()
        # print(conv_out.shape)
        b, c, f, h, w = conv_out.size()
        conv_out = conv_out.view(b, f, -1)
        cell_state = torch.zeros(b, self.hidden_size, device=X.device)
        hidden_state = torch.zeros(b, f, self.hidden_size, device=X.device)
        outs = torch.zeros(b, f, vocab_size, device=X.device)
        for t in (range(f-1, -1, -1) if isbackward else range(f)):

            xt = conv_out[:, t, :]
            prev_idx = t+1 if isbackward else t-1
            valid_prev = (prev_idx >= 0 and prev_idx < f)
            prev_hs = hidden_state[:, prev_idx, :] if valid_prev else torch.zeros(b, self.hidden_size, device=X.device)

            # print(xt.shape)
            # print(xt)
            xt = torch.cat([xt, prev_hs], 1)
            # print(xt.shape)

            forget = self.forget(xt)

            # print('forget passed')
            input = self.input(xt)
            # print('input passed')
            candidate = self.candidate(xt)
            # print('candidate passed')
            output = self.output(xt)
            new_cell_state = forget * cell_state + input * candidate
            # print('new_cell_state passed')
            new_hidden_state = output * torch.tanh(input=new_cell_state)
            
            cell_state = new_cell_state
            hidden_state[:, t, :] = new_hidden_state

        return hidden_state

    def forward(self, X):
        forward_outs = self.bidirectional(X)
        backward_outs = self.bidirectional(X, isbackward=True)

        hidden_state = torch.cat([forward_outs, backward_outs], 2)
        outs = self.f(hidden_state)
        
        # Ensure numerical stability in log_softmax
        return nn.functional.log_softmax(outs, dim=2)
            
                
            

In [104]:

Xs,ys = get_batch(X_train,y_train)
Xs = torch.tensor(Xs).float()
Xs = Xs.permute(0,4,1,2,3)
Xs = Xs / 255.0
model = Model(in_channels=3,frames=75,height=Xs.shape[3],width=Xs.shape[4])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
Xs = Xs.to(device)

# Xs.shape

In [105]:
yts =torch.cat( [ torch.tensor(yi) for  yi in ys])
# len(yts)
yts = yts.to(device)

In [106]:
target_lengths = torch.tensor([len(yi) for yi in ys]).to(device)
input_lengths = torch.tensor([72 for _ in range(len(ys))]).to(device)

print("Input shapes:")
print(f"Xs shape: {Xs.shape}")
print(f"Predictions shape before permute: {model(Xs).shape}")
print(f"Predictions shape after permute: {model(Xs).permute(1,0,2).shape}")

Input shapes:
Xs shape: torch.Size([4, 3, 75, 288, 360])
Predictions shape before permute: torch.Size([4, 75, 28])
Predictions shape after permute: torch.Size([75, 4, 28])


In [107]:
assert all(target_lengths[i] <= input_lengths[i] for i in range(len(input_lengths)))
# input_lengths

In [None]:
ctc_loss = nn.CTCLoss(blank=0,zero_infinity=True) 
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
epochs = 300
for epoch in range(epochs): 
    preds = model(Xs)    
    preds = preds.permute(1,0,2).to(device)

    loss = ctc_loss(preds, yts, input_lengths, target_lengths)
    
    if not torch.isnan(loss) and not torch.isinf(loss):
        print(f'epoch:{epoch} , loss: {loss.item()}')  # Use .item() to avoid memory leaks
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    else:
        print(f"Invalid loss detected: {loss}")




epoch:0 , loss: 4.348903656005859


  torch.nn.utils.clip_grad_norm(model.parameters(),max_norm=1)


epoch:1 , loss: 4.3899617195129395
epoch:2 , loss: 3.9921717643737793
epoch:3 , loss: 3.6761672496795654
epoch:4 , loss: 3.408346176147461


epoch:5 , loss: 3.2136080265045166
epoch:6 , loss: 3.0822715759277344
epoch:7 , loss: 3.00744891166687
epoch:8 , loss: 2.93611741065979
epoch:9 , loss: 2.8645260334014893
epoch:10 , loss: 2.807953357696533
epoch:11 , loss: 2.7496933937072754
epoch:12 , loss: 2.691725254058838
epoch:13 , loss: 2.641314744949341
epoch:14 , loss: 2.5985398292541504
epoch:15 , loss: 2.5616064071655273
epoch:16 , loss: 2.533553123474121
epoch:17 , loss: 2.499467134475708
epoch:18 , loss: 2.467519760131836
epoch:19 , loss: 2.436652421951294
epoch:20 , loss: 2.409245252609253
epoch:21 , loss: 2.387781858444214
epoch:22 , loss: 2.371035575866699
epoch:23 , loss: 2.356872081756592
epoch:24 , loss: 2.34299635887146
epoch:25 , loss: 2.3267741203308105
epoch:26 , loss: 2.3124356269836426
epoch:27 , loss: 2.3017921447753906
epoch:28 , loss: 2.288360834121704
epoch:29 , loss: 2.272719621658325
epoch:30 , loss: 2.2620315551757812
epoch:31 , loss: 2.2532238960266113
epoch:32 , loss: 2.2466378211975098
epoch:33 , loss:

KeyboardInterrupt: 

In [19]:
preds = model(Xs)    
preds = preds.permute(1,0,2)

In [20]:
loss = ctc_loss(preds,yts,input_lengths,target_lengths)