In [1]:
import os
import re
import math
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm
from transformers import BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## Обработка данных

In [2]:
# Tokenizer for the target transcripts. Its vocabulary size later sets the
# decoder's output layer size (target_vocab_size), and its [CLS]/[SEP] ids are
# used as the start/end symbols during decoding.
BERT_CHECKPOINT = 'DeepPavlov/rubert-base-cased'
bert_tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)

In [3]:
# словарь глосс
def load_vocab(vocab_file):
    word2id = {}
    id2word = {}
    with open(vocab_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            if len(parts) != 2:
                continue
            token, id_str = parts
            token_id = int(id_str)
            word2id[token] = token_id
            id2word[token_id] = token
    return word2id, id2word

# из строки глосс в индексы 
def tokenize_text(text, vocab, add_special_tokens=True):
    if not isinstance(text, str):
        text = ""
    tokens = text.strip().split()
    token_ids = [vocab.get(tok, vocab.get("<unk>", 0)) for tok in tokens]
    if add_special_tokens:
        bos = vocab.get("<bos>", 1)
        eos = vocab.get("<eos>", 2)
        token_ids = [bos] + token_ids + [eos]
    return torch.tensor(token_ids, dtype=torch.long)

# токенизатор для кодирования предложения
def tokenize_transcript_with_bert(text):
    if not isinstance(text, str):
        text = ""
    encoded = bert_tokenizer.encode(text, add_special_tokens=True)
    return torch.tensor(encoded, dtype=torch.long)


## Model

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (B, T, d_model)
        return x + self.pe[:, :x.size(1)]

In [13]:
class SpatialEmbedding(nn.Module):
    def __init__(self, cnn_output_dim=512):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, cnn_output_dim, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(cnn_output_dim),
            nn.ReLU(),
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
    
    def forward(self, x):
        # x: (B, T, C, H, W)
        B, T, C, H, W = x.size()
        x = x.view(B * T, C, H, W)
        features = self.conv(x)
        features = self.pool(features)
        features = features.view(B, T, -1)  # (B, T, cnn_output_dim)
        return features

In [5]:
class SignLanguageTransformer(nn.Module):
    def __init__(self, cnn_output_dim=512, d_model=512, num_encoder_layers=3, num_decoder_layers=3,
                 nhead=8, gloss_vocab_size=3194, target_vocab_size=3194, dropout=0.1):
        super().__init__()
        self.spatial_embed = SpatialEmbedding(cnn_output_dim)
        self.input_linear = nn.Linear(cnn_output_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        self.gloss_projection = nn.Linear(d_model, gloss_vocab_size)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.word_embedding = nn.Embedding(target_vocab_size, d_model)
        self.translation_projection = nn.Linear(d_model, target_vocab_size)
    
    def forward(self, video_frames, target_seq=None):
        # video_frames: (B, T, C, H, W)
        B = video_frames.size(0)
        spatial_feats = self.spatial_embed(video_frames)     # (B, T, cnn_output_dim)
        x = self.input_linear(spatial_feats)                 # (B, T, d_model)
        x = self.pos_enc(x)                                  # (B, T, d_model)
        encoder_input = x.permute(1, 0, 2)                   # (T, B, d_model)
        memory = self.encoder(encoder_input)                 # (T, B, d_model)
        memory = memory.permute(1, 0, 2)                       # (B, T, d_model)
        gloss_logits = self.gloss_projection(memory)         # (B, T, gloss_vocab_size)
        
        translation_logits = None
        if target_seq is not None:
            target_emb = self.word_embedding(target_seq)     # (B, L, d_model)
            target_emb = self.pos_enc(target_emb)              # (B, L, d_model)
            tgt = target_emb.permute(1, 0, 2)                    # (L, B, d_model)
            L = target_seq.size(1)
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(L).to(video_frames.device)
            decoder_output = self.decoder(tgt, memory.permute(1, 0, 2), tgt_mask=tgt_mask)
            decoder_output = decoder_output.permute(1, 0, 2)     # (B, L, d_model)
            translation_logits = self.translation_projection(decoder_output)  # (B, L, target_vocab_size)
        return gloss_logits, translation_logits

In [6]:
#  жадный алгоритм декодинга для перевода 
def greedy_decode(model, video_frames, max_len, start_symbol):
    model.eval()
    B = video_frames.size(0)
    with torch.no_grad():
        spatial_feats = model.spatial_embed(video_frames)
        x = model.input_linear(spatial_feats)
        x = model.pos_enc(x)
        encoder_input = x.permute(1, 0, 2)
        memory = model.encoder(encoder_input).permute(1, 0, 2)  # (B, T, d_model)
        ys = torch.full((B, 1), start_symbol, dtype=torch.long, device=video_frames.device)
        for i in range(max_len - 1):
            tgt = model.word_embedding(ys)
            tgt = model.pos_enc(tgt)
            tgt = tgt.permute(1, 0, 2)
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(1)).to(video_frames.device)
            decoder_output = model.decoder(tgt, memory.permute(1, 0, 2), tgt_mask=tgt_mask)
            decoder_output = decoder_output.permute(1, 0, 2)
            out = model.translation_projection(decoder_output)
            prob = F.log_softmax(out[:, -1, :], dim=-1)
            next_word = torch.argmax(prob, dim=-1).unsqueeze(1)
            ys = torch.cat([ys, next_word], dim=1)
        return ys

# Beam search декодинг для перевода 
def beam_search_decode_debug(model, video_frames, max_len, start_symbol, beam_size=5, end_symbol=None, debug=False):

    temperature = 1.5
    model.eval()
    B = video_frames.size(0)
    all_best = []  # храним лучшую последовательность
    with torch.no_grad():
        spatial_feats = model.spatial_embed(video_frames)  # (B, T, cnn_output_dim)
        x = model.input_linear(spatial_feats)              # (B, T, d_model)
        x = model.pos_enc(x)                               # (B, T, d_model)
        encoder_input = x.permute(1, 0, 2)                 # (T, B, d_model)
        memory = model.encoder(encoder_input).permute(1, 0, 2)  # (B, T, d_model)
        
        for b in range(B):
            mem = memory[b:b+1]  # (1, T, d_model)
            mem_t = mem.permute(1, 0, 2)  # (T, 1, d_model)
            beams = [([start_symbol], 0.0)]
            if debug:
                print(f"Sample {b}: initial beam: {beams}")
            for t in range(max_len - 1):
                new_beams = []
                for seq, score in beams:
                    if end_symbol is not None and seq[-1] == end_symbol:
                        new_beams.append((seq, score))
                        continue
                    seq_tensor = torch.tensor(seq, dtype=torch.long, device=video_frames.device).unsqueeze(1)
                    tgt = model.word_embedding(seq_tensor)  # (L, 1, d_model)
                    tgt = model.pos_enc(tgt)
                    tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_tensor.size(0)).to(video_frames.device)
                    decoder_output = model.decoder(tgt, mem_t, tgt_mask=tgt_mask)  # (L, 1, d_model)
                    decoder_last = decoder_output[-1, 0, :]  # (d_model)
                    logits = model.translation_projection(decoder_last)  # (target_vocab_size)
                    # log_probs = F.log_softmax(logits, dim=-1)
                    log_probs = F.log_softmax(logits / temperature, dim=-1)  
                    log_probs[101] = float('-inf')

                    top_log_probs, top_indices = torch.topk(log_probs, beam_size)
                    top_log_probs = top_log_probs.cpu().numpy()
                    top_indices = top_indices.cpu().numpy()
                    for i in range(beam_size):
                        new_seq = seq + [int(top_indices[i])]
                        new_score = score + float(top_log_probs[i])
                        new_beams.append((new_seq, new_score))
                beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
                if debug:
                    print(f"Time step {t+1}: beams: {beams}")
            best_seq, best_score = beams[0]
            if debug:
                print(f"Best sequence for sample {b}: {best_seq} with score {best_score}")
            all_best.append(best_seq)
    
    max_length = max(len(seq) for seq in all_best)
    decoded = []
    for seq in all_best:
        padded_seq = seq + [0]*(max_length - len(seq))
        decoded.append(padded_seq)
    decoded = torch.tensor(decoded, dtype=torch.long, device=video_frames.device)
    return decoded

In [7]:
# загрузка фреймов видео + обработка 
def load_video_frames(video_path, num_frames=50, resize=(224,224)):

    cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
    if not cap.isOpened():
        raise ValueError(f"VideoCapture failed to open the file: {video_path}")
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, resize)
        frame = frame / 255.0
        frames.append(frame)
    cap.release()
    if len(frames) == 0:
        raise ValueError(f"Failed to load frames from {video_path}")
        
    T = len(frames)
    if T > num_frames:
        indices = np.linspace(0, T - 1, num_frames).astype(int)
        frames = [frames[i] for i in indices]
    elif T < num_frames:
        while len(frames) < num_frames:
            frames.append(frames[-1])
    frames = np.stack(frames, axis=0)  # (T, H, W, C)
    frames = torch.tensor(frames, dtype=torch.float32).permute(0, 3, 1, 2)  # (T, C, H, W)
    return frames


In [8]:
class SignLanguageDataset(Dataset):
    def __init__(self, csv_file, video_dir, gloss_vocab, transcript_tokenizer):

        self.video_dir = video_dir
        self.gloss_vocab = gloss_vocab
        self.transcript_tokenizer = transcript_tokenizer
        df = pd.read_csv(csv_file)
        valid_indices = []
        for idx, row in df.iterrows():
            video_name = row['video_name']
            
            video_file = video_name if video_name.endswith(".mp4") else video_name + ".mp4"
            video_path = os.path.join(video_dir, video_file)
            if os.path.exists(video_path):
                valid_indices.append(idx)
            else:
                print(f"Warning: Video file not found for row {idx}: {video_path}")
        self.df = df.loc[valid_indices].reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        video_name = row['video_name']
        video_file = video_name if video_name.endswith(".mp4") else video_name + ".mp4"
        video_path = os.path.join(self.video_dir, video_file)
        video_tensor = load_video_frames(video_path)  # (T, C, H, W)
        gloss_target = tokenize_text(row['glosses'], self.gloss_vocab, add_special_tokens=False)
        transcript_target = self.transcript_tokenizer(row['transcript'])
        return video_tensor, gloss_target, transcript_target



In [9]:
def collate_fn(batch):
    videos, glosses, transcripts = zip(*batch)
    videos = torch.stack(videos, dim=0)  # (B, T, C, H, W)

    def pad_sequences(sequences, pad_value=0):
        lengths = [seq.size(0) for seq in sequences]
        max_len = max(lengths)
        padded = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
        for i, seq in enumerate(sequences):
            padded[i, :seq.size(0)] = seq
        return padded

    padded_glosses = pad_sequences(glosses, pad_value=0)
    padded_transcripts = pad_sequences(transcripts, pad_value=0)
    return videos, padded_glosses, padded_transcripts


## Обучение

In [10]:
def train_model(model, train_loader, val_loader, num_epochs, device, checkpoint_dir="checkpoints"):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)
    translation_loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    lambda_R = 1.0  # вес функцим потерь для распознавания 
    lambda_T = 1.0  # вес функции потерь для перевода
    best_bleu = 0.0
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0.0
        train_references = []
        train_hypotheses = []
        for videos, glosses, transcripts in tqdm(train_loader):
            videos = videos.to(device)
            transcripts = transcripts.to(device)
            optimizer.zero_grad()
            gloss_logits, translation_logits = model(videos, target_seq=transcripts)
            # print(gloss_logits)
            # print(translation_logits)
            gloss_log_probs = F.log_softmax(gloss_logits, dim=-1).permute(1, 0, 2)
            B, T, _ = gloss_logits.size()
            input_lengths = torch.full((B,), T, dtype=torch.long, device=device)
            target_gloss = glosses.view(-1)
            target_lengths = torch.full((B,), glosses.size(1), dtype=torch.long, device=device)
            loss_ctc = ctc_loss_fn(gloss_log_probs, target_gloss, input_lengths, target_lengths)
            loss_trans = translation_loss_fn(translation_logits.view(-1, translation_logits.size(-1)),
                                             transcripts.view(-1))
            # print(loss_ctc)
            # print(loss_trans)
            loss = lambda_R * loss_ctc + lambda_T * loss_trans
            # loss = lambda_T * loss_trans
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
            with torch.no_grad():
                decoded = beam_search_decode_debug(model, videos, max_len=transcripts.size(1),
                                   start_symbol=bert_tokenizer.cls_token_id,
                                   beam_size=5, end_symbol=bert_tokenizer.sep_token_id, debug=False)
                for i in range(decoded.size(0)):
                    pred_tokens = decoded[i].tolist()
                    ref_tokens = transcripts[i].tolist()
                    pred_sentence = bert_tokenizer.decode([tok for tok in pred_tokens if tok not in [0, bert_tokenizer.cls_token_id, bert_tokenizer.sep_token_id]])
                    ref_sentence = bert_tokenizer.decode([tok for tok in ref_tokens if tok not in [0, bert_tokenizer.cls_token_id, bert_tokenizer.sep_token_id]])
                    train_hypotheses.append(pred_sentence.split())
                    train_references.append([ref_sentence.split()])
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch}: Training Loss = {avg_loss:.4f}")
        
        train_bleu_score = corpus_bleu(train_references, train_hypotheses)
        print(f"Epoch {epoch}: Training BLEU = {train_bleu_score:.4f}")
        
        # валидация: декодинг и подсчет метрики 
        model.eval()
        references = []
        hypotheses = []
        with torch.no_grad():
            for videos, glosses, transcripts in tqdm(val_loader):
                videos = videos.to(device)
                transcripts = transcripts.to(device)
                # decoded = greedy_decode(model, videos, max_len=transcripts.size(1), start_symbol=1)
                decoded = beam_search_decode_debug(model, videos, max_len=transcripts.size(1),
                                   start_symbol=bert_tokenizer.cls_token_id,
                                   beam_size=5, end_symbol=bert_tokenizer.sep_token_id, debug=False)
                for i in range(decoded.size(0)):
                    pred_tokens = decoded[i].tolist()
                    # print("pred_tokens", pred_tokens)
                    ref_tokens = transcripts[i].tolist()
                    # print("ref_tokens", ref_tokens)
                    pred_sentence = bert_tokenizer.decode([tok for tok in pred_tokens if tok not in [0, bert_tokenizer.cls_token_id, bert_tokenizer.sep_token_id]])
                    ref_sentence = bert_tokenizer.decode([tok for tok in ref_tokens if tok not in [0, bert_tokenizer.cls_token_id, bert_tokenizer.sep_token_id]])
                    # print("pred_sentence", pred_sentence)
                    # print("ref_sentence", ref_sentence)
                    hypotheses.append(pred_sentence.split())
                    references.append([ref_sentence.split()])
        bleu_score = corpus_bleu(references, hypotheses)
        print(f"Epoch {epoch}: Validation BLEU = {bleu_score:.4f}")
        if bleu_score > best_bleu:
            best_bleu = bleu_score
            ckpt_path = os.path.join(checkpoint_dir, f"model_epoch{epoch}_BLEU{bleu_score:.4f}.pt")
            torch.save(model.state_dict(), ckpt_path)
            print(f"Checkpoint saved: {ckpt_path}")

In [11]:
csv_file = "./data/train.csv"                # таблица с данными
video_dir = "./data/video_segments"          # папка с видео
vocab_file = "./data/vocab.txt"                   # словарь
num_epochs = 10
batch_size = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  

word2id, id2word = load_vocab(vocab_file)
gloss_vocab = word2id
gloss_vocab_size = max(gloss_vocab.values()) + 1
target_vocab_size = bert_tokenizer.vocab_size

df = pd.read_csv(csv_file)
train_df = df[df['is_train'] == True][:100].reset_index(drop=True)
test_df = df[df['is_train'] == False][:10].reset_index(drop=True)

train_csv = "./data/train_table.csv"
test_csv = "./data/test_table.csv"
train_df.to_csv(train_csv, index=False)
test_df.to_csv(test_csv, index=False)
    

train_dataset = SignLanguageDataset(train_csv, video_dir, gloss_vocab, tokenize_transcript_with_bert)
test_dataset = SignLanguageDataset(test_csv, video_dir, gloss_vocab, tokenize_transcript_with_bert)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [None]:
model = SignLanguageTransformer(
    cnn_output_dim=512,
    d_model=256,
    num_encoder_layers=3,
    num_decoder_layers=3,
    nhead=4,
    gloss_vocab_size=gloss_vocab_size,
    target_vocab_size=target_vocab_size,
    dropout=0.4)
train_model(model, train_loader, test_loader, 15, device)